/*
 * Decompiled with CFR 0.152.
 */
package io.trino.plugin.jdbc;

import com.google.common.base.Verify;
import io.trino.Session;
import io.trino.plugin.jdbc.ConnectionFactory;
import io.trino.plugin.jdbc.DriverConnectionFactory;
import io.trino.plugin.jdbc.ForwardingConnection;
import io.trino.spi.connector.ConnectorSession;
import io.trino.testing.AbstractTestQueryFramework;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import org.assertj.core.api.Assertions;
import org.intellij.lang.annotations.Language;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;

public abstract class BaseJdbcConnectionCreationTest
extends AbstractTestQueryFramework {
    protected ConnectionCountingConnectionFactory connectionFactory;

    @BeforeAll
    public void verifySetup() {
        Objects.requireNonNull(this.connectionFactory, "connectionFactory is null");
        this.connectionFactory.assertThatNoConnectionHasLeaked();
    }

    @AfterAll
    public void destroy() throws Exception {
        this.connectionFactory.close();
        this.connectionFactory = null;
    }

    protected void assertJdbcConnections(@Language(value="SQL") String query, int expectedJdbcConnectionsCount, Optional<String> errorMessage) {
        this.assertJdbcConnections(this.getSession(), query, expectedJdbcConnectionsCount, errorMessage);
    }

    protected void assertJdbcConnections(Session session, @Language(value="SQL") String query, int expectedJdbcConnectionsCount, Optional<String> errorMessage) {
        int before = this.connectionFactory.openConnections.get();
        if (errorMessage.isPresent()) {
            this.assertQueryFails(query, errorMessage.get());
        } else {
            Session querySession = Session.builder((Session)session).setSystemProperty("task_max_writer_count", "4").build();
            this.getQueryRunner().execute(querySession, query);
        }
        int after = this.connectionFactory.openConnections.get();
        Assertions.assertThat((int)(after - before)).isEqualTo(expectedJdbcConnectionsCount);
        this.connectionFactory.assertThatNoConnectionHasLeaked();
    }

    protected static class ConnectionCountingConnectionFactory
    implements ConnectionFactory {
        private final Map<Connection, Exception> connectionCreations = Collections.synchronizedMap(new IdentityHashMap());
        private final AtomicInteger openConnections = new AtomicInteger();
        private final ConnectionFactory delegate;

        public ConnectionCountingConnectionFactory(DriverConnectionFactory delegate) {
            this.delegate = (ConnectionFactory)Objects.requireNonNull(delegate, "delegate is null");
        }

        public Connection openConnection(ConnectorSession session) throws SQLException {
            this.openConnections.incrementAndGet();
            final Connection connection = this.delegate.openConnection(session);
            Exception previous = this.connectionCreations.put(connection, new Exception("STACKTRACE"));
            if (previous != null) {
                IllegalStateException exception = new IllegalStateException("Two connections are opened for same session");
                exception.addSuppressed(previous);
                throw exception;
            }
            return new ForwardingConnection(this){
                private volatile boolean closed;
                final /* synthetic */ ConnectionCountingConnectionFactory this$0;
                {
                    this.this$0 = this$0;
                }

                protected Connection delegate() {
                    return connection;
                }

                public void close() throws SQLException {
                    if (this.closed) {
                        return;
                    }
                    this.closed = true;
                    Verify.verify((this.this$0.connectionCreations.remove(connection) != null ? 1 : 0) != 0, (String)("Connection was not created with ConnectionCountingConnectionFactory: " + String.valueOf(connection)), (Object[])new Object[0]);
                    super.close();
                }
            };
        }

        private void assertThatNoConnectionHasLeaked() {
            if (!this.connectionCreations.isEmpty()) {
                AssertionError error = new AssertionError((Object)"%s connections leaked, see attached places".formatted(this.connectionCreations.size()));
                this.connectionCreations.values().forEach(arg_0 -> error.addSuppressed(arg_0));
                throw error;
            }
        }

        public void close() throws SQLException {
            this.delegate.close();
        }
    }
}

