/*
 * Decompiled with CFR 0.152.
 */
package com.github.database.rider.junit5;

import com.github.database.rider.core.RiderRunner;
import com.github.database.rider.core.RiderTestContext;
import com.github.database.rider.core.api.configuration.DBUnit;
import com.github.database.rider.core.api.connection.ConnectionHolder;
import com.github.database.rider.core.api.dataset.DataSet;
import com.github.database.rider.core.api.dataset.DataSetExecutor;
import com.github.database.rider.core.api.dataset.ExpectedDataSet;
import com.github.database.rider.core.api.leak.LeakHunter;
import com.github.database.rider.core.configuration.DBUnitConfig;
import com.github.database.rider.core.configuration.DataSetConfig;
import com.github.database.rider.core.connection.ConnectionHolderImpl;
import com.github.database.rider.core.connection.RiderDataSource;
import com.github.database.rider.core.dataset.DataSetExecutorImpl;
import com.github.database.rider.core.leak.LeakHunterFactory;
import com.github.database.rider.core.util.ClassUtils;
import com.github.database.rider.junit5.DBUnitTestContext;
import com.github.database.rider.junit5.JUnit5RiderTestContext;
import com.github.database.rider.junit5.api.DBRider;
import com.github.database.rider.junit5.util.EntityManagerProvider;
import io.micronaut.inject.qualifiers.Qualifiers;
import io.micronaut.test.extensions.junit5.MicronautJunit5Extension;
import java.lang.annotation.Annotation;
import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.Optional;
import java.util.stream.Stream;
import javax.sql.DataSource;
import org.dbunit.DatabaseUnitException;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.extension.AfterAllCallback;
import org.junit.jupiter.api.extension.AfterEachCallback;
import org.junit.jupiter.api.extension.AfterTestExecutionCallback;
import org.junit.jupiter.api.extension.BeforeAllCallback;
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.BeforeTestExecutionCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.platform.commons.util.AnnotationUtils;
import org.junit.platform.commons.util.StringUtils;
import org.springframework.context.ApplicationContext;
import org.springframework.test.context.junit.jupiter.SpringExtension;

/**
 * JUnit 5 extension that wires Database Rider into the Jupiter test lifecycle.
 *
 * <p>Around each test execution it resolves an executor id, obtains a {@link ConnectionHolder}
 * and hands a {@link JUnit5RiderTestContext} to {@link RiderRunner} ({@code setup},
 * {@code runBeforeTest}, {@code runAfterTest}, {@code teardown}). The connection is taken, in
 * order of preference, from a Spring test context, a Micronaut application context, or a
 * {@code ConnectionHolder} field/method declared on the test class (or one of its superclasses).
 *
 * <p>It also applies {@link DataSet} and {@link ExpectedDataSet} annotations found on public
 * {@code @BeforeEach}, {@code @AfterEach}, {@code @BeforeAll} and {@code @AfterAll} methods.
 */
public class DBUnitExtension
        implements BeforeTestExecutionCallback,
        AfterTestExecutionCallback,
        BeforeEachCallback,
        AfterEachCallback,
        BeforeAllCallback,
        AfterAllCallback {

    /** Store namespace under which the per-test-class {@link DBUnitTestContext} is kept. */
    private static final ExtensionContext.Namespace NAMESPACE = ExtensionContext.Namespace.create(DBUnitExtension.class);
    /** Executor id used when no non-blank {@code @DataSet(executorId)} is configured. */
    private static final String JUNIT5_EXECUTOR = "junit5";
    private static final String EMPTY_STRING = "";

    /**
     * Creates the dataset executor for the current test, runs the Rider "before test" phase and,
     * when the leak hunter is enabled, records the number of open connections.
     */
    @Override
    public void beforeTestExecution(ExtensionContext extensionContext) throws Exception {
        clearEntityManager();
        String executorId = getExecutorId(extensionContext, null);
        ConnectionHolder connectionHolder = getTestConnection(extensionContext, executorId);
        DataSetExecutorImpl dataSetExecutor = DataSetExecutorImpl.instance(executorId, connectionHolder);
        DBUnitTestContext dbUnitTestContext = getTestContext(extensionContext);
        dbUnitTestContext.setExecutor(dataSetExecutor);
        JUnit5RiderTestContext riderTestContext = new JUnit5RiderTestContext(dbUnitTestContext.getExecutor(), extensionContext);
        RiderRunner riderRunner = new RiderRunner();
        riderRunner.setup(riderTestContext);
        riderRunner.runBeforeTest(riderTestContext);
        DBUnitConfig dbUnitConfig = riderTestContext.getDataSetExecutor().getDBUnitConfig();
        if (dbUnitConfig.isLeakHunter()) {
            LeakHunter leakHunter = LeakHunterFactory.from(dataSetExecutor.getRiderDataSource(), extensionContext.getRequiredTestMethod().getName());
            leakHunter.measureConnectionsBeforeExecution();
            dbUnitTestContext.setLeakHunter(leakHunter);
        }
    }

    /**
     * Computes the executor id for the current context.
     *
     * @param dataSet explicit dataset annotation to use, or {@code null} to look it up on the
     *                current test method / test class
     * @return the dataset's non-blank {@code executorId}, or {@value #JUNIT5_EXECUTOR}, suffixed
     *         with {@code "-" + dataSourceBeanName} when a data source bean name is configured
     */
    private String getExecutorId(ExtensionContext extensionContext, DataSet dataSet) {
        Optional<DataSet> annDataSet = dataSet != null ? Optional.of(dataSet) : findDataSetAnnotation(extensionContext);
        String dataSourceBeanName = getConfiguredDataSourceBeanName(extensionContext);
        String executionIdSuffix = dataSourceBeanName.isEmpty() ? EMPTY_STRING : "-" + dataSourceBeanName;
        return annDataSet.map(DataSet::executorId)
                .filter(StringUtils::isNotBlank)
                .map(id -> id + executionIdSuffix)
                .orElseGet(() -> JUNIT5_EXECUTOR + executionIdSuffix);
    }

    /** Finds {@link DataSet} on the test method, falling back to the test class; empty outside a method context. */
    private Optional<DataSet> findDataSetAnnotation(ExtensionContext extensionContext) {
        return findTestMethodOrClassAnnotation(extensionContext, DataSet.class);
    }

    /**
     * Looks up {@code annotationType} on the current test method and, if absent, on the test class.
     * Returns empty when the context has no test method (e.g. class-level callbacks).
     */
    private static <A extends Annotation> Optional<A> findTestMethodOrClassAnnotation(ExtensionContext extensionContext, Class<A> annotationType) {
        Optional<Method> testMethod = extensionContext.getTestMethod();
        if (!testMethod.isPresent()) {
            return Optional.empty();
        }
        Optional<A> methodAnnotation = AnnotationUtils.findAnnotation(testMethod.get(), annotationType);
        if (methodAnnotation.isPresent()) {
            return methodAnnotation;
        }
        return AnnotationUtils.findAnnotation(extensionContext.getRequiredTestClass(), annotationType);
    }

    /**
     * Runs the leak check (if enabled) and the Rider "after test" phase; teardown always runs,
     * even if one of those throws.
     */
    @Override
    public void afterTestExecution(ExtensionContext extensionContext) throws Exception {
        DBUnitTestContext dbUnitTestContext = getTestContext(extensionContext);
        DBUnitConfig dbUnitConfig = dbUnitTestContext.getExecutor().getDBUnitConfig();
        JUnit5RiderTestContext riderTestContext = new JUnit5RiderTestContext(dbUnitTestContext.getExecutor(), extensionContext);
        RiderRunner riderRunner = new RiderRunner();
        try {
            if (dbUnitConfig != null && dbUnitConfig.isLeakHunter()) {
                LeakHunter leakHunter = dbUnitTestContext.getLeakHunter();
                leakHunter.checkConnectionsAfterExecution();
            }
            riderRunner.runAfterTest(riderTestContext);
        } finally {
            riderRunner.teardown(riderTestContext);
        }
    }

    /** Chooses the connection source: Spring context, then Micronaut context, then the test class. */
    private ConnectionHolder getTestConnection(ExtensionContext extensionContext, String executorId) {
        if (isSpringExtensionEnabled(extensionContext) && isSpringTestContextEnabled(extensionContext)) {
            return getConnectionFromSpringContext(extensionContext, executorId);
        }
        if (isMicronautExtensionEnabled(extensionContext) && getMicronautApplicationContext(extensionContext).isPresent()) {
            return getConnectionFromMicronautContext(extensionContext, executorId);
        }
        return getConnectionFromTestClass(extensionContext, executorId);
    }

    private ConnectionHolder getConnectionFromSpringContext(ExtensionContext extensionContext, String executorId) {
        String configuredDataSourceBeanName = getConfiguredDataSourceBeanName(extensionContext);
        DataSource dataSource = getDataSourceFromSpringContext(extensionContext, configuredDataSourceBeanName);
        return getConnectionHolder(executorId, dataSource);
    }

    private ConnectionHolder getConnectionFromMicronautContext(ExtensionContext extensionContext, String executorId) {
        String configuredDataSourceBeanName = getConfiguredDataSourceBeanName(extensionContext);
        DataSource dataSource = getDataSourceFromMicronautContext(extensionContext, configuredDataSourceBeanName);
        return getConnectionHolder(executorId, dataSource);
    }

    /**
     * Wraps a connection, preferring the existing executor's data source when connection caching
     * is enabled for it, otherwise opening one from {@code dataSource}.
     *
     * @throws RuntimeException wrapping the {@link SQLException} if no connection can be obtained
     */
    private ConnectionHolder getConnectionHolder(String executorId, DataSource dataSource) {
        try {
            DataSetExecutorImpl dataSetExecutor = DataSetExecutorImpl.getExecutorById(executorId);
            if (isCachedConnection(dataSetExecutor)) {
                return new ConnectionHolderImpl(dataSetExecutor.getRiderDataSource().getConnection());
            }
            return new ConnectionHolderImpl(dataSource.getConnection());
        } catch (SQLException e) {
            // Preserve the cause so the underlying driver error is visible in the test report.
            throw new RuntimeException("Could not get connection from DataSource.", e);
        }
    }

    /** Returns the (optionally named) {@link DataSource} bean from the Spring test context. */
    private static DataSource getDataSourceFromSpringContext(ExtensionContext extensionContext, String beanName) {
        ApplicationContext context = SpringExtension.getApplicationContext(extensionContext);
        return beanName.isEmpty() ? context.getBean(DataSource.class) : context.getBean(beanName, DataSource.class);
    }

    /**
     * Returns the (optionally named) {@link DataSource} bean from the Micronaut context.
     *
     * @throws RuntimeException if no Micronaut context is stored for this test run
     */
    private static DataSource getDataSourceFromMicronautContext(ExtensionContext extensionContext, String beanName) {
        io.micronaut.context.ApplicationContext context = getMicronautApplicationContext(extensionContext)
                .orElseThrow(() -> new RuntimeException("Micronaut context is not available for test: "
                        + extensionContext.getTestClass().get().getName()));
        return beanName.isEmpty()
                ? context.getBean(DataSource.class)
                : context.getBean(DataSource.class, Qualifiers.byName(beanName));
    }

    /**
     * Returns {@code @DBRider(dataSourceBeanName)} from the test method or test class, or an empty
     * string when absent or when the context has no test method.
     */
    private static String getConfiguredDataSourceBeanName(ExtensionContext extensionContext) {
        return findTestMethodOrClassAnnotation(extensionContext, DBRider.class)
                .map(DBRider::dataSourceBeanName)
                .orElse(EMPTY_STRING);
    }

    /**
     * Uses the cached executor connection when available, otherwise searches the test class
     * hierarchy for a {@code ConnectionHolder} field or method.
     */
    private ConnectionHolder getConnectionFromTestClass(ExtensionContext extensionContext, String executorId) {
        DataSetExecutorImpl dataSetExecutor = DataSetExecutorImpl.getExecutorById(executorId);
        if (isCachedConnection(dataSetExecutor)) {
            try {
                return new ConnectionHolderImpl(dataSetExecutor.getRiderDataSource().getConnection());
            } catch (SQLException ignored) {
                // Best effort: if the cached connection cannot be obtained, fall back to the
                // ConnectionHolder declared on the test class below.
            }
        }
        return findConnectionFromTestClass(extensionContext, extensionContext.getRequiredTestClass());
    }

    /**
     * Searches {@code testClass} and then its superclasses for the first declared field of type
     * {@code ConnectionHolder}, or else the first declared method returning one.
     *
     * @return the holder, or {@code null} if none is declared anywhere in the hierarchy
     * @throws RuntimeException if the member cannot be read or yields {@code null}
     */
    private ConnectionHolder findConnectionFromTestClass(ExtensionContext extensionContext, Class<?> testClass) {
        try {
            Optional<Field> fieldFound = Arrays.stream(testClass.getDeclaredFields())
                    .filter(f -> f.getType() == ConnectionHolder.class)
                    .findFirst();
            if (fieldFound.isPresent()) {
                Field field = fieldFound.get();
                field.setAccessible(true);
                Object testInstance = Modifier.isStatic(field.getModifiers()) ? null : extensionContext.getRequiredTestInstance();
                ConnectionHolder connectionHolder = (ConnectionHolder) field.get(testInstance);
                if (connectionHolder == null) {
                    throw new RuntimeException("ConnectionHolder not initialized correctly");
                }
                return connectionHolder;
            }
            Optional<Method> methodFound = Arrays.stream(testClass.getDeclaredMethods())
                    .filter(m -> m.getReturnType() == ConnectionHolder.class)
                    .findFirst();
            if (methodFound.isPresent()) {
                Method method = methodFound.get();
                method.setAccessible(true);
                // Static providers need no receiver; asking for one would fail in class-level
                // contexts such as @BeforeAll/@AfterAll where no test instance may exist.
                Object testInstance = Modifier.isStatic(method.getModifiers()) ? null : extensionContext.getRequiredTestInstance();
                ConnectionHolder connectionHolder = (ConnectionHolder) method.invoke(testInstance);
                if (connectionHolder == null) {
                    throw new RuntimeException("ConnectionHolder not initialized correctly");
                }
                return connectionHolder;
            }
        } catch (Exception e) {
            throw new RuntimeException("Could not get database connection for test " + testClass, e);
        }
        Class<?> superclass = testClass.getSuperclass();
        return superclass != null ? findConnectionFromTestClass(extensionContext, superclass) : null;
    }

    /** Returns the {@link DBUnitTestContext} for the current test class, creating it on first access. */
    private DBUnitTestContext getTestContext(ExtensionContext context) {
        Class<?> testClass = context.getRequiredTestClass();
        ExtensionContext.Store store = context.getStore(NAMESPACE);
        return store.getOrComputeIfAbsent(testClass, tc -> new DBUnitTestContext(), DBUnitTestContext.class);
    }

    private boolean isSpringExtensionEnabled(ExtensionContext extensionContext) {
        try {
            // The classpath check must run first: touching SpringExtension.class without Spring
            // on the classpath would fail with a class-loading error.
            return ClassUtils.isOnClasspath("org.springframework.test.context.junit.jupiter.SpringExtension")
                    && extensionContext.getRoot().getStore(ExtensionContext.Namespace.create(SpringExtension.class)) != null;
        } catch (Exception e) {
            return false;
        }
    }

    private boolean isMicronautExtensionEnabled(ExtensionContext extensionContext) {
        try {
            return ClassUtils.isOnClasspath("io.micronaut.test.extensions.junit5.MicronautJunit5Extension");
        } catch (Exception e) {
            return false;
        }
    }

    /** True when the Spring extension's root store holds an entry keyed by the current test class. */
    private boolean isSpringTestContextEnabled(ExtensionContext extensionContext) {
        Optional<Class<?>> testClass = extensionContext.getTestClass();
        if (!testClass.isPresent()) {
            return false;
        }
        ExtensionContext.Store springStore = extensionContext.getRoot().getStore(ExtensionContext.Namespace.create(SpringExtension.class));
        return springStore != null && springStore.get(testClass.get()) != null;
    }

    /**
     * Reads the Micronaut {@code ApplicationContext} from the Micronaut extension's root store.
     * Returns empty when nothing (or an object of another type) is stored under that key.
     */
    private static Optional<io.micronaut.context.ApplicationContext> getMicronautApplicationContext(ExtensionContext extensionContext) {
        ExtensionContext.Store micronautStore = extensionContext.getRoot().getStore(ExtensionContext.Namespace.create(MicronautJunit5Extension.class));
        if (micronautStore == null) {
            return Optional.empty();
        }
        Object appContext = micronautStore.get(io.micronaut.context.ApplicationContext.class);
        if (appContext instanceof io.micronaut.context.ApplicationContext) {
            return Optional.of((io.micronaut.context.ApplicationContext) appContext);
        }
        return Optional.empty();
    }

    /** True when an executor exists and its configuration enables connection caching. */
    private boolean isCachedConnection(DataSetExecutor executor) {
        return executor != null && Boolean.TRUE.equals(executor.getDBUnitConfig().isCacheConnection());
    }

    /** First public method (declared or inherited) of {@code testClass} annotated with {@code callback}. */
    private Optional<Method> findCallbackMethod(Class<?> testClass, Class<? extends Annotation> callback) {
        return Stream.of(testClass.getMethods())
                .filter(m -> m.isAnnotationPresent(callback))
                .findFirst();
    }

    /** Same as {@link #findCallbackMethod} but searched on the direct superclass of {@code testClass}. */
    private Optional<Method> findSuperclassCallbackMethod(Class<?> testClass, Class<? extends Annotation> callback) {
        Class<?> testSuperclass = testClass.getSuperclass();
        if (testSuperclass == null) {
            return Optional.empty();
        }
        return findCallbackMethod(testSuperclass, callback);
    }

    @Override
    public void beforeEach(ExtensionContext extensionContext) throws Exception {
        runCallbackDataSets(extensionContext, BeforeEach.class);
    }

    @Override
    public void afterEach(ExtensionContext extensionContext) throws Exception {
        runCallbackDataSets(extensionContext, AfterEach.class);
    }

    @Override
    public void beforeAll(ExtensionContext extensionContext) throws Exception {
        runCallbackDataSets(extensionContext, BeforeAll.class);
    }

    @Override
    public void afterAll(ExtensionContext extensionContext) throws Exception {
        runCallbackDataSets(extensionContext, AfterAll.class);
    }

    /**
     * Shared body of the four lifecycle callbacks: finds the test class's method annotated with
     * {@code callbackAnnotation} and applies its {@link DataSet}, then its {@link ExpectedDataSet}.
     */
    private void runCallbackDataSets(ExtensionContext extensionContext, Class<? extends Annotation> callbackAnnotation) throws DatabaseUnitException {
        Optional<Class<?>> testClass = extensionContext.getTestClass();
        if (!testClass.isPresent()) {
            return;
        }
        Optional<Method> callbackMethod = findCallbackMethod(testClass.get(), callbackAnnotation);
        if (callbackMethod.isPresent()) {
            executeDataSetForCallback(extensionContext, callbackAnnotation, callbackMethod.get());
            executeExpectedDataSetForCallback(extensionContext, callbackAnnotation, callbackMethod.get());
        }
    }

    /**
     * Finds {@code annotationType} on {@code callbackMethod} or, failing that, on the superclass's
     * method annotated with the same {@code callbackAnnotation}.
     */
    private <A extends Annotation> Optional<A> findCallbackAnnotation(Class<?> testClass, Class<? extends Annotation> callbackAnnotation, Method callbackMethod, Class<A> annotationType) {
        Optional<A> annotation = AnnotationUtils.findAnnotation(callbackMethod, annotationType);
        if (annotation.isPresent()) {
            return annotation;
        }
        return findSuperclassCallbackMethod(testClass, callbackAnnotation)
                .flatMap(superMethod -> AnnotationUtils.findAnnotation(superMethod, annotationType));
    }

    /**
     * Seeds the {@link DataSet} attached to a lifecycle callback, merging it with the class-level
     * dataset when {@code mergeDataSets} is enabled.
     */
    private void executeDataSetForCallback(ExtensionContext extensionContext, Class<? extends Annotation> callbackAnnotation, Method callbackMethod) {
        Class<?> testClass = extensionContext.getRequiredTestClass();
        Optional<DataSet> dataSetAnnotation = findCallbackAnnotation(testClass, callbackAnnotation, callbackMethod, DataSet.class);
        if (!dataSetAnnotation.isPresent()) {
            return;
        }
        clearEntityManager();
        DBUnitConfig dbUnitConfig = resolveDbUnitConfig(callbackAnnotation, callbackMethod, testClass);
        DataSet dataSet = dbUnitConfig.isMergeDataSets()
                ? resolveDataSet(dataSetAnnotation.get(), AnnotationUtils.findAnnotation(testClass, DataSet.class))
                : dataSetAnnotation.get();
        String executorId = getExecutorId(extensionContext, dataSet);
        ConnectionHolder connectionHolder = getTestConnection(extensionContext, executorId);
        DataSetExecutorImpl dataSetExecutor = DataSetExecutorImpl.instance(executorId, connectionHolder, dbUnitConfig);
        dataSetExecutor.createDataSet(new DataSetConfig().from(dataSet));
    }

    /** Compares the database against the {@link ExpectedDataSet} attached to a lifecycle callback. */
    private void executeExpectedDataSetForCallback(ExtensionContext extensionContext, Class<? extends Annotation> callbackAnnotation, Method callbackMethod) throws DatabaseUnitException {
        Class<?> testClass = extensionContext.getRequiredTestClass();
        Optional<ExpectedDataSet> expectedDataSetAnnotation = findCallbackAnnotation(testClass, callbackAnnotation, callbackMethod, ExpectedDataSet.class);
        if (!expectedDataSetAnnotation.isPresent()) {
            return;
        }
        ExpectedDataSet expectedDataSet = expectedDataSetAnnotation.get();
        DBUnitConfig dbUnitConfig = resolveDbUnitConfig(callbackAnnotation, callbackMethod, testClass);
        String executorId = getExecutorId(extensionContext, null);
        ConnectionHolder connectionHolder = getTestConnection(extensionContext, executorId);
        DataSetExecutorImpl dataSetExecutor = DataSetExecutorImpl.instance(executorId, connectionHolder, dbUnitConfig);
        dataSetExecutor.compareCurrentDataSetWith(
                new DataSetConfig(expectedDataSet.value()).disableConstraints(true).datasetProvider(expectedDataSet.provider()),
                expectedDataSet.ignoreCols(),
                expectedDataSet.replacers(),
                expectedDataSet.orderBy(),
                expectedDataSet.compareOperation());
    }

    /**
     * Resolves {@link DBUnit} for a callback. Search order: callback method, test class,
     * superclass callback method, superclass. Falls back to the global configuration.
     */
    private DBUnitConfig resolveDbUnitConfig(Class<? extends Annotation> callbackAnnotation, Method callbackMethod, Class<?> testClass) {
        Optional<DBUnit> dbUnitAnnotation = AnnotationUtils.findAnnotation(callbackMethod, DBUnit.class);
        if (!dbUnitAnnotation.isPresent()) {
            dbUnitAnnotation = AnnotationUtils.findAnnotation(testClass, DBUnit.class);
        }
        if (!dbUnitAnnotation.isPresent()) {
            dbUnitAnnotation = findSuperclassCallbackMethod(testClass, callbackAnnotation)
                    .flatMap(superMethod -> AnnotationUtils.findAnnotation(superMethod, DBUnit.class));
        }
        if (!dbUnitAnnotation.isPresent() && testClass.getSuperclass() != null) {
            dbUnitAnnotation = AnnotationUtils.findAnnotation(testClass.getSuperclass(), DBUnit.class);
        }
        return dbUnitAnnotation.isPresent() ? DBUnitConfig.from(dbUnitAnnotation.get()) : DBUnitConfig.fromGlobalConfig();
    }

    /** Merges the method-level dataset into the class-level one when the latter exists. */
    private DataSet resolveDataSet(DataSet methodLevelDataSet, Optional<DataSet> classLevelDataSet) {
        if (classLevelDataSet.isPresent()) {
            return com.github.database.rider.core.util.AnnotationUtils.mergeDataSetAnnotations(classLevelDataSet.get(), methodLevelDataSet);
        }
        return methodLevelDataSet;
    }

    /** Clears the JPA entity manager (if one is active) so tests do not see stale cached entities. */
    private void clearEntityManager() {
        if (EntityManagerProvider.isEntityManagerActive()) {
            EntityManagerProvider.em().clear();
        }
    }
}

