/*
 * Decompiled with CFR 0.152.
 */
package io.trino.server.security.oauth2;

import com.google.common.base.Strings;
import com.google.common.base.Verify;
import com.google.common.collect.Ordering;
import com.google.common.hash.Hashing;
import com.google.common.io.BaseEncoding;
import com.google.common.io.Resources;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.Jws;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.SignatureAlgorithm;
import io.jsonwebtoken.SigningKeyResolver;
import io.trino.server.security.oauth2.ChallengeFailedException;
import io.trino.server.security.oauth2.ForOAuth2;
import io.trino.server.security.oauth2.OAuth2Client;
import io.trino.server.security.oauth2.OAuth2Config;
import java.io.IOException;
import java.net.URI;
import java.net.URL;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.time.Duration;
import java.time.Instant;
import java.time.temporal.TemporalAmount;
import java.util.Date;
import java.util.Objects;
import java.util.Optional;
import java.util.Random;
import java.util.Set;
import java.util.UUID;
import javax.inject.Inject;

/**
 * Drives the OAuth2 authorization-code flow.
 *
 * <p>It issues HMAC-signed {@code state} JWTs for Web UI and REST challenges. It exchanges
 * authorization codes for access tokens through {@link OAuth2Client}. When the {@code openid}
 * scope is configured, it validates the nonce carried in the ID token. It also renders the bundled
 * success and failure HTML pages.
 */
public class OAuth2Service {
    public static final String REDIRECT_URI = "redirect_uri";
    public static final String STATE = "state";
    public static final String NONCE = "nonce";
    public static final String OPENID_SCOPE = "openid";
    // Audience values written into the state JWT; getAuthId() uses them to tell the two flows apart.
    private static final String STATE_AUDIENCE_UI = "trino_oauth_ui";
    private static final String STATE_AUDIENCE_REST = "trino_oauth_rest";
    // Marker inside failure.html that is replaced with the (HTML-escaped) error message.
    private static final String FAILURE_REPLACEMENT_TEXT = "<!-- ERROR_MESSAGE -->";
    private static final SecureRandom SECURE_RANDOM = new SecureRandom();
    private final OAuth2Client client;
    private final SigningKeyResolver signingKeyResolver;
    private final String successHtml;
    private final String failureHtml;
    private final Set<String> scopes;
    private final TemporalAmount challengeTimeout;
    private final byte[] stateHmac;

    /**
     * Creates the service and loads the success/failure HTML templates from the classpath.
     *
     * @param client client used to build authorization URIs and exchange authorization codes
     * @param signingKeyResolver resolver used to verify access-token and ID-token signatures
     * @param oauth2Config source of the requested scopes, the challenge timeout and the optional state key
     * @throws IOException if {@code /oauth2/success.html} or {@code /oauth2/failure.html} cannot be read
     */
    @Inject
    public OAuth2Service(OAuth2Client client, @ForOAuth2 SigningKeyResolver signingKeyResolver, OAuth2Config oauth2Config) throws IOException {
        this.client = Objects.requireNonNull(client, "client is null");
        this.signingKeyResolver = Objects.requireNonNull(signingKeyResolver, "signingKeyResolver is null");
        this.successHtml = Resources.toString(Resources.getResource(this.getClass(), "/oauth2/success.html"), StandardCharsets.UTF_8);
        this.failureHtml = Resources.toString(Resources.getResource(this.getClass(), "/oauth2/failure.html"), StandardCharsets.UTF_8);
        Verify.verify(this.failureHtml.contains(FAILURE_REPLACEMENT_TEXT), "failure.html does not contain the replacement text");
        Objects.requireNonNull(oauth2Config, "oauth2Config is null");
        this.scopes = oauth2Config.getScopes();
        this.challengeTimeout = Duration.ofMillis(oauth2Config.getChallengeTimeout().toMillis());
        // A configured state key yields a deterministic HMAC key (SHA-256 of the key). Without one, a
        // random key is generated here, so only this instance can validate the state it issued.
        this.stateHmac = oauth2Config.getStateKey()
                .map(key -> Hashing.sha256().hashString(key, StandardCharsets.UTF_8).asBytes())
                .orElseGet(() -> secureRandomBytes(32));
    }

    /**
     * Starts a Web UI login challenge.
     *
     * @param callbackUri URI the authorization server redirects back to
     * @return the authorization redirect URI, the challenge expiration and, when the {@code openid}
     *         scope is configured, the raw (unhashed) nonce the caller must keep for validation
     */
    public OAuthChallenge startWebUiChallenge(URI callbackUri) {
        Instant challengeExpiration = Instant.now().plus(this.challengeTimeout);
        String state = Jwts.builder()
                .signWith(SignatureAlgorithm.HS256, this.stateHmac)
                .setAudience(STATE_AUDIENCE_UI)
                .setExpiration(Date.from(challengeExpiration))
                .compact();
        // Only the hash of the nonce is sent to the authorization server; validateNonce() later
        // requires the ID token's nonce claim to equal that hash.
        Optional<String> nonce = this.scopes.contains(OPENID_SCOPE) ? Optional.of(randomNonce()) : Optional.empty();
        URI redirectUri = this.client.getAuthorizationUri(state, callbackUri, nonce.map(OAuth2Service::hashNonce));
        return new OAuthChallenge(redirectUri, challengeExpiration, nonce);
    }

    /**
     * Starts a REST (non-browser) login challenge whose state carries {@code authId} as its JWT ID.
     *
     * @param callbackUri URI the authorization server redirects back to
     * @param authId identifier of the pending REST authentication
     * @return the authorization redirect URI (no nonce is used for this flow)
     */
    public URI startRestChallenge(URI callbackUri, UUID authId) {
        String state = Jwts.builder()
                .signWith(SignatureAlgorithm.HS256, this.stateHmac)
                .setId(authId.toString())
                .setAudience(STATE_AUDIENCE_REST)
                .setExpiration(Date.from(Instant.now().plus(this.challengeTimeout)))
                .compact();
        return this.client.getAuthorizationUri(state, callbackUri, Optional.empty());
    }

    /**
     * Exchanges the authorization code for an access token and validates the result.
     *
     * @param authId REST auth ID from the state, or empty for the Web UI flow
     * @param code authorization code returned by the authorization server
     * @param callbackUri callback URI used when the challenge was started
     * @param nonce raw nonce issued by {@link #startWebUiChallenge}, if any
     * @return the access token and the earliest of its reported and embedded expirations
     * @throws ChallengeFailedException if the token exchange fails, the nonce cannot be validated,
     *         or the access token carries no expiration
     */
    public OAuthResult finishChallenge(Optional<UUID> authId, String code, URI callbackUri, Optional<String> nonce) throws ChallengeFailedException {
        Objects.requireNonNull(callbackUri, "callbackUri is null");
        Objects.requireNonNull(authId, "authId is null");
        Objects.requireNonNull(code, "code is null");
        OAuth2Client.AccessToken accessToken = this.client.getAccessToken(code, callbackUri);
        Claims parsedToken = Jwts.parser()
                .setSigningKeyResolver(this.signingKeyResolver)
                .parseClaimsJws(accessToken.getAccessToken())
                .getBody();
        this.validateNonce(authId, accessToken, nonce);
        Date expiration = parsedToken.getExpiration();
        if (expiration == null) {
            // Previously this surfaced as a NullPointerException; report it as a failed challenge instead.
            throw new ChallengeFailedException("Access token does not contain an expiration");
        }
        Instant tokenExpiration = expiration.toInstant();
        // Trust whichever expiration comes first: the token endpoint's or the one inside the JWT.
        Instant validUntil = accessToken.getValidUntil()
                .map(instant -> Ordering.<Instant>natural().min(instant, tokenExpiration))
                .orElse(tokenExpiration);
        return new OAuthResult(authId, accessToken.getAccessToken(), validUntil);
    }

    /**
     * Validates the {@code state} parameter and extracts the REST auth ID from it.
     *
     * @param state state value returned by the authorization server
     * @return empty for a Web UI challenge, or the auth ID for a REST challenge
     * @throws ChallengeFailedException if the state fails JWT validation, has an unknown audience,
     *         or is a REST state without a valid UUID ID
     */
    public Optional<UUID> getAuthId(String state) throws ChallengeFailedException {
        Claims stateClaims = this.parseState(state);
        if (STATE_AUDIENCE_UI.equals(stateClaims.getAudience())) {
            return Optional.empty();
        }
        if (STATE_AUDIENCE_REST.equals(stateClaims.getAudience())) {
            String id = stateClaims.getId();
            // UUID.fromString(null) throws NullPointerException, not IllegalArgumentException.
            if (id == null) {
                throw new ChallengeFailedException("State does not contain an auth ID");
            }
            try {
                return Optional.of(UUID.fromString(id));
            }
            catch (IllegalArgumentException e) {
                throw new ChallengeFailedException("State does not contain an auth ID", e);
            }
        }
        throw new ChallengeFailedException("Unexpected state audience");
    }

    /** Verifies the state JWT against the state HMAC key, mapping any parser failure to a checked exception. */
    private Claims parseState(String state) throws ChallengeFailedException {
        try {
            return Jwts.parser()
                    .setSigningKey(this.stateHmac)
                    .parseClaimsJws(state)
                    .getBody();
        }
        catch (RuntimeException e) {
            throw new ChallengeFailedException("State validation failed", e);
        }
    }

    /**
     * Parses and verifies a signed JWT using the configured signing key resolver.
     *
     * @param token compact JWS string
     * @return the parsed JWS; parser failures propagate as runtime exceptions
     */
    public Jws<Claims> parseClaimsJws(String token) {
        return Jwts.parser().setSigningKeyResolver(this.signingKeyResolver).parseClaimsJws(token);
    }

    /** Returns the contents of {@code /oauth2/success.html}. */
    public String getSuccessHtml() {
        return this.successHtml;
    }

    /**
     * Renders the failure page for an OAuth2 error code received on the callback.
     *
     * <p>The error code comes from the callback request and cannot be trusted, so the message is
     * HTML-escaped before insertion.
     *
     * @param errorCode value of the callback's error parameter; may be {@code null}
     */
    public String getCallbackErrorHtml(String errorCode) {
        return this.failureHtml.replace(FAILURE_REPLACEMENT_TEXT, escapeHtml(getOAuth2ErrorMessage(errorCode)));
    }

    /**
     * Renders the failure page with an arbitrary error message (HTML-escaped; {@code null} renders as empty).
     */
    public String getInternalFailureHtml(String errorMessage) {
        return this.failureHtml.replace(FAILURE_REPLACEMENT_TEXT, escapeHtml(Strings.nullToEmpty(errorMessage)));
    }

    /**
     * Checks that the ID token's nonce claim matches the hash of the nonce issued for this challenge.
     * The check is skipped for the REST flow (present {@code authId}), which never issues a nonce.
     */
    private void validateNonce(Optional<UUID> authId, OAuth2Client.AccessToken accessToken, Optional<String> nonce) throws ChallengeFailedException {
        if (authId.isPresent()) {
            return;
        }
        // A nonce without an ID token, or an ID token without a nonce, cannot be validated.
        if (nonce.isPresent() != accessToken.getIdToken().isPresent()) {
            throw new ChallengeFailedException("Cannot validate nonce parameter");
        }
        nonce.ifPresent(n -> Jwts.parser()
                .setSigningKeyResolver(this.signingKeyResolver)
                .require(NONCE, hashNonce(n))
                .parseClaimsJws(accessToken.getIdToken().get()));
    }

    /** Returns {@code count} bytes from the shared {@link SecureRandom}. */
    private static byte[] secureRandomBytes(int count) {
        byte[] bytes = new byte[count];
        SECURE_RANDOM.nextBytes(bytes);
        return bytes;
    }

    /** Maps a known OAuth2 error code to a human-readable message; unknown (or null) codes are echoed back. */
    private static String getOAuth2ErrorMessage(String errorCode) {
        // nullToEmpty avoids the NullPointerException a switch on a null String would throw.
        switch (Strings.nullToEmpty(errorCode)) {
            case "access_denied":
                return "OAuth2 server denied the login";
            case "unauthorized_client":
                return "OAuth2 server does not allow request from this Trino server";
            case "server_error":
                return "OAuth2 server had a failure";
            case "temporarily_unavailable":
                return "OAuth2 server is temporarily unavailable";
            default:
                return "OAuth2 unknown error code: " + errorCode;
        }
    }

    /** Escapes the characters that are significant in HTML text and attribute contexts. */
    private static String escapeHtml(String value) {
        StringBuilder escaped = new StringBuilder(value.length());
        for (int i = 0; i < value.length(); i++) {
            char c = value.charAt(i);
            switch (c) {
                case '&':
                    escaped.append("&amp;");
                    break;
                case '<':
                    escaped.append("&lt;");
                    break;
                case '>':
                    escaped.append("&gt;");
                    break;
                case '"':
                    escaped.append("&quot;");
                    break;
                case '\'':
                    escaped.append("&#39;");
                    break;
                default:
                    escaped.append(c);
            }
        }
        return escaped.toString();
    }

    /** Generates a URL-safe Base64 nonce from 18 random bytes. */
    private static String randomNonce() {
        return BaseEncoding.base64Url().encode(secureRandomBytes(18));
    }

    /** Returns the lowercase hex SHA-256 of the nonce's UTF-8 bytes. */
    private static String hashNonce(String nonce) {
        return Hashing.sha256().hashString(nonce, StandardCharsets.UTF_8).toString();
    }

    /** Outcome of a successful challenge: the optional REST auth ID, the access token and its expiration. */
    public static class OAuthResult {
        private final Optional<UUID> authId;
        private final String accessToken;
        private final Instant tokenExpiration;

        public OAuthResult(Optional<UUID> authId, String accessToken, Instant tokenExpiration) {
            this.authId = Objects.requireNonNull(authId, "authId is null");
            this.accessToken = Objects.requireNonNull(accessToken, "accessToken is null");
            this.tokenExpiration = Objects.requireNonNull(tokenExpiration, "tokenExpiration is null");
        }

        public Optional<UUID> getAuthId() {
            return this.authId;
        }

        public String getAccessToken() {
            return this.accessToken;
        }

        public Instant getTokenExpiration() {
            return this.tokenExpiration;
        }
    }

    /** A started Web UI challenge: where to redirect, when it expires, and the raw nonce (if any). */
    public static class OAuthChallenge {
        private final URI redirectUrl;
        private final Instant challengeExpiration;
        private final Optional<String> nonce;

        public OAuthChallenge(URI redirectUrl, Instant challengeExpiration, Optional<String> nonce) {
            this.redirectUrl = Objects.requireNonNull(redirectUrl, "redirectUrl is null");
            this.challengeExpiration = Objects.requireNonNull(challengeExpiration, "challengeExpiration is null");
            this.nonce = Objects.requireNonNull(nonce, "nonce is null");
        }

        public URI getRedirectUrl() {
            return this.redirectUrl;
        }

        public Instant getChallengeExpiration() {
            return this.challengeExpiration;
        }

        public Optional<String> getNonce() {
            return this.nonce;
        }
    }
}

