/*
 * Decompiled with CFR 0.152.
 */
package io.goodforgod.testcontainers.extensions;

import io.goodforgod.testcontainers.extensions.ContainerContext;
import io.goodforgod.testcontainers.extensions.ContainerMetadata;
import io.goodforgod.testcontainers.extensions.ContainerMode;
import java.lang.annotation.Annotation;
import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Constructor;
import java.lang.reflect.Executable;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import org.jetbrains.annotations.ApiStatus;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.extension.AfterAllCallback;
import org.junit.jupiter.api.extension.AfterEachCallback;
import org.junit.jupiter.api.extension.BeforeAllCallback;
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.ExtensionConfigurationException;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.ParameterContext;
import org.junit.jupiter.api.extension.ParameterResolutionException;
import org.junit.jupiter.api.extension.ParameterResolver;
import org.junit.platform.commons.support.AnnotationSupport;
import org.junit.platform.commons.util.ReflectionUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testcontainers.containers.GenericContainer;
import org.testcontainers.containers.Network;

/**
 * Base JUnit 5 extension that creates, starts, shares and stops a Testcontainers container and
 * injects its {@code Connection} into annotated test fields and parameters.
 *
 * <p>The container lifetime is selected by {@link ContainerMetadata#runMode()}:
 * <ul>
 *   <li>{@link ContainerMode#PER_RUN}: started in {@link #beforeAll} and shared across test classes
 *       through {@code CLASS_TO_SHARED_CONTAINERS}; this class never stops it;</li>
 *   <li>{@link ContainerMode#PER_CLASS}: started in {@link #beforeAll}, stopped in {@link #afterAll};</li>
 *   <li>{@link ContainerMode#PER_METHOD}: started in {@link #beforeEach}, stopped in {@link #afterEach}.</li>
 * </ul>
 *
 * @param <Connection> type injected into members annotated with {@link #getConnectionAnnotation()}
 * @param <Container>  container type managed by this extension
 * @param <Metadata>   extension configuration resolved for the current test
 */
@ApiStatus.Internal
public abstract class AbstractTestcontainersExtension<Connection, Container extends GenericContainer<?>, Metadata extends ContainerMetadata>
        implements BeforeAllCallback, BeforeEachCallback, AfterAllCallback, AfterEachCallback, ParameterResolver {

    /**
     * Containers started in {@link ContainerMode#PER_RUN} mode, grouped first by the canonical name of the
     * concrete extension class and then by {@link SharedKey}, so equal keys reuse one running container.
     */
    static final Map<String, Map<SharedKey, ContainerContext<?>>> CLASS_TO_SHARED_CONTAINERS = new ConcurrentHashMap<>();

    protected final Logger logger = LoggerFactory.getLogger(getClass());

    /** @return annotation marking a test field that holds a user-declared container */
    protected abstract Class<? extends Annotation> getContainerAnnotation();

    /** @return annotation marking test fields and parameters that receive the connection */
    protected abstract Class<? extends Annotation> getConnectionAnnotation();

    /** @return exact type an annotated connection parameter must declare */
    protected abstract Class<Connection> getConnectionType();

    /** @return type an annotated container field value must be an instance of */
    protected abstract Class<Container> getContainerType();

    /**
     * @param context current extension context
     * @return extension configuration, or empty when it cannot be found for this context
     */
    protected abstract Optional<Metadata> findMetadata(ExtensionContext context);

    /**
     * Returns the extension configuration for {@code context}.
     *
     * @throws ExtensionConfigurationException if {@link #findMetadata} returns empty
     */
    protected final Metadata getMetadata(ExtensionContext context) {
        return findMetadata(context)
                .orElseThrow(() -> new ExtensionConfigurationException("Extension annotation not found"));
    }

    /** @return namespace used for this extension's {@link ExtensionContext.Store} entries */
    protected abstract ExtensionContext.Namespace getNamespace();

    /** Creates the container used when the test class does not declare one in an annotated field. */
    protected abstract Container createContainerDefault(Metadata metadata);

    /** Wraps {@code container} into a context that can start/stop it and expose its connection. */
    protected abstract ContainerContext<Connection> createContainerContext(Container container);

    /**
     * Returns the store holding this extension's container contexts.
     *
     * <p>A context that has a grandparent (presumably a test-method context, or a {@code @Nested} class
     * context) uses its parent's store; any other context uses its own store. As a result, test methods
     * share the class-level store, and nested classes share the enclosing class's store.
     */
    protected final ExtensionContext.Store getStorage(ExtensionContext context) {
        boolean hasGrandparent = context.getParent().flatMap(ExtensionContext::getParent).isPresent();
        ExtensionContext owner = hasGrandparent ? context.getParent().get() : context;
        return owner.getStore(getNamespace());
    }

    /**
     * @return container context stored under the current run mode, or {@code null} if none is stored yet
     */
    @SuppressWarnings("unchecked")
    protected ContainerContext<Connection> getContainerContext(ExtensionContext context) {
        Metadata metadata = getMetadata(context);
        return (ContainerContext<Connection>) getStorage(context).get(metadata.runMode(), ContainerContext.class);
    }

    /**
     * Searches for {@code annotationType} on the test class (and its superclasses) of {@code context},
     * then of each ancestor context. Contexts without a test class (e.g. the engine root) are skipped
     * rather than failing.
     *
     * @return the first annotation found, or empty if none of the visited classes carries it
     */
    protected <T extends Annotation> Optional<T> findAnnotation(Class<T> annotationType, ExtensionContext context) {
        Optional<ExtensionContext> current = Optional.of(context);
        while (current.isPresent()) {
            ExtensionContext ctx = current.get();
            for (Class<?> type = ctx.getTestClass().orElse(null);
                 type != null && type != Object.class;
                 type = type.getSuperclass()) {
                Optional<T> annotation = AnnotationSupport.findAnnotation(type, annotationType);
                if (annotation.isPresent()) {
                    return annotation;
                }
            }
            current = ctx.getParent();
        }
        return Optional.empty();
    }

    /**
     * Looks up a user-declared container in a field annotated with {@link #getContainerAnnotation()}.
     * It checks the current test instance first and then, for {@code @Nested} tests, the enclosing test instance.
     *
     * @return the declared container, or empty if there is no test instance or no annotated field
     * @throws IllegalArgumentException if the annotated field's value is not a {@link #getContainerType()}
     */
    protected Optional<Container> findContainerFromField(ExtensionContext context) {
        logger.debug("Looking for {} Container...", getContainerType().getSimpleName());
        if (context.getTestClass().isEmpty() || context.getTestInstance().isEmpty()) {
            return Optional.empty();
        }
        return findContainerInClassField(context.getTestInstance().get())
                .or(() -> findParentTestClassIfNested(context).flatMap(this::findContainerInClassField));
    }

    /** @return whether the context's test class is annotated with {@link Nested} */
    private static boolean isNestedTestClass(ExtensionContext context) {
        return context.getTestClass().filter(c -> c.isAnnotationPresent(Nested.class)).isPresent();
    }

    /**
     * For a {@code @Nested} test instance, returns the enclosing test instance. It reads the first declared
     * field of the nested class whose type equals the parent test class.
     *
     * @return enclosing instance, or empty when not nested, no instance exists, or no such field is found
     */
    private static Optional<Object> findParentTestClassIfNested(ExtensionContext context) {
        if (!isNestedTestClass(context)) {
            return Optional.empty();
        }
        return context.getTestInstance().flatMap(instance -> {
            Class<?> instanceClass = instance.getClass();
            return findParentTestClass(instanceClass, context)
                    .flatMap(parentClass -> Arrays.stream(instanceClass.getDeclaredFields())
                            .filter(f -> f.getType().equals(parentClass))
                            .findFirst())
                    .map(field -> readOuterInstance(field, instance));
        });
    }

    /** Reads {@code field} from {@code instance}, wrapping reflective access failures. */
    private static Object readOuterInstance(Field field, Object instance) {
        try {
            field.setAccessible(true);
            return field.get(instance);
        } catch (IllegalAccessException e) {
            throw new IllegalStateException(e);
        }
    }

    /**
     * @return value of the first non-synthetic field (top-down over the class hierarchy) annotated with
     *         {@link #getContainerAnnotation()}, or empty when no such field exists
     */
    private Optional<Container> findContainerInClassField(Object testClassInstance) {
        Class<? extends Annotation> containerAnnotation = getContainerAnnotation();
        return ReflectionUtils.findFields(testClassInstance.getClass(),
                        f -> !f.isSynthetic() && f.getAnnotation(containerAnnotation) != null,
                        ReflectionUtils.HierarchyTraversalMode.TOP_DOWN)
                .stream()
                .findFirst()
                .map(field -> readContainerField(field, testClassInstance));
    }

    /**
     * Reads and type-checks an annotated container field. A {@code null} value fails the type check and
     * is reported with the same descriptive message instead of an NPE.
     */
    private Container readContainerField(Field field, Object testClassInstance) {
        final Object possibleContainer;
        try {
            field.setAccessible(true);
            possibleContainer = field.get(testClassInstance);
        } catch (IllegalAccessException e) {
            throw new IllegalStateException(String.format("Failed retrieving value from field '%s' annotated with @%s",
                    field.getName(), getContainerAnnotation().getSimpleName()), e);
        }

        if (!getContainerType().isInstance(possibleContainer)) {
            throw new IllegalArgumentException(String.format("Field '%s' annotated with @%s value must be instance of %s",
                    field.getName(), getContainerAnnotation().getSimpleName(), getContainerType()));
        }

        logger.debug("Found {} Container in field: {}", getContainerType().getSimpleName(), field.getName());
        return getContainerType().cast(possibleContainer);
    }

    /** Walks up the context hierarchy to the nearest test class that differs from {@code childTestClass}. */
    private static Optional<Class<?>> findParentTestClass(Class<?> childTestClass, ExtensionContext context) {
        return context.getTestClass()
                .filter(c -> !c.equals(childTestClass))
                .or(() -> context.getParent().flatMap(parent -> findParentTestClass(childTestClass, parent)));
    }

    /**
     * Injects the connection into the current test instance and, for {@code @Nested} tests, into the
     * enclosing test instance as well.
     */
    protected void injectContext(ContainerContext<Connection> containerContext, ExtensionContext context) {
        context.getTestInstance().ifPresent(instance -> injectContextIntoInstance(containerContext, instance));
        findParentTestClassIfNested(context).ifPresent(instance -> injectContextIntoInstance(containerContext, instance));
    }

    /**
     * Sets the connection on every non-synthetic, non-final, non-static field of {@code testClassInstance}
     * (including inherited fields) annotated with {@link #getConnectionAnnotation()}.
     */
    protected void injectContextIntoInstance(ContainerContext<Connection> containerContext, Object testClassInstance) {
        Class<? extends Annotation> connectionAnnotation = getConnectionAnnotation();
        List<Field> connectionFields = ReflectionUtils.findFields(testClassInstance.getClass(),
                f -> !f.isSynthetic()
                        && !Modifier.isFinal(f.getModifiers())
                        && !Modifier.isStatic(f.getModifiers())
                        && f.getAnnotation(connectionAnnotation) != null,
                ReflectionUtils.HierarchyTraversalMode.TOP_DOWN);

        logger.debug("Starting field injection for connection: {}", containerContext.connection());
        for (Field field : connectionFields) {
            injectContextIntoField(containerContext, field, testClassInstance);
        }
    }

    /**
     * Assigns the connection to a single field.
     *
     * @throws IllegalStateException if the field cannot be written reflectively
     */
    protected void injectContextIntoField(ContainerContext<Connection> containerContext, Field field, Object testClassInstance) {
        try {
            field.setAccessible(true);
            field.set(testClassInstance, containerContext.connection());
        } catch (IllegalAccessException e) {
            throw new IllegalStateException(String.format("Field '%s' annotated with @%s can't set connection",
                    field.getName(), getConnectionAnnotation().getSimpleName()), e);
        }
    }

    /**
     * Supports parameters annotated with {@link #getConnectionAnnotation()}.
     *
     * @throws ParameterResolutionException if an annotated parameter's type is not exactly {@link #getConnectionType()}
     */
    @Override
    public boolean supportsParameter(ParameterContext parameterContext, ExtensionContext extensionContext)
            throws ParameterResolutionException {
        Class<? extends Annotation> connectionAnnotation = getConnectionAnnotation();
        if (parameterContext.getParameter().getAnnotation(connectionAnnotation) == null) {
            return false;
        }
        if (!parameterContext.getParameter().getType().equals(getConnectionType())) {
            throw new ParameterResolutionException(String.format("Parameter '%s' annotated @%s is not of type %s",
                    parameterContext.getParameter().getName(), connectionAnnotation.getSimpleName(), getConnectionType()));
        }
        return true;
    }

    /**
     * Resolves the connection parameter. If no container is running yet, it is started eagerly through
     * {@link #beforeAll} or {@link #beforeEach}, depending on the run mode.
     *
     * @throws ParameterResolutionException if PER_METHOD mode is combined with an unsupported injection point,
     *                                      or if no connection is available after starting
     */
    @Override
    public Object resolveParameter(ParameterContext parameterContext, ExtensionContext extensionContext)
            throws ParameterResolutionException {
        CallMode callMode = getCallMode(parameterContext);
        ContainerContext<Connection> existing = getContainerContext(extensionContext);
        if (existing != null) {
            return existing.connection();
        }

        Metadata metadata = getMetadata(extensionContext);
        if (metadata.runMode() == ContainerMode.PER_RUN || metadata.runMode() == ContainerMode.PER_CLASS) {
            beforeAll(extensionContext);
        } else if (metadata.runMode() == ContainerMode.PER_METHOD) {
            TestInstance.Lifecycle lifecycle = extensionContext.getTestInstanceLifecycle()
                    .orElse(TestInstance.Lifecycle.PER_METHOD);
            if (callMode == CallMode.CONSTRUCTOR && lifecycle == TestInstance.Lifecycle.PER_CLASS) {
                throw new ParameterResolutionException(String.format(
                        "@%s can't be injected into constructor parameter when ContainerMode.%s is used and lifecycle is @%s",
                        getConnectionAnnotation().getSimpleName(), ContainerMode.PER_METHOD, TestInstance.Lifecycle.PER_CLASS));
            }
            if (callMode == CallMode.BEFORE_ALL) {
                throw new ParameterResolutionException(String.format(
                        "@%s can't be injected into @%s method parameter when ContainerMode.%s is used",
                        getConnectionAnnotation().getSimpleName(), BeforeAll.class.getSimpleName(), ContainerMode.PER_METHOD));
            }
            beforeEach(extensionContext);
        }

        // The context may still be absent (e.g. an unhandled run mode); report it instead of throwing an NPE.
        ContainerContext<Connection> started = getContainerContext(extensionContext);
        Connection connection = (started == null) ? null : started.connection();
        if (connection == null) {
            throw new ParameterResolutionException(String.format(
                    "Parameter named '%s' with type '%s' can't be resolved cause it probably isn't initialized yet, please check extension annotation execution order",
                    parameterContext.getParameter().getName(), getConnectionType()));
        }
        return connection;
    }

    /**
     * @return where the parameter is declared, or {@code null} for any other executable
     *         (e.g. a regular test method)
     */
    private CallMode getCallMode(ParameterContext parameterContext) {
        Executable executable = parameterContext.getDeclaringExecutable();
        if (executable.isAnnotationPresent(BeforeAll.class)) {
            return CallMode.BEFORE_ALL;
        }
        if (executable.isAnnotationPresent(BeforeEach.class)) {
            return CallMode.BEFORE_EACH;
        }
        if (executable instanceof Constructor) {
            return CallMode.CONSTRUCTOR;
        }
        return null;
    }

    /**
     * Starts (or reuses) the container for {@link ContainerMode#PER_RUN} and {@link ContainerMode#PER_CLASS}
     * modes when none is stored yet, stores it, and injects its connection.
     */
    @Override
    public void beforeAll(ExtensionContext context) {
        Metadata metadata = getMetadata(context);
        if (getContainerContext(context) != null) {
            return;
        }

        final ContainerContext<Connection> containerContext;
        if (metadata.runMode() == ContainerMode.PER_RUN) {
            containerContext = getOrStartSharedContainer(metadata, context);
        } else if (metadata.runMode() == ContainerMode.PER_CLASS) {
            Container container = findContainerFromField(context).orElseGet(() -> defaultContainer(metadata));
            containerContext = startContainerContext(metadata, container);
        } else {
            return;
        }

        getStorage(context).put(metadata.runMode(), containerContext);
        injectContext(containerContext, context);
    }

    /**
     * Returns the PER_RUN container shared by this extension class under the matching {@link SharedKey}.
     * If none exists, it starts one (with reuse enabled).
     */
    @SuppressWarnings("unchecked")
    private ContainerContext<Connection> getOrStartSharedContainer(Metadata metadata, ExtensionContext context) {
        Optional<Container> containerFromField = findContainerFromField(context);
        // A field-declared container is shared by identity; otherwise share by image, network mode and alias.
        SharedKey sharedKey = containerFromField
                .<SharedKey>map(SharedContainerInstance::new)
                .orElseGet(() -> new SharedContainerKey(metadata.image(), metadata.networkShared(), metadata.networkAlias()));

        Map<SharedKey, ContainerContext<?>> sharedContainers = CLASS_TO_SHARED_CONTAINERS
                .computeIfAbsent(getClass().getCanonicalName(), k -> new ConcurrentHashMap<>());

        return (ContainerContext<Connection>) sharedContainers.computeIfAbsent(sharedKey, k -> {
            Container container = containerFromField.orElseGet(() -> defaultContainer(metadata));
            container.withReuse(true);
            return startContainerContext(metadata, container);
        });
    }

    /** Creates the default container for {@code metadata}, logging the image being used. */
    private Container defaultContainer(Metadata metadata) {
        logger.debug("Getting default container for image: {}", metadata.image());
        return createContainerDefault(metadata);
    }

    /** Wraps {@code container} in a context and starts it through the context (mirrors {@code stop()} on teardown). */
    private ContainerContext<Connection> startContainerContext(Metadata metadata, Container container) {
        ContainerContext<Connection> containerContext = createContainerContext(container);
        logger.debug("Starting in mode '{}' container: {}", metadata.runMode(), containerContext);
        containerContext.start();
        logger.info("Started in mode '{}' container: {}", metadata.runMode(), containerContext);
        return containerContext;
    }

    /**
     * Starts the {@link ContainerMode#PER_METHOD} container if needed, then injects the connection.
     * Injection happens on every test with PER_METHOD instance lifecycle (fresh instance). With PER_CLASS
     * lifecycle it happens only in PER_METHOD run mode (fresh container).
     */
    @Override
    public void beforeEach(ExtensionContext context) {
        Metadata metadata = getMetadata(context);
        if (metadata.runMode() == ContainerMode.PER_METHOD && getContainerContext(context) == null) {
            Container container = findContainerFromField(context).orElseGet(() -> defaultContainer(metadata));
            getStorage(context).put(metadata.runMode(), startContainerContext(metadata, container));
        }

        TestInstance.Lifecycle lifecycle = context.getTestInstanceLifecycle().orElse(TestInstance.Lifecycle.PER_METHOD);
        if (lifecycle == TestInstance.Lifecycle.PER_METHOD || metadata.runMode() == ContainerMode.PER_METHOD) {
            ContainerContext<Connection> containerContext = getContainerContext(context);
            if (containerContext != null) {
                injectContext(containerContext, context);
            }
        }
    }

    /** Stops and forgets the {@link ContainerMode#PER_METHOD} container after each test. */
    @Override
    public void afterEach(ExtensionContext context) {
        Metadata metadata = getMetadata(context);
        if (metadata.runMode() != ContainerMode.PER_METHOD) {
            return;
        }

        ContainerContext<Connection> containerContext = getContainerContext(context);
        if (containerContext != null) {
            logger.debug("Stopping in mode '{}' container: {}", metadata.runMode(), containerContext);
            containerContext.stop();
            logger.info("Stopped in mode '{}' container: {}", metadata.runMode(), containerContext);
            ExtensionContext.Store storage = getStorage(context);
            // NOTE(review): this key is never put in this class; presumably kept for subclasses — confirm.
            storage.remove(getConnectionType());
            storage.remove(metadata.runMode());
        }
    }

    /** Stops and forgets the {@link ContainerMode#PER_CLASS} container after the test class. */
    @Override
    public void afterAll(ExtensionContext context) {
        Metadata metadata = getMetadata(context);
        if (metadata.runMode() != ContainerMode.PER_CLASS) {
            return;
        }

        ContainerContext<Connection> containerContext = getContainerContext(context);
        if (containerContext != null) {
            logger.debug("Stopping in mode '{}' container: {}", metadata.runMode(), containerContext);
            containerContext.stop();
            logger.info("Stopped in mode '{}' container: {}", metadata.runMode(), containerContext);
            // Nested classes share the enclosing class's store (see getStorage). Drop the stopped container
            // so a later beforeAll starts a fresh one instead of reusing a stopped instance.
            getStorage(context).remove(metadata.runMode());
        }
    }

    /** Kind of executable a connection parameter is declared on. */
    public enum CallMode {
        CONSTRUCTOR,
        BEFORE_EACH,
        BEFORE_ALL
    }

    /** Key identifying a PER_RUN container in {@code CLASS_TO_SHARED_CONTAINERS}. */
    interface SharedKey {
    }

    /** Value-based key for default containers: equal image, network flag and alias share one container. */
    static final class SharedContainerKey implements SharedKey {

        private final String image;
        private final boolean network;
        private final String alias;

        SharedContainerKey(String image, boolean network, String alias) {
            this.image = image;
            this.network = network;
            this.alias = alias;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            SharedContainerKey other = (SharedContainerKey) o;
            return network == other.network && Objects.equals(image, other.image) && Objects.equals(alias, other.alias);
        }

        @Override
        public int hashCode() {
            return Objects.hash(image, network, alias);
        }

        @Override
        public String toString() {
            return alias == null
                    ? "[image=" + image + "]"
                    : "[image=" + image + ", alias=" + alias + "]";
        }
    }

    /**
     * Identity-based key for a user-declared container instance. The hash code is identity-based too, so it
     * stays consistent with {@link #equals} even if the container's own state (and {@code hashCode}) changes.
     */
    static final class SharedContainerInstance implements SharedKey {

        private final GenericContainer<?> container;

        public SharedContainerInstance(GenericContainer<?> container) {
            this.container = container;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            return container == ((SharedContainerInstance) o).container;
        }

        @Override
        public int hashCode() {
            return System.identityHashCode(container);
        }
    }
}

