/*
 * Decompiled with CFR 0.152.
 */
package io.trino.server.security.oauth2;

import com.google.common.base.Preconditions;
import com.nimbusds.jose.EncryptionMethod;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWEAlgorithm;
import com.nimbusds.jose.JWEDecrypter;
import com.nimbusds.jose.JWEEncrypter;
import com.nimbusds.jose.JWEHeader;
import com.nimbusds.jose.JWEObject;
import com.nimbusds.jose.KeyLengthException;
import com.nimbusds.jose.Payload;
import com.nimbusds.jose.crypto.AESDecrypter;
import com.nimbusds.jose.crypto.AESEncrypter;
import io.airlift.log.Logger;
import io.airlift.units.Duration;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.CompressionCodec;
import io.jsonwebtoken.CompressionException;
import io.jsonwebtoken.Header;
import io.jsonwebtoken.JwtBuilder;
import io.jsonwebtoken.JwtParser;
import io.trino.server.security.jwt.JwtUtil;
import io.trino.server.security.oauth2.OAuth2Client;
import io.trino.server.security.oauth2.RefreshTokensConfig;
import io.trino.server.security.oauth2.TokenPairSerializer;
import io.trino.server.security.oauth2.ZstdCodec;
import java.security.NoSuchAlgorithmException;
import java.text.ParseException;
import java.time.Clock;
import java.util.Date;
import java.util.Map;
import java.util.Objects;
import javax.crypto.KeyGenerator;
import javax.crypto.SecretKey;

/**
 * Serializes an OAuth2 access/refresh token pair into a single opaque token string, and back.
 *
 * <p>The token pair is first written as claims into a JWT built by {@link JwtUtil#newJwtBuilder()}.
 * The JWT carries the issuer, the audience, an expiration of {@code now + tokenExpiration}, the
 * principal claim, the access token, its expiration time and, when present, the refresh token.
 * The JWT is compressed with ZSTD and then wrapped in a JWE using AES-256 key wrap
 * ({@code A256KW}) and {@code A256CBC-HS512} content encryption. The inner JWT is read back with
 * {@code parseClaimsJwt}, i.e. it is not separately signed; integrity relies on the JWE layer.
 *
 * <p>The AES key comes from {@link RefreshTokensConfig#getSecretKey()}. When none is configured,
 * a fresh random 256-bit key is generated per instance. Tokens produced with such a key can only
 * be decrypted by the same instance.
 */
public class JweTokenSerializer
        implements TokenPairSerializer
{
    private static final Logger LOG = Logger.get(JweTokenSerializer.class);
    // JWE key management algorithm: AES-256 key wrap of the content encryption key.
    private static final JWEAlgorithm ALGORITHM = JWEAlgorithm.A256KW;
    // JWE content encryption: AES-256-CBC combined with HMAC-SHA-512.
    private static final EncryptionMethod ENCRYPTION_METHOD = EncryptionMethod.A256CBC_HS512;
    private static final CompressionCodec COMPRESSION_CODEC = new ZstdCodec();
    // Claim names used inside the inner JWT.
    private static final String ACCESS_TOKEN_KEY = "access_token";
    private static final String EXPIRATION_TIME_KEY = "expiration_time";
    private static final String REFRESH_TOKEN_KEY = "refresh_token";

    private final OAuth2Client client;
    private final Clock clock;
    private final String issuer;
    private final String audience;
    private final Duration tokenExpiration;
    private final JwtParser parser;
    private final AESEncrypter jweEncrypter;
    private final AESDecrypter jweDecrypter;
    private final String principalField;

    /**
     * Creates a serializer.
     *
     * @param config source of the AES secret key; if it provides none, a random key is generated
     * @param client used to extract claims (including the principal) from access tokens
     * @param issuer value written to, and required in, the inner JWT {@code iss} claim
     * @param audience value written to, and required in, the inner JWT {@code aud} claim
     * @param principalField name of the access-token claim copied into the inner JWT
     * @param clock time source for expiration stamping and validation
     * @param tokenExpiration lifetime of the serialized token
     * @throws KeyLengthException if the key is not valid for AES key wrap
     * @throws NoSuchAlgorithmException if an AES key must be generated and AES is unavailable
     */
    public JweTokenSerializer(
            RefreshTokensConfig config,
            OAuth2Client client,
            String issuer,
            String audience,
            String principalField,
            Clock clock,
            Duration tokenExpiration)
            throws KeyLengthException, NoSuchAlgorithmException
    {
        SecretKey secretKey = createKey(Objects.requireNonNull(config, "config is null"));
        this.jweEncrypter = new AESEncrypter(secretKey);
        this.jweDecrypter = new AESDecrypter(secretKey);
        this.client = Objects.requireNonNull(client, "client is null");
        this.issuer = Objects.requireNonNull(issuer, "issuer is null");
        this.principalField = Objects.requireNonNull(principalField, "principalField is null");
        this.audience = Objects.requireNonNull(audience, "audience is null");
        this.clock = Objects.requireNonNull(clock, "clock is null");
        this.tokenExpiration = Objects.requireNonNull(tokenExpiration, "tokenExpiration is null");
        this.parser = JwtUtil.newJwtParserBuilder()
                .setClock(() -> Date.from(clock.instant()))
                .requireIssuer(this.issuer)
                .requireAudience(this.audience)
                .setCompressionCodecResolver(JweTokenSerializer::resolveCompressionCodec)
                .build();
    }

    /**
     * Decrypts and parses a token produced by {@link #serialize}.
     *
     * <p>If {@code token} cannot be parsed as a JWE at all, it is returned unchanged as a bare
     * access token (presumably to accept plain access tokens not produced by this serializer).
     * Exceptions thrown by the JWT parser (e.g. issuer/audience mismatch or expiry) propagate.
     *
     * @throws IllegalArgumentException if the JWE cannot be decrypted with this instance's key
     */
    @Override
    public TokenPairSerializer.TokenPair deserialize(String token)
    {
        Objects.requireNonNull(token, "token is null");
        try {
            JWEObject jwe = JWEObject.parse(token);
            jwe.decrypt(jweDecrypter);
            Claims claims = parser.parseClaimsJwt(jwe.getPayload().toString()).getBody();
            return TokenPairSerializer.TokenPair.accessAndRefreshTokens(
                    claims.get(ACCESS_TOKEN_KEY, String.class),
                    claims.get(EXPIRATION_TIME_KEY, Date.class),
                    claims.get(REFRESH_TOKEN_KEY, String.class));
        }
        catch (ParseException ex) {
            // Not JWE-formatted: fall back to treating the whole string as an access token.
            return TokenPairSerializer.TokenPair.accessToken(token);
        }
        catch (JOSEException ex) {
            throw new IllegalArgumentException("Decryption failed", ex);
        }
    }

    /**
     * Encodes {@code tokenPair} as an encrypted compact JWE string.
     *
     * @throws IllegalArgumentException if the access token yields no claims, or if the principal
     *         claim is absent or {@code null}
     * @throws IllegalStateException if JWE encryption fails
     */
    @Override
    public String serialize(TokenPairSerializer.TokenPair tokenPair)
    {
        Objects.requireNonNull(tokenPair, "tokenPair is null");
        Map<String, Object> claims = client.getClaims(tokenPair.getAccessToken())
                .orElseThrow(() -> new IllegalArgumentException("Claims are missing"));
        // A present-but-null principal would otherwise surface as an NPE from toString().
        Object principal = claims.get(principalField);
        if (principal == null) {
            throw new IllegalArgumentException(String.format("%s field is missing", principalField));
        }

        JwtBuilder jwt = JwtUtil.newJwtBuilder()
                .setExpiration(Date.from(clock.instant().plusMillis(tokenExpiration.toMillis())))
                .claim(principalField, principal.toString())
                .setAudience(audience)
                .setIssuer(issuer)
                .claim(ACCESS_TOKEN_KEY, tokenPair.getAccessToken())
                .claim(EXPIRATION_TIME_KEY, tokenPair.getExpiration())
                .compressWith(COMPRESSION_CODEC);
        tokenPair.getRefreshToken().ifPresentOrElse(
                refreshToken -> jwt.claim(REFRESH_TOKEN_KEY, refreshToken),
                () -> LOG.info("No refresh token has been issued, although coordinator expects one. Please check your IdP whether that is correct behaviour"));

        try {
            JWEObject jwe = new JWEObject(new JWEHeader(ALGORITHM, ENCRYPTION_METHOD), new Payload(jwt.compact()));
            jwe.encrypt(jweEncrypter);
            return jwe.serialize();
        }
        catch (JOSEException ex) {
            throw new IllegalStateException("Encryption failed", ex);
        }
    }

    /**
     * Returns the configured secret key, or a freshly generated 256-bit AES key if none is set.
     */
    private static SecretKey createKey(RefreshTokensConfig config)
            throws NoSuchAlgorithmException
    {
        SecretKey configuredKey = config.getSecretKey();
        if (configuredKey != null) {
            return configuredKey;
        }
        KeyGenerator generator = KeyGenerator.getInstance("AES");
        generator.init(256);
        return generator.generateKey();
    }

    /**
     * Resolves the compression codec named in a JWT header.
     *
     * @return the ZSTD codec, or {@code null} when the header declares no compression
     * @throws IllegalStateException if the header names any codec other than {@code ZSTD}
     */
    private static CompressionCodec resolveCompressionCodec(Header header)
            throws CompressionException
    {
        String algorithm = header.getCompressionAlgorithm();
        if (algorithm == null) {
            return null;
        }
        Preconditions.checkState(algorithm.equals("ZSTD"), "Unknown codec '%s' used for token compression", algorithm);
        return COMPRESSION_CODEC;
    }
}

