/*
 * Decompiled with CFR 0.152.
 */
package io.trino.server.ui;

import com.google.common.base.MoreObjects;
import com.google.common.collect.ImmutableSet;
import io.airlift.log.Logger;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.Jws;
import io.jsonwebtoken.Jwt;
import io.jsonwebtoken.JwtException;
import io.trino.server.ServletSecurityUtils;
import io.trino.server.security.UserMapping;
import io.trino.server.security.UserMappingException;
import io.trino.server.security.oauth2.NonceCookie;
import io.trino.server.security.oauth2.OAuth2Config;
import io.trino.server.security.oauth2.OAuth2Service;
import io.trino.server.ui.FormWebUiAuthenticationFilter;
import io.trino.server.ui.OAuthWebUiCookie;
import io.trino.server.ui.WebUiAuthenticationFilter;
import io.trino.spi.security.BasicPrincipal;
import io.trino.spi.security.Identity;
import java.net.URI;
import java.security.Principal;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import javax.inject.Inject;
import javax.ws.rs.container.ContainerRequestContext;
import javax.ws.rs.core.Cookie;
import javax.ws.rs.core.NewCookie;
import javax.ws.rs.core.Response;

public class OAuth2WebUiAuthenticationFilter
implements WebUiAuthenticationFilter {
    private static final Logger LOG = Logger.get(OAuth2WebUiAuthenticationFilter.class);
    private final String principalField;
    private final OAuth2Service service;
    private final UserMapping userMapping;
    private final Optional<String> validAudience;

    @Inject
    public OAuth2WebUiAuthenticationFilter(OAuth2Service service, OAuth2Config oauth2Config) {
        this.service = Objects.requireNonNull(service, "service is null");
        Objects.requireNonNull(oauth2Config, "oauth2Config is null");
        this.userMapping = UserMapping.createUserMapping(oauth2Config.getUserMappingPattern(), oauth2Config.getUserMappingFile());
        this.validAudience = oauth2Config.getAudience();
        this.principalField = oauth2Config.getPrincipalField();
    }

    public void filter(ContainerRequestContext request) {
        String path = request.getUriInfo().getRequestUri().getPath();
        if (path.equals("/ui/disabled.html")) {
            return;
        }
        if (!request.getSecurityContext().isSecure()) {
            if (path.startsWith("/ui/api/")) {
                ServletSecurityUtils.sendWwwAuthenticate(request, "Unauthorized", (Collection<String>)ImmutableSet.of((Object)"Trino-Form-Login"));
                return;
            }
            request.abortWith(Response.seeOther((URI)FormWebUiAuthenticationFilter.DISABLED_LOCATION_URI).build());
            return;
        }
        Optional<Claims> claims = this.getAccessToken(request).map(Jwt::getBody);
        if (claims.isEmpty()) {
            this.needAuthentication(request);
            return;
        }
        Object audience = claims.get().get((Object)"aud");
        if (!this.hasValidAudience(audience)) {
            LOG.debug("Invalid audience: %s. Expected audience to be equal to or contain: %s", new Object[]{audience, this.validAudience});
            ServletSecurityUtils.sendErrorMessage(request, Response.Status.UNAUTHORIZED, "Unauthorized");
            return;
        }
        try {
            Object principal = claims.get().get((Object)this.principalField);
            if (!this.isValidPrincipal(principal)) {
                LOG.debug("Invalid principal field: %s. Expected principal to be non-empty", new Object[]{this.principalField});
                ServletSecurityUtils.sendErrorMessage(request, Response.Status.UNAUTHORIZED, "Unauthorized");
                return;
            }
            String principalName = (String)principal;
            ServletSecurityUtils.setAuthenticatedIdentity(request, Identity.forUser((String)this.userMapping.mapUser(principalName)).withPrincipal((Principal)new BasicPrincipal(principalName)).build());
        }
        catch (UserMappingException e) {
            ServletSecurityUtils.sendErrorMessage(request, Response.Status.UNAUTHORIZED, (String)MoreObjects.firstNonNull((Object)e.getMessage(), (Object)"Unauthorized"));
        }
    }

    private Optional<Jws<Claims>> getAccessToken(ContainerRequestContext request) {
        return OAuthWebUiCookie.read((Cookie)request.getCookies().get("__Secure-Trino-OAuth2-Token")).flatMap(token -> {
            try {
                return Optional.ofNullable(this.service.parseClaimsJws((String)token));
            }
            catch (JwtException | IllegalArgumentException e) {
                LOG.debug("Unable to parse JWT token: " + e.getMessage(), new Object[]{e});
                return Optional.empty();
            }
        });
    }

    private void needAuthentication(ContainerRequestContext request) {
        if (request.getUriInfo().getRequestUri().getPath().startsWith("/ui/api/")) {
            ServletSecurityUtils.sendWwwAuthenticate(request, "Unauthorized", (Collection<String>)ImmutableSet.of((Object)"Trino-Form-Login"));
            return;
        }
        OAuth2Service.OAuthChallenge challenge = this.service.startWebUiChallenge(request.getUriInfo().getBaseUri().resolve("/oauth2/callback"));
        Response.ResponseBuilder response = Response.seeOther((URI)challenge.getRedirectUrl());
        challenge.getNonce().ifPresent(nonce -> response.cookie(new NewCookie[]{NonceCookie.create(nonce, challenge.getChallengeExpiration())}));
        request.abortWith(response.build());
    }

    private boolean hasValidAudience(Object audience) {
        if (this.validAudience.isEmpty()) {
            return true;
        }
        if (audience == null) {
            return false;
        }
        if (audience instanceof String) {
            return audience.equals(this.validAudience.get());
        }
        if (audience instanceof List) {
            return ((List)audience).contains(this.validAudience.get());
        }
        return false;
    }

    private boolean isValidPrincipal(Object principal) {
        return principal instanceof String && !((String)principal).isEmpty();
    }
}

