/*
 * Copyright 2022 the original author or authors.
 * <p>
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * <p>
 * https://www.apache.org/licenses/LICENSE-2.0
 * <p>
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.openrewrite.java.template.internal;

import com.sun.tools.javac.code.Symbol;
import com.sun.tools.javac.code.Type;
import com.sun.tools.javac.tree.JCTree;
import com.sun.tools.javac.tree.JCTree.JCFieldAccess;
import com.sun.tools.javac.tree.TreeScanner;
import org.jspecify.annotations.Nullable;

import javax.tools.JavaFileObject;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Tree scanner that collects the names of the jars providing the classes referenced from a javac
 * tree, so that they can be put on the classpath of a generated template.
 * <p>
 * A jar name is taken from the {@code classfile} URI of a referenced class and shortened to its
 * major version, e.g. a class loaded from {@code guava-31.1-jre.jar} is recorded as
 * {@code guava-31}.
 */
public class ClasspathJarNameDetector extends TreeScanner {

    /**
     * Captures the file name, without extension, of the jar in a URI such as
     * {@code jar:file:/repo/guava-31.1-jre.jar!/com/google/Foo.class}. At least one character is
     * required so that a degenerate {@code /.jar!/} path cannot produce an empty jar name.
     */
    private static final Pattern JAR_NAME = Pattern.compile("([^/]+)\\.jar!/");

    /**
     * Matches the first {@code -<digits>} segment and everything after it; replacing the match
     * with group 1 keeps only the major version.
     */
    private static final Pattern AFTER_MAJOR_VERSION = Pattern.compile("(-\\d+).*?$");

    private final Set<String> jarNames = new LinkedHashSet<>();

    /**
     * Classes whose supertype hierarchy has already been walked. Without this, supertypes shared
     * by many referenced classes (e.g. common interfaces) are re-walked on every reference, which
     * grows quickly for diamond-shaped interface hierarchies.
     */
    private final Set<Symbol.ClassSymbol> walkedHierarchies = new HashSet<>();

    /**
     * Scan the given tree and collect the names of the jars providing the classes it refers to:
     * classes named directly (simple or qualified names), owners of invoked methods, and the
     * receiver, return and thrown types of invocations together with their supertypes.
     * <p>
     * Symbols and types are read from the tree, so it is expected to have been attributed.
     *
     * @param input the tree to scan
     * @return the jar names (without {@code .jar}, truncated after the major version) in the order
     * they were first found. This is the detector's own set, which accumulates across calls
     * on the same instance.
     */
    public Set<String> classpathFor(JCTree input) {
        scan(input);
        return jarNames;
    }

    /**
     * Record the jar that the class file of {@code owner} (or of its enclosing class, when
     * {@code owner} is not itself a class) was loaded from, if any.
     */
    private void addJarNameFor(Symbol owner) {
        Symbol.ClassSymbol enclClass = owner instanceof Symbol.ClassSymbol ? (Symbol.ClassSymbol) owner : owner.enclClass();
        if (enclClass == null) {
            // enclClass() may yield null (the loop below guards against it too); nothing to record
            return;
        }
        // NOTE(review): javac's enclClass() usually returns a ClassSymbol itself, so this loop
        // normally exits immediately; nested classes have their own class file in the same jar.
        while (enclClass.enclClass() != null && enclClass.enclClass() != enclClass) {
            enclClass = enclClass.enclClass();
        }
        JavaFileObject classfile = enclClass.classfile;
        if (classfile == null) {
            return;
        }
        Matcher matcher = JAR_NAME.matcher(classfile.toUri().toString());
        if (!matcher.find()) {
            // Not loaded from a jar (e.g. a source file or a module image)
            return;
        }
        String jarName = matcher.group(1);
        // Ignore when `@Matches` on arguments tries to add rewrite-templating, which is implied present
        if (jarName.startsWith("rewrite-templating") || jarName.startsWith("error_prone_core")) {
            return;
        }
        // Retain major version number, to avoid `log4j` conflict between `log4j-1` and `log4j2-1`
        jarNames.add(AFTER_MAJOR_VERSION.matcher(jarName).replaceFirst("$1"));
    }

    @Override
    public void scan(JCTree tree) {
        // Detect fully qualified classes
        if (tree instanceof JCFieldAccess &&
                ((JCFieldAccess) tree).sym instanceof Symbol.ClassSymbol &&
                Character.isUpperCase(((JCFieldAccess) tree).getIdentifier().toString().charAt(0))) {
            addJarNameFor(((JCFieldAccess) tree).sym);
        }

        // Detect method invocations and their types
        if (tree instanceof JCTree.JCMethodInvocation) {
            JCTree.JCMethodInvocation invocation = (JCTree.JCMethodInvocation) tree;
            Symbol.MethodSymbol methodSym = null;

            if (invocation.meth instanceof JCFieldAccess) {
                JCFieldAccess methodAccess = (JCFieldAccess) invocation.meth;
                if (methodAccess.sym instanceof Symbol.MethodSymbol) {
                    methodSym = (Symbol.MethodSymbol) methodAccess.sym;
                }

                // Add jar for the receiver type and its transitive dependencies
                if (methodAccess.selected != null && methodAccess.selected.type != null) {
                    addTypeAndTransitiveDependencies(methodAccess.selected.type);
                }
            } else if (invocation.meth instanceof JCTree.JCIdent) {
                // Handle unqualified method calls (e.g., from static imports)
                JCTree.JCIdent methodIdent = (JCTree.JCIdent) invocation.meth;
                if (methodIdent.sym instanceof Symbol.MethodSymbol) {
                    methodSym = (Symbol.MethodSymbol) methodIdent.sym;
                }
            }

            if (methodSym != null) {
                // Add jar for the method's owner class
                addJarNameFor(methodSym.owner);

                // Add jar for the return type
                if (methodSym.getReturnType() != null) {
                    addTypeAndTransitiveDependencies(methodSym.getReturnType());
                }

                // Add jars for exception types
                for (Type thrownType : methodSym.getThrownTypes()) {
                    addTypeAndTransitiveDependencies(thrownType);
                }
            }
        }

        // Detect identifiers that reference classes
        if (tree instanceof JCTree.JCIdent) {
            JCTree.JCIdent ident = (JCTree.JCIdent) tree;
            if (ident.sym instanceof Symbol.ClassSymbol) {
                Symbol.ClassSymbol classSym = (Symbol.ClassSymbol) ident.sym;
                addJarNameFor(classSym);

                // Add transitive dependencies through inheritance
                addTypeAndTransitiveDependencies(classSym.type);
            }
        }

        super.scan(tree);
    }

    /**
     * Record the jar of {@code type}'s class and, recursively, of its superclass and interfaces.
     * Types whose symbol is not a class (e.g. type variables) and {@code null} are ignored.
     * Each class hierarchy is walked at most once per detector instance.
     */
    private void addTypeAndTransitiveDependencies(@Nullable Type type) {
        if (type == null || !(type.tsym instanceof Symbol.ClassSymbol)) {
            return;
        }
        Symbol.ClassSymbol classSym = (Symbol.ClassSymbol) type.tsym;
        if (walkedHierarchies.contains(classSym)) {
            // Its jar and those of all its supertypes were recorded by the earlier walk
            return;
        }

        // Query the supertype before reading the class file: NOTE(review) javac's getSuperclass()
        // completes the symbol, so a lazily loaded class has its classfile populated by then.
        Type superType = classSym.getSuperclass();
        walkedHierarchies.add(classSym);
        addJarNameFor(classSym);

        // Check superclass recursively
        if (superType != null && superType.tsym != null) {
            addTypeAndTransitiveDependencies(superType);
        }

        // Check interfaces recursively
        for (Type iface : classSym.getInterfaces()) {
            if (iface.tsym != null) {
                addTypeAndTransitiveDependencies(iface);
            }
        }
    }

}
