/*
 * Decompiled with CFR 0.152.
 */
package fish.payara.security.openid.controller;

import fish.payara.security.openid.controller.CacheKey;
import fish.payara.security.openid.domain.OpenIdConfiguration;
import fish.payara.security.shaded.nimbusds.jose.Algorithm;
import fish.payara.security.shaded.nimbusds.jose.EncryptionMethod;
import fish.payara.security.shaded.nimbusds.jose.JOSEException;
import fish.payara.security.shaded.nimbusds.jose.JWEAlgorithm;
import fish.payara.security.shaded.nimbusds.jose.JWEHeader;
import fish.payara.security.shaded.nimbusds.jose.JWSAlgorithm;
import fish.payara.security.shaded.nimbusds.jose.JWSHeader;
import fish.payara.security.shaded.nimbusds.jose.jwk.source.ImmutableSecret;
import fish.payara.security.shaded.nimbusds.jose.jwk.source.JWKSource;
import fish.payara.security.shaded.nimbusds.jose.jwk.source.RemoteJWKSet;
import fish.payara.security.shaded.nimbusds.jose.proc.BadJOSEException;
import fish.payara.security.shaded.nimbusds.jose.proc.JWEDecryptionKeySelector;
import fish.payara.security.shaded.nimbusds.jose.proc.JWEKeySelector;
import fish.payara.security.shaded.nimbusds.jose.proc.JWSKeySelector;
import fish.payara.security.shaded.nimbusds.jose.proc.JWSVerificationKeySelector;
import fish.payara.security.shaded.nimbusds.jose.util.DefaultResourceRetriever;
import fish.payara.security.shaded.nimbusds.jwt.EncryptedJWT;
import fish.payara.security.shaded.nimbusds.jwt.JWT;
import fish.payara.security.shaded.nimbusds.jwt.JWTClaimsSet;
import fish.payara.security.shaded.nimbusds.jwt.PlainJWT;
import fish.payara.security.shaded.nimbusds.jwt.SignedJWT;
import fish.payara.security.shaded.nimbusds.jwt.proc.DefaultJWTProcessor;
import fish.payara.security.shaded.nimbusds.jwt.proc.JWTClaimsSetVerifier;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;
import java.nio.charset.StandardCharsets;
import java.text.ParseException;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;

@ApplicationScoped
public class JWTValidator {
    @Inject
    OpenIdConfiguration configuration;
    private ConcurrentHashMap<CacheKey, JWSKeySelector> jwsCache = new ConcurrentHashMap();
    private ConcurrentHashMap<CacheKey, JWEKeySelector> jweCache = new ConcurrentHashMap();

    public JWTClaimsSet validateBearerToken(JWT token, JWTClaimsSetVerifier jwtVerifier) {
        JWTClaimsSet claimsSet;
        block6: {
            try {
                if (token instanceof PlainJWT) {
                    PlainJWT plainToken = (PlainJWT)token;
                    claimsSet = plainToken.getJWTClaimsSet();
                    jwtVerifier.verify(claimsSet, null);
                    break block6;
                }
                if (token instanceof SignedJWT) {
                    SignedJWT signedToken = (SignedJWT)token;
                    JWSHeader header = signedToken.getHeader();
                    String alg = header.getAlgorithm().getName();
                    if (Objects.isNull(alg)) {
                        alg = "RS256";
                    }
                    DefaultJWTProcessor jwtProcessor = new DefaultJWTProcessor();
                    jwtProcessor.setJWSKeySelector(this.getJWSKeySelector(alg));
                    jwtProcessor.setJWTClaimsSetVerifier(jwtVerifier);
                    claimsSet = jwtProcessor.process(signedToken, null);
                    break block6;
                }
                if (token instanceof EncryptedJWT) {
                    EncryptedJWT encryptedToken = (EncryptedJWT)token;
                    JWEHeader header = encryptedToken.getHeader();
                    String alg = header.getAlgorithm().getName();
                    DefaultJWTProcessor jwtProcessor = new DefaultJWTProcessor();
                    jwtProcessor.setJWSKeySelector(this.getJWSKeySelector(alg));
                    jwtProcessor.setJWEKeySelector(this.getJWEKeySelector());
                    jwtProcessor.setJWTClaimsSetVerifier(jwtVerifier);
                    claimsSet = jwtProcessor.process(encryptedToken, null);
                    break block6;
                }
                throw new IllegalStateException("Unexpected JWT type : " + token.getClass());
            }
            catch (JOSEException | BadJOSEException | ParseException ex) {
                throw new IllegalStateException(ex);
            }
        }
        return claimsSet;
    }

    private JWSKeySelector<?> getJWSKeySelector(String alg) {
        return this.jwsCache.computeIfAbsent(this.createCacheKey(alg), k -> this.createJWSKeySelector(alg));
    }

    private CacheKey createCacheKey(String alg) {
        return new CacheKey(alg, this.configuration.getEncryptionMetadata().getEncryptionAlgorithm(), this.configuration.getEncryptionMetadata().getEncryptionMethod(), this.configuration.getEncryptionMetadata().getPrivateKeySource(), this.configuration.getJwksConnectTimeout(), this.configuration.getJwksReadTimeout(), this.configuration.getProviderMetadata().getJwksURL(), this.configuration.getClientSecret());
    }

    private JWEKeySelector<?> getJWEKeySelector() {
        return this.jweCache.computeIfAbsent(this.createCacheKey(null), k -> this.createJweKeySelector());
    }

    private JWEKeySelector<?> createJweKeySelector() {
        JWEAlgorithm jwsAlg = this.configuration.getEncryptionMetadata().getEncryptionAlgorithm();
        EncryptionMethod jweEnc = this.configuration.getEncryptionMetadata().getEncryptionMethod();
        JWKSource jwkSource = this.configuration.getEncryptionMetadata().getPrivateKeySource();
        if (Objects.isNull(jwsAlg)) {
            throw new IllegalStateException("Missing JWE encryption algorithm ");
        }
        if (!this.configuration.getProviderMetadata().getIdTokenEncryptionAlgValuesSupported().contains(jwsAlg.getName())) {
            throw new IllegalStateException("Unsupported ID tokens algorithm :" + jwsAlg.getName());
        }
        if (Objects.isNull(jweEnc)) {
            throw new IllegalStateException("Missing JWE encryption method");
        }
        if (!this.configuration.getProviderMetadata().getIdTokenEncryptionEncValuesSupported().contains(jweEnc.getName())) {
            throw new IllegalStateException("Unsupported ID tokens encryption method :" + jweEnc.getName());
        }
        return new JWEDecryptionKeySelector(jwsAlg, jweEnc, jwkSource);
    }

    private JWSKeySelector<?> createJWSKeySelector(String alg) {
        JWKSource jwkSource;
        JWSAlgorithm jWSAlgorithm = new JWSAlgorithm(alg);
        if (Algorithm.NONE.equals(jWSAlgorithm)) {
            throw new IllegalStateException("Unsupported JWS algorithm : " + jWSAlgorithm);
        }
        if (JWSAlgorithm.Family.RSA.contains(jWSAlgorithm) || JWSAlgorithm.Family.EC.contains(jWSAlgorithm)) {
            DefaultResourceRetriever jwkSetRetriever = new DefaultResourceRetriever(this.configuration.getJwksConnectTimeout(), this.configuration.getJwksReadTimeout(), 51200);
            jwkSource = new RemoteJWKSet(this.configuration.getProviderMetadata().getJwksURL(), jwkSetRetriever);
        } else if (JWSAlgorithm.Family.HMAC_SHA.contains(jWSAlgorithm)) {
            byte[] clientSecret = new String(this.configuration.getClientSecret()).getBytes(StandardCharsets.UTF_8);
            if (Objects.isNull(clientSecret)) {
                throw new IllegalStateException("Missing client secret");
            }
            jwkSource = new ImmutableSecret(clientSecret);
        } else {
            throw new IllegalStateException("Unsupported JWS algorithm : " + jWSAlgorithm);
        }
        return new JWSVerificationKeySelector(jWSAlgorithm, jwkSource);
    }
}

