/*
 * Decompiled with CFR 0.152.
 */
package com.github.database.rider.junit5;

import com.github.database.rider.core.RiderRunner;
import com.github.database.rider.core.RiderTestContext;
import com.github.database.rider.core.api.configuration.DBUnit;
import com.github.database.rider.core.api.configuration.DataSetMergingStrategy;
import com.github.database.rider.core.api.connection.ConnectionHolder;
import com.github.database.rider.core.api.dataset.DataSet;
import com.github.database.rider.core.api.dataset.DataSetExecutor;
import com.github.database.rider.core.api.dataset.ExpectedDataSet;
import com.github.database.rider.core.api.leak.LeakHunter;
import com.github.database.rider.core.configuration.DBUnitConfig;
import com.github.database.rider.core.configuration.DataSetConfig;
import com.github.database.rider.core.connection.RiderDataSource;
import com.github.database.rider.core.dataset.DataSetExecutorImpl;
import com.github.database.rider.core.leak.LeakHunterFactory;
import com.github.database.rider.junit5.DBUnitTestContext;
import com.github.database.rider.junit5.JUnit5RiderTestContext;
import com.github.database.rider.junit5.jdbc.ConnectionManager;
import com.github.database.rider.junit5.util.Constants;
import com.github.database.rider.junit5.util.EntityManagerProvider;
import java.lang.annotation.Annotation;
import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Method;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.dbunit.DatabaseUnitException;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.extension.AfterAllCallback;
import org.junit.jupiter.api.extension.AfterEachCallback;
import org.junit.jupiter.api.extension.AfterTestExecutionCallback;
import org.junit.jupiter.api.extension.BeforeAllCallback;
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.BeforeTestExecutionCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.platform.commons.util.AnnotationUtils;
import org.junit.platform.commons.util.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * JUnit 5 extension that wires Database Rider into the Jupiter lifecycle.
 *
 * <p>Around each test method it resolves the effective {@link DBUnitConfig}, optionally measures
 * connection leaks with a {@link LeakHunter}, and delegates dataset handling to {@link RiderRunner}.
 * It also honours {@link DataSet} and {@link ExpectedDataSet} annotations placed on
 * {@code @BeforeAll}, {@code @BeforeEach}, {@code @AfterEach} and {@code @AfterAll} methods that are
 * declared in the test class or its direct superclass.
 */
public class DBUnitExtension
        implements BeforeTestExecutionCallback,
        AfterTestExecutionCallback,
        BeforeEachCallback,
        AfterEachCallback,
        BeforeAllCallback,
        AfterAllCallback {
    private static final Logger LOG = LoggerFactory.getLogger(DBUnitExtension.class.getName());

    /**
     * Configures the executor for the upcoming test method and runs Rider's setup and before-test phases.
     *
     * @param extensionContext the method-level extension context
     * @throws Exception propagated from {@link RiderRunner}
     */
    @Override
    public void beforeTestExecution(ExtensionContext extensionContext) throws Exception {
        EntityManagerProvider.clear();
        DBUnitTestContext dbUnitTestContext = getTestContext(extensionContext);
        DataSetExecutor dataSetExecutor = dbUnitTestContext.getExecutor();
        DBUnitConfig dbUnitConfig = resolveDbUnitConfig(Optional.empty(), extensionContext.getTestMethod(),
                extensionContext.getRequiredTestClass());
        dataSetExecutor.setDBUnitConfig(dbUnitConfig);
        if (dbUnitConfig.isLeakHunter()) {
            String testName = extensionContext.getRequiredTestMethod().getName();
            try {
                LeakHunter leakHunter = LeakHunterFactory.from(dataSetExecutor.getRiderDataSource(), testName,
                        dbUnitConfig.isCacheConnection());
                leakHunter.measureConnectionsBeforeExecution();
                dbUnitTestContext.setLeakHunter(leakHunter);
            } catch (SQLException e) {
                // Leak detection is best-effort. The test context may be shared across test methods
                // (it is cached per test class), so drop any hunter left over from an earlier test
                // rather than checking this test against a stale baseline.
                dbUnitTestContext.setLeakHunter(null);
                LOG.warn(String.format("Could not create leak hunter for test %s", testName), e);
            }
        }
        RiderTestContext riderTestContext = new JUnit5RiderTestContext(dataSetExecutor, extensionContext);
        RiderRunner riderRunner = new RiderRunner();
        riderRunner.setup(riderTestContext);
        riderRunner.runBeforeTest(riderTestContext);
    }

    /**
     * Runs Rider's after-test phase and, when enabled, the leak check; teardown always runs.
     *
     * @param extensionContext the method-level extension context
     * @throws Exception propagated from {@link RiderRunner} or the leak check
     */
    @Override
    public void afterTestExecution(ExtensionContext extensionContext) throws Exception {
        DBUnitTestContext dbUnitTestContext = getTestContext(extensionContext);
        DataSetExecutor dataSetExecutor = dbUnitTestContext.getExecutor();
        DBUnitConfig dbUnitConfig = dataSetExecutor.getDBUnitConfig();
        RiderTestContext riderTestContext = new JUnit5RiderTestContext(dataSetExecutor, extensionContext);
        RiderRunner riderRunner = new RiderRunner();
        try {
            riderRunner.runAfterTest(riderTestContext);
            if (dbUnitConfig != null && dbUnitConfig.isLeakHunter()) {
                LeakHunter leakHunter = dbUnitTestContext.getLeakHunter();
                // null when beforeTestExecution could not create the hunter (already logged there)
                if (leakHunter != null) {
                    leakHunter.checkConnectionsAfterExecution();
                }
            }
        } finally {
            riderRunner.teardown(riderTestContext);
        }
    }

    /** Returns the {@link DBUnitTestContext} stored for the current test class, creating it on first use. */
    private DBUnitTestContext getTestContext(ExtensionContext context) {
        Class<?> testClass = context.getRequiredTestClass();
        ExtensionContext.Store store = context.getStore(Constants.NAMESPACE);
        return store.getOrComputeIfAbsent(testClass, key -> createDBUnitTestContext(context),
                DBUnitTestContext.class);
    }

    /** Builds a new context around an executor bound to the connection resolved for this test. */
    private DBUnitTestContext createDBUnitTestContext(ExtensionContext extensionContext) {
        String executorId = getExecutorId(extensionContext, null);
        ConnectionHolder connectionHolder = ConnectionManager.getTestConnection(extensionContext, executorId);
        DataSetExecutor dataSetExecutor = DataSetExecutorImpl.instance(executorId, connectionHolder);
        return new DBUnitTestContext(dataSetExecutor);
    }

    /**
     * Finds methods carrying {@code callback}, looking at the direct superclass and then at the test class.
     *
     * @return an unmodifiable set in discovery order (superclass methods first)
     */
    private Set<Method> findCallbackMethods(Class<?> testClass, Class<? extends Annotation> callback) {
        Class<?> superclass = testClass.getSuperclass();
        Stream<Method> inherited = superclass == null ? Stream.empty() : Stream.of(superclass.getDeclaredMethods());
        // LinkedHashSet keeps a deterministic order; a HashSet would make callback execution order
        // (and the "first callback" used by resolveDbUnitConfig) depend on hash codes.
        Set<Method> methods = Stream.concat(inherited, Stream.of(testClass.getDeclaredMethods()))
                .filter(m -> m.isAnnotationPresent(callback))
                .collect(Collectors.toCollection(LinkedHashSet::new));
        return Collections.unmodifiableSet(methods);
    }

    @Override
    public void beforeEach(ExtensionContext extensionContext) throws Exception {
        runCallbackDataSets(extensionContext, BeforeEach.class);
    }

    @Override
    public void afterEach(ExtensionContext extensionContext) throws Exception {
        runCallbackDataSets(extensionContext, AfterEach.class);
    }

    @Override
    public void beforeAll(ExtensionContext extensionContext) throws Exception {
        runCallbackDataSets(extensionContext, BeforeAll.class);
    }

    @Override
    public void afterAll(ExtensionContext extensionContext) throws Exception {
        runCallbackDataSets(extensionContext, AfterAll.class);
    }

    /**
     * For every method annotated with {@code callbackAnnotation}, seeds its {@link DataSet} and then
     * verifies its {@link ExpectedDataSet}. Does nothing when the context has no test class.
     */
    private void runCallbackDataSets(ExtensionContext extensionContext, Class<? extends Annotation> callbackAnnotation)
            throws DatabaseUnitException, SQLException {
        Optional<Class<?>> testClass = extensionContext.getTestClass();
        if (!testClass.isPresent()) {
            return;
        }
        for (Method callbackMethod : findCallbackMethods(testClass.get(), callbackAnnotation)) {
            executeDataSetForCallback(extensionContext, callbackAnnotation, callbackMethod);
            executeExpectedDataSetForCallback(extensionContext, callbackAnnotation, callbackMethod);
        }
    }

    /**
     * Seeds the {@link DataSet} declared on a lifecycle callback method. The callback's dataset is
     * merged with the class-level one when merging is enabled.
     */
    private void executeDataSetForCallback(ExtensionContext extensionContext,
            Class<? extends Annotation> callbackAnnotation, Method callbackMethod) throws SQLException {
        Optional<DataSet> dataSetAnnotation = AnnotationUtils.findAnnotation(callbackMethod, DataSet.class);
        if (!dataSetAnnotation.isPresent()) {
            LOG.warn("Could not find dataset annotation from callback method: " + callbackMethod);
            return;
        }
        Class<?> testClass = extensionContext.getRequiredTestClass();
        EntityManagerProvider.clear();
        DBUnitTestContext dbUnitTestContext = getTestContext(extensionContext);
        DBUnitConfig dbUnitConfig = resolveDbUnitConfig(Optional.of(callbackAnnotation),
                Optional.of(callbackMethod), testClass);
        DataSet dataSet = dbUnitConfig.isMergeDataSets()
                ? resolveDataSet(dataSetAnnotation, AnnotationUtils.findAnnotation(testClass, DataSet.class), dbUnitConfig)
                : dataSetAnnotation.get();
        DataSetExecutor dataSetExecutor = dbUnitTestContext.getExecutor();
        dataSetExecutor.setDBUnitConfig(dbUnitConfig);
        dataSetExecutor = resetExecutorConnectionIfNeeded(extensionContext, callbackAnnotation, dbUnitConfig,
                dataSetExecutor);
        dataSetExecutor.createDataSet(new DataSetConfig().from(dataSet));
        closeConnectionForAfterCallback(dataSetExecutor, callbackAnnotation);
    }

    /**
     * After {@code @AfterEach}/{@code @AfterAll} callbacks, closes the executor's JDBC connection and
     * clears its data source, unless connection caching is enabled.
     */
    private void closeConnectionForAfterCallback(DataSetExecutor dataSetExecutor,
            Class<? extends Annotation> callbackAnnotation) throws SQLException {
        if (!isAfterTestCallback(callbackAnnotation) || dataSetExecutor.getDBUnitConfig().isCacheConnection()) {
            return;
        }
        Connection connection = dataSetExecutor.getRiderDataSource().getDBUnitConnection().getConnection();
        if (!connection.isClosed()) {
            connection.close();
            if (dataSetExecutor instanceof DataSetExecutorImpl) {
                ((DataSetExecutorImpl) dataSetExecutor).clearRiderDataSource();
            }
        }
    }

    /** Compares the database against the {@link ExpectedDataSet} declared on a lifecycle callback method. */
    private void executeExpectedDataSetForCallback(ExtensionContext extensionContext,
            Class<? extends Annotation> callbackAnnotation, Method callbackMethod)
            throws DatabaseUnitException, SQLException {
        Optional<ExpectedDataSet> expectedDataSetAnnotation =
                AnnotationUtils.findAnnotation(callbackMethod, ExpectedDataSet.class);
        if (!expectedDataSetAnnotation.isPresent()) {
            LOG.warn("Could not find expectedDataSet annotation from callback method: " + callbackMethod);
            return;
        }
        Class<?> testClass = extensionContext.getRequiredTestClass();
        ExpectedDataSet expectedDataSet = expectedDataSetAnnotation.get();
        DBUnitConfig dbUnitConfig = resolveDbUnitConfig(Optional.of(callbackAnnotation),
                Optional.of(callbackMethod), testClass);
        DataSetExecutor dataSetExecutor = getTestContext(extensionContext).getExecutor();
        dataSetExecutor.setDBUnitConfig(dbUnitConfig);
        dataSetExecutor = resetExecutorConnectionIfNeeded(extensionContext, callbackAnnotation, dbUnitConfig,
                dataSetExecutor);
        dataSetExecutor.compareCurrentDataSetWith(
                new DataSetConfig(expectedDataSet.value()).disableConstraints(true)
                        .datasetProvider(expectedDataSet.provider()),
                expectedDataSet.ignoreCols(), expectedDataSet.replacers(), expectedDataSet.orderBy(),
                expectedDataSet.compareOperation());
        closeConnectionForAfterCallback(dataSetExecutor, callbackAnnotation);
    }

    /**
     * For after-callbacks without connection caching, a previous callback may have closed the
     * connection, so this returns an executor bound to a freshly obtained one. Otherwise it returns
     * {@code dataSetExecutor} unchanged.
     */
    private DataSetExecutor resetExecutorConnectionIfNeeded(ExtensionContext extensionContext,
            Class<? extends Annotation> callbackAnnotation, DBUnitConfig dbUnitConfig,
            DataSetExecutor dataSetExecutor) {
        if (dbUnitConfig.isCacheConnection() || !isAfterTestCallback(callbackAnnotation)) {
            return dataSetExecutor;
        }
        String executorId = dataSetExecutor.getExecutorId();
        ConnectionHolder connectionHolder = ConnectionManager.getTestConnection(extensionContext, executorId);
        return DataSetExecutorImpl.instance(executorId, connectionHolder, dbUnitConfig);
    }

    /** Returns true for {@code @AfterEach} and {@code @AfterAll}. */
    private boolean isAfterTestCallback(Class<? extends Annotation> callbackAnnotation) {
        return AfterEach.class.equals(callbackAnnotation) || AfterAll.class.equals(callbackAnnotation);
    }

    /**
     * Resolves {@link DBUnit} configuration. The first match wins, in this order:
     * <ol>
     *   <li>the given method</li>
     *   <li>the test class</li>
     *   <li>the first callback method of the given kind</li>
     *   <li>the direct superclass</li>
     * </ol>
     * When nothing matches, the global configuration is used.
     */
    private DBUnitConfig resolveDbUnitConfig(Optional<Class<? extends Annotation>> callbackAnnotation,
            Optional<Method> method, Class<?> testClass) {
        Optional<DBUnit> dbUnitAnnotation = AnnotationUtils.findAnnotation(method, DBUnit.class);
        if (!dbUnitAnnotation.isPresent()) {
            dbUnitAnnotation = AnnotationUtils.findAnnotation(testClass, DBUnit.class);
        }
        if (!dbUnitAnnotation.isPresent() && callbackAnnotation.isPresent()) {
            Set<Method> callbackMethods = findCallbackMethods(testClass, callbackAnnotation.get());
            if (!callbackMethods.isEmpty()) {
                dbUnitAnnotation = AnnotationUtils.findAnnotation(callbackMethods.iterator().next(), DBUnit.class);
            }
        }
        if (!dbUnitAnnotation.isPresent() && testClass.getSuperclass() != null) {
            dbUnitAnnotation = AnnotationUtils.findAnnotation(testClass.getSuperclass(), DBUnit.class);
        }
        return dbUnitAnnotation.isPresent()
                ? DBUnitConfig.from(dbUnitAnnotation.get())
                : DBUnitConfig.fromGlobalConfig();
    }

    /**
     * Merges a method-level dataset with the class-level one, if any.
     *
     * <p>With {@link DataSetMergingStrategy#METHOD}, the class-level dataset is passed as the first
     * argument to the merge; otherwise the method-level dataset is passed first. Callers must supply
     * a present {@code methodLevelDataSet}.
     */
    private DataSet resolveDataSet(Optional<DataSet> methodLevelDataSet, Optional<DataSet> classLevelDataSet,
            DBUnitConfig config) {
        if (!classLevelDataSet.isPresent()) {
            return methodLevelDataSet.get();
        }
        if (DataSetMergingStrategy.METHOD.equals(config.getMergingStrategy())) {
            return com.github.database.rider.core.util.AnnotationUtils.mergeDataSetAnnotations(
                    classLevelDataSet.get(), methodLevelDataSet.get());
        }
        return com.github.database.rider.core.util.AnnotationUtils.mergeDataSetAnnotations(
                methodLevelDataSet.get(), classLevelDataSet.get());
    }

    /**
     * Computes the executor id. It uses the dataset's non-blank {@code executorId}, falling back to
     * {@code "junit5"}, and adds {@code "-<beanName>"} when a data source bean name is configured.
     *
     * @param dataSet explicit dataset, or {@code null} to look one up on the test method/class
     */
    private String getExecutorId(ExtensionContext extensionContext, DataSet dataSet) {
        Optional<DataSet> annDataSet = dataSet != null ? Optional.of(dataSet) : findDataSetAnnotation(extensionContext);
        String dataSourceBeanName = ConnectionManager.getConfiguredDataSourceBeanName(extensionContext);
        String executionIdSuffix = dataSourceBeanName.isEmpty() ? "" : "-" + dataSourceBeanName;
        return annDataSet.map(DataSet::executorId)
                .filter(StringUtils::isNotBlank)
                .map(id -> id + executionIdSuffix)
                .orElseGet(() -> "junit5" + executionIdSuffix);
    }

    /**
     * Returns the {@link DataSet} on the current test method, falling back to the test class.
     * Returns empty when the context has no test method (e.g. class-level callbacks).
     */
    private Optional<DataSet> findDataSetAnnotation(ExtensionContext extensionContext) {
        Optional<Method> testMethod = extensionContext.getTestMethod();
        if (!testMethod.isPresent()) {
            return Optional.empty();
        }
        Optional<DataSet> annDataSet = AnnotationUtils.findAnnotation(testMethod.get(), DataSet.class);
        return annDataSet.isPresent()
                ? annDataSet
                : AnnotationUtils.findAnnotation(extensionContext.getRequiredTestClass(), DataSet.class);
    }
}

