/*
 * Decompiled with CFR 0.152.
 */
package io.trino.jdbc;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.io.Resources;
import com.google.inject.Binder;
import com.google.inject.Inject;
import com.google.inject.Key;
import com.google.inject.Module;
import com.google.inject.Scopes;
import io.airlift.configuration.AbstractConfigurationAwareModule;
import io.airlift.jaxrs.JaxrsBinder;
import io.airlift.log.Logging;
import io.airlift.testing.Closeables;
import io.trino.client.ClientException;
import io.trino.client.auth.external.DesktopBrowserRedirectHandler;
import io.trino.client.auth.external.RedirectException;
import io.trino.client.auth.external.RedirectHandler;
import io.trino.jdbc.TrinoDriverUri;
import io.trino.plugin.tpch.TpchPlugin;
import io.trino.server.security.AuthenticationException;
import io.trino.server.security.Authenticator;
import io.trino.server.security.ResourceSecurity;
import io.trino.server.security.ServerSecurityModule;
import io.trino.server.testing.TestingTrinoServer;
import io.trino.spi.Plugin;
import io.trino.spi.security.Identity;
import java.io.Closeable;
import java.io.File;
import java.io.IOException;
import java.net.URI;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;
import java.sql.Statement;
import java.time.Duration;
import java.util.ConcurrentModificationException;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Properties;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.IntSupplier;
import javax.servlet.http.HttpServletRequest;
import javax.ws.rs.GET;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.Produces;
import javax.ws.rs.container.ContainerRequestContext;
import javax.ws.rs.core.Context;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import okhttp3.HttpUrl;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import org.assertj.core.api.AbstractThrowableAssert;
import org.assertj.core.api.Assertions;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

/**
 * Integration tests for the JDBC driver's external (redirect-based) authentication flow.
 *
 * <p>A {@link TestingTrinoServer} is started with a dummy authenticator that answers requests
 * lacking a valid bearer token with a {@code Bearer} challenge pointing at the in-test
 * {@link DummyExternalAuthResources} endpoints. Each test installs a {@link RedirectHandler}
 * that completes, ignores, or fails the redirect, and checks how the driver reacts.
 *
 * <p>The redirect handler and the failure fixtures are process-wide static state, which is why
 * the class is marked single-threaded.
 */
@Test(singleThreaded = true)
public class TestJdbcExternalAuthentication {
    private static final String TEST_CATALOG = "test_catalog";
    private TestingTrinoServer server;

    @BeforeClass
    public void setup() throws Exception {
        Logging.initialize();
        // The authenticator needs the server port to build its challenge URIs, but the port is only
        // known after build(), so it is passed as a supplier that is read lazily at request time.
        this.server = TestingTrinoServer.builder()
                .setAdditionalModule(new DummyExternalAuthModule(() -> this.server.getAddress().getPort()))
                .setProperties(ImmutableMap.<String, String>builder()
                        .put("http-server.authentication.type", "dummy-external")
                        .put("http-server.https.enabled", "true")
                        .put("http-server.https.keystore.path", new File(Resources.getResource("localhost.keystore").toURI()).getPath())
                        .put("http-server.https.keystore.key", "changeit")
                        .put("web-ui.enabled", "false")
                        .buildOrThrow())
                .build();
        this.server.installPlugin(new TpchPlugin());
        this.server.createCatalog(TEST_CATALOG, "tpch");
        this.server.waitForNodeRefresh(Duration.ofSeconds(10));
    }

    @AfterClass(alwaysRun = true)
    public void teardown() throws Exception {
        Closeables.closeAll(this.server);
        this.server = null;
    }

    /**
     * Invalidates every token issued so far, so each test starts with zero valid tokens.
     */
    @BeforeMethod(alwaysRun = true)
    public void clearUpLoggingSessions() {
        this.invalidateAllTokens();
    }

    @Test
    public void testSuccessfulAuthenticationWithHttpGetOnlyRedirectHandler() throws Exception {
        try (RedirectHandlerFixture ignore = RedirectHandlerFixture.withHandler(new HttpGetOnlyRedirectHandler());
                Connection connection = this.createConnection();
                Statement statement = connection.createStatement()) {
            Assertions.assertThat(statement.execute("SELECT 123")).isTrue();
        }
    }

    // Disabled: relies on the default redirect handler, which opens a real desktop browser.
    @Test(enabled = false)
    public void testSuccessfulAuthenticationWithDefaultBrowserRedirect() throws Exception {
        try (Connection connection = this.createConnection();
                Statement statement = connection.createStatement()) {
            Assertions.assertThat(statement.execute("SELECT 123")).isTrue();
        }
    }

    @Test
    public void testAuthenticationFailsAfterUnfinishedRedirect() throws Exception {
        // The no-op handler never visits the login URI, so no token is ever issued.
        try (RedirectHandlerFixture ignore = RedirectHandlerFixture.withHandler(new NoOpRedirectHandler());
                Connection connection = this.createConnection();
                Statement statement = connection.createStatement()) {
            Assertions.assertThatThrownBy(() -> statement.execute("SELECT 123"))
                    .isInstanceOf(SQLException.class);
        }
    }

    @Test
    public void testAuthenticationFailsAfterRedirectException() throws Exception {
        try (RedirectHandlerFixture ignore = RedirectHandlerFixture.withHandler(new FailingRedirectHandler());
                Connection connection = this.createConnection();
                Statement statement = connection.createStatement()) {
            Assertions.assertThatThrownBy(() -> statement.execute("SELECT 123"))
                    .isInstanceOf(SQLException.class)
                    .hasCauseExactlyInstanceOf(RedirectException.class);
        }
    }

    @Test
    public void testAuthenticationFailsAfterServerAuthenticationFailure() throws Exception {
        try (RedirectHandlerFixture ignore = RedirectHandlerFixture.withHandler(new HttpGetOnlyRedirectHandler());
                AutoCloseable ignore2 = TokenPollingErrorFixture.withPollingError("error occurred during token polling");
                Connection connection = this.createConnection();
                Statement statement = connection.createStatement()) {
            Assertions.assertThatThrownBy(() -> statement.execute("SELECT 123"))
                    .isInstanceOf(SQLException.class)
                    .hasMessage("error occurred during token polling");
        }
    }

    @Test
    public void testAuthenticationFailsAfterReceivingMalformedHeaderFromServer() throws Exception {
        try (RedirectHandlerFixture ignore = RedirectHandlerFixture.withHandler(new HttpGetOnlyRedirectHandler());
                AutoCloseable ignored = WwwAuthenticateHeaderFixture.withWwwAuthenticate("Bearer no-valid-fields");
                Connection connection = this.createConnection();
                Statement statement = connection.createStatement()) {
            Assertions.assertThatThrownBy(() -> statement.execute("SELECT 123"))
                    .isInstanceOf(SQLException.class)
                    .hasCauseInstanceOf(ClientException.class)
                    .hasMessage("Authentication failed: Authentication required");
        }
    }

    @Test
    public void testAuthenticationReusesObtainedTokenPerConnection() throws Exception {
        try (RedirectHandlerFixture ignore = RedirectHandlerFixture.withHandler(new HttpGetOnlyRedirectHandler());
                Connection connection = this.createConnection();
                Statement statement = connection.createStatement()) {
            statement.execute("SELECT 123");
            statement.execute("SELECT 123");
            statement.execute("SELECT 123");
            Assertions.assertThat(this.countIssuedTokens()).isEqualTo(1);
        }
    }

    @Test
    public void testAuthenticationAfterInitialTokenHasBeenInvalidated() throws Exception {
        try (RedirectHandlerFixture ignore = RedirectHandlerFixture.withHandler(new HttpGetOnlyRedirectHandler());
                Connection connection = this.createConnection();
                Statement statement = connection.createStatement()) {
            statement.execute("SELECT 123");
            this.invalidateAllTokens();
            Assertions.assertThat(this.countIssuedTokens()).isEqualTo(0);
            // The connection must transparently re-authenticate once its cached token is rejected.
            Assertions.assertThat(statement.execute("SELECT 123")).isTrue();
        }
    }

    /**
     * Opens a TLS JDBC connection to the test server with external authentication enabled.
     */
    private Connection createConnection() throws Exception {
        String url = String.format("jdbc:trino://localhost:%s", this.server.getHttpsAddress().getPort());
        Properties properties = new Properties();
        properties.setProperty("SSL", "true");
        properties.setProperty("SSLTrustStorePath", new File(Resources.getResource("localhost.truststore").toURI()).getPath());
        properties.setProperty("SSLTrustStorePassword", "changeit");
        properties.setProperty("externalAuthentication", "true");
        properties.setProperty("externalAuthenticationTimeout", "2s");
        return DriverManager.getConnection(url, properties);
    }

    private void invalidateAllTokens() {
        this.server.getInstance(Key.get(Authentications.class)).invalidateAllTokens();
    }

    /**
     * Returns the number of currently valid tokens. Invalidated tokens are not counted.
     */
    private int countIssuedTokens() {
        return this.server.getInstance(Key.get(Authentications.class)).countValidTokens();
    }

    /**
     * Registers the {@code dummy-external} authenticator together with its singleton
     * {@link Authentications} store, the port supplier, and the JAX-RS login/token resources.
     */
    private static class DummyExternalAuthModule extends AbstractConfigurationAwareModule {
        private final IntSupplier port;

        public DummyExternalAuthModule(IntSupplier port) {
            this.port = Objects.requireNonNull(port, "port is null");
        }

        @Override
        protected void setup(Binder ignored) {
            this.install(ServerSecurityModule.authenticatorModule("dummy-external", DummyAuthenticator.class, binder -> {
                binder.bind(Authentications.class).in(Scopes.SINGLETON);
                binder.bind(IntSupplier.class).toInstance(this.port);
                JaxrsBinder.jaxrsBinder(binder).bind(DummyExternalAuthResources.class);
            }));
        }
    }

    /**
     * Completes the "login" by issuing a plain HTTP GET to the redirect URI instead of opening
     * a browser. Any status other than 200 is reported as a {@link RedirectException}.
     */
    public static class HttpGetOnlyRedirectHandler implements RedirectHandler {
        // An OkHttpClient owns a connection pool and dispatcher threads. Share one instance
        // instead of creating (and leaking) a new client on every redirect.
        private static final OkHttpClient CLIENT = new OkHttpClient();

        @Override
        public void redirectTo(URI uri) throws RedirectException {
            Request request = new Request.Builder().url(HttpUrl.get(uri.toString())).build();
            try (okhttp3.Response response = CLIENT.newCall(request).execute()) {
                if (response.code() != 200) {
                    throw new RedirectException("HTTP GET failed with status " + response.code());
                }
            }
            catch (IOException e) {
                throw new RedirectException("Redirection failed", e);
            }
        }
    }

    /**
     * Installs a redirect handler on {@link TrinoDriverUri} for the duration of a
     * try-with-resources block.
     */
    static class RedirectHandlerFixture implements AutoCloseable {
        private static final RedirectHandlerFixture INSTANCE = new RedirectHandlerFixture();

        private RedirectHandlerFixture() {
        }

        public static RedirectHandlerFixture withHandler(RedirectHandler handler) {
            TrinoDriverUri.setRedirectHandler(handler);
            return INSTANCE;
        }

        /**
         * Restores a {@link DesktopBrowserRedirectHandler}. Presumably this is the driver's
         * default; confirm against TrinoDriverUri.
         */
        @Override
        public void close() {
            TrinoDriverUri.setRedirectHandler(new DesktopBrowserRedirectHandler());
        }
    }

    /**
     * Ignores the redirect entirely, so the login is never completed.
     */
    public static class NoOpRedirectHandler implements RedirectHandler {
        @Override
        public void redirectTo(URI uri) throws RedirectException {
        }
    }

    /**
     * Fails every redirect with a {@link RedirectException}.
     */
    public static class FailingRedirectHandler implements RedirectHandler {
        @Override
        public void redirectTo(URI uri) throws RedirectException {
            throw new RedirectException("Redirect to uri has failed " + uri);
        }
    }

    /**
     * While open, makes the token endpoint answer every poll with the given error message.
     */
    static class TokenPollingErrorFixture implements AutoCloseable {
        private static final AtomicReference<String> ERROR = new AtomicReference<>();

        TokenPollingErrorFixture() {
        }

        /**
         * Installs {@code error}. Throws {@link ConcurrentModificationException} if another
         * error is already installed.
         */
        public static AutoCloseable withPollingError(String error) {
            if (ERROR.compareAndSet(null, error)) {
                return new TokenPollingErrorFixture();
            }
            throw new ConcurrentModificationException("polling errors can't be invoked in parallel");
        }

        @Override
        public void close() {
            ERROR.set(null);
        }
    }

    /**
     * While open, makes the authenticator send the given raw {@code WWW-Authenticate} value
     * instead of the well-formed challenge.
     */
    static class WwwAuthenticateHeaderFixture implements AutoCloseable {
        private static final AtomicReference<String> HEADER = new AtomicReference<>();

        WwwAuthenticateHeaderFixture() {
        }

        /**
         * Installs {@code header}. Throws {@link ConcurrentModificationException} if another
         * header is already installed.
         */
        public static AutoCloseable withWwwAuthenticate(String header) {
            if (HEADER.compareAndSet(null, header)) {
                return new WwwAuthenticateHeaderFixture();
            }
            throw new ConcurrentModificationException("with WWW-Authenticate header can't be invoked in parallel");
        }

        @Override
        public void close() {
            HEADER.set(null);
        }
    }

    /**
     * In-memory store of login sessions and valid tokens, shared (as a singleton) between the
     * authenticator and the JAX-RS resources.
     *
     * <p>A session maps to {@code ""} until it is logged in, and then to its token.
     */
    private static class Authentications {
        private final Map<String, String> loginSessions = new ConcurrentHashMap<>();
        private final Set<String> validTokens = ConcurrentHashMap.newKeySet();

        private Authentications() {
        }

        /**
         * Creates a new pending session with a random UUID id and returns that id.
         */
        public String startAuthentication() {
            String sessionId = UUID.randomUUID().toString();
            this.loginSessions.put(sessionId, "");
            return sessionId;
        }

        /**
         * Issues the token {@code <sessionId>_token}, marks it valid, and attaches it to the session.
         */
        public void logIn(String sessionId) {
            String token = sessionId + "_token";
            this.validTokens.add(token);
            this.loginSessions.put(sessionId, token);
        }

        /**
         * Returns the session's token. Returns empty if the session is unknown or not yet logged in.
         */
        public Optional<String> getToken(String sessionId) {
            return Optional.ofNullable(this.loginSessions.get(sessionId))
                    .filter(token -> !token.isEmpty());
        }

        public boolean verifyToken(String token) {
            return this.validTokens.contains(token);
        }

        /**
         * Invalidates all tokens. Session-to-token mappings are kept.
         */
        public void invalidateAllTokens() {
            this.validTokens.clear();
        }

        public int countValidTokens() {
            return this.validTokens.size();
        }
    }

    /**
     * Public endpoints for the fake identity provider: one to "log in" a session and one that
     * the driver polls for the resulting token.
     */
    @Path("/v1/authentications/dummy")
    public static class DummyExternalAuthResources {
        private final Authentications authentications;

        @Inject
        public DummyExternalAuthResources(Authentications authentications) {
            this.authentications = authentications;
        }

        @GET
        @Produces("text/plain")
        @ResourceSecurity(ResourceSecurity.AccessType.PUBLIC)
        @Path("logins/{sessionId}")
        public String logInUser(@PathParam("sessionId") String sessionId) {
            this.authentications.logIn(sessionId);
            return "User has been successfully logged in during " + sessionId + " session";
        }

        /**
         * Token polling endpoint. It returns, in order of precedence:
         * <ul>
         * <li>the error installed by {@link TokenPollingErrorFixture};</li>
         * <li>the session's token, once logged in;</li>
         * <li>otherwise the request URI as {@code nextUri}, presumably so the client polls again.</li>
         * </ul>
         */
        @GET
        @ResourceSecurity(ResourceSecurity.AccessType.PUBLIC)
        @Path("{sessionId}")
        public Response getToken(@PathParam("sessionId") String sessionId, @Context HttpServletRequest request) {
            try {
                String pollingError = TokenPollingErrorFixture.ERROR.get();
                if (pollingError != null) {
                    return jsonResponse(String.format("{ \"error\" : \"%s\"}", pollingError));
                }
                return this.authentications.getToken(sessionId)
                        .map(token -> jsonResponse(String.format("{ \"token\" : \"%s\"}", token)))
                        .orElseGet(() -> jsonResponse(String.format("{ \"nextUri\" : \"%s\" }", request.getRequestURI())));
            }
            catch (IllegalArgumentException ex) {
                // Defensive: map an unexpected IllegalArgumentException to 404.
                return Response.status(Response.Status.NOT_FOUND).build();
            }
        }

        /**
         * Builds a 200 response carrying {@code body} as {@code application/json}.
         */
        private static Response jsonResponse(String body) {
            return Response.ok(body, MediaType.APPLICATION_JSON_TYPE).build();
        }
    }

    /**
     * Accepts requests carrying a valid {@code Bearer} token. Otherwise it starts a new login
     * session and rejects the request with a challenge naming that session's login and token URIs.
     */
    public static class DummyAuthenticator implements Authenticator {
        private static final String BEARER_PREFIX = "Bearer ";

        private final IntSupplier port;
        private final Authentications authentications;

        @Inject
        public DummyAuthenticator(IntSupplier port, Authentications authentications) {
            this.port = Objects.requireNonNull(port, "port is null");
            this.authentications = Objects.requireNonNull(authentications, "authentications is null");
        }

        @Override
        public Identity authenticate(ContainerRequestContext request) throws AuthenticationException {
            List<String> authorizationHeaders = request.getHeaders().getOrDefault("Authorization", ImmutableList.of());
            boolean hasValidToken = authorizationHeaders.stream()
                    .filter(header -> header.startsWith(BEARER_PREFIX))
                    .map(header -> header.substring(BEARER_PREFIX.length()))
                    .anyMatch(this.authentications::verifyToken);
            if (hasValidToken) {
                return Identity.ofUser("user");
            }

            String sessionId = this.authentications.startAuthentication();
            // A header installed by WwwAuthenticateHeaderFixture overrides the well-formed challenge.
            String challenge = Optional.ofNullable(WwwAuthenticateHeaderFixture.HEADER.get())
                    .orElseGet(() -> String.format(
                            "Bearer x_redirect_server=\"http://localhost:%s/v1/authentications/dummy/logins/%s\", x_token_server=\"http://localhost:%s/v1/authentications/dummy/%s\"",
                            this.port.getAsInt(), sessionId, this.port.getAsInt(), sessionId));
            throw new AuthenticationException("Authentication required", challenge);
        }
    }
}

