/*
 * Decompiled with CFR 0.152.
 */
package io.trino.server.security.oauth2;

import com.nimbusds.jose.EncryptionMethod;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWEAlgorithm;
import com.nimbusds.jose.JWEDecrypter;
import com.nimbusds.jose.JWEEncrypter;
import com.nimbusds.jose.JWEHeader;
import com.nimbusds.jose.JWEObject;
import com.nimbusds.jose.KeyLengthException;
import com.nimbusds.jose.Payload;
import com.nimbusds.jose.crypto.AESDecrypter;
import com.nimbusds.jose.crypto.AESEncrypter;
import io.airlift.log.Logger;
import io.airlift.units.Duration;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.JwtBuilder;
import io.jsonwebtoken.JwtParser;
import io.jsonwebtoken.JwtParserBuilder;
import io.jsonwebtoken.io.CompressionAlgorithm;
import io.jsonwebtoken.lang.NestedCollection;
import io.trino.server.security.jwt.JwtUtil;
import io.trino.server.security.oauth2.OAuth2Client;
import io.trino.server.security.oauth2.RefreshTokensConfig;
import io.trino.server.security.oauth2.TokenPairSerializer;
import io.trino.server.security.oauth2.ZstdCodec;
import java.security.NoSuchAlgorithmException;
import java.text.ParseException;
import java.time.Clock;
import java.util.Date;
import java.util.Map;
import java.util.Objects;
import javax.crypto.KeyGenerator;
import javax.crypto.SecretKey;

public class JweTokenSerializer
implements TokenPairSerializer {
    private static final CompressionAlgorithm COMPRESSION_ALGORITHM = new ZstdCodec();
    private static final Logger LOG = Logger.get(JweTokenSerializer.class);
    private static final String ACCESS_TOKEN_KEY = "access_token";
    private static final String EXPIRATION_TIME_KEY = "expiration_time";
    private static final String REFRESH_TOKEN_KEY = "refresh_token";
    private final JweEncryptedSerializer jweSerializer;
    private final OAuth2Client client;
    private final Clock clock;
    private final String issuer;
    private final String audience;
    private final Duration tokenExpiration;
    private final JwtParser parser;
    private final String principalField;

    public JweTokenSerializer(RefreshTokensConfig config, OAuth2Client client, String issuer, String audience, String principalField, Clock clock, Duration tokenExpiration) {
        this.jweSerializer = new JweEncryptedSerializer(JweTokenSerializer.getOrGenerateKey(config));
        this.client = Objects.requireNonNull(client, "client is null");
        this.issuer = Objects.requireNonNull(issuer, "issuer is null");
        this.principalField = Objects.requireNonNull(principalField, "principalField is null");
        this.audience = Objects.requireNonNull(audience, "issuer is null");
        this.clock = Objects.requireNonNull(clock, "clock is null");
        this.tokenExpiration = Objects.requireNonNull(tokenExpiration, "tokenExpiration is null");
        this.parser = ((JwtParserBuilder)((NestedCollection)JwtUtil.newJwtParserBuilder().clock(() -> Date.from(clock.instant())).requireIssuer(this.issuer).requireAudience(this.audience).zip().add((Object)COMPRESSION_ALGORITHM)).and()).unsecuredDecompression().unsecured().build();
    }

    @Override
    public TokenPairSerializer.TokenPair deserialize(String token) {
        Objects.requireNonNull(token, "token is null");
        try {
            Claims claims = (Claims)this.parser.parseUnsecuredClaims((CharSequence)this.jweSerializer.deserialize(token)).getBody();
            return TokenPairSerializer.TokenPair.withAccessAndRefreshTokens((String)claims.get(ACCESS_TOKEN_KEY, String.class), (Date)claims.get(EXPIRATION_TIME_KEY, Date.class), (String)claims.get(REFRESH_TOKEN_KEY, String.class));
        }
        catch (ParseException ex) {
            return TokenPairSerializer.TokenPair.withAccessToken(token);
        }
    }

    @Override
    public String serialize(TokenPairSerializer.TokenPair tokenPair) {
        Objects.requireNonNull(tokenPair, "tokenPair is null");
        Map<String, Object> claims = this.client.getClaims(tokenPair.accessToken()).orElseThrow(() -> new IllegalArgumentException("Claims are missing"));
        if (!claims.containsKey(this.principalField)) {
            throw new IllegalArgumentException(String.format("%s field is missing", this.principalField));
        }
        JwtBuilder jwt = ((JwtBuilder)((NestedCollection)JwtUtil.newJwtBuilder().expiration(Date.from(this.clock.instant().plusMillis(this.tokenExpiration.toMillis()))).claim(this.principalField, (Object)claims.get(this.principalField).toString()).audience().add((Object)this.audience)).and()).issuer(this.issuer).claim(ACCESS_TOKEN_KEY, (Object)tokenPair.accessToken()).claim(EXPIRATION_TIME_KEY, (Object)tokenPair.expiration()).compressWith(COMPRESSION_ALGORITHM);
        if (tokenPair.refreshToken().isPresent()) {
            jwt.claim(REFRESH_TOKEN_KEY, (Object)tokenPair.refreshToken().orElseThrow());
        } else {
            LOG.info("No refresh token has been issued, although coordinator expects one. Please check your IdP whether that is correct behaviour");
        }
        return this.jweSerializer.serialize(jwt.compact());
    }

    private static SecretKey getOrGenerateKey(RefreshTokensConfig config) {
        SecretKey signingKey = config.getSecretKey();
        if (signingKey == null) {
            try {
                KeyGenerator generator = KeyGenerator.getInstance("AES");
                generator.init(256);
                return generator.generateKey();
            }
            catch (NoSuchAlgorithmException e) {
                throw new RuntimeException(e);
            }
        }
        return signingKey;
    }

    private static class JweEncryptedSerializer {
        private final AESEncrypter jweEncrypter;
        private final AESDecrypter jweDecrypter;
        private final JWEHeader encryptionHeader;

        private JweEncryptedSerializer(SecretKey secretKey) {
            try {
                this.encryptionHeader = this.createEncryptionHeader(secretKey);
                this.jweEncrypter = new AESEncrypter(secretKey);
                this.jweDecrypter = new AESDecrypter(secretKey);
            }
            catch (KeyLengthException e) {
                throw new RuntimeException(e);
            }
        }

        private JWEHeader createEncryptionHeader(SecretKey key) {
            int keyLength = key.getEncoded().length;
            return switch (keyLength) {
                case 16 -> new JWEHeader(JWEAlgorithm.A128GCMKW, EncryptionMethod.A128GCM);
                case 24 -> new JWEHeader(JWEAlgorithm.A192GCMKW, EncryptionMethod.A192GCM);
                case 32 -> new JWEHeader(JWEAlgorithm.A256GCMKW, EncryptionMethod.A256GCM);
                default -> throw new IllegalArgumentException("Secret key size must be either 16, 24 or 32 bytes but was %d".formatted(keyLength));
            };
        }

        private String serialize(String payload) {
            try {
                JWEObject jwe = new JWEObject(this.encryptionHeader, new Payload(payload));
                jwe.encrypt((JWEEncrypter)this.jweEncrypter);
                return jwe.serialize();
            }
            catch (JOSEException e) {
                throw new RuntimeException(e);
            }
        }

        private String deserialize(String token) throws ParseException {
            try {
                JWEObject jwe = JWEObject.parse((String)token);
                jwe.decrypt((JWEDecrypter)this.jweDecrypter);
                return jwe.getPayload().toString();
            }
            catch (JOSEException e) {
                throw new RuntimeException(e);
            }
        }
    }
}

