/*
 * Decompiled with CFR 0.152.
 */
package org.openrewrite.java.testing.junitassertj;

import java.util.List;
import org.openrewrite.AutoConfigure;
import org.openrewrite.RefactorVisitor;
import org.openrewrite.java.AutoFormat;
import org.openrewrite.java.JavaIsoRefactorVisitor;
import org.openrewrite.java.MethodMatcher;
import org.openrewrite.java.tree.Expression;
import org.openrewrite.java.tree.Flag;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.JavaType;
import org.openrewrite.java.tree.MethodTypeBuilder;
import org.openrewrite.java.tree.TypeUtils;

/**
 * Refactoring visitor that replaces method invocations matching
 * {@code org.junit.jupiter.api.Assertions assertArrayEquals(..)} with an AssertJ
 * {@code assertThat(actual).containsExactly(expected)} chain.
 *
 * <p>The replacement is produced by formatting one of four source templates with the
 * printed argument expressions and handing the result to the tree builder:
 * <ul>
 *   <li>two arguments: {@code assertThat(actual).containsExactly(expected);}</li>
 *   <li>three arguments, third not floating point: the third argument is treated as a
 *       message and attached with {@code as(..)} when it is a String, otherwise with
 *       {@code withFailMessage(..)};</li>
 *   <li>three arguments, third floating point: the third argument is treated as a delta
 *       and emitted as {@code within(delta)};</li>
 *   <li>any other argument count: third argument is the delta, fourth the message.</li>
 * </ul>
 *
 * <p>After building the replacement the visitor requests the AssertJ static imports it
 * used, requests removal of the JUnit {@code Assertions} import, and schedules an
 * {@link AutoFormat} pass over the new invocation.
 */
@AutoConfigure
public class AssertArrayEqualsToAssertThat
extends JavaIsoRefactorVisitor {
    private static final String JUNIT_QUALIFIED_ASSERTIONS_CLASS_NAME = "org.junit.jupiter.api.Assertions";
    private static final String ASSERTJ_QUALIFIED_ASSERTIONS_CLASS_NAME = "org.assertj.core.api.Assertions";
    private static final String ASSERTJ_ASSERT_THAT_METHOD_NAME = "assertThat";
    private static final String ASSERTJ_WITHIN_METHOD_NAME = "within";

    /** Source templates for the four supported replacement shapes. */
    private static final String SIMPLE_TEMPLATE = "assertThat(%s).containsExactly(%s);";
    private static final String MESSAGE_TEMPLATE = "assertThat(%s).%s(%s).containsExactly(%s);";
    private static final String DELTA_TEMPLATE = "assertThat(%s).containsExactly(%s, within(%s));";
    private static final String DELTA_MESSAGE_TEMPLATE = "assertThat(%s).%s(%s).containsExactly(%s, within(%s));";

    /** Matches any-arity calls to JUnit Jupiter's {@code Assertions.assertArrayEquals}. */
    private static final MethodMatcher JUNIT_ASSERT_ARRAY_EQUALS_MATCHER =
            new MethodMatcher(JUNIT_QUALIFIED_ASSERTIONS_CLASS_NAME + " assertArrayEquals(..)");

    /** Method type named {@code *} on AssertJ's {@code Assertions}, passed to the snippet builder as its import. */
    private static final JavaType ASSERTJ_ASSERTIONS_WILDCARD_STATIC_IMPORT = MethodTypeBuilder.newMethodType()
            .declaringClass(ASSERTJ_QUALIFIED_ASSERTIONS_CLASS_NAME)
            .flags(new Flag[]{Flag.Public, Flag.Static})
            .name("*")
            .build();

    /**
     * Creates the visitor with cursoring switched on; {@code getCursor()} is later passed
     * to the snippet builder (presumably cursoring must be enabled for that to work).
     */
    public AssertArrayEqualsToAssertThat() {
        this.setCursoringOn();
    }

    /**
     * Visits a method invocation and, when it matches {@code assertArrayEquals(..)},
     * returns the equivalent AssertJ invocation instead.
     *
     * @param method the invocation being visited
     * @return the result of the superclass visit when the invocation does not match,
     *         otherwise the newly built AssertJ invocation
     */
    public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method) {
        J.MethodInvocation visited = super.visitMethodInvocation(method);
        if (!JUNIT_ASSERT_ARRAY_EQUALS_MATCHER.matches(method)) {
            return visited;
        }

        // JUnit puts the expected array first; AssertJ wants the actual value in assertThat(..).
        List<?> arguments = visited.getArgs().getArgs();
        Expression expected = (Expression) arguments.get(0);
        Expression actual = (Expression) arguments.get(1);

        String snippet;
        boolean usesWithin = false;
        switch (arguments.size()) {
            case 2: {
                snippet = String.format(SIMPLE_TEMPLATE, actual.printTrimmed(), expected.printTrimmed());
                break;
            }
            case 3: {
                Expression third = (Expression) arguments.get(2);
                // NOTE(review): only float/double (and their boxes) count as a delta here, so an
                // integral delta literal would be handled as a message - confirm this is intended.
                if (this.isFloatingPointType(third)) {
                    snippet = String.format(DELTA_TEMPLATE, actual.printTrimmed(), expected.printTrimmed(), third.printTrimmed());
                    usesWithin = true;
                } else {
                    String describeWith = messageMethodFor(third);
                    snippet = String.format(MESSAGE_TEMPLATE, actual.printTrimmed(), describeWith, third.printTrimmed(), expected.printTrimmed());
                }
                break;
            }
            default: {
                Expression delta = (Expression) arguments.get(2);
                Expression message = (Expression) arguments.get(3);
                String describeWith = messageMethodFor(message);
                snippet = String.format(DELTA_MESSAGE_TEMPLATE, actual.printTrimmed(), describeWith, message.printTrimmed(), expected.printTrimmed(), delta.printTrimmed());
                usesWithin = true;
                break;
            }
        }

        J.MethodInvocation replacement = this.buildAssertion(snippet);
        if (usesWithin) {
            this.maybeAddImport(ASSERTJ_QUALIFIED_ASSERTIONS_CLASS_NAME, ASSERTJ_WITHIN_METHOD_NAME);
        }
        this.maybeAddImport(ASSERTJ_QUALIFIED_ASSERTIONS_CLASS_NAME, ASSERTJ_ASSERT_THAT_METHOD_NAME);
        this.maybeRemoveImport(JUNIT_QUALIFIED_ASSERTIONS_CLASS_NAME);
        this.andThen((RefactorVisitor) new AutoFormat(new J[]{replacement}));
        return replacement;
    }

    /**
     * Builds the given source snippet at the current cursor and returns its first element
     * as a method invocation.
     *
     * @param snippet source text of the AssertJ assertion statement
     * @return the first element produced by the tree builder, cast to a method invocation
     */
    private J.MethodInvocation buildAssertion(String snippet) {
        List<?> built = this.treeBuilder.buildSnippet(this.getCursor(), snippet, new JavaType[]{ASSERTJ_ASSERTIONS_WILDCARD_STATIC_IMPORT});
        return (J.MethodInvocation) built.get(0);
    }

    /**
     * Chooses the AssertJ method used to attach the message.
     *
     * @param message the message argument of the original call
     * @return {@code "as"} when the argument's type is a String, {@code "withFailMessage"} otherwise
     */
    private static String messageMethodFor(Expression message) {
        if (TypeUtils.isString((JavaType) message.getType())) {
            return "as";
        }
        return "withFailMessage";
    }

    /**
     * Tells whether the expression's type is {@code double}, {@code float},
     * {@code java.lang.Double} or {@code java.lang.Float}.
     *
     * @param expression the expression whose type is inspected
     * @return {@code true} for the four types above, {@code false} otherwise
     */
    private boolean isFloatingPointType(Expression expression) {
        JavaType type = (JavaType) expression.getType();
        JavaType.FullyQualified boxed = TypeUtils.asFullyQualified(type);
        if (boxed == null) {
            // Not a class type: fall back to the primitive view of the same type.
            JavaType.Primitive primitive = TypeUtils.asPrimitive(type);
            return primitive == JavaType.Primitive.Double || primitive == JavaType.Primitive.Float;
        }
        String qualifiedName = boxed.getFullyQualifiedName();
        return qualifiedName.equals("java.lang.Double") || qualifiedName.equals("java.lang.Float");
    }
}

