/*
 * Decompiled with CFR 0.152.
 */
package io.trino.security;

import com.google.common.collect.ImmutableSet;
import com.google.inject.Binder;
import com.google.inject.multibindings.OptionalBinder;
import io.airlift.log.Logging;
import io.trino.jdbc.BaseTrinoDriverTest;
import io.trino.jdbc.TrinoConnection;
import io.trino.metadata.SystemSecurityMetadata;
import io.trino.plugin.memory.MemoryPlugin;
import io.trino.security.TestingSystemSecurityMetadata;
import io.trino.server.testing.TestingTrinoServer;
import io.trino.spi.Plugin;
import io.trino.spi.security.PrincipalType;
import io.trino.spi.security.TrinoPrincipal;
import io.trino.testing.TestingAccessControlManager;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.parallel.Execution;
import org.junit.jupiter.api.parallel.ExecutionMode;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;

/**
 * Tests that {@code SET SESSION AUTHORIZATION} (impersonating another user) over JDBC depends on
 * the roles currently enabled in the session.
 *
 * <p>All tests share one {@link TestingTrinoServer} with the memory connector. The class uses the
 * {@code PER_CLASS} lifecycle and runs tests on a single thread because every test mutates the
 * shared {@link TestingSystemSecurityMetadata} and {@link TestingAccessControlManager}. Each test
 * resets both before it starts.
 */
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
@Execution(ExecutionMode.SAME_THREAD)
public class TestImpersonation {
    private TestingTrinoServer server;
    private final TestingSystemSecurityMetadata securityMetadata = new TestingSystemSecurityMetadata();
    private TestingAccessControlManager accessControl;

    /**
     * Starts the shared server. It binds {@link #securityMetadata} as the server's
     * {@link SystemSecurityMetadata}, so tests can create and grant roles directly, and it
     * registers a {@code memory} catalog.
     */
    @BeforeAll
    public void setup() throws Exception {
        Logging.initialize();
        server = TestingTrinoServer.builder()
                .setAdditionalModule(binder -> OptionalBinder.newOptionalBinder(binder, SystemSecurityMetadata.class)
                        .setBinding()
                        .toInstance(securityMetadata))
                .build();
        server.installPlugin(new MemoryPlugin());
        server.createCatalog("memory", "memory");
        accessControl = server.getAccessControl();
    }

    /**
     * Shuts down the shared server started in {@link #setup()}. Previously the server was never
     * closed and outlived the test class.
     */
    @AfterAll
    public void teardown() throws Exception {
        // setup() may have failed before the server was assigned
        if (server != null) {
            server.close();
            server = null;
        }
    }

    /**
     * Checks that alice can impersonate john only after enabling {@code alice_role}, either by
     * name or through {@code SET ROLE ALL}. An enabled role that lacks the permission is not
     * enough. A granted role that is not enabled is not enough either.
     *
     * @param roleName the argument passed to {@code SET ROLE}, supplied by {@link #roles()}
     */
    @ParameterizedTest
    @MethodSource("roles")
    @Timeout(10)
    public void testImpersonationAllowedByRole(String roleName) throws Exception {
        securityMetadata.reset();
        accessControl.reset();
        try (TrinoConnection connection = createConnection("memory", "default", "alice").unwrap(TrinoConnection.class);
                Statement statement = connection.createStatement()) {
            Assertions.assertThat(BaseTrinoDriverTest.getCurrentUser(connection)).isEqualTo("alice");

            // A role that denyImpersonation() does not recognize must not allow impersonation
            createRoleGrantedToAlice("invalid_role");
            denyImpersonation();
            statement.execute("SET ROLE invalid_role");
            assertImpersonationOfJohnDenied(statement);

            // Granting alice_role is not sufficient; invalid_role is still the enabled role
            createRoleGrantedToAlice("alice_role");
            assertImpersonationOfJohnDenied(statement);

            statement.execute("SET ROLE " + roleName);
            statement.execute("SET SESSION AUTHORIZATION john");
            // NOTE(review): the query is deliberately kept twice, presumably to check that the
            // impersonated session stays usable across several statements. Confirm intent.
            statement.execute("SHOW SCHEMAS IN memory");
            statement.execute("SHOW SCHEMAS IN memory");
        }
    }

    /**
     * Checks that impersonation is rejected after {@code SET ROLE NONE}, even though
     * {@code alice_role} has been granted to alice.
     */
    @Test
    @Timeout(10)
    public void testImpersonationDisallowedWhenRoleIsNone() throws Exception {
        securityMetadata.reset();
        accessControl.reset();
        try (TrinoConnection connection = createConnection("memory", "default", "alice").unwrap(TrinoConnection.class);
                Statement statement = connection.createStatement()) {
            Assertions.assertThat(BaseTrinoDriverTest.getCurrentUser(connection)).isEqualTo("alice");
            createRoleGrantedToAlice("alice_role");
            denyImpersonation();
            statement.execute("SET ROLE NONE");
            assertImpersonationOfJohnDenied(statement);
        }
    }

    /**
     * Opens a JDBC connection to the shared server without a password.
     *
     * @param catalog catalog segment of the JDBC URL
     * @param schema schema segment of the JDBC URL
     * @param user user name to connect as
     * @return a new connection, which the caller must close
     * @throws SQLException if the driver cannot connect
     */
    private Connection createConnection(String catalog, String schema, String user) throws SQLException {
        String url = String.format("jdbc:trino://%s/%s/%s", server.getAddress(), catalog, schema);
        return DriverManager.getConnection(url, user, null);
    }

    /**
     * Supplies the {@code SET ROLE} arguments for
     * {@link #testImpersonationAllowedByRole(String)}. The method is non-static, which JUnit
     * allows because of the {@code PER_CLASS} lifecycle.
     */
    private Stream<String> roles() {
        return Stream.of("alice_role", "ALL");
    }

    /**
     * Creates {@code role} in the testing security metadata and grants it to user alice, without
     * admin option and without a grantor.
     */
    private void createRoleGrantedToAlice(String role) {
        securityMetadata.createRole(null, role, Optional.empty());
        securityMetadata.grantRoles(
                null,
                ImmutableSet.of(role),
                ImmutableSet.of(new TrinoPrincipal(PrincipalType.USER, "alice")),
                false,
                Optional.empty());
    }

    /**
     * Asserts that {@code SET SESSION AUTHORIZATION john} fails with an access-denied message
     * for alice.
     */
    private static void assertImpersonationOfJohnDenied(Statement statement) {
        Assertions.assertThatThrownBy(() -> statement.execute("SET SESSION AUTHORIZATION john"))
                .hasMessageContaining("User alice cannot impersonate user john");
    }

    /**
     * Installs an impersonation predicate on the access control. The predicate evaluates to true
     * when the identity has {@code alice_role} enabled (case-insensitive match).
     *
     * <p>NOTE(review): the tests expect impersonation to succeed only when {@code alice_role} is
     * enabled. That suggests the predicate selects identities that are allowed to impersonate.
     * Confirm against {@code TestingAccessControlManager#denyImpersonation}.
     */
    private void denyImpersonation() {
        accessControl.denyImpersonation((identity, impersonatedUser) -> identity.getEnabledRoles().stream()
                .anyMatch(role -> role.equalsIgnoreCase("alice_role")));
    }
}

