/*
 * Decompiled with CFR 0.152.
 */
package com.github.database.rider.junit5;

import com.github.database.rider.core.api.connection.ConnectionHolder;
import com.github.database.rider.core.api.dataset.DataSet;
import com.github.database.rider.core.api.dataset.DataSetExecutor;
import com.github.database.rider.core.api.dataset.ExpectedDataSet;
import com.github.database.rider.core.api.exporter.DataSetExportConfig;
import com.github.database.rider.core.api.exporter.ExportDataSet;
import com.github.database.rider.core.api.leak.LeakHunter;
import com.github.database.rider.core.configuration.ConnectionConfig;
import com.github.database.rider.core.configuration.DBUnitConfig;
import com.github.database.rider.core.configuration.DataSetConfig;
import com.github.database.rider.core.connection.ConnectionHolderImpl;
import com.github.database.rider.core.dataset.DataSetExecutorImpl;
import com.github.database.rider.core.exporter.DataSetExporter;
import com.github.database.rider.core.leak.LeakHunterException;
import com.github.database.rider.core.leak.LeakHunterFactory;
import com.github.database.rider.core.util.EntityManagerProvider;
import com.github.database.rider.junit5.DBUnitTestContext;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.sql.Connection;
import java.sql.DriverManager;
import java.util.Arrays;
import java.util.Optional;
import org.dbunit.DatabaseUnitException;
import org.junit.jupiter.api.extension.AfterTestExecutionCallback;
import org.junit.jupiter.api.extension.BeforeTestExecutionCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class DBUnitExtension
implements BeforeTestExecutionCallback,
AfterTestExecutionCallback {
    private static final Logger log = LoggerFactory.getLogger(DBUnitExtension.class);
    private static final ExtensionContext.Namespace namespace = ExtensionContext.Namespace.create((Object[])new Object[]{DBUnitExtension.class});

    public void beforeTestExecution(ExtensionContext testExtensionContext) throws Exception {
        DataSet annotation;
        if (!this.shouldCreateDataSet(testExtensionContext)) {
            return;
        }
        ConnectionHolder connectionHolder = this.findTestConnection(testExtensionContext);
        if (EntityManagerProvider.isEntityManagerActive()) {
            EntityManagerProvider.em().clear();
        }
        if ((annotation = ((Method)testExtensionContext.getTestMethod().get()).getAnnotation(DataSet.class)) == null) {
            annotation = ((Class)testExtensionContext.getTestClass().get()).getAnnotation(DataSet.class);
        }
        DBUnitConfig dbUnitConfig = DBUnitConfig.from((Method)((Method)testExtensionContext.getTestMethod().get()));
        DataSetConfig dataSetConfig = new DataSetConfig().from(annotation);
        if (connectionHolder == null || connectionHolder.getConnection() == null) {
            connectionHolder = this.createConnection(dbUnitConfig, ((Method)testExtensionContext.getTestMethod().get()).getName());
        }
        DataSetExecutorImpl executor = DataSetExecutorImpl.instance((String)dataSetConfig.getExecutorId(), (ConnectionHolder)connectionHolder);
        executor.setDBUnitConfig(dbUnitConfig);
        DBUnitTestContext dbUnitTestContext = this.getTestContext(testExtensionContext);
        dbUnitTestContext.setExecutor((DataSetExecutor)executor).setDataSetConfig(dataSetConfig);
        if (dataSetConfig != null && dataSetConfig.getExecuteStatementsBefore() != null && dataSetConfig.getExecuteStatementsBefore().length > 0) {
            try {
                executor.executeStatements(dataSetConfig.getExecuteStatementsBefore());
            }
            catch (Exception e) {
                log.error(((Method)testExtensionContext.getTestMethod().get()).getName() + "() - Could not execute statements Before:" + e.getMessage(), (Throwable)e);
            }
        }
        if (dataSetConfig.getExecuteScriptsBefore() != null && dataSetConfig.getExecuteScriptsBefore().length > 0) {
            try {
                for (int i = 0; i < dataSetConfig.getExecuteScriptsBefore().length; ++i) {
                    executor.executeScript(dataSetConfig.getExecuteScriptsBefore()[i]);
                }
            }
            catch (Exception e) {
                if (e instanceof DatabaseUnitException) {
                    throw e;
                }
                log.error(((Method)testExtensionContext.getTestMethod().get()).getName() + "() - Could not execute scriptsBefore:" + e.getMessage(), (Throwable)e);
            }
        }
        if (dbUnitConfig.isLeakHunter().booleanValue()) {
            LeakHunter leakHunter = LeakHunterFactory.from((Connection)connectionHolder.getConnection());
            dbUnitTestContext.setLeakHunter(leakHunter).setOpenConnections(leakHunter.openConnections());
        }
        try {
            executor.createDataSet(dataSetConfig);
        }
        catch (Exception e) {
            throw new RuntimeException(String.format("Could not create dataset for test method %s due to following error " + e.getMessage(), ((Method)testExtensionContext.getTestMethod().get()).getName()), e);
        }
        boolean isTransactional = dataSetConfig.isTransactional();
        if (isTransactional) {
            if (EntityManagerProvider.isEntityManagerActive()) {
                if (!EntityManagerProvider.tx().isActive()) {
                    EntityManagerProvider.em().getTransaction().begin();
                }
            } else {
                Connection connection = executor.getRiderDataSource().getConnection();
                connection.setAutoCommit(false);
            }
        }
    }

    private boolean shouldCreateDataSet(ExtensionContext testExtensionContext) {
        return ((Method)testExtensionContext.getTestMethod().get()).isAnnotationPresent(DataSet.class) || ((Class)testExtensionContext.getTestClass().get()).isAnnotationPresent(DataSet.class);
    }

    private boolean shouldCompareDataSet(ExtensionContext testExtensionContext) {
        return ((Method)testExtensionContext.getTestMethod().get()).isAnnotationPresent(ExpectedDataSet.class) || ((Class)testExtensionContext.getTestClass().get()).isAnnotationPresent(ExpectedDataSet.class);
    }

    private boolean shouldExportDataSet(ExtensionContext testExtensionContext) {
        return ((Method)testExtensionContext.getTestMethod().get()).isAnnotationPresent(ExportDataSet.class) || ((Class)testExtensionContext.getTestClass().get()).isAnnotationPresent(ExportDataSet.class);
    }

    public void exportDataSet(DataSetExecutor dataSetExecutor, Method method) {
        ExportDataSet exportDataSet = this.resolveExportDataSet(method);
        if (exportDataSet != null) {
            DataSetExportConfig exportConfig = DataSetExportConfig.from((ExportDataSet)exportDataSet);
            String outputName = exportConfig.getOutputFileName();
            if (outputName == null || "".equals(outputName.trim())) {
                outputName = method.getName().toLowerCase() + "." + exportConfig.getDataSetFormat().name().toLowerCase();
            }
            exportConfig.outputFileName(outputName);
            try {
                DataSetExporter.getInstance().export(dataSetExecutor.getRiderDataSource().getDBUnitConnection(), exportConfig);
            }
            catch (Exception e) {
                log.warn("Could not export dataset after method " + method.getName(), (Throwable)e);
            }
        }
    }

    private ExportDataSet resolveExportDataSet(Method method) {
        ExportDataSet exportDataSet = method.getAnnotation(ExportDataSet.class);
        if (exportDataSet == null) {
            exportDataSet = method.getDeclaringClass().getAnnotation(ExportDataSet.class);
        }
        return exportDataSet;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public void afterTestExecution(ExtensionContext testExtensionContext) throws Exception {
        Connection connection;
        DBUnitTestContext dbUnitTestContext = this.getTestContext(testExtensionContext);
        DataSetConfig dataSetConfig = dbUnitTestContext.getDataSetConfig();
        DataSetExecutor executor = dbUnitTestContext.getExecutor();
        DBUnitConfig dbUnitConfig = executor != null ? executor.getDBUnitConfig() : DBUnitConfig.from((Method)((Method)testExtensionContext.getTestMethod().get()));
        boolean isTransactional = dataSetConfig != null && dataSetConfig.isTransactional();
        try {
            if (isTransactional) {
                if (EntityManagerProvider.isEntityManagerActive()) {
                    if (EntityManagerProvider.tx().isActive()) {
                        EntityManagerProvider.tx().commit();
                    }
                } else {
                    connection = executor.getRiderDataSource().getConnection();
                    connection.commit();
                    connection.setAutoCommit(false);
                }
            }
            if (dataSetConfig != null && executor != null && this.shouldCompareDataSet(testExtensionContext)) {
                ExpectedDataSet expectedDataSet = ((Method)testExtensionContext.getTestMethod().get()).getAnnotation(ExpectedDataSet.class);
                if (expectedDataSet == null) {
                    expectedDataSet = ((Class)testExtensionContext.getTestClass().get()).getAnnotation(ExpectedDataSet.class);
                }
                if (expectedDataSet != null) {
                    executor.compareCurrentDataSetWith(new DataSetConfig(expectedDataSet.value()).disableConstraints(true), expectedDataSet.ignoreCols());
                }
            }
            if (dbUnitConfig != null && dbUnitConfig.isLeakHunter().booleanValue()) {
                LeakHunter leakHunter = dbUnitTestContext.getLeakHunter();
                int openConnectionsBefore = dbUnitTestContext.getOpenConnections();
                int openConnectionsAfter = leakHunter.openConnections();
                if (openConnectionsAfter > openConnectionsBefore) {
                    throw new LeakHunterException(((Method)testExtensionContext.getTestMethod().get()).getName(), openConnectionsAfter - openConnectionsBefore);
                }
            }
        }
        finally {
            if (dataSetConfig == null || executor == null) {
                return;
            }
            if (isTransactional) {
                if (EntityManagerProvider.isEntityManagerActive() && EntityManagerProvider.em().getTransaction().isActive()) {
                    EntityManagerProvider.em().getTransaction().rollback();
                } else {
                    connection = executor.getRiderDataSource().getConnection();
                    connection.rollback();
                }
            }
            if (this.shouldExportDataSet(testExtensionContext)) {
                this.exportDataSet(executor, (Method)testExtensionContext.getTestMethod().get());
            }
            if (dataSetConfig.getExecuteStatementsAfter() != null && dataSetConfig.getExecuteStatementsAfter().length > 0) {
                try {
                    executor.executeStatements(dataSetConfig.getExecuteStatementsAfter());
                }
                catch (Exception e) {
                    log.error(((Method)testExtensionContext.getTestMethod().get()).getName() + "() - Could not execute statements after:" + e.getMessage(), (Throwable)e);
                }
            }
            if (dataSetConfig.getExecuteScriptsAfter() != null && dataSetConfig.getExecuteScriptsAfter().length > 0) {
                try {
                    for (int i = 0; i < dataSetConfig.getExecuteScriptsAfter().length; ++i) {
                        executor.executeScript(dataSetConfig.getExecuteScriptsAfter()[i]);
                    }
                }
                catch (Exception e) {
                    if (e instanceof DatabaseUnitException) {
                        throw e;
                    }
                    log.error(((Method)testExtensionContext.getTestMethod().get()).getName() + "() - Could not execute scriptsAfter:" + e.getMessage(), (Throwable)e);
                }
            }
            if (dataSetConfig.isCleanAfter()) {
                executor.clearDatabase(dataSetConfig);
            }
            executor.enableConstraints();
        }
    }

    private ConnectionHolder findTestConnection(ExtensionContext testExtensionContext) {
        Class testClass = (Class)testExtensionContext.getTestClass().get();
        try {
            Optional<Field> fieldFound = Arrays.stream(testClass.getDeclaredFields()).filter(f -> f.getType() == ConnectionHolder.class).findFirst();
            if (fieldFound.isPresent()) {
                ConnectionHolder connectionHolder;
                Field field = fieldFound.get();
                if (!field.isAccessible()) {
                    field.setAccessible(true);
                }
                if ((connectionHolder = (ConnectionHolder)ConnectionHolder.class.cast(field.get(testExtensionContext.getTestInstance().get()))) == null || connectionHolder.getConnection() == null) {
                    throw new RuntimeException("ConnectionHolder not initialized correctly");
                }
                return connectionHolder;
            }
            Optional<Method> methodFound = Arrays.stream(testClass.getDeclaredMethods()).filter(m -> m.getReturnType() == ConnectionHolder.class).findFirst();
            if (methodFound.isPresent()) {
                ConnectionHolder connectionHolder;
                Method method = methodFound.get();
                if (!method.isAccessible()) {
                    method.setAccessible(true);
                }
                if ((connectionHolder = (ConnectionHolder)ConnectionHolder.class.cast(method.invoke(testExtensionContext.getTestInstance().get(), new Object[0]))) == null || connectionHolder == null) {
                    throw new RuntimeException("ConnectionHolder not initialized correctly");
                }
                return connectionHolder;
            }
        }
        catch (Exception e) {
            throw new RuntimeException("Could not get database connection for test " + testClass, e);
        }
        return null;
    }

    private ConnectionHolder createConnection(DBUnitConfig dbUnitConfig, String currentMethod) {
        ConnectionConfig connectionConfig = dbUnitConfig.getConnectionConfig();
        if ("".equals(connectionConfig.getUrl()) || "".equals(connectionConfig.getUser())) {
            throw new RuntimeException(String.format("Could not create JDBC connection for method %s, provide a connection at test level or via configuration, see documentation here: https://github.com/database-rider/database-rider#7-junit-5", currentMethod));
        }
        try {
            if (!"".equals(connectionConfig.getDriver())) {
                Class.forName(connectionConfig.getDriver());
            }
            return new ConnectionHolderImpl(DriverManager.getConnection(connectionConfig.getUrl(), connectionConfig.getUser(), connectionConfig.getPassword()));
        }
        catch (Exception e) {
            log.error("Could not create JDBC connection for method " + currentMethod, (Throwable)e);
            return null;
        }
    }

    private DBUnitTestContext getTestContext(ExtensionContext context) {
        Class testClass = (Class)context.getTestClass().get();
        ExtensionContext.Store store = context.getStore(namespace);
        DBUnitTestContext testContext = (DBUnitTestContext)store.get((Object)testClass, DBUnitTestContext.class);
        if (testContext == null) {
            testContext = new DBUnitTestContext();
            store.put((Object)testClass, (Object)testContext);
        }
        return testContext;
    }
}

