/*
 * Decompiled with CFR 0.152.
 */
package org.openrewrite.java.testing.mockito;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import org.jspecify.annotations.Nullable;
import org.openrewrite.Cursor;
import org.openrewrite.ExecutionContext;
import org.openrewrite.Preconditions;
import org.openrewrite.Recipe;
import org.openrewrite.Tree;
import org.openrewrite.TreeVisitor;
import org.openrewrite.internal.ListUtils;
import org.openrewrite.java.AnnotationMatcher;
import org.openrewrite.java.JavaIsoVisitor;
import org.openrewrite.java.JavaParser;
import org.openrewrite.java.JavaTemplate;
import org.openrewrite.java.MethodMatcher;
import org.openrewrite.java.VariableNameUtils;
import org.openrewrite.java.search.UsesMethod;
import org.openrewrite.java.tree.Expression;
import org.openrewrite.java.tree.Flag;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.JavaType;
import org.openrewrite.java.tree.Statement;
import org.openrewrite.java.tree.TypeUtils;

public class MockitoWhenOnStaticToMockStatic
extends Recipe {
    private static final AnnotationMatcher BEFORE = new AnnotationMatcher("org.junit.Before");
    private static final AnnotationMatcher BEFORE_CLASS = new AnnotationMatcher("org.junit.BeforeClass");
    private static final AnnotationMatcher AFTER = new AnnotationMatcher("org.junit.After");
    private static final AnnotationMatcher AFTER_CLASS = new AnnotationMatcher("org.junit.AfterClass");
    private static final MethodMatcher MOCKITO_WHEN = new MethodMatcher("org.mockito.Mockito when(..)");
    private int varCounter = 0;

    public String getDisplayName() {
        return "Replace `Mockito.when` on static (non mock) with try-with-resource with MockedStatic";
    }

    public String getDescription() {
        return "Replace `Mockito.when` on static (non mock) with try-with-resource with MockedStatic as Mockito4 no longer allows this. When `@Before` or `@BeforeClass` is used, a `close` method is added to either the `@After` or `@AfterClass` method. This change moves away from implicit bytecode manipulation for static method stubbing, making mocking behavior more explicit and scoped to avoid unintended side effects.";
    }

    public TreeVisitor<?, ExecutionContext> getVisitor() {
        return Preconditions.check((TreeVisitor)new UsesMethod(MOCKITO_WHEN), (TreeVisitor)new JavaIsoVisitor<ExecutionContext>(){

            public J.Block visitBlock(J.Block block, ExecutionContext ctx) {
                List<Statement> newStatements = MockitoWhenOnStaticToMockStatic.isMethodDeclarationWithAnnotation((Statement)this.getCursor().firstEnclosing(J.MethodDeclaration.class), new AnnotationMatcher[]{BEFORE, BEFORE_CLASS}) ? this.maybeStatementsToMockedStatic(block, block.getStatements(), ctx) : this.maybeWrapStatementsInTryWithResourcesMockedStatic(block, block.getStatements(), ctx);
                J.Block b = super.visitBlock(block.withStatements(newStatements), (Object)ctx);
                return (J.Block)this.maybeAutoFormat((J)block, (J)b, ctx);
            }

            private List<Statement> maybeStatementsToMockedStatic(J.Block m, List<Statement> statements, ExecutionContext ctx) {
                ArrayList<Statement> list = new ArrayList<Statement>();
                for (Statement statement : statements) {
                    J.MethodInvocation whenArg = this.getWhenArg(statement);
                    if (whenArg != null) {
                        String className = this.getClassName(whenArg);
                        if (className == null) continue;
                        list.addAll(this.mockedStatic(m, (J.MethodInvocation)statement, className, whenArg, ctx));
                        continue;
                    }
                    list.add(statement);
                }
                return list;
            }

            private List<Statement> maybeWrapStatementsInTryWithResourcesMockedStatic(J.Block block, List<Statement> statements, ExecutionContext ctx) {
                AtomicBoolean restInTry = new AtomicBoolean(false);
                return ListUtils.map(statements, (index, statement) -> {
                    String className;
                    if (restInTry.get()) {
                        return null;
                    }
                    J.MethodInvocation whenArg = this.getWhenArg((Statement)statement);
                    if (whenArg != null && (className = this.getClassName(whenArg)) != null) {
                        Optional nameOfWrappingMockedStatic = MockitoWhenOnStaticToMockStatic.tryGetMatchedWrappingResourceName(this.getCursor(), className);
                        if (nameOfWrappingMockedStatic.isPresent()) {
                            return this.reuseMockedStatic(block, (J.MethodInvocation)statement, (String)nameOfWrappingMockedStatic.get(), whenArg, ctx);
                        }
                        restInTry.set(true);
                        return this.tryWithMockedStatic(block, statements, (Integer)index, (J.MethodInvocation)statement, className, whenArg, ctx);
                    }
                    return statement;
                });
            }

            private // Could not load outer class - annotation placement on inner may be incorrect
            @Nullable J.MethodInvocation getWhenArg(Statement statement) {
                J.MethodInvocation whenArg;
                J.MethodInvocation when;
                if (statement instanceof J.MethodInvocation && MOCKITO_WHEN.matches(((J.MethodInvocation)statement).getSelect()) && (when = (J.MethodInvocation)((J.MethodInvocation)statement).getSelect()) != null && when.getArguments().get(0) instanceof J.MethodInvocation && (whenArg = (J.MethodInvocation)when.getArguments().get(0)).getMethodType() != null && whenArg.getMethodType().hasFlags(new Flag[]{Flag.Static})) {
                    return whenArg;
                }
                return null;
            }

            private @Nullable String getClassName(J.MethodInvocation whenArg) {
                J.Identifier clazz = null;
                if (whenArg.getSelect() instanceof J.Identifier && ((J.Identifier)whenArg.getSelect()).getFieldType() == null) {
                    clazz = (J.Identifier)whenArg.getSelect();
                } else if (whenArg.getSelect() instanceof J.FieldAccess && ((J.FieldAccess)whenArg.getSelect()).getTarget() instanceof J.Identifier) {
                    clazz = (J.Identifier)((J.FieldAccess)whenArg.getSelect()).getTarget();
                }
                return clazz != null && clazz.getType() != null ? clazz.getSimpleName() : null;
            }

            private J.Try tryWithMockedStatic(J.Block block, List<Statement> statements, Integer index, J.MethodInvocation statement, String className, J.MethodInvocation whenArg, ExecutionContext ctx) {
                String variableName = VariableNameUtils.generateVariableName((String)("mock" + className + ++MockitoWhenOnStaticToMockStatic.this.varCounter), (Cursor)this.updateCursor((Tree)block), (VariableNameUtils.GenerationStrategy)VariableNameUtils.GenerationStrategy.INCREMENT_NUMBER);
                Expression thenReturnArg = (Expression)statement.getArguments().get(0);
                J.Try try_ = (J.Try)((J.Block)this.javaTemplateMockStatic(String.format("try(MockedStatic<%1$s> %2$s = mockStatic(%1$s.class)) {\n    %2$s.when(() -> #{any()}).thenReturn(#{any()});\n}", className, variableName), ctx).apply(this.getCursor(), block.getCoordinates().firstStatement(), new Object[]{whenArg, thenReturnArg})).getStatements().get(0);
                List<Statement> precedingStatements = statements.subList(0, index);
                List handledStatements = ListUtils.concat(precedingStatements, (Object)try_);
                List<Statement> remainingStatements = statements.subList(index + 1, statements.size());
                List newStatements = ListUtils.concatAll((List)try_.getBody().getStatements(), this.maybeWrapStatementsInTryWithResourcesMockedStatic(block.withStatements(handledStatements), remainingStatements, ctx));
                return try_.withBody(try_.getBody().withStatements(newStatements)).withPrefix(statement.getPrefix());
            }

            private Statement reuseMockedStatic(J.Block block, J.MethodInvocation statement, String variableName, J.MethodInvocation whenArg, ExecutionContext ctx) {
                return (Statement)((J.Block)this.javaTemplateMockStatic(String.format("%1$s.when(() -> #{any()}).thenReturn(#{any()});", variableName), ctx).apply(this.getCursor(), block.getCoordinates().firstStatement(), new Object[]{whenArg, statement.getArguments().get(0)})).getStatements().get(0);
            }

            private List<Statement> mockedStatic(J.Block block, J.MethodInvocation statement, final String className, J.MethodInvocation whenArg, ExecutionContext ctx) {
                final boolean staticSetup = MockitoWhenOnStaticToMockStatic.isMethodDeclarationWithAnnotation((Statement)this.getCursor().firstEnclosing(J.MethodDeclaration.class), new AnnotationMatcher[]{BEFORE_CLASS});
                final String variableName = VariableNameUtils.generateVariableName((String)("mock" + className + ++MockitoWhenOnStaticToMockStatic.this.varCounter), (Cursor)this.updateCursor((Tree)block), (VariableNameUtils.GenerationStrategy)VariableNameUtils.GenerationStrategy.INCREMENT_NUMBER);
                Expression thenReturnArg = (Expression)statement.getArguments().get(0);
                List<Statement> statements = ((J.Block)this.javaTemplateMockStatic(String.format("%2$s = mockStatic(%1$s.class);\n%2$s.when(() -> #{any()}).thenReturn(#{any()});", className, variableName), ctx).apply(this.getCursor(), block.getCoordinates().firstStatement(), new Object[]{whenArg, thenReturnArg})).getStatements().subList(0, 2);
                this.doAfterVisit((TreeVisitor)new JavaIsoVisitor<ExecutionContext>(){

                    public J.ClassDeclaration visitClassDeclaration(J.ClassDeclaration classDecl, ExecutionContext ctx) {
                        Optional<Statement> beforeMethod;
                        J.ClassDeclaration after = (J.ClassDeclaration)JavaTemplate.builder((String)String.format("private%s MockedStatic<%s> %s;", staticSetup ? " static" : "", className, variableName)).contextSensitive().build().apply(this.updateCursor((Tree)classDecl), classDecl.getBody().getCoordinates().firstStatement(), new Object[0]);
                        if (classDecl.getBody().getStatements().stream().noneMatch(it -> MockitoWhenOnStaticToMockStatic.isMethodDeclarationWithAnnotation(it, new AnnotationMatcher[]{AFTER, AFTER_CLASS})) && (beforeMethod = after.getBody().getStatements().stream().filter(it -> MockitoWhenOnStaticToMockStatic.isMethodDeclarationWithAnnotation(it, new AnnotationMatcher[]{BEFORE, BEFORE_CLASS})).findFirst()).isPresent()) {
                            this.maybeAddImport("org.junit.AfterClass");
                            this.maybeAddImport("org.junit.After");
                            after = (J.ClassDeclaration)JavaTemplate.builder((String)String.format("%s void tearDown() {}", staticSetup ? "@AfterClass public static" : "@After public")).imports(new String[]{staticSetup ? "org.junit.AfterClass" : "org.junit.After"}).javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, new String[]{"junit-4"})).build().apply(this.updateCursor((Tree)after), beforeMethod.get().getCoordinates().after(), new Object[0]);
                        }
                        J.ClassDeclaration cd = super.visitClassDeclaration(after, (Object)ctx);
                        return (J.ClassDeclaration)this.maybeAutoFormat((J)classDecl, (J)cd, ctx);
                    }

                    public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration methodDecl, ExecutionContext ctx) {
                        J.MethodDeclaration md = super.visitMethodDeclaration(methodDecl, (Object)ctx);
                        if (MockitoWhenOnStaticToMockStatic.isMethodDeclarationWithAnnotation((Statement)md, new AnnotationMatcher[]{AFTER, AFTER_CLASS})) {
                            return (J.MethodDeclaration)JavaTemplate.builder((String)(variableName + ".close();")).contextSensitive().build().apply(this.getCursor(), md.getBody().getCoordinates().lastStatement(), new Object[0]);
                        }
                        return md;
                    }
                });
                return statements;
            }

            private JavaTemplate javaTemplateMockStatic(String code, ExecutionContext ctx) {
                this.maybeAddImport("org.mockito.MockedStatic", false);
                this.maybeAddImport("org.mockito.Mockito", "mockStatic");
                return JavaTemplate.builder((String)code).contextSensitive().imports(new String[]{"org.mockito.MockedStatic"}).staticImports(new String[]{"org.mockito.Mockito.mockStatic"}).javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, new String[]{"mockito-core-5"})).build();
            }
        });
    }

    private static List<J.Try.Resource> getMatchingFilteredResources(@Nullable List<// Could not load outer class - annotation placement on inner may be incorrect
    J.Try.Resource> resources, String className) {
        if (resources != null) {
            return resources.stream().filter(res -> {
                J.VariableDeclarations vds = (J.VariableDeclarations)res.getVariableDeclarations();
                return TypeUtils.isAssignableTo((String)("org.mockito.MockedStatic<" + className + ">"), (JavaType)vds.getTypeAsFullyQualified());
            }).collect(Collectors.toList());
        }
        return Collections.emptyList();
    }

    private static Optional<String> tryGetMatchedWrappingResourceName(Cursor cursor, String className) {
        try {
            Cursor foundParentCursor = cursor.dropParentUntil(val -> {
                if (val instanceof J.Try) {
                    List<J.Try.Resource> filteredResources = MockitoWhenOnStaticToMockStatic.getMatchingFilteredResources(((J.Try)val).getResources(), className);
                    return !filteredResources.isEmpty();
                }
                return false;
            });
            return MockitoWhenOnStaticToMockStatic.getMatchingFilteredResources(((J.Try)foundParentCursor.getValue()).getResources(), className).stream().findFirst().map(res -> ((J.VariableDeclarations.NamedVariable)((J.VariableDeclarations)res.getVariableDeclarations()).getVariables().get(0)).getSimpleName());
        }
        catch (IllegalStateException e) {
            return Optional.empty();
        }
    }

    private static boolean isMethodDeclarationWithAnnotation(@Nullable Statement statement, AnnotationMatcher ... matchers) {
        if (statement instanceof J.MethodDeclaration) {
            return ((J.MethodDeclaration)statement).getLeadingAnnotations().stream().anyMatch(it -> Arrays.stream(matchers).anyMatch(m -> m.matches(it)));
        }
        return false;
    }
}

