/*
 * Decompiled with CFR 0.152.
 */
package uk.gov.di.ipv.cri.common.library.service;

import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSVerifier;
import com.nimbusds.jose.crypto.ECDSAVerifier;
import com.nimbusds.jose.crypto.RSASSAVerifier;
import com.nimbusds.jose.crypto.impl.ECDSA;
import com.nimbusds.jose.jwk.ECKey;
import com.nimbusds.jose.util.Base64URL;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import com.nimbusds.jwt.proc.BadJWTException;
import com.nimbusds.jwt.proc.DefaultJWTClaimsVerifier;
import com.nimbusds.oauth2.sdk.id.ClientID;
import java.io.ByteArrayInputStream;
import java.nio.charset.StandardCharsets;
import java.security.PublicKey;
import java.security.cert.Certificate;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.interfaces.ECPublicKey;
import java.security.interfaces.RSAPublicKey;
import java.text.ParseException;
import java.time.Instant;
import java.time.LocalDateTime;
import java.time.ZoneOffset;
import java.time.temporal.ChronoUnit;
import java.util.Base64;
import java.util.Map;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import uk.gov.di.ipv.cri.common.library.exception.ClientConfigurationException;
import uk.gov.di.ipv.cri.common.library.exception.SessionValidationException;
import uk.gov.di.ipv.cri.common.library.util.JwkKeyCache;

/**
 * Verifies signed JWTs against a client's authentication configuration.
 *
 * <p>Verification runs in three steps, in this order: the header algorithm is compared with
 * the configured one, the claims set is checked for required and expected claims, and finally
 * the signature is verified with the client's public signing key.
 */
public class JWTVerifier {
    private static final Logger LOGGER = LoggerFactory.getLogger(JWTVerifier.class);
    private final JwkKeyCache jwkKeyCache = new JwkKeyCache();

    /**
     * Verifies an authorization JWT.
     *
     * <p>The claims {@code exp}, {@code sub} and {@code nbf} must be present, and the issuer and
     * audience must match the {@code issuer} and {@code audience} entries of the configuration.
     *
     * @param clientAuthenticationConfig the client's authentication configuration
     * @param signedJWT the JWT to verify
     * @throws SessionValidationException if the header, claims or signature are not acceptable
     * @throws ClientConfigurationException if the configured signing key cannot be used
     */
    public void verifyAuthorizationJWT(Map<String, String> clientAuthenticationConfig, SignedJWT signedJWT) throws SessionValidationException, ClientConfigurationException {
        Set<String> requiredClaims = Set.of("exp", "sub", "nbf");
        JWTClaimsSet expectedClaimValues = new JWTClaimsSet.Builder()
                .issuer(clientAuthenticationConfig.get("issuer"))
                .audience(clientAuthenticationConfig.get("audience"))
                .build();
        this.verifyJWT(clientAuthenticationConfig, signedJWT, requiredClaims, expectedClaimValues);
    }

    /**
     * Verifies an access token request JWT.
     *
     * <p>The claims {@code exp}, {@code sub}, {@code iss}, {@code aud} and {@code jti} must be
     * present. Issuer and subject must both equal the client ID, and the audience must match the
     * {@code audience} entry of the configuration.
     *
     * @param clientAuthenticationConfig the client's authentication configuration
     * @param signedJWT the JWT to verify
     * @param clientID the client ID expected as issuer and subject
     * @throws SessionValidationException if the header, claims or signature are not acceptable
     * @throws ClientConfigurationException if the configured signing key cannot be used
     */
    public void verifyAccessTokenJWT(Map<String, String> clientAuthenticationConfig, SignedJWT signedJWT, ClientID clientID) throws SessionValidationException, ClientConfigurationException {
        Set<String> requiredClaims = Set.of("exp", "sub", "iss", "aud", "jti");
        JWTClaimsSet expectedClaimValues = new JWTClaimsSet.Builder()
                .issuer(clientID.getValue())
                .subject(clientID.getValue())
                .audience(clientAuthenticationConfig.get("audience"))
                .build();
        this.verifyJWT(clientAuthenticationConfig, signedJWT, requiredClaims, expectedClaimValues);
    }

    /**
     * Rejects a JWT whose expiry lies further in the future than the allowed time-to-live.
     *
     * @param jwtExpirationTime the expiry time taken from the JWT
     * @param maxAllowedTtl the maximum allowed time-to-live, in seconds from now
     * @throws SessionValidationException if the expiry is after now plus {@code maxAllowedTtl}
     */
    public void validateMaxAllowedJarTtl(Instant jwtExpirationTime, long maxAllowedTtl) throws SessionValidationException {
        // Instants are directly comparable; no conversion to a local date-time is needed.
        Instant maximumExpirationTime = Instant.now().plus(maxAllowedTtl, ChronoUnit.SECONDS);
        if (jwtExpirationTime.isAfter(maximumExpirationTime)) {
            throw new SessionValidationException("The client JWT expiry date has surpassed the maximum allowed ttl value");
        }
    }

    /** Runs the header, claims and signature checks in that order. */
    private void verifyJWT(Map<String, String> clientAuthenticationConfig, SignedJWT signedJWT, Set<String> requiredClaims, JWTClaimsSet expectedClaimValues) throws SessionValidationException, ClientConfigurationException {
        this.verifyJWTHeader(clientAuthenticationConfig, signedJWT);
        this.verifyJWTClaimsSet(signedJWT, requiredClaims, expectedClaimValues);
        this.verifyJWTSignature(clientAuthenticationConfig, signedJWT);
    }

    /**
     * Checks that the JWT's signing algorithm equals the {@code authenticationAlg} configured
     * for the client.
     *
     * @throws SessionValidationException if the algorithms differ
     */
    private void verifyJWTHeader(Map<String, String> clientAuthenticationConfig, SignedJWT signedJWT) throws SessionValidationException {
        JWSAlgorithm configuredAlgorithm = JWSAlgorithm.parse(clientAuthenticationConfig.get("authenticationAlg"));
        JWSAlgorithm jwtAlgorithm = signedJWT.getHeader().getAlgorithm();
        // Compare by value: equal algorithms are not guaranteed to be the same instance.
        if (!configuredAlgorithm.equals(jwtAlgorithm)) {
            throw new SessionValidationException(String.format("jwt signing algorithm %s does not match signing algorithm configured for client: %s", jwtAlgorithm, configuredAlgorithm));
        }
    }

    /**
     * Verifies the JWT signature with the client's public signing key.
     *
     * <p>The JWK cache is asked for a key matching the JWT's key ID at the configured
     * {@code jwksEndpoint}. When one is returned it replaces the {@code publicSigningJwkBase64}
     * entry of the supplied map (the map is modified in place) and is used for this
     * verification; otherwise the already configured value is used.
     *
     * @throws SessionValidationException if the signature is invalid or cannot be processed
     * @throws ClientConfigurationException if the signing key cannot be loaded or is unsupported
     */
    private void verifyJWTSignature(Map<String, String> clientAuthenticationConfig, SignedJWT signedJWT) throws SessionValidationException, ClientConfigurationException {
        try {
            SignedJWT concatSignatureJwt = this.signatureIsDerFormat(signedJWT) ? this.transcodeSignature(signedJWT) : signedJWT;
            String keyId = signedJWT.getHeader().getKeyID();
            this.jwkKeyCache.getBase64JwkForKid(clientAuthenticationConfig.get("jwksEndpoint"), keyId).ifPresentOrElse(
                    value -> clientAuthenticationConfig.replace("publicSigningJwkBase64", (String) value),
                    () -> LOGGER.warn("{} not found in public JWK response", keyId));
            // Read the key only after the refresh above, so a key just fetched for this
            // key ID is the one actually used to verify this JWT.
            String publicCertificateToVerify = clientAuthenticationConfig.get("publicSigningJwkBase64");
            JWSAlgorithm signingAlgorithm = signedJWT.getHeader().getAlgorithm();
            PublicKey publicKeyFromConfig = this.getPublicKeyFromConfig(publicCertificateToVerify, signingAlgorithm);
            if (!this.verifySignature(concatSignatureJwt, publicKeyFromConfig)) {
                throw new SessionValidationException("JWT signature verification failed");
            }
        }
        catch (JOSEException | ParseException e) {
            throw new SessionValidationException("JWT signature verification failed", e);
        }
        catch (CertificateException e) {
            throw new ClientConfigurationException("Certificate problem encountered", e);
        }
    }

    /**
     * Tells whether the signature needs transcoding from DER to the concatenated R||S form.
     *
     * <p>Only ECDSA signatures can be DER encoded, so any non-EC algorithm yields
     * {@code false}. For EC algorithms the signature is treated as DER when its length differs
     * from the concatenated length defined for the JWT's own algorithm.
     *
     * @throws JOSEException if the expected signature length cannot be determined
     */
    private boolean signatureIsDerFormat(SignedJWT signedJWT) throws JOSEException {
        JWSAlgorithm algorithm = signedJWT.getHeader().getAlgorithm();
        if (!JWSAlgorithm.Family.EC.contains(algorithm)) {
            return false;
        }
        return signedJWT.getSignature().decode().length != ECDSA.getSignatureByteArrayLength(algorithm);
    }

    /**
     * Rebuilds the JWT with its DER signature transcoded to the concatenated R||S form, sized
     * for the JWT's own algorithm. Header and payload parts are carried over unchanged.
     *
     * @throws JOSEException if the signature cannot be transcoded
     * @throws ParseException if the rebuilt JWT cannot be parsed
     */
    private SignedJWT transcodeSignature(SignedJWT signedJWT) throws JOSEException, ParseException {
        int concatLength = ECDSA.getSignatureByteArrayLength(signedJWT.getHeader().getAlgorithm());
        byte[] concatSignature = ECDSA.transcodeSignatureToConcat(signedJWT.getSignature().decode(), concatLength);
        Base64URL transcodedSignatureBase64 = Base64URL.encode(concatSignature);
        String[] jwtParts = signedJWT.serialize().split("\\.");
        return SignedJWT.parse(String.format("%s.%s.%s", jwtParts[0], jwtParts[1], transcodedSignatureBase64));
    }

    /**
     * Checks the JWT claims for the required claim names and the expected claim values.
     *
     * @throws SessionValidationException if the claims cannot be parsed or fail verification
     */
    private void verifyJWTClaimsSet(SignedJWT signedJWT, Set<String> requiredClaims, JWTClaimsSet expectedClaimValues) throws SessionValidationException {
        try {
            new DefaultJWTClaimsVerifier<>(expectedClaimValues, requiredClaims).verify(signedJWT.getJWTClaimsSet(), null);
        }
        catch (BadJWTException | ParseException e) {
            throw new SessionValidationException(e.getMessage(), e);
        }
    }

    /**
     * Builds the public key from its Base64 encoded configuration value.
     *
     * <p>For RSA algorithms the decoded bytes are read as an X.509 certificate; for EC
     * algorithms they are read as the UTF-8 text of a JWK.
     *
     * @param serialisedPublicKey the Base64 encoded certificate or JWK
     * @param signingAlgorithm the algorithm the JWT was signed with
     * @return the public key to verify the signature with
     * @throws CertificateException if the certificate cannot be read
     * @throws ParseException if the JWK cannot be parsed
     * @throws JOSEException if the JWK cannot be converted to a public key
     * @throws IllegalArgumentException if the algorithm is neither RSA nor EC based
     */
    private PublicKey getPublicKeyFromConfig(String serialisedPublicKey, JWSAlgorithm signingAlgorithm) throws CertificateException, ParseException, JOSEException {
        if (JWSAlgorithm.Family.RSA.contains(signingAlgorithm)) {
            byte[] binaryCertificate = Base64.getDecoder().decode(serialisedPublicKey);
            CertificateFactory factory = CertificateFactory.getInstance("X.509");
            Certificate certificate = factory.generateCertificate(new ByteArrayInputStream(binaryCertificate));
            return certificate.getPublicKey();
        }
        if (JWSAlgorithm.Family.EC.contains(signingAlgorithm)) {
            // Decode with an explicit charset instead of the platform default.
            String jwkJson = new String(Base64.getDecoder().decode(serialisedPublicKey), StandardCharsets.UTF_8);
            return ECKey.parse(jwkJson).toECPublicKey();
        }
        throw new IllegalArgumentException("Unexpected signing algorithm encountered: " + signingAlgorithm.getName());
    }

    /**
     * Verifies the JWT signature with a verifier matching the key type.
     *
     * @return {@code true} if the signature is valid
     * @throws JOSEException if verification cannot be carried out
     * @throws ClientConfigurationException if the key is neither an RSA nor an EC public key
     */
    private boolean verifySignature(SignedJWT signedJWT, PublicKey clientPublicKey) throws JOSEException, ClientConfigurationException {
        if (clientPublicKey instanceof RSAPublicKey) {
            JWSVerifier rsassaVerifier = new RSASSAVerifier((RSAPublicKey) clientPublicKey);
            return signedJWT.verify(rsassaVerifier);
        }
        if (clientPublicKey instanceof ECPublicKey) {
            JWSVerifier ecdsaVerifier = new ECDSAVerifier((ECPublicKey) clientPublicKey);
            return signedJWT.verify(ecdsaVerifier);
        }
        throw new ClientConfigurationException(new IllegalStateException("unknown public signing key: " + clientPublicKey.getAlgorithm()));
    }
}

