/*
 * Decompiled with CFR 0.152.
 */
package io.trino.server;

import com.google.common.base.MoreObjects;
import com.google.common.base.Splitter;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.errorprone.annotations.FormatMethod;
import com.google.inject.Inject;
import io.airlift.units.DataSize;
import io.airlift.units.Duration;
import io.trino.Session;
import io.trino.client.ProtocolDetectionException;
import io.trino.client.ProtocolHeaders;
import io.trino.metadata.Metadata;
import io.trino.security.AccessControl;
import io.trino.server.SessionContext;
import io.trino.server.protocol.PreparedStatementEncoder;
import io.trino.spi.security.AccessDeniedException;
import io.trino.spi.security.GroupProvider;
import io.trino.spi.security.Identity;
import io.trino.spi.security.SelectedRole;
import io.trino.spi.session.ResourceEstimates;
import io.trino.sql.parser.ParsingException;
import io.trino.sql.parser.SqlParser;
import io.trino.transaction.TransactionId;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.ws.rs.WebApplicationException;
import jakarta.ws.rs.core.HttpHeaders;
import jakarta.ws.rs.core.MultivaluedMap;
import jakarta.ws.rs.core.Response;
import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

/**
 * Builds {@link SessionContext} and authorized {@link Identity} instances from the headers of an
 * HTTP client request.
 *
 * <p>Header names are resolved per request through {@link ProtocolHeaders#detectProtocol}, optionally
 * using an alternate header name. Malformed or inconsistent headers are rejected with a
 * {@link WebApplicationException} whose response is HTTP 400 (Bad Request) with a plain-text body.
 */
public class HttpRequestSessionContextFactory {
    /** Splits a session property name into an optional catalog prefix and the property name. */
    private static final Splitter DOT_SPLITTER = Splitter.on('.');
    /** Splits a comma-separated header value into trimmed, non-empty items. Splitters are immutable, so they are shared. */
    private static final Splitter COMMA_SPLITTER = Splitter.on(',').trimResults().omitEmptyStrings();
    /** Splits one {@code name=value} header item into its trimmed parts. */
    private static final Splitter EQUALS_SPLITTER = Splitter.on('=').trimResults();

    /** Servlet request attribute from which the authenticated {@link Identity} is read. */
    public static final String AUTHENTICATED_IDENTITY = "trino.authenticated-identity";

    private final PreparedStatementEncoder preparedStatementEncoder;
    private final Metadata metadata;
    private final GroupProvider groupProvider;
    private final AccessControl accessControl;

    @Inject
    public HttpRequestSessionContextFactory(PreparedStatementEncoder preparedStatementEncoder, Metadata metadata, GroupProvider groupProvider, AccessControl accessControl) {
        this.preparedStatementEncoder = Objects.requireNonNull(preparedStatementEncoder, "preparedStatementEncoder is null");
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        this.groupProvider = Objects.requireNonNull(groupProvider, "groupProvider is null");
        this.accessControl = Objects.requireNonNull(accessControl, "accessControl is null");
    }

    /**
     * Creates the session context described by the request headers.
     *
     * @param headers the request headers
     * @param alternateHeaderName optional alternate header name used for protocol detection
     * @param remoteAddress the client's remote address, if known; must not be null
     * @param authenticatedIdentity the identity established by authentication, if any; must not be null
     * @return the parsed session context
     * @throws WebApplicationException (HTTP 400) if any header is malformed, a schema is given
     *         without a catalog, or no user can be determined
     */
    public SessionContext createSessionContext(MultivaluedMap<String, String> headers, Optional<String> alternateHeaderName, Optional<String> remoteAddress, Optional<Identity> authenticatedIdentity) throws WebApplicationException {
        ProtocolHeaders protocolHeaders = detectProtocolHeaders(alternateHeaderName, headers);

        Optional<String> catalog = Optional.ofNullable(trimEmptyToNull(headers.getFirst(protocolHeaders.requestCatalog())));
        Optional<String> schema = Optional.ofNullable(trimEmptyToNull(headers.getFirst(protocolHeaders.requestSchema())));
        Optional<String> path = Optional.ofNullable(trimEmptyToNull(headers.getFirst(protocolHeaders.requestPath())));
        assertRequest(catalog.isPresent() || schema.isEmpty(), "Schema is set but catalog is not");

        Objects.requireNonNull(authenticatedIdentity, "authenticatedIdentity is null");
        Identity identity = buildSessionIdentity(authenticatedIdentity, protocolHeaders, headers);
        Identity originalIdentity = buildSessionOriginalIdentity(identity, protocolHeaders, headers);
        SelectedRole selectedRole = parseSystemRoleHeaders(protocolHeaders, headers);

        Optional<String> source = Optional.ofNullable(headers.getFirst(protocolHeaders.requestSource()));
        Optional<String> traceToken = Optional.ofNullable(trimEmptyToNull(headers.getFirst(protocolHeaders.requestTraceToken())));
        Optional<String> userAgent = Optional.ofNullable(headers.getFirst(HttpHeaders.USER_AGENT));
        Optional<String> remoteUserAddress = Objects.requireNonNull(remoteAddress, "remoteAddress is null");
        Optional<String> timeZoneId = Optional.ofNullable(headers.getFirst(protocolHeaders.requestTimeZone()));
        Optional<String> language = Optional.ofNullable(headers.getFirst(protocolHeaders.requestLanguage()));
        Optional<String> clientInfo = Optional.ofNullable(headers.getFirst(protocolHeaders.requestClientInfo()));
        Set<String> clientTags = parseClientTags(protocolHeaders, headers);
        Set<String> clientCapabilities = parseClientCapabilities(protocolHeaders, headers);
        ResourceEstimates resourceEstimates = parseResourceEstimate(protocolHeaders, headers);

        // Session properties: "name" is a system property, "catalog.name" a catalog property;
        // anything with more dots (or an empty part) is rejected.
        ImmutableMap.Builder<String, String> systemProperties = ImmutableMap.builder();
        Map<String, Map<String, String>> catalogSessionProperties = new HashMap<>();
        for (Map.Entry<String, String> entry : parseSessionHeaders(protocolHeaders, headers).entrySet()) {
            String propertyValue = entry.getValue();
            List<String> nameParts = DOT_SPLITTER.splitToList(entry.getKey());
            if (nameParts.size() == 1) {
                String propertyName = nameParts.get(0);
                assertRequest(!propertyName.isEmpty(), "Invalid %s header", protocolHeaders.requestSession());
                systemProperties.put(propertyName, propertyValue);
            }
            else if (nameParts.size() == 2) {
                String catalogName = nameParts.get(0);
                String propertyName = nameParts.get(1);
                assertRequest(!catalogName.isEmpty(), "Invalid %s header", protocolHeaders.requestSession());
                assertRequest(!propertyName.isEmpty(), "Invalid %s header", protocolHeaders.requestSession());
                catalogSessionProperties.computeIfAbsent(catalogName, ignored -> new HashMap<>()).put(propertyName, propertyValue);
            }
            else {
                throw badRequest(String.format("Invalid %s header", protocolHeaders.requestSession()));
            }
        }
        Map<String, Map<String, String>> immutableCatalogSessionProperties = catalogSessionProperties.entrySet().stream()
                .collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, entry -> ImmutableMap.copyOf(entry.getValue())));

        Map<String, String> preparedStatements = parsePreparedStatementsHeaders(protocolHeaders, headers);

        // The mere presence of the transaction header (even "none") signals client transaction support.
        String transactionIdHeader = headers.getFirst(protocolHeaders.requestTransactionId());
        boolean clientTransactionSupport = transactionIdHeader != null;
        Optional<TransactionId> transactionId = parseTransactionId(transactionIdHeader);

        return new SessionContext(
                protocolHeaders,
                catalog,
                schema,
                path,
                authenticatedIdentity,
                identity,
                originalIdentity,
                selectedRole,
                source,
                traceToken,
                userAgent,
                remoteUserAddress,
                timeZoneId,
                language,
                clientTags,
                clientCapabilities,
                resourceEstimates,
                systemProperties.buildOrThrow(),
                immutableCatalogSessionProperties,
                preparedStatements,
                transactionId,
                clientTransactionSupport,
                clientInfo);
    }

    /**
     * Extracts and authorizes the session identity for a servlet request, reading the authenticated
     * identity from the {@link #AUTHENTICATED_IDENTITY} request attribute (absent if unset).
     */
    public Identity extractAuthorizedIdentity(HttpServletRequest servletRequest, HttpHeaders httpHeaders, Optional<String> alternateHeaderName) {
        Optional<Identity> authenticatedIdentity = Optional.ofNullable((Identity) servletRequest.getAttribute(AUTHENTICATED_IDENTITY));
        return extractAuthorizedIdentity(authenticatedIdentity, httpHeaders.getRequestHeaders(), alternateHeaderName);
    }

    /**
     * Builds the session identity from the request headers and runs the access-control checks for
     * setting and impersonating users, then applies the requested system role.
     *
     * @param optionalAuthenticatedIdentity the identity established by authentication, if any
     * @param headers the request headers
     * @param alternateHeaderName optional alternate header name used for protocol detection
     * @return the authorized identity with enabled roles applied
     * @throws AccessDeniedException declared for access-control denials (user/impersonation/role checks)
     * @throws WebApplicationException (HTTP 400) if the headers are malformed
     */
    public Identity extractAuthorizedIdentity(Optional<Identity> optionalAuthenticatedIdentity, MultivaluedMap<String, String> headers, Optional<String> alternateHeaderName) throws AccessDeniedException {
        ProtocolHeaders protocolHeaders = detectProtocolHeaders(alternateHeaderName, headers);

        Identity identity = buildSessionIdentity(optionalAuthenticatedIdentity, protocolHeaders, headers);
        Identity originalIdentity = buildSessionOriginalIdentity(identity, protocolHeaders, headers);

        accessControl.checkCanSetUser(originalIdentity.getPrincipal(), originalIdentity.getUser());

        // Impersonation is checked only when the authenticated user differs from the requested original user.
        optionalAuthenticatedIdentity.ifPresent(authenticatedIdentity -> {
            if (!authenticatedIdentity.getUser().equals(originalIdentity.getUser())) {
                // Resolve the authenticated user's enabled roles first so the impersonation check can consider them.
                Identity authenticatedWithRoles = Identity.from(authenticatedIdentity)
                        .withEnabledRoles(metadata.listEnabledRoles(authenticatedIdentity))
                        .build();
                accessControl.checkCanImpersonateUser(authenticatedWithRoles, originalIdentity.getUser());
            }
        });

        if (!originalIdentity.getUser().equals(identity.getUser())) {
            accessControl.checkCanSetUser(originalIdentity.getPrincipal(), identity.getUser());
            accessControl.checkCanImpersonateUser(originalIdentity, identity.getUser());
        }

        return addEnabledRoles(identity, parseSystemRoleHeaders(protocolHeaders, headers), metadata);
    }

    /**
     * Returns {@code identity} with enabled roles derived from {@code selectedRole}.
     *
     * <p>For {@code NONE} the identity is returned unchanged. For {@code ROLE} the role must be among
     * the roles listed by {@code metadata} (otherwise {@link AccessDeniedException#denySetRole} is
     * invoked) and becomes the only enabled role. For any other type, all listed roles are enabled.
     */
    public static Identity addEnabledRoles(Identity identity, SelectedRole selectedRole, Metadata metadata) {
        if (selectedRole.getType() == SelectedRole.Type.NONE) {
            return identity;
        }
        Set<String> enabledRoles = metadata.listEnabledRoles(identity);
        if (selectedRole.getType() == SelectedRole.Type.ROLE) {
            String role = selectedRole.getRole().orElseThrow();
            if (!enabledRoles.contains(role)) {
                AccessDeniedException.denySetRole(role);
            }
            enabledRoles = ImmutableSet.of(role);
        }
        return Identity.from(identity)
                .withEnabledRoles(enabledRoles)
                .build();
    }

    /** Resolves the protocol header names for a request, mapping detection failures to HTTP 400. */
    private static ProtocolHeaders detectProtocolHeaders(Optional<String> alternateHeaderName, MultivaluedMap<String, String> headers) {
        try {
            return ProtocolHeaders.detectProtocol(alternateHeaderName, headers.keySet());
        }
        catch (ProtocolDetectionException e) {
            throw badRequest(e.getMessage());
        }
    }

    /**
     * Builds the session identity: the user header (trimmed, non-empty) wins over the authenticated
     * user; the system role is enabled only when it names a specific role; connector roles, extra
     * credentials and the user's groups are added.
     */
    private Identity buildSessionIdentity(Optional<Identity> authenticatedIdentity, ProtocolHeaders protocolHeaders, MultivaluedMap<String, String> headers) {
        String trinoUser = trimEmptyToNull(headers.getFirst(protocolHeaders.requestUser()));
        String user = trinoUser != null ? trinoUser : authenticatedIdentity.map(Identity::getUser).orElse(null);
        assertRequest(user != null, "User must be set");

        SelectedRole systemRole = parseSystemRoleHeaders(protocolHeaders, headers);
        ImmutableSet.Builder<String> systemEnabledRoles = ImmutableSet.builder();
        if (systemRole.getType() == SelectedRole.Type.ROLE) {
            systemEnabledRoles.add(systemRole.getRole().orElseThrow());
        }

        return authenticatedIdentity
                .map(identity -> Identity.from(identity).withUser(user))
                .orElseGet(() -> Identity.forUser(user))
                .withEnabledRoles(systemEnabledRoles.build())
                .withAdditionalConnectorRoles(parseConnectorRoleHeaders(protocolHeaders, headers))
                .withAdditionalExtraCredentials(parseExtraCredentials(protocolHeaders, headers))
                .withAdditionalGroups(groupProvider.getGroups(user))
                .build();
    }

    /**
     * Returns the identity of the original-user header, if set: a copy of {@code identity} with that
     * user, no extra credentials, and that user's groups. Otherwise returns {@code identity}.
     */
    private Identity buildSessionOriginalIdentity(Identity identity, ProtocolHeaders protocolHeaders, MultivaluedMap<String, String> headers) {
        return Optional.ofNullable(trimEmptyToNull(headers.getFirst(protocolHeaders.requestOriginalUser())))
                .map(originalUser -> Identity.from(identity)
                        .withUser(originalUser)
                        .withExtraCredentials(new HashMap<>())
                        .withGroups(groupProvider.getGroups(originalUser))
                        .build())
                .orElse(identity);
    }

    /** Returns every comma-separated, trimmed, non-empty item across all values of header {@code name}. */
    private static List<String> splitHttpHeader(MultivaluedMap<String, String> headers, String name) {
        List<String> values = MoreObjects.firstNonNull(headers.get(name), ImmutableList.of());
        return values.stream()
                .map(COMMA_SPLITTER::splitToList)
                .flatMap(Collection::stream)
                .collect(ImmutableList.toImmutableList());
    }

    /** Parses the session-property header into raw property names and URL-decoded values. */
    private static Map<String, String> parseSessionHeaders(ProtocolHeaders protocolHeaders, MultivaluedMap<String, String> headers) {
        return parseProperty(headers, protocolHeaders.requestSession());
    }

    /** Returns the role set for the "system" key (case-insensitive) of the role header, or {@code ALL} if none. */
    private static SelectedRole parseSystemRoleHeaders(ProtocolHeaders protocolHeaders, MultivaluedMap<String, String> headers) {
        return parseProperty(headers, protocolHeaders.requestRole()).entrySet().stream()
                .filter(entry -> entry.getKey().equalsIgnoreCase("system"))
                .map(Map.Entry::getValue)
                .map(role -> toSelectedRole(protocolHeaders, role))
                .findFirst()
                .orElseGet(() -> new SelectedRole(SelectedRole.Type.ALL, Optional.empty()));
    }

    /** Returns the per-connector roles from the role header, excluding the "system" key (case-insensitive). */
    private static Map<String, SelectedRole> parseConnectorRoleHeaders(ProtocolHeaders protocolHeaders, MultivaluedMap<String, String> headers) {
        ImmutableMap.Builder<String, SelectedRole> roles = ImmutableMap.builder();
        parseProperty(headers, protocolHeaders.requestRole()).forEach((key, value) -> {
            if (!key.equalsIgnoreCase("system")) {
                roles.put(key, toSelectedRole(protocolHeaders, value));
            }
        });
        return roles.buildOrThrow();
    }

    /** Parses a role header value, mapping an unparseable value to HTTP 400. */
    private static SelectedRole toSelectedRole(ProtocolHeaders protocolHeaders, String value) {
        try {
            return SelectedRole.valueOf(value);
        }
        catch (IllegalArgumentException e) {
            throw badRequest(String.format("Invalid %s header", protocolHeaders.requestRole()));
        }
    }

    /** Parses the extra-credential header into names and URL-decoded values. */
    private static Map<String, String> parseExtraCredentials(ProtocolHeaders protocolHeaders, MultivaluedMap<String, String> headers) {
        return parseProperty(headers, protocolHeaders.requestExtraCredential());
    }

    /**
     * Parses all values of {@code headerName} as comma-separated {@code name=value} items.
     * Names are kept as sent (trimmed); values are URL-decoded. A later duplicate name overwrites
     * an earlier one. Items that are not exactly one {@code name=value} pair, or whose value cannot
     * be URL-decoded, are rejected with HTTP 400.
     */
    private static Map<String, String> parseProperty(MultivaluedMap<String, String> headers, String headerName) {
        Map<String, String> properties = new HashMap<>();
        for (String header : splitHttpHeader(headers, headerName)) {
            List<String> nameValue = EQUALS_SPLITTER.splitToList(header);
            assertRequest(nameValue.size() == 2, "Invalid %s header", headerName);
            try {
                properties.put(nameValue.get(0), urlDecode(nameValue.get(1)));
            }
            catch (IllegalArgumentException e) {
                throw badRequest(String.format("Invalid %s header: %s", headerName, e));
            }
        }
        return properties;
    }

    /** Returns the comma-separated, trimmed, non-empty client tags (empty if the header is absent). */
    private static Set<String> parseClientTags(ProtocolHeaders protocolHeaders, MultivaluedMap<String, String> headers) {
        return ImmutableSet.copyOf(COMMA_SPLITTER.split(Strings.nullToEmpty(headers.getFirst(protocolHeaders.requestClientTags()))));
    }

    /** Returns the comma-separated, trimmed, non-empty client capabilities (empty if the header is absent). */
    private static Set<String> parseClientCapabilities(ProtocolHeaders protocolHeaders, MultivaluedMap<String, String> headers) {
        return ImmutableSet.copyOf(COMMA_SPLITTER.split(Strings.nullToEmpty(headers.getFirst(protocolHeaders.requestClientCapabilities()))));
    }

    /**
     * Parses the resource-estimate header. Supported names (case-insensitive) are EXECUTION_TIME and
     * CPU_TIME (parsed as {@link Duration}) and PEAK_MEMORY (parsed as {@link DataSize}); unknown
     * names or unparseable values are rejected with HTTP 400.
     */
    private static ResourceEstimates parseResourceEstimate(ProtocolHeaders protocolHeaders, MultivaluedMap<String, String> headers) {
        Session.ResourceEstimateBuilder builder = new Session.ResourceEstimateBuilder();
        parseProperty(headers, protocolHeaders.requestResourceEstimate()).forEach((name, value) -> {
            try {
                switch (name.toUpperCase(Locale.ENGLISH)) {
                    case "EXECUTION_TIME":
                        builder.setExecutionTime(Duration.valueOf(value));
                        break;
                    case "CPU_TIME":
                        builder.setCpuTime(Duration.valueOf(value));
                        break;
                    case "PEAK_MEMORY":
                        builder.setPeakMemory(DataSize.valueOf(value));
                        break;
                    default:
                        // WebApplicationException is not an IllegalArgumentException, so it escapes the catch below.
                        throw badRequest(String.format("Unsupported resource name %s", name));
                }
            }
            catch (IllegalArgumentException e) {
                throw badRequest(String.format("Unsupported format for resource estimate '%s': %s", value, e));
            }
        });
        return builder.build();
    }

    /** Throws an HTTP 400 exception with the formatted message unless {@code expression} holds. */
    @FormatMethod
    private static void assertRequest(boolean expression, String format, Object... args) {
        if (!expression) {
            throw badRequest(String.format(format, args));
        }
    }

    /**
     * Parses the prepared-statement header into URL-decoded statement names and decoded SQL text,
     * validating that each statement parses. Undecodable names, unparseable SQL, and names that
     * collide after URL-decoding are rejected with HTTP 400.
     */
    private Map<String, String> parsePreparedStatementsHeaders(ProtocolHeaders protocolHeaders, MultivaluedMap<String, String> headers) {
        String headerName = protocolHeaders.requestPreparedStatement();
        ImmutableMap.Builder<String, String> preparedStatements = ImmutableMap.builder();
        parseProperty(headers, headerName).forEach((key, value) -> {
            String statementName;
            try {
                statementName = urlDecode(key);
            }
            catch (IllegalArgumentException e) {
                throw badRequest(String.format("Invalid %s header: %s", headerName, e.getMessage()));
            }
            String sqlString = preparedStatementEncoder.decodePreparedStatementFromHeader(value);

            // Validate only: the parsed statement is discarded.
            SqlParser sqlParser = new SqlParser();
            try {
                sqlParser.createStatement(sqlString);
            }
            catch (ParsingException e) {
                throw badRequest(String.format("Invalid %s header: %s", headerName, e.getMessage()));
            }
            preparedStatements.put(statementName, sqlString);
        });
        try {
            return preparedStatements.buildOrThrow();
        }
        catch (IllegalArgumentException e) {
            // Distinct raw names can URL-decode to the same statement name; buildOrThrow rejects the duplicate.
            throw badRequest(String.format("Invalid %s header: %s", headerName, e.getMessage()));
        }
    }

    /** Parses the transaction-id header; absent, blank, or "none" (case-insensitive) yields empty. */
    private static Optional<TransactionId> parseTransactionId(String transactionId) {
        String trimmed = trimEmptyToNull(transactionId);
        if (trimmed == null || trimmed.equalsIgnoreCase("none")) {
            return Optional.empty();
        }
        try {
            return Optional.of(TransactionId.valueOf(trimmed));
        }
        catch (Exception e) {
            throw badRequest(e.getMessage());
        }
    }

    /**
     * Creates (does not throw) a {@link WebApplicationException} whose response is HTTP 400 with
     * {@code message} as a plain-text body. Callers write {@code throw badRequest(...)}.
     */
    private static WebApplicationException badRequest(String message) {
        return new WebApplicationException(message, Response.status(Response.Status.BAD_REQUEST)
                .type("text/plain")
                .entity(message)
                .build());
    }

    /** Trims {@code value} and returns null if it is null or empty after trimming. */
    private static String trimEmptyToNull(String value) {
        return Strings.emptyToNull(Strings.nullToEmpty(value).trim());
    }

    /** URL-decodes {@code value} as UTF-8; throws IllegalArgumentException on malformed escapes. */
    private static String urlDecode(String value) {
        return URLDecoder.decode(value, StandardCharsets.UTF_8);
    }
}

