/*
 * Decompiled with CFR 0.152.
 */
package com.github.database.rider.junit5;

import com.github.database.rider.core.RiderRunner;
import com.github.database.rider.core.RiderTestContext;
import com.github.database.rider.core.api.connection.ConnectionHolder;
import com.github.database.rider.core.api.dataset.DataSet;
import com.github.database.rider.core.api.dataset.DataSetExecutor;
import com.github.database.rider.core.api.leak.LeakHunter;
import com.github.database.rider.core.configuration.DBUnitConfig;
import com.github.database.rider.core.configuration.DataSetConfig;
import com.github.database.rider.core.connection.ConnectionHolderImpl;
import com.github.database.rider.core.connection.RiderDataSource;
import com.github.database.rider.core.dataset.DataSetExecutorImpl;
import com.github.database.rider.core.leak.LeakHunterFactory;
import com.github.database.rider.core.util.ClassUtils;
import com.github.database.rider.junit5.DBUnitTestContext;
import com.github.database.rider.junit5.JUnit5RiderTestContext;
import com.github.database.rider.junit5.api.DBRider;
import com.github.database.rider.junit5.util.EntityManagerProvider;
import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.Optional;
import java.util.stream.Stream;
import javax.sql.DataSource;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.extension.AfterTestExecutionCallback;
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.BeforeTestExecutionCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.platform.commons.util.AnnotationUtils;
import org.junit.platform.commons.util.StringUtils;
import org.springframework.context.ApplicationContext;
import org.springframework.test.context.junit.jupiter.SpringExtension;

/**
 * JUnit 5 extension that integrates Database Rider with the Jupiter test lifecycle.
 *
 * <ul>
 *   <li>{@link #beforeEach} creates the dataset declared via {@link DataSet} on the test class's
 *       {@link BeforeEach} method, if any.</li>
 *   <li>{@link #beforeTestExecution} resolves a connection (from the Spring test context when it is
 *       active, otherwise from a {@link ConnectionHolder} field or no-arg method on the test class
 *       hierarchy), builds the {@link DataSetExecutorImpl} for the test and delegates to
 *       {@link RiderRunner}.</li>
 *   <li>{@link #afterTestExecution} runs the connection-leak check (when enabled), then
 *       {@link RiderRunner}'s after-test and teardown phases, and finally creates any dataset
 *       declared on the {@link AfterEach} method.</li>
 * </ul>
 *
 * <p>Per-test state ({@link DBUnitTestContext}) is kept in this extension's
 * {@link ExtensionContext.Store}, keyed by the test class.
 */
public class DBUnitExtension
implements BeforeTestExecutionCallback,
AfterTestExecutionCallback,
BeforeEachCallback {
    private static final ExtensionContext.Namespace NAMESPACE = ExtensionContext.Namespace.create(DBUnitExtension.class);
    private static final String JUNIT5_EXECUTOR = "junit5";
    private static final String EMPTY_STRING = "";

    /**
     * Prepares the database for the test method about to execute.
     *
     * @param extensionContext the JUnit context of the current test method
     * @throws Exception propagated from the Rider runner setup/before-test phases
     */
    @Override
    public void beforeTestExecution(ExtensionContext extensionContext) throws Exception {
        clearEntityManager();
        String executorId = getExecutorId(extensionContext, null);
        ConnectionHolder connectionHolder = getTestConnection(extensionContext, executorId);
        DataSetExecutorImpl dataSetExecutor = DataSetExecutorImpl.instance(executorId, connectionHolder);
        DBUnitTestContext dbUnitTestContext = getTestContext(extensionContext);
        dbUnitTestContext.setExecutor(dataSetExecutor);

        JUnit5RiderTestContext riderTestContext = new JUnit5RiderTestContext(dbUnitTestContext.getExecutor(), extensionContext);
        RiderRunner riderRunner = new RiderRunner();
        riderRunner.setup(riderTestContext);
        riderRunner.runBeforeTest(riderTestContext);

        DBUnitConfig dbUnitConfig = riderTestContext.getDataSetExecutor().getDBUnitConfig();
        if (dbUnitConfig.isLeakHunter()) {
            // Snapshot the open-connection count now so afterTestExecution can compare against it.
            LeakHunter leakHunter = LeakHunterFactory.from(dataSetExecutor.getRiderDataSource(), extensionContext.getRequiredTestMethod().getName());
            leakHunter.measureConnectionsBeforeExecution();
            dbUnitTestContext.setLeakHunter(leakHunter);
        }
    }

    /**
     * Builds the executor id for the current test.
     *
     * <p>The id is the {@link DataSet#executorId()} of {@code dataSet} (or, when {@code dataSet} is
     * {@code null}, of the {@code @DataSet} found on the test method or class), falling back to
     * {@value #JUNIT5_EXECUTOR} when absent or blank. When {@link DBRider#dataSourceBeanName()} is
     * configured, {@code "-" + beanName} is appended so each data source gets its own executor.
     */
    private String getExecutorId(ExtensionContext extensionContext, DataSet dataSet) {
        Optional<DataSet> annDataSet = dataSet != null ? Optional.of(dataSet) : findDataSetAnnotation(extensionContext);
        String dataSourceBeanName = getConfiguredDataSourceBeanName(extensionContext);
        String executionIdSuffix = dataSourceBeanName.isEmpty() ? EMPTY_STRING : "-" + dataSourceBeanName;
        return annDataSet.map(DataSet::executorId)
                .filter(StringUtils::isNotBlank)
                .map(id -> id + executionIdSuffix)
                .orElseGet(() -> JUNIT5_EXECUTOR + executionIdSuffix);
    }

    /** Returns the {@code @DataSet} on the test method, or else on the test class. */
    private Optional<DataSet> findDataSetAnnotation(ExtensionContext extensionContext) {
        Optional<DataSet> annDataSet = AnnotationUtils.findAnnotation(extensionContext.getRequiredTestMethod(), DataSet.class);
        if (!annDataSet.isPresent()) {
            annDataSet = AnnotationUtils.findAnnotation(extensionContext.getRequiredTestClass(), DataSet.class);
        }
        return annDataSet;
    }

    /**
     * Verifies the test's database state and releases Rider resources after the test method ran.
     *
     * <p>The teardown and the {@link AfterEach} dataset creation run in a {@code finally} block, so
     * they happen even when the leak check or the after-test phase throws.
     *
     * @param extensionContext the JUnit context of the current test method
     * @throws Exception propagated from the leak check or the Rider after-test phase
     */
    @Override
    public void afterTestExecution(ExtensionContext extensionContext) throws Exception {
        DBUnitTestContext dbUnitTestContext = getTestContext(extensionContext);
        DBUnitConfig dbUnitConfig = dbUnitTestContext.getExecutor().getDBUnitConfig();
        JUnit5RiderTestContext riderTestContext = new JUnit5RiderTestContext(dbUnitTestContext.getExecutor(), extensionContext);
        RiderRunner riderRunner = new RiderRunner();
        try {
            if (dbUnitConfig != null && dbUnitConfig.isLeakHunter()) {
                // The LeakHunter was stored by beforeTestExecution under the same config flag.
                LeakHunter leakHunter = dbUnitTestContext.getLeakHunter();
                leakHunter.checkConnectionsAfterExecution();
            }
            riderRunner.runAfterTest(riderTestContext);
        } finally {
            riderRunner.teardown(riderTestContext);
            afterEach(extensionContext);
        }
    }

    /**
     * Resolves the connection for the test: from the Spring application context when Spring's
     * extension is on the classpath and has a context for this test class, otherwise from the test
     * class itself.
     *
     * @return the connection holder; may be {@code null} when the test class hierarchy declares no
     *         {@link ConnectionHolder} field or no-arg method
     */
    private ConnectionHolder getTestConnection(ExtensionContext extensionContext, String executorId) {
        // isSpringExtensionEnabled() is checked first so Spring classes are only touched when present.
        if (isSpringExtensionEnabled() && isSpringTestContextEnabled(extensionContext)) {
            return getConnectionFromSpringContext(extensionContext, executorId);
        }
        return getConnectionFromTestClass(extensionContext, executorId);
    }

    /**
     * Obtains a connection for a Spring-managed test, reusing the cached executor's data source when
     * connection caching is enabled for {@code executorId}.
     *
     * @throws RuntimeException wrapping the {@link SQLException} when no connection can be obtained
     */
    private ConnectionHolder getConnectionFromSpringContext(ExtensionContext extensionContext, String executorId) {
        String configuredDataSourceBeanName = getConfiguredDataSourceBeanName(extensionContext);
        DataSource dataSource = getDataSource(extensionContext, configuredDataSourceBeanName);
        try {
            DataSetExecutorImpl dataSetExecutor = DataSetExecutorImpl.getExecutorById(executorId);
            if (isCachedConnection(dataSetExecutor)) {
                return new ConnectionHolderImpl(dataSetExecutor.getRiderDataSource().getConnection());
            }
            return new ConnectionHolderImpl(dataSource.getConnection());
        } catch (SQLException e) {
            // Keep the SQLException as the cause so the underlying driver error is not lost.
            throw new RuntimeException("Could not get connection from DataSource.", e);
        }
    }

    /**
     * Looks up the {@link DataSource} bean from the test's Spring application context: by type when
     * {@code beanName} is empty, otherwise by name.
     */
    private static DataSource getDataSource(ExtensionContext extensionContext, String beanName) {
        ApplicationContext context = SpringExtension.getApplicationContext(extensionContext);
        return beanName.isEmpty() ? context.getBean(DataSource.class) : context.getBean(beanName, DataSource.class);
    }

    /**
     * Returns {@link DBRider#dataSourceBeanName()} from the test method's {@code @DBRider}, or else
     * the test class's; returns the empty string when neither is annotated.
     */
    private static String getConfiguredDataSourceBeanName(ExtensionContext extensionContext) {
        Optional<DBRider> annotation = AnnotationUtils.findAnnotation(extensionContext.getRequiredTestMethod(), DBRider.class);
        if (!annotation.isPresent()) {
            annotation = AnnotationUtils.findAnnotation(extensionContext.getRequiredTestClass(), DBRider.class);
        }
        return annotation.map(DBRider::dataSourceBeanName).orElse(EMPTY_STRING);
    }

    /**
     * Obtains a connection for a non-Spring test: from the cached executor's data source when
     * connection caching is enabled for {@code executorId}, otherwise from the test class hierarchy.
     */
    private ConnectionHolder getConnectionFromTestClass(ExtensionContext extensionContext, String executorId) {
        DataSetExecutorImpl dataSetExecutor = DataSetExecutorImpl.getExecutorById(executorId);
        if (isCachedConnection(dataSetExecutor)) {
            try {
                return new ConnectionHolderImpl(dataSetExecutor.getRiderDataSource().getConnection());
            } catch (SQLException ignored) {
                // Best effort: when the cached data source cannot hand out a connection,
                // fall back to the ConnectionHolder declared on the test class below.
            }
        }
        return findConnectionFromTestClass(extensionContext, extensionContext.getRequiredTestClass());
    }

    /**
     * Searches {@code testClass} and then its superclasses for a {@link ConnectionHolder}: first a
     * declared field of exactly that type, then a declared no-arg method returning it. The member is
     * read from the current test instance.
     *
     * @return the holder found, or {@code null} when no class in the hierarchy declares one
     * @throws RuntimeException when the member is found but yields {@code null}, or reflective access fails
     */
    private ConnectionHolder findConnectionFromTestClass(ExtensionContext extensionContext, Class<?> testClass) {
        try {
            Optional<Field> fieldFound = Arrays.stream(testClass.getDeclaredFields())
                    .filter(f -> f.getType() == ConnectionHolder.class)
                    .findFirst();
            if (fieldFound.isPresent()) {
                Field field = fieldFound.get();
                field.setAccessible(true);
                return requireInitialized((ConnectionHolder) field.get(extensionContext.getRequiredTestInstance()));
            }
            // Only no-arg methods can be invoked without arguments below; others are skipped.
            Optional<Method> methodFound = Arrays.stream(testClass.getDeclaredMethods())
                    .filter(m -> m.getReturnType() == ConnectionHolder.class && m.getParameterCount() == 0)
                    .findFirst();
            if (methodFound.isPresent()) {
                Method method = methodFound.get();
                method.setAccessible(true);
                return requireInitialized((ConnectionHolder) method.invoke(extensionContext.getRequiredTestInstance()));
            }
        } catch (Exception e) {
            throw new RuntimeException("Could not get database connection for test " + testClass, e);
        }
        Class<?> superclass = testClass.getSuperclass();
        return superclass != null ? findConnectionFromTestClass(extensionContext, superclass) : null;
    }

    /** Rejects a {@code null} holder read from the test instance. */
    private static ConnectionHolder requireInitialized(ConnectionHolder connectionHolder) {
        if (connectionHolder == null) {
            throw new RuntimeException("ConnectionHolder not initialized correctly");
        }
        return connectionHolder;
    }

    /** Returns (creating on first use) the per-test-class {@link DBUnitTestContext}. */
    private DBUnitTestContext getTestContext(ExtensionContext context) {
        Class<?> testClass = context.getRequiredTestClass();
        ExtensionContext.Store store = context.getStore(NAMESPACE);
        return store.getOrComputeIfAbsent(testClass, tc -> new DBUnitTestContext(), DBUnitTestContext.class);
    }

    /** Whether Spring's JUnit Jupiter extension class is on the classpath. */
    private boolean isSpringExtensionEnabled() {
        return ClassUtils.isOnClasspath("org.springframework.test.context.junit.jupiter.SpringExtension");
    }

    /**
     * Whether the root store under {@link SpringExtension}'s namespace holds an entry for the
     * current test class, which this extension takes as "a Spring test context is active".
     */
    private boolean isSpringTestContextEnabled(ExtensionContext extensionContext) {
        Optional<Class<?>> testClass = extensionContext.getTestClass();
        if (!testClass.isPresent()) {
            return false;
        }
        ExtensionContext.Store springStore = extensionContext.getRoot().getStore(ExtensionContext.Namespace.create(SpringExtension.class));
        return springStore != null && springStore.get(testClass.get()) != null;
    }

    /** Whether {@code executor} exists and its configuration enables connection caching. */
    private boolean isCachedConnection(DataSetExecutor executor) {
        return executor != null && executor.getDBUnitConfig().isCacheConnection();
    }

    /**
     * Finds the {@link DataSet} declared on the test class's callback method (e.g. the method
     * annotated with {@link BeforeEach}), falling back to the direct superclass's callback method.
     *
     * @param extensionContext the JUnit context; an empty result is returned when it has no test class
     * @param callback the lifecycle annotation type identifying the callback method
     * @return the dataset annotation, or empty when none is declared
     */
    @SuppressWarnings("rawtypes") // raw parameter kept for backward compatibility of this public API
    public Optional<DataSet> getDataSetFromCallbackMethod(ExtensionContext extensionContext, Class callback) {
        Optional<Class<?>> testClass = extensionContext.getTestClass();
        if (!testClass.isPresent()) {
            return Optional.empty();
        }
        Optional<DataSet> dataSet = findDataSetOnCallback(testClass.get(), callback);
        if (dataSet.isPresent()) {
            return dataSet;
        }
        Class<?> testSuperclass = testClass.get().getSuperclass();
        return testSuperclass != null ? findDataSetOnCallback(testSuperclass, callback) : Optional.empty();
    }

    /** Returns the {@code @DataSet} on {@code type}'s first method annotated with {@code callback}. */
    @SuppressWarnings("rawtypes")
    private Optional<DataSet> findDataSetOnCallback(Class<?> type, Class callback) {
        return findCallbackMethod(type, callback)
                .flatMap(method -> AnnotationUtils.findAnnotation(method, DataSet.class));
    }

    /**
     * Returns the first method of {@code testClass} annotated with {@code callback}. Only public
     * methods (including inherited ones) are considered, as reported by {@link Class#getMethods()}.
     */
    @SuppressWarnings({"rawtypes", "unchecked"}) // callback is expected to be an annotation type
    private Optional<Method> findCallbackMethod(Class<?> testClass, Class callback) {
        return Stream.of(testClass.getMethods())
                .filter(m -> m.getAnnotation(callback) != null)
                .findFirst();
    }

    /**
     * Creates the dataset declared on the test class's {@link BeforeEach} method, if any.
     *
     * @param extensionContext the JUnit context of the current test
     * @throws Exception declared by {@link BeforeEachCallback}
     */
    @Override
    public void beforeEach(ExtensionContext extensionContext) throws Exception {
        createCallbackDataSet(extensionContext, BeforeEach.class);
    }

    /**
     * Creates the dataset declared on the test class's {@link AfterEach} method, if any. This class
     * does not implement {@code AfterEachCallback}; it is invoked from {@link #afterTestExecution}.
     *
     * @param extensionContext the JUnit context of the current test
     */
    public void afterEach(ExtensionContext extensionContext) {
        createCallbackDataSet(extensionContext, AfterEach.class);
    }

    /** Shared body of {@link #beforeEach} / {@link #afterEach}: seeds the callback's {@code @DataSet}. */
    private void createCallbackDataSet(ExtensionContext extensionContext, Class<?> callback) {
        Optional<DataSet> dataSet = getDataSetFromCallbackMethod(extensionContext, callback);
        if (dataSet.isPresent()) {
            clearEntityManager();
            String executorId = getExecutorId(extensionContext, dataSet.get());
            ConnectionHolder connectionHolder = getTestConnection(extensionContext, executorId);
            DataSetExecutorImpl dataSetExecutor = DataSetExecutorImpl.instance(executorId, connectionHolder);
            dataSetExecutor.createDataSet(new DataSetConfig().from(dataSet.get()));
        }
    }

    /** Clears the active JPA EntityManager, if any, so cached entities do not mask dataset changes. */
    private void clearEntityManager() {
        if (EntityManagerProvider.isEntityManagerActive()) {
            EntityManagerProvider.em().clear();
        }
    }
}

