/*
 * Decompiled with CFR 0.152.
 */
package io.trino.server.security.oauth2;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Strings;
import com.google.common.base.Verify;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Ordering;
import com.google.common.hash.Hashing;
import com.google.common.io.BaseEncoding;
import com.google.common.io.Resources;
import io.airlift.http.client.HttpClient;
import io.airlift.http.client.JsonResponseHandler;
import io.airlift.http.client.Request;
import io.airlift.json.JsonCodec;
import io.airlift.log.Logger;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.JwtParser;
import io.jsonwebtoken.SigningKeyResolver;
import io.jsonwebtoken.impl.DefaultClaims;
import io.jsonwebtoken.security.Keys;
import io.trino.server.security.jwt.JwtUtil;
import io.trino.server.security.oauth2.ChallengeFailedException;
import io.trino.server.security.oauth2.ForOAuth2;
import io.trino.server.security.oauth2.NonceCookie;
import io.trino.server.security.oauth2.OAuth2Client;
import io.trino.server.security.oauth2.OAuth2Config;
import io.trino.server.security.oauth2.OAuth2TokenHandler;
import io.trino.server.ui.OAuth2WebUiInstalled;
import io.trino.server.ui.OAuthWebUiCookie;
import java.io.IOException;
import java.net.URI;
import java.net.URL;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.Key;
import java.security.SecureRandom;
import java.time.Duration;
import java.time.Instant;
import java.time.temporal.TemporalAmount;
import java.util.Collection;
import java.util.Date;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Random;
import java.util.Set;
import javax.inject.Inject;
import javax.ws.rs.core.NewCookie;
import javax.ws.rs.core.Response;
import javax.ws.rs.core.UriBuilder;

/**
 * Implements the server side of the OAuth 2.0 authorization-code flow.
 *
 * <p>The service:
 * <ul>
 *   <li>builds the redirect to the authorization server with an HMAC-signed {@code state} JWT and,
 *       when the {@code openid} scope is configured, a nonce whose SHA-256 hash is sent to the server
 *       while the raw value is stored in a cookie;</li>
 *   <li>validates the callback (state signature/audience/expiry, optional ID token and nonce) and
 *       hands the resulting access token either to the Web UI cookie or to the {@link OAuth2TokenHandler};</li>
 *   <li>converts access tokens to claims, either through the configured userinfo endpoint or by
 *       verifying the token as a signed JWT.</li>
 * </ul>
 */
public class OAuth2Service {
    private static final Logger LOG = Logger.get(OAuth2Service.class);

    public static final String REDIRECT_URI = "redirect_uri";
    public static final String STATE = "state";
    public static final String NONCE = "nonce";
    public static final String OPENID_SCOPE = "openid";

    private static final String STATE_AUDIENCE_UI = "trino_oauth_ui";
    private static final String FAILURE_REPLACEMENT_TEXT = "<!-- ERROR_MESSAGE -->";
    private static final Random SECURE_RANDOM = new SecureRandom();

    /** Claim inside the signed {@code state} JWT that carries the caller-supplied handler state. */
    public static final String HANDLER_STATE_CLAIM = "handler_state";

    private static final JsonResponseHandler<Map<String, Object>> USERINFO_RESPONSE_HANDLER =
            JsonResponseHandler.createJsonResponseHandler(JsonCodec.mapJsonCodec(String.class, Object.class));

    private final OAuth2Client client;
    private final SigningKeyResolver signingKeyResolver;
    private final String successHtml;
    private final String failureHtml;
    private final Set<String> scopes;
    private final TemporalAmount challengeTimeout;
    private final Key stateHmac;
    private final JwtParser jwtParser;
    private final HttpClient httpClient;
    private final String issuer;
    private final String accessTokenIssuer;
    private final String clientId;
    private final Optional<URI> userinfoUri;
    private final Set<String> allowedAudiences;
    private final OAuth2TokenHandler tokenHandler;
    private final boolean webUiOAuthEnabled;

    /**
     * Creates the service from injected configuration.
     *
     * @throws IOException if the bundled {@code /oauth2/success.html} or {@code /oauth2/failure.html}
     *         resources cannot be read
     */
    @Inject
    public OAuth2Service(
            OAuth2Client client,
            @ForOAuth2 SigningKeyResolver signingKeyResolver,
            @ForOAuth2 HttpClient httpClient,
            OAuth2Config oauth2Config,
            OAuth2TokenHandler tokenHandler,
            Optional<OAuth2WebUiInstalled> webUiOAuthEnabled)
            throws IOException {
        this.client = Objects.requireNonNull(client, "client is null");
        this.signingKeyResolver = Objects.requireNonNull(signingKeyResolver, "signingKeyResolver is null");
        Objects.requireNonNull(oauth2Config, "oauth2Config is null");

        this.successHtml = Resources.toString(Resources.getResource(getClass(), "/oauth2/success.html"), StandardCharsets.UTF_8);
        this.failureHtml = Resources.toString(Resources.getResource(getClass(), "/oauth2/failure.html"), StandardCharsets.UTF_8);
        Verify.verify(failureHtml.contains(FAILURE_REPLACEMENT_TEXT), "failure.html does not contain the replacement text");

        this.scopes = oauth2Config.getScopes();
        this.challengeTimeout = Duration.ofMillis(oauth2Config.getChallengeTimeout().toMillis());
        // With a configured state key the HMAC key is its SHA-256 digest; without one, a random
        // per-process key is used, so state issued by another process will not verify here.
        this.stateHmac = Keys.hmacShaKeyFor(oauth2Config.getStateKey()
                .map(key -> Hashing.sha256().hashString(key, StandardCharsets.UTF_8).asBytes())
                .orElseGet(() -> secureRandomBytes(32)));
        this.jwtParser = JwtUtil.newJwtParserBuilder()
                .setSigningKey(stateHmac)
                .requireAudience(STATE_AUDIENCE_UI)
                .build();

        this.httpClient = Objects.requireNonNull(httpClient, "httpClient is null");
        this.issuer = oauth2Config.getIssuer();
        this.accessTokenIssuer = oauth2Config.getAccessTokenIssuer().orElse(issuer);
        this.clientId = oauth2Config.getClientId();
        this.userinfoUri = oauth2Config.getUserinfoUrl().map(url -> UriBuilder.fromUri(url).build());
        this.allowedAudiences = ImmutableSet.<String>builder()
                .addAll(oauth2Config.getAdditionalAudiences())
                .add(clientId)
                .build();
        this.tokenHandler = Objects.requireNonNull(tokenHandler, "tokenHandler is null");
        this.webUiOAuthEnabled = Objects.requireNonNull(webUiOAuthEnabled, "webUiOAuthEnabled is null").isPresent();
    }

    /**
     * Starts a challenge by redirecting (303) to the authorization server.
     *
     * @param callbackUri URI the authorization server should redirect back to
     * @param handlerState optional opaque value embedded in the signed state and returned on completion
     * @return a See Other response; when the {@code openid} scope is configured it also sets the nonce cookie
     */
    public Response startOAuth2Challenge(URI callbackUri, Optional<String> handlerState) {
        Instant challengeExpiration = Instant.now().plus(challengeTimeout);
        String state = JwtUtil.newJwtBuilder()
                .signWith(stateHmac)
                .setAudience(STATE_AUDIENCE_UI)
                .claim(HANDLER_STATE_CLAIM, handlerState.orElse(null))
                .setExpiration(Date.from(challengeExpiration))
                .compact();

        // Only the hash of the nonce goes to the authorization server; the raw value stays in a cookie.
        Optional<String> nonce = scopes.contains(OPENID_SCOPE) ? Optional.of(randomNonce()) : Optional.empty();
        Response.ResponseBuilder response = Response.seeOther(
                client.getAuthorizationUri(state, callbackUri, nonce.map(OAuth2Service::hashNonce)));
        nonce.ifPresent(value -> response.cookie(NonceCookie.create(value, challengeExpiration)));
        return response.build();
    }

    /**
     * Handles a callback in which the authorization server reported an error.
     *
     * <p>If the state cannot be verified a 400 response is returned. Otherwise any handler state is
     * notified of the error and an error page is rendered.
     */
    public Response handleOAuth2Error(String state, String error, String errorDescription, String errorUri) {
        try {
            Claims stateClaims = parseState(state);
            Optional.ofNullable(stateClaims.get(HANDLER_STATE_CLAIM, String.class))
                    .ifPresent(value -> tokenHandler.setTokenExchangeError(
                            value,
                            String.format(
                                    "Authentication response could not be verified: error=%s, errorDescription=%s, errorUri=%s",
                                    error,
                                    errorDescription,
                                    errorUri)));
        }
        catch (ChallengeFailedException | RuntimeException e) {
            LOG.debug(e, "Authentication response could not be verified invalid state: state=%s", state);
            return verificationFailedResponse();
        }

        LOG.debug("OAuth server returned an error: error=%s, error_description=%s, error_uri=%s, state=%s", error, errorDescription, errorUri, state);
        return Response.ok()
                .entity(getCallbackErrorHtml(error))
                .cookie(NonceCookie.delete())
                .build();
    }

    /**
     * Completes a challenge: verifies the state, exchanges {@code code} for tokens and validates them.
     *
     * <p>Without handler state the browser is redirected to {@code /ui/} with the Web UI cookie set.
     * With handler state the access token is passed to the token handler and a success page is returned;
     * the Web UI cookie is also set when Web UI OAuth is installed. Any failure yields a 400 response,
     * and the failure is also reported to the handler state when one is present.
     */
    public Response finishOAuth2Challenge(String state, String code, URI callbackUri, Optional<String> nonce) {
        Optional<String> handlerState;
        try {
            Claims stateClaims = parseState(state);
            handlerState = Optional.ofNullable(stateClaims.get(HANDLER_STATE_CLAIM, String.class));
        }
        catch (ChallengeFailedException | RuntimeException e) {
            LOG.debug(e, "Authentication response could not be verified invalid state: state=%s", state);
            return verificationFailedResponse();
        }

        try {
            OAuth2Client.OAuth2Response oauth2Response = client.getOAuth2Response(code, callbackUri);
            Claims parsedToken = validateAndParseOAuth2Response(oauth2Response, nonce)
                    .orElseThrow(() -> new ChallengeFailedException("invalid access token"));
            Instant validUntil = determineExpiration(oauth2Response.getValidUntil(), parsedToken.getExpiration());

            if (handlerState.isEmpty()) {
                // No handler state: set the Web UI cookie and send the browser to the UI.
                return Response.seeOther(URI.create("/ui/"))
                        .cookie(OAuthWebUiCookie.create(oauth2Response.getAccessToken(), validUntil), NonceCookie.delete())
                        .build();
            }

            tokenHandler.setAccessToken(handlerState.get(), oauth2Response.getAccessToken());
            Response.ResponseBuilder builder = Response.ok(getSuccessHtml());
            if (webUiOAuthEnabled) {
                builder.cookie(OAuthWebUiCookie.create(oauth2Response.getAccessToken(), validUntil));
            }
            return builder.cookie(NonceCookie.delete()).build();
        }
        catch (ChallengeFailedException | RuntimeException e) {
            LOG.debug(e, "Authentication response could not be verified: state=%s", state);
            handlerState.ifPresent(value -> tokenHandler.setTokenExchangeError(
                    value,
                    String.format("Authentication response could not be verified: state=%s", value)));
            return verificationFailedResponse();
        }
    }

    /** Builds the 400 page returned whenever a callback cannot be verified; it also clears the nonce cookie. */
    private Response verificationFailedResponse() {
        return Response.status(Response.Status.BAD_REQUEST)
                .entity(getInternalFailureHtml("Authentication response could not be verified"))
                .cookie(NonceCookie.delete())
                .build();
    }

    /**
     * Picks the effective token expiry: the earlier of the two when both are known, otherwise whichever is present.
     *
     * @throws ChallengeFailedException if neither value is available
     */
    private static Instant determineExpiration(Optional<Instant> validUntil, Date expiration) throws ChallengeFailedException {
        if (validUntil.isPresent()) {
            if (expiration == null) {
                return validUntil.get();
            }
            Instant tokenExpiration = expiration.toInstant();
            // Keep validUntil on ties, matching a natural-order min that favours its first argument.
            return validUntil.get().isAfter(tokenExpiration) ? tokenExpiration : validUntil.get();
        }
        if (expiration != null) {
            return expiration.toInstant();
        }
        throw new ChallengeFailedException("no valid expiration date");
    }

    /** Verifies the HMAC-signed state JWT (signature, audience, expiry) and returns its claims. */
    private Claims parseState(String state) throws ChallengeFailedException {
        try {
            return jwtParser.parseClaimsJws(state).getBody();
        }
        catch (RuntimeException e) {
            throw new ChallengeFailedException("State validation failed", e);
        }
    }

    /** Validates the ID token/nonce pair, then converts the access token to claims. */
    private Optional<Claims> validateAndParseOAuth2Response(OAuth2Client.OAuth2Response oauth2Response, Optional<String> nonce)
            throws ChallengeFailedException {
        validateIdTokenAndNonce(oauth2Response, nonce);
        return internalConvertTokenToClaims(oauth2Response.getAccessToken());
    }

    /**
     * Checks the ID token against the nonce cookie.
     *
     * <p>When both are present, the ID token must be signed by a resolvable key, come from the configured
     * issuer, carry the hashed nonce, and be addressed to this client. Having exactly one of them is an error.
     */
    private void validateIdTokenAndNonce(OAuth2Client.OAuth2Response oauth2Response, Optional<String> nonce) throws ChallengeFailedException {
        Optional<String> idToken = oauth2Response.getIdToken();
        if (nonce.isPresent() && idToken.isPresent()) {
            Claims claims = JwtUtil.newJwtParserBuilder()
                    .setSigningKeyResolver(signingKeyResolver)
                    .requireIssuer(issuer)
                    .require(NONCE, hashNonce(nonce.get()))
                    .build()
                    .parseClaimsJws(idToken.get())
                    .getBody();
            validateAudience(claims, false);
        }
        else if (nonce.isPresent() != idToken.isPresent()) {
            throw new ChallengeFailedException("Cannot validate nonce parameter");
        }
    }

    /**
     * Converts an access token into its claims.
     *
     * @return the claims, or empty if the userinfo endpoint (when configured) returned a bad response
     * @throws ChallengeFailedException if the token's audience is not accepted
     */
    public Optional<Map<String, Object>> convertTokenToClaims(String token) throws ChallengeFailedException {
        return internalConvertTokenToClaims(token).map(claims -> claims);
    }

    /**
     * Resolves claims for an access token: via a POST to the userinfo endpoint when one is configured,
     * otherwise by verifying the token as a JWT issued by {@code accessTokenIssuer}.
     */
    private Optional<Claims> internalConvertTokenToClaims(String accessToken) throws ChallengeFailedException {
        if (userinfoUri.isPresent()) {
            Request request = Request.builder()
                    .setMethod("POST")
                    .addHeader("Authorization", "Bearer " + accessToken)
                    .setUri(userinfoUri.get())
                    .build();
            try {
                Map<String, Object> userinfoClaims = httpClient.execute(request, USERINFO_RESPONSE_HANDLER);
                Claims claims = new DefaultClaims(userinfoClaims);
                validateAudience(claims, true);
                return Optional.of(claims);
            }
            catch (RuntimeException e) {
                LOG.error(e, "Received bad response from userinfo endpoint");
                return Optional.empty();
            }
        }

        Claims claims = JwtUtil.newJwtParserBuilder()
                .setSigningKeyResolver(signingKeyResolver)
                .requireIssuer(accessTokenIssuer)
                .build()
                .parseClaimsJws(accessToken)
                .getBody();
        validateAudience(claims, true);
        return Optional.of(claims);
    }

    /**
     * Checks the {@code aud} claim.
     *
     * <p>Access tokens with no audience (missing or empty collection) are accepted; otherwise they must
     * name one of {@code allowedAudiences}. ID tokens must name this client id.
     *
     * @throws ChallengeFailedException if the audience is missing (ID tokens), malformed, or not accepted
     */
    private void validateAudience(Claims claims, boolean isAccessToken) throws ChallengeFailedException {
        Object tokenAudience = claims.get("aud");
        Set<String> validAudiences;
        if (isAccessToken) {
            if (tokenAudience == null || (tokenAudience instanceof Collection && ((Collection<?>) tokenAudience).isEmpty())) {
                return;
            }
            validAudiences = allowedAudiences;
        }
        else {
            validAudiences = Set.of(clientId);
        }

        if (tokenAudience instanceof String) {
            if (!validAudiences.contains((String) tokenAudience)) {
                throw new ChallengeFailedException(String.format("Invalid Audience: %s. Allowed audiences: %s", tokenAudience, validAudiences));
            }
        }
        else if (tokenAudience instanceof Collection) {
            // Non-string entries can never match; skipping them avoids a ClassCastException escaping to callers.
            boolean matched = ((Collection<?>) tokenAudience).stream()
                    .filter(String.class::isInstance)
                    .map(String.class::cast)
                    .anyMatch(validAudiences::contains);
            if (!matched) {
                throw new ChallengeFailedException(String.format("Invalid Audience: %s. Allowed audiences: %s", tokenAudience, validAudiences));
            }
        }
        else {
            throw new ChallengeFailedException(String.format("Invalid Audience: %s", tokenAudience));
        }
    }

    public String getSuccessHtml() {
        return successHtml;
    }

    /**
     * Renders the failure page for an OAuth2 {@code error} code returned by the authorization server.
     * The message is HTML-escaped because unknown codes are echoed back verbatim.
     */
    public String getCallbackErrorHtml(String errorCode) {
        return failureHtml.replace(FAILURE_REPLACEMENT_TEXT, escapeHtml(getOAuth2ErrorMessage(errorCode)));
    }

    /** Renders the failure page with {@code errorMessage} inserted as-is ({@code null} becomes empty). */
    public String getInternalFailureHtml(String errorMessage) {
        return failureHtml.replace(FAILURE_REPLACEMENT_TEXT, Strings.nullToEmpty(errorMessage));
    }

    private static byte[] secureRandomBytes(int count) {
        byte[] bytes = new byte[count];
        SECURE_RANDOM.nextBytes(bytes);
        return bytes;
    }

    /** Maps the standard OAuth2 error codes to user-facing text; unknown (or null) codes are echoed. */
    private static String getOAuth2ErrorMessage(String errorCode) {
        if (errorCode == null) {
            return "OAuth2 unknown error code: " + errorCode;
        }
        switch (errorCode) {
            case "access_denied":
                return "OAuth2 server denied the login";
            case "unauthorized_client":
                return "OAuth2 server does not allow request from this Trino server";
            case "server_error":
                return "OAuth2 server had a failure";
            case "temporarily_unavailable":
                return "OAuth2 server is temporarily unavailable";
            default:
                return "OAuth2 unknown error code: " + errorCode;
        }
    }

    /** Escapes the five HTML-significant characters so text can be embedded safely in element content. */
    private static String escapeHtml(String text) {
        StringBuilder escaped = new StringBuilder(text.length());
        for (int i = 0; i < text.length(); i++) {
            char c = text.charAt(i);
            switch (c) {
                case '&':
                    escaped.append("&amp;");
                    break;
                case '<':
                    escaped.append("&lt;");
                    break;
                case '>':
                    escaped.append("&gt;");
                    break;
                case '"':
                    escaped.append("&quot;");
                    break;
                case '\'':
                    escaped.append("&#39;");
                    break;
                default:
                    escaped.append(c);
            }
        }
        return escaped.toString();
    }

    /** Returns 18 random bytes, base64url-encoded (24 characters, no padding needed). */
    private static String randomNonce() {
        return BaseEncoding.base64Url().encode(secureRandomBytes(18));
    }

    /** Returns the lowercase hex SHA-256 of the UTF-8 encoded nonce. */
    @VisibleForTesting
    public static String hashNonce(String nonce) {
        return Hashing.sha256().hashString(nonce, StandardCharsets.UTF_8).toString();
    }
}

