/*
 * Decompiled with CFR 0.152.
 */
package io.trino.server.security.oauth2;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Ordering;
import com.google.common.hash.Hashing;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSHeader;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.jwk.source.RemoteJWKSet;
import com.nimbusds.jose.proc.BadJOSEException;
import com.nimbusds.jose.proc.JWSKeySelector;
import com.nimbusds.jose.proc.JWSVerificationKeySelector;
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jose.util.ResourceRetriever;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.proc.DefaultJWTClaimsVerifier;
import com.nimbusds.jwt.proc.DefaultJWTProcessor;
import com.nimbusds.jwt.proc.JWTClaimsSetVerifier;
import com.nimbusds.jwt.proc.JWTProcessor;
import com.nimbusds.oauth2.sdk.AccessTokenResponse;
import com.nimbusds.oauth2.sdk.AuthorizationCode;
import com.nimbusds.oauth2.sdk.AuthorizationCodeGrant;
import com.nimbusds.oauth2.sdk.AuthorizationGrant;
import com.nimbusds.oauth2.sdk.AuthorizationRequest;
import com.nimbusds.oauth2.sdk.Request;
import com.nimbusds.oauth2.sdk.ResponseType;
import com.nimbusds.oauth2.sdk.Scope;
import com.nimbusds.oauth2.sdk.TokenRequest;
import com.nimbusds.oauth2.sdk.auth.ClientAuthentication;
import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic;
import com.nimbusds.oauth2.sdk.auth.Secret;
import com.nimbusds.oauth2.sdk.id.ClientID;
import com.nimbusds.oauth2.sdk.id.Issuer;
import com.nimbusds.oauth2.sdk.id.State;
import com.nimbusds.oauth2.sdk.token.AccessToken;
import com.nimbusds.oauth2.sdk.token.BearerAccessToken;
import com.nimbusds.oauth2.sdk.token.Tokens;
import com.nimbusds.openid.connect.sdk.AuthenticationRequest;
import com.nimbusds.openid.connect.sdk.Nonce;
import com.nimbusds.openid.connect.sdk.OIDCScopeValue;
import com.nimbusds.openid.connect.sdk.OIDCTokenResponse;
import com.nimbusds.openid.connect.sdk.UserInfoRequest;
import com.nimbusds.openid.connect.sdk.UserInfoResponse;
import com.nimbusds.openid.connect.sdk.claims.AccessTokenHash;
import com.nimbusds.openid.connect.sdk.claims.IDTokenClaimsSet;
import com.nimbusds.openid.connect.sdk.token.OIDCTokens;
import com.nimbusds.openid.connect.sdk.validators.AccessTokenValidator;
import com.nimbusds.openid.connect.sdk.validators.IDTokenValidator;
import com.nimbusds.openid.connect.sdk.validators.InvalidHashException;
import io.airlift.log.Logger;
import io.airlift.units.Duration;
import io.trino.server.security.oauth2.ChallengeFailedException;
import io.trino.server.security.oauth2.NimbusAirliftHttpClient;
import io.trino.server.security.oauth2.NimbusHttpClient;
import io.trino.server.security.oauth2.OAuth2Client;
import io.trino.server.security.oauth2.OAuth2Config;
import io.trino.server.security.oauth2.OAuth2ServerConfigProvider;
import java.net.MalformedURLException;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.text.ParseException;
import java.time.Instant;
import java.util.Date;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.stream.Stream;
import javax.inject.Inject;

/**
 * {@link OAuth2Client} implementation built on the Nimbus OAuth 2.0 / OpenID Connect SDK.
 *
 * <p>{@link #load()} must be called before any other method. It reads the server configuration
 * from the {@link OAuth2ServerConfigProvider}, builds the access-token JWT processor, and selects
 * the authorization-code flow. The OpenID Connect flow is used when the configured scopes contain
 * {@code openid}; otherwise the plain OAuth 2.0 flow is used.
 */
public class NimbusOAuth2Client
        implements OAuth2Client {
    private static final Logger LOG = Logger.get(NimbusOAuth2Client.class);

    private final Issuer issuer;
    private final ClientID clientId;
    private final ClientSecretBasic clientAuth;
    private final Scope scope;
    private final String principalField;
    private final Set<String> accessTokenAudiences;
    private final Duration maxClockSkew;
    private final NimbusHttpClient httpClient;
    private final OAuth2ServerConfigProvider serverConfigurationProvider;

    // Written last in load(); readers check it first, so the non-volatile fields below are
    // published to them through this volatile write/read pair.
    private volatile boolean loaded;
    private URI authUrl;
    private URI tokenUrl;
    private Optional<URI> userinfoUrl;
    private JWSKeySelector<SecurityContext> jwsKeySelector;
    private JWTProcessor<SecurityContext> accessTokenProcessor;
    private AuthorizationCodeFlow flow;

    /**
     * Creates the client from static configuration; no network access happens until {@link #load()}.
     *
     * @param oauthConfig client id/secret, issuer, scopes, principal field, audiences and clock skew
     * @param serverConfigurationProvider source of the server endpoints, read by {@link #load()}
     * @param httpClient HTTP client used for token, userinfo and JWKS requests
     * @throws NullPointerException if any argument is null
     */
    @Inject
    public NimbusOAuth2Client(OAuth2Config oauthConfig, OAuth2ServerConfigProvider serverConfigurationProvider, NimbusHttpClient httpClient) {
        Objects.requireNonNull(oauthConfig, "oauthConfig is null");
        this.issuer = new Issuer(oauthConfig.getIssuer());
        this.clientId = new ClientID(oauthConfig.getClientId());
        this.clientAuth = new ClientSecretBasic(this.clientId, new Secret(oauthConfig.getClientSecret()));
        this.scope = Scope.parse(oauthConfig.getScopes());
        this.principalField = oauthConfig.getPrincipalField();
        this.maxClockSkew = oauthConfig.getMaxClockSkew();
        this.accessTokenAudiences = new HashSet<>(oauthConfig.getAdditionalAudiences());
        this.accessTokenAudiences.add(this.clientId.getValue());
        // A null entry in the accepted-audience set makes DefaultJWTClaimsVerifier also accept
        // access tokens that carry no "aud" claim at all.
        this.accessTokenAudiences.add(null);
        this.serverConfigurationProvider = Objects.requireNonNull(serverConfigurationProvider, "serverConfigurationProvider is null");
        this.httpClient = Objects.requireNonNull(httpClient, "httpClient is null");
    }

    /**
     * Fetches the server configuration and initializes verifiers and the authorization flow.
     *
     * @throws RuntimeException wrapping {@link MalformedURLException} if the JWKS URI cannot be
     *         converted to a URL
     */
    @Override
    public void load() {
        OAuth2ServerConfigProvider.OAuth2ServerConfig config = this.serverConfigurationProvider.get();
        this.authUrl = config.getAuthUrl();
        this.tokenUrl = config.getTokenUrl();
        this.userinfoUrl = config.getUserinfoUrl();

        // Signatures are accepted from the RSA and EC JWS algorithm families only.
        Set<JWSAlgorithm> signingAlgorithms = Stream.concat(JWSAlgorithm.Family.RSA.stream(), JWSAlgorithm.Family.EC.stream())
                .collect(ImmutableSet.toImmutableSet());
        try {
            JWKSource<SecurityContext> jwkSource = new RemoteJWKSet<>(config.getJwksUrl().toURL(), (ResourceRetriever) this.httpClient);
            this.jwsKeySelector = new JWSVerificationKeySelector<>(signingAlgorithms, jwkSource);
        }
        catch (MalformedURLException e) {
            throw new RuntimeException(e);
        }

        // Access tokens must come from the configured access-token issuer (falling back to the
        // main issuer), target an accepted audience, and contain the principal claim.
        DefaultJWTClaimsVerifier<SecurityContext> accessTokenVerifier = new DefaultJWTClaimsVerifier<>(
                this.accessTokenAudiences,
                new JWTClaimsSet.Builder()
                        .issuer(config.getAccessTokenIssuer().orElse(this.issuer.getValue()))
                        .build(),
                ImmutableSet.of(this.principalField),
                ImmutableSet.of());
        accessTokenVerifier.setMaxClockSkew((int) this.maxClockSkew.roundTo(TimeUnit.SECONDS));

        DefaultJWTProcessor<SecurityContext> processor = new DefaultJWTProcessor<>();
        processor.setJWSKeySelector(this.jwsKeySelector);
        processor.setJWTClaimsSetVerifier(accessTokenVerifier);
        this.accessTokenProcessor = processor;

        // Must run after jwsKeySelector is set: the OIDC flow builds its ID-token validator from it.
        this.flow = this.scope.contains(OIDCScopeValue.OPENID) ? new OAuth2WithOidcExtensionsCodeFlow() : new OAuth2AuthorizationCodeFlow();
        this.loaded = true;
    }

    /**
     * Builds the authorization-endpoint redirect for the configured flow.
     *
     * @throws IllegalStateException if {@link #load()} has not completed
     */
    @Override
    public OAuth2Client.Request createAuthorizationRequest(String state, URI callbackUri) {
        Preconditions.checkState(this.loaded, "OAuth2 client not initialized");
        return this.flow.createAuthorizationRequest(state, callbackUri);
    }

    /**
     * Exchanges an authorization code for tokens using the configured flow.
     *
     * @throws IllegalStateException if {@link #load()} has not completed
     * @throws ChallengeFailedException if the token exchange or token validation fails
     */
    @Override
    public OAuth2Client.Response getOAuth2Response(String code, URI callbackUri, Optional<String> nonce) throws ChallengeFailedException {
        Preconditions.checkState(this.loaded, "OAuth2 client not initialized");
        return this.flow.getOAuth2Response(code, callbackUri, nonce);
    }

    /**
     * Returns the claims for an access token, or empty if they cannot be obtained or verified.
     *
     * @throws IllegalStateException if {@link #load()} has not completed
     */
    @Override
    public Optional<Map<String, Object>> getClaims(String accessToken) {
        Preconditions.checkState(this.loaded, "OAuth2 client not initialized");
        return this.getJWTClaimsSet(accessToken).map(JWTClaimsSet::getClaims);
    }

    /**
     * Performs the authorization-code grant against the token endpoint.
     *
     * @throws ChallengeFailedException if the token endpoint returns an error response
     */
    private <T extends AccessTokenResponse> T getTokenResponse(String code, URI callbackUri, NimbusHttpClient.Parser<T> parser) throws ChallengeFailedException {
        AuthorizationCodeGrant grant = new AuthorizationCodeGrant(new AuthorizationCode(code), callbackUri);
        T tokenResponse = this.httpClient.execute(new TokenRequest(this.tokenUrl, this.clientAuth, grant), parser);
        if (!tokenResponse.indicatesSuccess()) {
            throw new ChallengeFailedException("Error while fetching access token: " + tokenResponse.toErrorResponse().toJSONObject());
        }
        return tokenResponse;
    }

    /** Resolves claims via the userinfo endpoint when configured, otherwise by verifying the token as a JWT. */
    private Optional<JWTClaimsSet> getJWTClaimsSet(String accessToken) {
        if (this.userinfoUrl.isPresent()) {
            return this.queryUserInfo(accessToken);
        }
        return this.parseAccessToken(accessToken);
    }

    /** Queries the userinfo endpoint with the token as a bearer credential; logs and returns empty on failure. */
    private Optional<JWTClaimsSet> queryUserInfo(String accessToken) {
        try {
            AccessToken bearerToken = new BearerAccessToken(accessToken);
            UserInfoResponse response = this.httpClient.execute(new UserInfoRequest(this.userinfoUrl.get(), bearerToken), UserInfoResponse::parse);
            if (!response.indicatesSuccess()) {
                LOG.error("Received bad response from userinfo endpoint: " + response.toErrorResponse().getErrorObject());
                return Optional.empty();
            }
            return Optional.of(response.toSuccessResponse().getUserInfo().toJWTClaimsSet());
        }
        catch (com.nimbusds.oauth2.sdk.ParseException | RuntimeException e) {
            LOG.error(e, "Received bad response from userinfo endpoint");
            return Optional.empty();
        }
    }

    /** Verifies the token as a signed JWT with the access-token processor; logs and returns empty on failure. */
    private Optional<JWTClaimsSet> parseAccessToken(String accessToken) {
        try {
            return Optional.of(this.accessTokenProcessor.process(accessToken, null));
        }
        catch (JOSEException | BadJOSEException | ParseException e) {
            LOG.error(e, "Failed to parse JWT access token");
            return Optional.empty();
        }
    }

    /**
     * Picks the earlier of the token-response lifetime and the JWT {@code exp} claim, whichever are present.
     *
     * @throws ChallengeFailedException if neither is present
     */
    private static Instant determineExpiration(Optional<Instant> validUntil, Date expiration) throws ChallengeFailedException {
        if (validUntil.isPresent()) {
            if (expiration != null) {
                return Ordering.<Instant>natural().min(validUntil.get(), expiration.toInstant());
            }
            return validUntil.get();
        }
        if (expiration != null) {
            return expiration.toInstant();
        }
        throw new ChallengeFailedException("no valid expiration date");
    }

    /** Converts the token's lifetime in seconds into an absolute instant; a lifetime of 0 means "not provided". */
    private static Optional<Instant> getExpiration(AccessToken accessToken) {
        return accessToken.getLifetime() != 0L ? Optional.of(Instant.now().plusSeconds(accessToken.getLifetime())) : Optional.empty();
    }

    /**
     * Authorization-code flow with OpenID Connect extensions.
     *
     * <p>Sends a nonce, validates the returned ID token, and validates the access-token hash when
     * the ID token carries one.
     */
    private class OAuth2WithOidcExtensionsCodeFlow
            implements AuthorizationCodeFlow {
        private final IDTokenValidator idTokenValidator;

        public OAuth2WithOidcExtensionsCodeFlow() {
            // No JWE key selector: encrypted ID tokens are not supported.
            this.idTokenValidator = new IDTokenValidator(issuer, clientId, jwsKeySelector, null);
            this.idTokenValidator.setMaxClockSkew((int) maxClockSkew.roundTo(TimeUnit.SECONDS));
        }

        /**
         * Builds an OIDC authentication request.
         *
         * <p>Only the SHA-256 hash of a freshly generated nonce is sent to the provider. The raw
         * nonce is returned to the caller so it can be presented again when the code is redeemed.
         */
        @Override
        public OAuth2Client.Request createAuthorizationRequest(String state, URI callbackUri) {
            String nonce = new Nonce().getValue();
            URI redirect = new AuthenticationRequest.Builder(ResponseType.CODE, scope, clientId, callbackUri)
                    .endpointURI(authUrl)
                    .state(new State(state))
                    .nonce(new Nonce(this.hashNonce(nonce)))
                    .build()
                    .toURI();
            return new OAuth2Client.Request(redirect, Optional.of(nonce));
        }

        @Override
        public OAuth2Client.Response getOAuth2Response(String code, URI callbackUri, Optional<String> nonce) throws ChallengeFailedException {
            OIDCTokenResponse tokenResponse = getTokenResponse(code, callbackUri, OIDCTokenResponse::parse);
            OIDCTokens tokens = tokenResponse.getOIDCTokens();
            this.validateTokens(tokens, nonce);
            AccessToken accessToken = tokens.getAccessToken();
            JWTClaimsSet claims = getJWTClaimsSet(accessToken.getValue()).orElseThrow(() -> new ChallengeFailedException("invalid access token"));
            Instant expiration = determineExpiration(getExpiration(accessToken), claims.getExpirationTime());
            return new OAuth2Client.Response(accessToken.getValue(), expiration, Optional.ofNullable(tokens.getIDTokenString()));
        }

        /**
         * Validates the ID token against the hashed nonce and, if present, the access-token hash claim.
         *
         * @throws ChallengeFailedException if the nonce is missing or validation fails
         */
        private void validateTokens(OIDCTokens tokens, Optional<String> nonce) throws ChallengeFailedException {
            try {
                Nonce expectedNonce = nonce.map(this::hashNonce).map(Nonce::new).orElseThrow(() -> new ChallengeFailedException("Missing nonce"));
                IDTokenClaimsSet idToken = this.idTokenValidator.validate(tokens.getIDToken(), expectedNonce);
                AccessTokenHash accessTokenHash = idToken.getAccessTokenHash();
                if (accessTokenHash != null) {
                    JWSAlgorithm idTokenAlgorithm = ((JWSHeader) tokens.getIDToken().getHeader()).getAlgorithm();
                    AccessTokenValidator.validate(tokens.getAccessToken(), idTokenAlgorithm, accessTokenHash);
                }
            }
            catch (JOSEException | BadJOSEException | InvalidHashException e) {
                throw new ChallengeFailedException("Cannot validate nonce parameter", e);
            }
        }

        /** Returns the lowercase-hex SHA-256 digest of the UTF-8 encoded nonce. */
        private String hashNonce(String nonce) {
            return Hashing.sha256().hashString(nonce, StandardCharsets.UTF_8).toString();
        }
    }

    /** Plain OAuth 2.0 authorization-code flow (no ID token, no nonce). */
    private class OAuth2AuthorizationCodeFlow
            implements AuthorizationCodeFlow {
        private OAuth2AuthorizationCodeFlow() {
        }

        @Override
        public OAuth2Client.Request createAuthorizationRequest(String state, URI callbackUri) {
            URI redirect = new AuthorizationRequest.Builder(ResponseType.CODE, clientId)
                    .redirectionURI(callbackUri)
                    .scope(scope)
                    .endpointURI(authUrl)
                    .state(new State(state))
                    .build()
                    .toURI();
            return new OAuth2Client.Request(redirect, Optional.empty());
        }

        /**
         * @throws IllegalArgumentException if a nonce is supplied (this flow never issues one)
         */
        @Override
        public OAuth2Client.Response getOAuth2Response(String code, URI callbackUri, Optional<String> nonce) throws ChallengeFailedException {
            Preconditions.checkArgument(nonce.isEmpty(), "Unexpected nonce provided");
            AccessTokenResponse tokenResponse = getTokenResponse(code, callbackUri, AccessTokenResponse::parse);
            Tokens tokens = tokenResponse.toSuccessResponse().getTokens();
            AccessToken accessToken = tokens.getAccessToken();
            JWTClaimsSet claims = getJWTClaimsSet(accessToken.getValue()).orElseThrow(() -> new ChallengeFailedException("invalid access token"));
            Instant expiration = determineExpiration(getExpiration(accessToken), claims.getExpirationTime());
            return new OAuth2Client.Response(accessToken.getValue(), expiration, Optional.empty());
        }
    }

    /** Strategy for the two supported authorization-code flows. */
    private interface AuthorizationCodeFlow {
        OAuth2Client.Request createAuthorizationRequest(String state, URI callbackUri);

        OAuth2Client.Response getOAuth2Response(String code, URI callbackUri, Optional<String> nonce) throws ChallengeFailedException;
    }
}

