/*
 * Decompiled with CFR 0.152.
 */
package io.trino.server.ui;

import com.google.common.base.MoreObjects;
import com.google.common.collect.ImmutableSet;
import com.google.inject.Inject;
import io.airlift.log.Logger;
import io.trino.server.ServletSecurityUtils;
import io.trino.server.security.UserMapping;
import io.trino.server.security.UserMappingException;
import io.trino.server.security.oauth2.ChallengeFailedException;
import io.trino.server.security.oauth2.ForRefreshTokens;
import io.trino.server.security.oauth2.OAuth2Client;
import io.trino.server.security.oauth2.OAuth2Config;
import io.trino.server.security.oauth2.OAuth2Service;
import io.trino.server.security.oauth2.TokenPairSerializer;
import io.trino.server.ui.FormWebUiAuthenticationFilter;
import io.trino.server.ui.OAuthIdTokenCookie;
import io.trino.server.ui.OAuthWebUiCookie;
import io.trino.server.ui.WebUiAuthenticationFilter;
import io.trino.spi.security.BasicPrincipal;
import io.trino.spi.security.Identity;
import jakarta.ws.rs.container.ContainerRequestContext;
import jakarta.ws.rs.core.Response;
import java.net.URI;
import java.security.Principal;
import java.time.Duration;
import java.time.Instant;
import java.time.temporal.TemporalAmount;
import java.util.Collection;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

/**
 * Web UI authentication filter that authenticates requests from an OAuth 2.0 token pair
 * stored in the Web UI cookie.
 * <p>
 * Requests over insecure connections are never authenticated: Web UI API calls receive a
 * {@code Trino-Form-Login} {@code WWW-Authenticate} challenge and all other requests are
 * redirected to the "UI disabled" location. Over secure connections, when the cookie holds an
 * unexpired token pair whose access token yields claims from the {@link OAuth2Client}, the user
 * principal (and, if configured, the groups) are taken from those claims. Otherwise the filter
 * tries to refresh the tokens and, failing that, answers with an authentication challenge.
 */
public class OAuth2WebUiAuthenticationFilter
implements WebUiAuthenticationFilter {
    private static final Logger LOG = Logger.get(OAuth2WebUiAuthenticationFilter.class);
    // Requests for this page pass through the filter untouched (no identity is set, nothing is aborted).
    private static final String DISABLED_PAGE_PATH = "/ui/disabled.html";
    // Web UI REST endpoints get a WWW-Authenticate challenge instead of a browser redirect.
    private static final String UI_API_PATH_PREFIX = "/ui/api/";
    private static final String FORM_LOGIN_CHALLENGE = "Trino-Form-Login";
    // Resolved against the request's base URI to build the OAuth 2.0 callback address.
    private static final String OAUTH2_CALLBACK_PATH = "/oauth2/callback";

    private final String principalField;
    private final OAuth2Service service;
    private final OAuth2Client client;
    private final TokenPairSerializer tokenPairSerializer;
    // When present, overrides the cookie expiration computed for refreshed tokens.
    private final Optional<Duration> tokenExpiration;
    private final UserMapping userMapping;
    private final Optional<String> groupsField;

    /**
     * Creates the filter.
     *
     * @param service used to start a new OAuth 2.0 challenge when no usable token is present
     * @param client used to read access-token claims and to refresh tokens
     * @param tokenPairSerializer converts the cookie value to and from a {@link TokenPairSerializer.TokenPair}
     * @param tokenExpiration optional fixed lifetime for cookies issued after a token refresh;
     *        when empty, the expiration reported by the refresh response is used
     * @param oauth2Config supplies the user mapping, the principal claim name and the optional groups claim name
     */
    @Inject
    public OAuth2WebUiAuthenticationFilter(OAuth2Service service, OAuth2Client client, TokenPairSerializer tokenPairSerializer, @ForRefreshTokens Optional<Duration> tokenExpiration, OAuth2Config oauth2Config) {
        this.service = Objects.requireNonNull(service, "service is null");
        this.client = Objects.requireNonNull(client, "client is null");
        this.tokenPairSerializer = Objects.requireNonNull(tokenPairSerializer, "tokenPairSerializer is null");
        this.tokenExpiration = Objects.requireNonNull(tokenExpiration, "tokenExpiration is null");
        Objects.requireNonNull(oauth2Config, "oauth2Config is null");
        this.userMapping = UserMapping.createUserMapping(oauth2Config.getUserMappingPattern(), oauth2Config.getUserMappingFile());
        this.principalField = oauth2Config.getPrincipalField();
        this.groupsField = Objects.requireNonNull(oauth2Config.getGroupsField(), "groupsField is null");
    }

    /**
     * Authenticates the request from the OAuth 2.0 Web UI cookie, or aborts it with a redirect,
     * a challenge, or an error response.
     *
     * @param request the incoming request; on success its authenticated identity is set
     */
    @Override
    public void filter(ContainerRequestContext request) {
        String path = request.getUriInfo().getRequestUri().getPath();
        if (path.equals(DISABLED_PAGE_PATH)) {
            return;
        }
        // OAuth 2.0 Web UI authentication is only performed over secure connections.
        if (!request.getSecurityContext().isSecure()) {
            if (isUiApiPath(path)) {
                sendFormLoginChallenge(request);
                return;
            }
            request.abortWith(Response.seeOther(FormWebUiAuthenticationFilter.DISABLED_LOCATION_URI).build());
            return;
        }
        Optional<TokenPairSerializer.TokenPair> tokenPair = getTokenPair(request);
        Optional<Map<String, Object>> claims = tokenPair
                .filter(this::tokenNotExpired)
                .flatMap(this::getAccessTokenClaims);
        if (claims.isEmpty()) {
            // Pass the (possibly expired) pair along so its refresh token can still be used.
            needAuthentication(request, tokenPair);
            return;
        }
        authenticate(request, claims.get());
    }

    /**
     * Builds the identity from access-token claims and sets it on the request, or aborts the
     * request with 401 when the claims cannot be turned into an identity.
     *
     * @param request the request being authenticated
     * @param claims claims returned by {@link OAuth2Client#getClaims} for the access token
     */
    private void authenticate(ContainerRequestContext request, Map<String, Object> claims) {
        Object principal = claims.get(principalField);
        if (!isValidPrincipal(principal)) {
            LOG.debug("Invalid principal field: %s. Expected principal to be non-empty", principalField);
            ServletSecurityUtils.sendErrorMessage(request, Response.Status.UNAUTHORIZED, "Unauthorized");
            return;
        }
        String principalName = (String) principal;
        try {
            Identity.Builder builder = Identity.forUser(userMapping.mapUser(principalName));
            builder.withPrincipal(new BasicPrincipal(principalName));
            if (groupsField.isPresent()) {
                Object groupsClaim = claims.get(groupsField.get());
                // An absent groups claim simply means "no groups"; a malformed one is rejected
                // instead of failing later with ClassCastException/NullPointerException.
                if (groupsClaim != null) {
                    Optional<Set<String>> groups = toGroups(groupsClaim);
                    if (groups.isEmpty()) {
                        LOG.debug("Invalid groups field: %s. Expected groups to be a list of non-null strings", groupsField.get());
                        ServletSecurityUtils.sendErrorMessage(request, Response.Status.UNAUTHORIZED, "Unauthorized");
                        return;
                    }
                    builder.withGroups(groups.get());
                }
            }
            ServletSecurityUtils.setAuthenticatedIdentity(request, builder.build());
        }
        catch (UserMappingException e) {
            ServletSecurityUtils.sendErrorMessage(request, Response.Status.UNAUTHORIZED, MoreObjects.firstNonNull(e.getMessage(), "Unauthorized"));
        }
    }

    /**
     * Reads and deserializes the token pair from the Web UI cookie.
     *
     * @return the token pair, or empty if the cookie is missing or cannot be deserialized
     */
    private Optional<TokenPairSerializer.TokenPair> getTokenPair(ContainerRequestContext request) {
        try {
            return OAuthWebUiCookie.read(request.getCookies()).map(tokenPairSerializer::deserialize);
        }
        catch (Exception e) {
            // A corrupt or foreign cookie is treated like a missing one: the user must re-authenticate.
            LOG.debug(e, "Exception occurred during token pair deserialization");
            return Optional.empty();
        }
    }

    /**
     * Returns whether the token pair's expiration lies strictly in the future.
     */
    private boolean tokenNotExpired(TokenPairSerializer.TokenPair tokenPair) {
        return tokenPair.expiration().after(Date.from(Instant.now()));
    }

    /**
     * Returns the claims the OAuth 2.0 client extracts from the pair's access token, or empty
     * when the client yields none.
     */
    private Optional<Map<String, Object>> getAccessTokenClaims(TokenPairSerializer.TokenPair tokenPair) {
        return client.getClaims(tokenPair.accessToken());
    }

    /**
     * Handles a request without usable access-token claims: attempts a token refresh when the
     * pair carries a refresh token, and otherwise (or if the refresh fails) falls back to
     * {@link #handleAuthenticationFailure}.
     */
    private void needAuthentication(ContainerRequestContext request, Optional<TokenPairSerializer.TokenPair> tokenPair) {
        Optional<String> refreshToken = tokenPair.flatMap(TokenPairSerializer.TokenPair::refreshToken);
        if (refreshToken.isPresent()) {
            try {
                redirectForNewToken(request, refreshToken.get());
                return;
            }
            catch (Exception e) {
                // Best effort: any refresh failure degrades to a normal authentication challenge.
                LOG.debug(e, "Tokens refresh challenge has failed");
            }
        }
        handleAuthenticationFailure(request);
    }

    /**
     * Refreshes the tokens and aborts the request with a temporary redirect back to the same URI,
     * setting a new token cookie (and re-issuing the ID token cookie, if present) with the new expiration.
     *
     * @throws ChallengeFailedException if the OAuth 2.0 client fails to refresh the tokens
     */
    private void redirectForNewToken(ContainerRequestContext request, String refreshToken) throws ChallengeFailedException {
        OAuth2Client.Response response = client.refreshTokens(refreshToken);
        String serializedToken = tokenPairSerializer.serialize(TokenPairSerializer.TokenPair.fromOAuth2Response(response));
        Instant newExpirationTime = tokenExpiration
                .map(expiration -> Instant.now().plus(expiration))
                .orElse(response.getExpiration());
        Response.ResponseBuilder builder = Response.temporaryRedirect(request.getUriInfo().getRequestUri())
                .cookie(OAuthWebUiCookie.create(serializedToken, newExpirationTime));
        // Keep the ID token cookie alive for as long as the refreshed token cookie.
        OAuthIdTokenCookie.read(request.getCookies())
                .ifPresent(idToken -> builder.cookie(OAuthIdTokenCookie.create(idToken, newExpirationTime)));
        request.abortWith(builder.build());
    }

    /**
     * Responds to an unauthenticated request: Web UI API calls get a form-login challenge,
     * everything else starts a new OAuth 2.0 challenge.
     */
    private void handleAuthenticationFailure(ContainerRequestContext request) {
        if (isUiApiPath(request.getUriInfo().getRequestUri().getPath())) {
            sendFormLoginChallenge(request);
        } else {
            startOAuth2Challenge(request);
        }
    }

    /**
     * Aborts the request with the response produced by {@link OAuth2Service#startOAuth2Challenge},
     * using this server's {@code /oauth2/callback} URI as the callback (presumably a redirect to the
     * authorization server — confirm in OAuth2Service).
     */
    private void startOAuth2Challenge(ContainerRequestContext request) {
        request.abortWith(service.startOAuth2Challenge(request.getUriInfo().getBaseUri().resolve(OAUTH2_CALLBACK_PATH), Optional.empty()));
    }

    /**
     * Aborts the request with an "Unauthorized" {@code WWW-Authenticate: Trino-Form-Login} challenge.
     */
    private static void sendFormLoginChallenge(ContainerRequestContext request) {
        ServletSecurityUtils.sendWwwAuthenticate(request, "Unauthorized", ImmutableSet.of(FORM_LOGIN_CHALLENGE));
    }

    /**
     * Returns whether the path addresses a Web UI REST endpoint.
     */
    private static boolean isUiApiPath(String path) {
        return path.startsWith(UI_API_PATH_PREFIX);
    }

    /**
     * Returns whether the principal claim is a non-empty string.
     */
    private static boolean isValidPrincipal(Object principal) {
        return principal instanceof String && !((String) principal).isEmpty();
    }

    /**
     * Converts a raw groups claim into a set of group names.
     *
     * @param claim the non-null claim value
     * @return the distinct group names, or empty if the claim is not a list or contains
     *         an element that is not a non-null string
     */
    private static Optional<Set<String>> toGroups(Object claim) {
        if (!(claim instanceof List)) {
            return Optional.empty();
        }
        ImmutableSet.Builder<String> groups = ImmutableSet.builder();
        for (Object group : (List<?>) claim) {
            if (!(group instanceof String)) {
                return Optional.empty();
            }
            groups.add((String) group);
        }
        return Optional.of(groups.build());
    }
}

