/*
 * Decompiled with CFR 0.152.
 */
package io.trino.server;

import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Multimap;
import io.airlift.jaxrs.testing.GuavaMultivaluedMap;
import io.trino.client.ProtocolHeaders;
import io.trino.metadata.Metadata;
import io.trino.metadata.MetadataManager;
import io.trino.security.AccessControl;
import io.trino.security.AllowAllAccessControl;
import io.trino.server.HttpRequestSessionContextFactory;
import io.trino.server.ProtocolConfig;
import io.trino.server.SessionContext;
import io.trino.server.protocol.PreparedStatementEncoder;
import io.trino.spi.security.Identity;
import io.trino.spi.security.SelectedRole;
import jakarta.ws.rs.WebApplicationException;
import jakarta.ws.rs.core.MultivaluedHashMap;
import jakarta.ws.rs.core.MultivaluedMap;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import org.assertj.core.api.AbstractThrowableAssert;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/**
 * Tests for {@link HttpRequestSessionContextFactory}.
 *
 * <p>Every scenario runs twice: once with the standard Trino protocol headers and once with a
 * custom protocol named {@code "taco"}, so that header-name resolution is exercised for both.
 */
public class TestHttpRequestSessionContextFactory {
    // NOTE(review): the third constructor argument (ImmutableSet::of) appears to act as a group
    // provider mapping a user name to a singleton set of that name; the expected identities below
    // rely on that (each user is also its own group) — confirm against the constructor signature.
    private static final HttpRequestSessionContextFactory SESSION_CONTEXT_FACTORY = new HttpRequestSessionContextFactory(
            new PreparedStatementEncoder(new ProtocolConfig()),
            MetadataManager.createTestMetadataManager(),
            ImmutableSet::of,
            new AllowAllAccessControl());

    @Test
    public void testSessionContext() {
        assertSessionContext(ProtocolHeaders.TRINO_HEADERS);
        assertSessionContext(ProtocolHeaders.createProtocolHeaders("taco"));
    }

    /**
     * Sends a full set of request headers and checks that each one is reflected in the resulting
     * {@link SessionContext}.
     *
     * <p>Among other things this verifies that repeated session headers are merged, that
     * comma-separated entries inside one header are split, that whitespace around {@code =} is
     * trimmed, that a percent-encoded comma ({@code %2C}) in a value is decoded, and that role
     * headers populate both the enabled (system) roles and the per-connector roles.
     */
    private static void assertSessionContext(ProtocolHeaders protocolHeaders) {
        MultivaluedMap<String, String> headers = new GuavaMultivaluedMap<>(ImmutableListMultimap.<String, String>builder()
                .put(protocolHeaders.requestUser(), "testUser")
                .put(protocolHeaders.requestSource(), "testSource")
                .put(protocolHeaders.requestCatalog(), "testCatalog")
                .put(protocolHeaders.requestSchema(), "testSchema")
                .put(protocolHeaders.requestPath(), "testPath")
                .put(protocolHeaders.requestLanguage(), "zh-TW")
                .put(protocolHeaders.requestTimeZone(), "Asia/Taipei")
                .put(protocolHeaders.requestClientInfo(), "client-info")
                .put(protocolHeaders.requestSession(), "query_max_memory=1GB")
                .put(protocolHeaders.requestSession(), "join_distribution_type=partitioned,max_hash_partition_count = 43")
                .put(protocolHeaders.requestSession(), "some_session_property=some value with %2C comma")
                .put(protocolHeaders.requestPreparedStatement(), "query1=select * from foo,query2=select * from bar")
                .put(protocolHeaders.requestRole(), "system=ROLE{system-role}")
                .put(protocolHeaders.requestRole(), "foo_connector=ALL")
                .put(protocolHeaders.requestRole(), "bar_connector=NONE")
                .put(protocolHeaders.requestRole(), "foobar_connector=ROLE{catalog-role}")
                .put(protocolHeaders.requestExtraCredential(), "test.token.foo=bar")
                .put(protocolHeaders.requestExtraCredential(), "test.token.abc=xyz")
                .build());

        SessionContext context = createSessionContext(headers, protocolHeaders, Optional.empty());

        Assertions.assertThat(context.getSource().orElse(null)).isEqualTo("testSource");
        Assertions.assertThat(context.getCatalog().orElse(null)).isEqualTo("testCatalog");
        Assertions.assertThat(context.getSchema().orElse(null)).isEqualTo("testSchema");
        Assertions.assertThat(context.getPath().orElse(null)).isEqualTo("testPath");
        Assertions.assertThat(context.getIdentity()).isEqualTo(Identity.forUser("testUser")
                .withGroups(ImmutableSet.of("testUser"))
                .withConnectorRoles(ImmutableMap.of(
                        "foo_connector", new SelectedRole(SelectedRole.Type.ALL, Optional.empty()),
                        "bar_connector", new SelectedRole(SelectedRole.Type.NONE, Optional.empty()),
                        "foobar_connector", new SelectedRole(SelectedRole.Type.ROLE, Optional.of("catalog-role"))))
                .withEnabledRoles(ImmutableSet.of("system-role"))
                .build());
        Assertions.assertThat(context.getClientInfo().orElse(null)).isEqualTo("client-info");
        Assertions.assertThat(context.getLanguage().orElse(null)).isEqualTo("zh-TW");
        Assertions.assertThat(context.getTimeZoneId().orElse(null)).isEqualTo("Asia/Taipei");
        Assertions.assertThat(context.getSystemProperties()).isEqualTo(ImmutableMap.of(
                "query_max_memory", "1GB",
                "join_distribution_type", "partitioned",
                "max_hash_partition_count", "43",
                "some_session_property", "some value with , comma"));
        Assertions.assertThat(context.getPreparedStatements()).isEqualTo(ImmutableMap.of(
                "query1", "select * from foo",
                "query2", "select * from bar"));
        Assertions.assertThat(context.getSelectedRole()).isEqualTo(new SelectedRole(SelectedRole.Type.ROLE, Optional.of("system-role")));
        Assertions.assertThat(context.getIdentity().getExtraCredentials()).isEqualTo(ImmutableMap.of(
                "test.token.foo", "bar",
                "test.token.abc", "xyz"));
    }

    @Test
    public void testMappedUser() {
        assertMappedUser(ProtocolHeaders.TRINO_HEADERS);
        assertMappedUser(ProtocolHeaders.createProtocolHeaders("taco"));
    }

    /**
     * Checks how the session identity is chosen from the user header and the authenticated
     * identity, covering the header only, the authenticated identity only, both, and neither.
     */
    private static void assertMappedUser(ProtocolHeaders protocolHeaders) {
        MultivaluedMap<String, String> userHeaders = new GuavaMultivaluedMap<>(ImmutableListMultimap.of(protocolHeaders.requestUser(), "testUser"));
        MultivaluedMap<String, String> emptyHeaders = new MultivaluedHashMap<>();

        // User header only: the identity is built from the header.
        SessionContext context = createSessionContext(userHeaders, protocolHeaders, Optional.empty());
        Assertions.assertThat(context.getIdentity()).isEqualTo(Identity.forUser("testUser").withGroups(ImmutableSet.of("testUser")).build());

        // Authenticated identity only: its user is kept, and its existing groups are extended with the user name.
        context = createSessionContext(emptyHeaders, protocolHeaders, Optional.of(Identity.forUser("mappedUser").withGroups(ImmutableSet.of("test")).build()));
        Assertions.assertThat(context.getIdentity()).isEqualTo(Identity.forUser("mappedUser").withGroups(ImmutableSet.of("test", "mappedUser")).build());

        // Both present: the user header takes precedence over the authenticated identity.
        context = createSessionContext(userHeaders, protocolHeaders, Optional.of(Identity.ofUser("mappedUser")));
        Assertions.assertThat(context.getIdentity()).isEqualTo(Identity.forUser("testUser").withGroups(ImmutableSet.of("testUser")).build());

        // Neither present: the request is rejected with HTTP 400.
        Assertions.assertThatThrownBy(() -> createSessionContext(emptyHeaders, protocolHeaders, Optional.empty()))
                .isInstanceOf(WebApplicationException.class)
                .matches(e -> ((WebApplicationException) e).getResponse().getStatus() == 400);
    }

    @Test
    public void testPreparedStatementsHeaderDoesNotParse() {
        assertPreparedStatementsHeaderDoesNotParse(ProtocolHeaders.TRINO_HEADERS);
        assertPreparedStatementsHeaderDoesNotParse(ProtocolHeaders.createProtocolHeaders("taco"));
    }

    /**
     * Verifies that a prepared-statement header whose SQL cannot be parsed is rejected with a
     * {@link WebApplicationException} that names the offending header.
     */
    private static void assertPreparedStatementsHeaderDoesNotParse(ProtocolHeaders protocolHeaders) {
        MultivaluedMap<String, String> headers = new GuavaMultivaluedMap<>(ImmutableListMultimap.<String, String>builder()
                .put(protocolHeaders.requestUser(), "testUser")
                .put(protocolHeaders.requestSource(), "testSource")
                .put(protocolHeaders.requestCatalog(), "testCatalog")
                .put(protocolHeaders.requestSchema(), "testSchema")
                .put(protocolHeaders.requestPath(), "testPath")
                .put(protocolHeaders.requestLanguage(), "zh-TW")
                .put(protocolHeaders.requestTimeZone(), "Asia/Taipei")
                .put(protocolHeaders.requestClientInfo(), "null")
                .put(protocolHeaders.requestPreparedStatement(), "query1=abcdefg")
                .build());

        Assertions.assertThatThrownBy(() -> createSessionContext(headers, protocolHeaders, Optional.empty()))
                .isInstanceOf(WebApplicationException.class)
                .hasMessageMatching("Invalid " + protocolHeaders.requestPreparedStatement() + " header: line 1:1: mismatched input 'abcdefg'. Expecting: .*");
    }

    /**
     * Builds a session context from {@code headers} with the shared factory. The protocol name of
     * {@code protocolHeaders} is passed as the second argument and {@code "testRemote"} as the third.
     */
    private static SessionContext createSessionContext(
            MultivaluedMap<String, String> headers,
            ProtocolHeaders protocolHeaders,
            Optional<Identity> authenticatedIdentity) {
        return SESSION_CONTEXT_FACTORY.createSessionContext(
                headers,
                Optional.of(protocolHeaders.getProtocolName()),
                Optional.of("testRemote"),
                authenticatedIdentity);
    }
}

