/*
 * Decompiled with CFR 0.152.
 */
package com.github.database.rider.junit5;

import com.github.database.rider.core.RiderRunner;
import com.github.database.rider.core.RiderTestContext;
import com.github.database.rider.core.api.connection.ConnectionHolder;
import com.github.database.rider.core.api.dataset.DataSet;
import com.github.database.rider.core.api.dataset.DataSetExecutor;
import com.github.database.rider.core.api.leak.LeakHunter;
import com.github.database.rider.core.configuration.DBUnitConfig;
import com.github.database.rider.core.connection.RiderDataSource;
import com.github.database.rider.core.dataset.DataSetExecutorImpl;
import com.github.database.rider.core.leak.LeakHunterFactory;
import com.github.database.rider.core.util.EntityManagerProvider;
import com.github.database.rider.junit5.DBUnitTestContext;
import com.github.database.rider.junit5.JUnit5RiderTestContext;
import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.Optional;
import org.junit.jupiter.api.extension.AfterTestExecutionCallback;
import org.junit.jupiter.api.extension.BeforeTestExecutionCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.platform.commons.util.AnnotationUtils;

public class DBUnitExtension
implements BeforeTestExecutionCallback,
AfterTestExecutionCallback {
    private static final ExtensionContext.Namespace namespace = ExtensionContext.Namespace.create((Object[])new Object[]{DBUnitExtension.class});

    public void beforeTestExecution(ExtensionContext extensionContext) throws Exception {
        DataSet dataSet;
        ConnectionHolder connectionHolder = this.findTestConnection(extensionContext);
        if (EntityManagerProvider.isEntityManagerActive()) {
            EntityManagerProvider.em().clear();
        }
        if ((dataSet = (DataSet)AnnotationUtils.findAnnotation((AnnotatedElement)extensionContext.getRequiredTestMethod(), DataSet.class).orElse(null)) == null) {
            dataSet = AnnotationUtils.findAnnotation((AnnotatedElement)extensionContext.getRequiredTestClass(), DataSet.class).orElse(null);
        }
        DataSetExecutorImpl executor = dataSet == null ? DataSetExecutorImpl.instance((String)"default", (ConnectionHolder)connectionHolder) : DataSetExecutorImpl.instance((String)dataSet.executorId(), (ConnectionHolder)connectionHolder);
        DBUnitTestContext dbUnitTestContext = this.getTestContext(extensionContext);
        dbUnitTestContext.setExecutor((DataSetExecutor)executor);
        JUnit5RiderTestContext riderTestContext = new JUnit5RiderTestContext(dbUnitTestContext.getExecutor(), extensionContext);
        RiderRunner riderRunner = new RiderRunner();
        riderRunner.setup((RiderTestContext)riderTestContext);
        riderRunner.runBeforeTest((RiderTestContext)riderTestContext);
        DBUnitConfig dbUnitConfig = riderTestContext.getDataSetExecutor().getDBUnitConfig();
        if (dbUnitConfig.isLeakHunter().booleanValue()) {
            LeakHunter leakHunter = LeakHunterFactory.from((RiderDataSource)executor.getRiderDataSource(), (String)extensionContext.getRequiredTestMethod().getName());
            leakHunter.measureConnectionsBeforeExecution();
            dbUnitTestContext.setLeakHunter(leakHunter);
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public void afterTestExecution(ExtensionContext extensionContext) throws Exception {
        DBUnitTestContext dbUnitTestContext = this.getTestContext(extensionContext);
        DBUnitConfig dbUnitConfig = dbUnitTestContext.getExecutor().getDBUnitConfig();
        JUnit5RiderTestContext riderTestContext = new JUnit5RiderTestContext(dbUnitTestContext.getExecutor(), extensionContext);
        RiderRunner riderRunner = new RiderRunner();
        try {
            if (dbUnitConfig != null && dbUnitConfig.isLeakHunter().booleanValue()) {
                LeakHunter leakHunter = dbUnitTestContext.getLeakHunter();
                leakHunter.checkConnectionsAfterExecution();
            }
            riderRunner.runAfterTest((RiderTestContext)riderTestContext);
        }
        finally {
            riderRunner.teardown((RiderTestContext)riderTestContext);
        }
    }

    private ConnectionHolder findTestConnection(ExtensionContext extensionContext) {
        Class testClass = extensionContext.getRequiredTestClass();
        try {
            Optional<Field> fieldFound = Arrays.stream(testClass.getDeclaredFields()).filter(f -> f.getType() == ConnectionHolder.class).findFirst();
            if (fieldFound.isPresent()) {
                ConnectionHolder connectionHolder;
                Field field = fieldFound.get();
                if (!field.isAccessible()) {
                    field.setAccessible(true);
                }
                if ((connectionHolder = (ConnectionHolder)field.get(extensionContext.getRequiredTestInstance())) == null) {
                    throw new RuntimeException("ConnectionHolder not initialized correctly");
                }
                return connectionHolder;
            }
            Optional<Method> methodFound = Arrays.stream(testClass.getDeclaredMethods()).filter(m -> m.getReturnType() == ConnectionHolder.class).findFirst();
            if (methodFound.isPresent()) {
                ConnectionHolder connectionHolder;
                Method method = methodFound.get();
                if (!method.isAccessible()) {
                    method.setAccessible(true);
                }
                if ((connectionHolder = (ConnectionHolder)method.invoke(extensionContext.getRequiredTestInstance(), new Object[0])) == null) {
                    throw new RuntimeException("ConnectionHolder not initialized correctly");
                }
                return connectionHolder;
            }
        }
        catch (Exception e) {
            throw new RuntimeException("Could not get database connection for test " + testClass, e);
        }
        return null;
    }

    private DBUnitTestContext getTestContext(ExtensionContext context) {
        Class testClass = context.getRequiredTestClass();
        ExtensionContext.Store store = context.getStore(namespace);
        return (DBUnitTestContext)store.getOrComputeIfAbsent((Object)testClass, tc -> new DBUnitTestContext(), DBUnitTestContext.class);
    }
}

