/*
 * Decompiled with CFR 0.152.
 */
package org.rnorth.testcontainers.jdbc;

import com.google.common.base.Charsets;
import com.google.common.io.Resources;
import com.spotify.docker.client.messages.Container;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URL;
import java.nio.charset.Charset;
import java.sql.Connection;
import java.sql.Driver;
import java.sql.DriverManager;
import java.sql.DriverPropertyInfo;
import java.sql.SQLException;
import java.sql.SQLFeatureNotSupportedException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Properties;
import java.util.ServiceLoader;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.script.ScriptException;
import org.rnorth.testcontainers.containers.DatabaseContainer;
import org.rnorth.testcontainers.jdbc.ConnectionWrapper;
import org.rnorth.testcontainers.jdbc.ext.ScriptUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * JDBC {@link Driver} that handles {@code jdbc:tc:} URLs by starting a
 * {@link DatabaseContainer} and delegating to that container's real JDBC driver.
 *
 * <p>The driver registers itself with {@link DriverManager} when the class is
 * loaded. For each distinct JDBC URL one container is located through
 * {@link ServiceLoader}, started and cached; every connection handed out is
 * wrapped so that closing the last open connection of a container stops the
 * container and evicts it from the cache.
 *
 * <p>Two optional URL parameters are recognised and applied once per container:
 * {@code TC_INITSCRIPT=<classpath resource>} and
 * {@code TC_INITFUNCTION=<class>::<static method taking a Connection>}.
 *
 * <p>All mutable state is guarded by the monitor of this driver instance.
 */
public class ContainerDatabaseDriver
implements Driver {
    public static final Pattern URL_MATCHING_PATTERN = Pattern.compile("jdbc:tc:(mysql|postgresql|oracle)(:([^:]+))?://[^\\?]+(\\?.*)?");
    public static final Pattern INITSCRIPT_MATCHING_PATTERN = Pattern.compile(".*([\\?&]?)TC_INITSCRIPT=([^\\?&]+).*");
    public static final Pattern INITFUNCTION_MATCHING_PATTERN = Pattern.compile(".*([\\?&]?)TC_INITFUNCTION=((\\p{javaJavaIdentifierStart}\\p{javaJavaIdentifierPart}*\\.)*\\p{javaJavaIdentifierStart}\\p{javaJavaIdentifierPart}*)::(\\p{javaJavaIdentifierStart}\\p{javaJavaIdentifierPart}*).*");
    private static final Logger LOGGER = LoggerFactory.getLogger(ContainerDatabaseDriver.class);
    /** Placeholder version reported while no delegate driver has been loaded yet. */
    private static final int FALLBACK_MAJOR_VERSION = 1;
    private static final int FALLBACK_MINOR_VERSION = 0;
    /** Most recently used delegate; backs the metadata methods of {@link Driver}. May be null before the first connect. */
    private Driver delegate;
    /** Delegate driver instances, keyed by driver class name, so each container uses its own driver. */
    private final Map<String, Driver> delegatesByClassName = new HashMap<String, Driver>();
    /** Open (unwrapped) connections per running container. */
    private final Map<DatabaseContainer, Set<Connection>> containerConnections = new HashMap<DatabaseContainer, Set<Connection>>();
    /** Running containers, keyed by the exact JDBC URL that created them. */
    private final Map<String, DatabaseContainer> jdbcUrlContainerCache = new HashMap<String, DatabaseContainer>();
    /** Containers whose init script / init function have already been run. */
    private final Set<DatabaseContainer> initializedContainers = new HashSet<DatabaseContainer>();

    /** Registers an instance of this driver with {@link DriverManager}; failures are logged, not thrown. */
    private static void load() {
        try {
            DriverManager.registerDriver(new ContainerDatabaseDriver());
        }
        catch (SQLException e) {
            LOGGER.warn("Failed to register driver", (Throwable)e);
        }
    }

    /**
     * Reports whether this driver handles the given URL.
     *
     * @param url the JDBC URL to test
     * @return {@code true} if the URL starts with {@code jdbc:tc:}
     * @throws SQLException if {@code url} is null, as specified by {@link Driver#acceptsURL(String)}
     */
    @Override
    public boolean acceptsURL(String url) throws SQLException {
        if (url == null) {
            throw new SQLException("JDBC URL must not be null");
        }
        return url.startsWith("jdbc:tc:");
    }

    /**
     * Opens a connection to the container associated with {@code url}, starting the
     * container first if this URL has not been seen (or its container was stopped).
     *
     * <p>The {@code user} and {@code password} entries of {@code info} are overwritten
     * with the container's credentials.
     *
     * @param url  a {@code jdbc:tc:} URL
     * @param info connection properties; may be null
     * @return a wrapped connection, or {@code null} if the URL is not a {@code jdbc:tc:} URL
     * @throws SQLException                  if {@code url} is null or the delegate driver fails
     * @throws IllegalArgumentException      if the URL has the prefix but not the expected structure
     * @throws UnsupportedOperationException if no container implementation matches the database type
     */
    @Override
    public synchronized Connection connect(String url, Properties info) throws SQLException {
        // DriverManager offers every URL to every registered driver; the Driver
        // contract requires returning null (not throwing) for foreign URLs.
        if (!this.acceptsURL(url)) {
            return null;
        }
        Matcher urlMatcher = URL_MATCHING_PATTERN.matcher(url);
        if (!urlMatcher.matches()) {
            throw new IllegalArgumentException("JDBC URL matches jdbc:tc: prefix but the database or tag name could not be identified");
        }
        // Parsed on every call so cached containers still receive the URL parameters.
        String queryString = urlMatcher.group(4);
        if (queryString == null) {
            queryString = "";
        }

        DatabaseContainer container = this.jdbcUrlContainerCache.get(url);
        boolean newlyCreated = container == null;
        if (newlyCreated) {
            container = this.locateContainer(urlMatcher.group(1), urlMatcher.group(3));
        }
        // Resolve the driver before starting so a missing driver does not leave a running container behind.
        Driver driver = this.delegateFor(container);
        if (newlyCreated) {
            // Cache only after a successful start so a failed start is retried on the next call.
            container.start();
            this.jdbcUrlContainerCache.put(url, container);
        }

        Properties connectionInfo = info != null ? info : new Properties();
        connectionInfo.put("user", container.getUsername());
        connectionInfo.put("password", container.getPassword());
        Connection connection = driver.connect(container.getJdbcUrl() + queryString, connectionInfo);

        if (!this.initializedContainers.contains(container)) {
            try {
                this.runInitScriptIfRequired(url, connection);
                this.runInitFunctionIfRequired(url, connection);
            }
            catch (SQLException | RuntimeException e) {
                // The caller never sees this connection, so it must be released here.
                this.closeQuietly(connection);
                throw e;
            }
            this.initializedContainers.add(container);
        }
        return this.wrapConnection(connection, container, url);
    }

    /**
     * Finds the {@link DatabaseContainer} service whose name equals {@code databaseType}
     * and applies {@code tag} to it. If several services match, the last one wins.
     *
     * @throws UnsupportedOperationException if no service matches
     */
    private DatabaseContainer locateContainer(String databaseType, String tag) {
        DatabaseContainer match = null;
        ServiceLoader<DatabaseContainer> databaseContainers = ServiceLoader.load(DatabaseContainer.class);
        for (DatabaseContainer candidateContainerType : databaseContainers) {
            if (!candidateContainerType.getName().equals(databaseType)) continue;
            candidateContainerType.setTag(tag);
            match = candidateContainerType;
        }
        if (match == null) {
            throw new UnsupportedOperationException("Database name " + databaseType + " not supported");
        }
        return match;
    }

    /** Returns the (cached) delegate driver for the container and remembers it as the current delegate. */
    private Driver delegateFor(DatabaseContainer container) {
        String driverClassName = container.getDriverClassName();
        Driver driver = this.delegatesByClassName.get(driverClassName);
        if (driver == null) {
            driver = this.getDriver(driverClassName);
            this.delegatesByClassName.put(driverClassName, driver);
        }
        this.delegate = driver;
        return driver;
    }

    /** Closes a connection on a failure path; a secondary close failure is logged rather than masking the original error. */
    private void closeQuietly(Connection connection) {
        if (connection == null) {
            return;
        }
        try {
            connection.close();
        }
        catch (SQLException ignored) {
            LOGGER.warn("Failed to close connection after initialization error", (Throwable)ignored);
        }
    }

    /**
     * Registers {@code connection} as open for {@code container} and wraps it so that
     * closing it unregisters it again; when the last connection of the container is
     * closed the container is stopped and all bookkeeping for it is discarded.
     */
    private Connection wrapConnection(final Connection connection, final DatabaseContainer container, final String url) {
        Set<Connection> connections = this.containerConnections.get(container);
        if (connections == null) {
            connections = new HashSet<Connection>();
            this.containerConnections.put(container, connections);
        }
        connections.add(connection);
        final Set<Connection> openConnections = connections;
        return new ConnectionWrapper(connection, new Runnable(){

            @Override
            public void run() {
                // Same monitor as connect(): the connection set and caches are shared state.
                synchronized (ContainerDatabaseDriver.this) {
                    if (!openConnections.remove(connection)) {
                        // Already unregistered (e.g. repeated close); nothing left to do.
                        return;
                    }
                    if (openConnections.isEmpty()) {
                        container.stop();
                        ContainerDatabaseDriver.this.jdbcUrlContainerCache.remove(url);
                        ContainerDatabaseDriver.this.containerConnections.remove(container);
                        ContainerDatabaseDriver.this.initializedContainers.remove(container);
                    }
                }
            }
        });
    }

    /**
     * Executes the classpath SQL script named by the {@code TC_INITSCRIPT} URL parameter, if present.
     * A script that cannot be loaded or that fails with a {@link ScriptException} is logged and skipped.
     */
    private void runInitScriptIfRequired(String url, Connection connection) throws SQLException {
        Matcher matcher = INITSCRIPT_MATCHING_PATTERN.matcher(url);
        if (!matcher.matches()) {
            return;
        }
        String initScriptPath = matcher.group(2);
        try {
            URL resource = Resources.getResource(initScriptPath);
            String sql = Resources.toString(resource, (Charset)Charsets.UTF_8);
            ScriptUtils.executeSqlScript(connection, initScriptPath, sql);
        }
        catch (IOException | IllegalArgumentException e) {
            LOGGER.warn("Could not load classpath init script: " + initScriptPath, (Throwable)e);
        }
        catch (ScriptException e) {
            LOGGER.error("Error while executing init script: " + initScriptPath, (Throwable)e);
        }
    }

    /**
     * Invokes the static method named by the {@code TC_INITFUNCTION=<class>::<method>} URL
     * parameter, if present, passing it the connection. Reflection failures are logged and skipped.
     */
    private void runInitFunctionIfRequired(String url, Connection connection) throws SQLException {
        Matcher matcher = INITFUNCTION_MATCHING_PATTERN.matcher(url);
        if (!matcher.matches()) {
            return;
        }
        String className = matcher.group(2);
        String methodName = matcher.group(4);
        try {
            Class<?> initFunctionClazz = Class.forName(className);
            Method method = initFunctionClazz.getMethod(methodName, Connection.class);
            // null receiver: the init function is expected to be static.
            method.invoke(null, connection);
        }
        catch (ClassNotFoundException | NoSuchMethodException | InvocationTargetException | IllegalAccessException e) {
            LOGGER.error("Error while executing init function: " + className + "::" + methodName, (Throwable)e);
        }
    }

    /**
     * Instantiates the named driver class via the system class loader.
     *
     * @throws RuntimeException if the class cannot be found or instantiated
     */
    private Driver getDriver(String driverClassName) {
        try {
            return (Driver)ClassLoader.getSystemClassLoader().loadClass(driverClassName).newInstance();
        }
        catch (ClassNotFoundException | IllegalAccessException | InstantiationException e) {
            throw new RuntimeException("Could not get Driver", e);
        }
    }

    /** Delegates to the current delegate driver; returns an empty array if none has been loaded yet. */
    @Override
    public DriverPropertyInfo[] getPropertyInfo(String url, Properties info) throws SQLException {
        Driver current = this.delegate;
        if (current == null) {
            return new DriverPropertyInfo[0];
        }
        return current.getPropertyInfo(url, info);
    }

    /** Delegates to the current delegate driver; returns a placeholder version if none has been loaded yet. */
    @Override
    public int getMajorVersion() {
        Driver current = this.delegate;
        return current != null ? current.getMajorVersion() : FALLBACK_MAJOR_VERSION;
    }

    /** Delegates to the current delegate driver; returns a placeholder version if none has been loaded yet. */
    @Override
    public int getMinorVersion() {
        Driver current = this.delegate;
        return current != null ? current.getMinorVersion() : FALLBACK_MINOR_VERSION;
    }

    /** Delegates to the current delegate driver; reports {@code false} if none has been loaded yet. */
    @Override
    public boolean jdbcCompliant() {
        Driver current = this.delegate;
        return current != null && current.jdbcCompliant();
    }

    /**
     * Delegates to the current delegate driver.
     *
     * @throws SQLFeatureNotSupportedException if no delegate has been loaded yet, or the delegate throws it
     */
    @Override
    public java.util.logging.Logger getParentLogger() throws SQLFeatureNotSupportedException {
        Driver current = this.delegate;
        if (current == null) {
            throw new SQLFeatureNotSupportedException("No delegate driver has been loaded yet");
        }
        return current.getParentLogger();
    }

    static {
        ContainerDatabaseDriver.load();
    }
}

