/*
 * Decompiled with CFR 0.152.
 */
package uk.gov.di.ipv.cri.common.library.service;

import com.nimbusds.jose.EncryptionMethod;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWEAlgorithm;
import com.nimbusds.jose.JWEDecrypter;
import com.nimbusds.jose.JWEHeader;
import com.nimbusds.jose.crypto.impl.AlgorithmSupportMessage;
import com.nimbusds.jose.crypto.impl.ContentCryptoProvider;
import com.nimbusds.jose.jca.JWEJCAContext;
import com.nimbusds.jose.util.Base64URL;
import java.util.Objects;
import java.util.Set;
import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.kms.KmsClient;
import software.amazon.awssdk.services.kms.model.DecryptRequest;
import software.amazon.awssdk.services.kms.model.DecryptResponse;
import software.amazon.awssdk.services.kms.model.EncryptionAlgorithmSpec;
import uk.gov.di.ipv.cri.common.library.util.EventProbe;

/**
 * A {@link JWEDecrypter} that unwraps the JWE content encryption key (CEK) with an RSA key held
 * in AWS KMS ({@code RSAES_OAEP_SHA_256}), then decrypts the JWE content locally with that CEK.
 *
 * <p>Only {@code RSA-OAEP-256} key management with {@code A256GCM} content encryption is
 * supported.
 *
 * <p>When key rotation is enabled, the CEK is unwrapped by trying the KMS aliases
 * {@code alias/session_decryption_key_active_alias}, {@code alias/session_decryption_key_inactive_alias}
 * and {@code alias/session_decryption_key_previous_alias}, in that order. If every alias fails, an
 * {@code all_aliases_unavailable_for_decryption} counter metric is emitted and, when the legacy-key
 * fallback is enabled, the configured key ID is tried as a last resort. When key rotation is
 * disabled, only the configured key ID is used.
 */
public class KMSRSADecrypter implements JWEDecrypter {
    private static final Set<JWEAlgorithm> SUPPORTED_ALGORITHMS = Set.of(JWEAlgorithm.RSA_OAEP_256);
    private static final Set<EncryptionMethod> SUPPORTED_ENCRYPTION_METHODS = Set.of(EncryptionMethod.A256GCM);
    private static final Logger LOGGER = LogManager.getLogger();
    private static final String SESSION_DECRYPTION_KEY_PRIMARY_ALIAS = "session_decryption_key_active_alias";
    private static final String SESSION_DECRYPTION_KEY_SECONDARY_ALIAS = "session_decryption_key_inactive_alias";
    private static final String SESSION_DECRYPTION_KEY_PREVIOUS_ALIAS = "session_decryption_key_previous_alias";
    private static final String ALL_ALIASES_UNAVAILABLE = "all_aliases_unavailable_for_decryption";

    private final boolean keyRotationEnabled;
    private final boolean keyRotationLegacyKeyFallbackEnabled;
    private final JWEJCAContext jcaContext = new JWEJCAContext();
    private final KmsClient kmsClient;
    private final EventProbe eventProbe;
    private final String keyId;

    /**
     * Creates a decrypter whose feature flags are read from the environment variables
     * {@code ENV_VAR_FEATURE_FLAG_KEY_ROTATION} and
     * {@code ENV_VAR_FEATURE_FLAG_KEY_ROTATION_LEGACY_KEY_FALLBACK}. An unset variable, or any
     * value other than {@code "true"} (case-insensitive), means disabled.
     *
     * @param keyId the legacy KMS key ID used when key rotation is disabled or as a fallback
     * @param kmsClient client used for the KMS {@code Decrypt} calls
     * @param eventProbe sink for the "all aliases unavailable" counter metric
     */
    public KMSRSADecrypter(String keyId, KmsClient kmsClient, EventProbe eventProbe) {
        this(
                kmsClient,
                eventProbe,
                keyId,
                Boolean.parseBoolean(System.getenv("ENV_VAR_FEATURE_FLAG_KEY_ROTATION")),
                Boolean.parseBoolean(System.getenv("ENV_VAR_FEATURE_FLAG_KEY_ROTATION_LEGACY_KEY_FALLBACK")));
    }

    /**
     * Creates a decrypter with explicit feature flags.
     *
     * @param kmsClient client used for the KMS {@code Decrypt} calls
     * @param eventProbe sink for the "all aliases unavailable" counter metric
     * @param keyId the legacy KMS key ID used when key rotation is disabled or as a fallback
     * @param keyRotationEnabled whether to decrypt via the rotating key aliases; {@code null} is
     *     treated as {@code false}
     * @param legacyKeyFallbackEnabled whether to try {@code keyId} after every alias has failed
     *     (only relevant when key rotation is enabled)
     */
    public KMSRSADecrypter(
            KmsClient kmsClient,
            EventProbe eventProbe,
            String keyId,
            Boolean keyRotationEnabled,
            boolean legacyKeyFallbackEnabled) {
        this.kmsClient = kmsClient;
        this.eventProbe = eventProbe;
        this.keyId = keyId;
        // Avoid an auto-unboxing NullPointerException for a null flag.
        this.keyRotationEnabled = Boolean.TRUE.equals(keyRotationEnabled);
        this.keyRotationLegacyKeyFallbackEnabled = legacyKeyFallbackEnabled;
    }

    @Override
    public Set<JWEAlgorithm> supportedJWEAlgorithms() {
        return SUPPORTED_ALGORITHMS;
    }

    @Override
    public Set<EncryptionMethod> supportedEncryptionMethods() {
        return SUPPORTED_ENCRYPTION_METHODS;
    }

    @Override
    public JWEJCAContext getJCAContext() {
        return this.jcaContext;
    }

    /** Returns whether CEK unwrapping goes through the rotating KMS key aliases. */
    public boolean isKeyRotationEnabled() {
        return this.keyRotationEnabled;
    }

    /**
     * Unwraps the CEK with KMS and decrypts the JWE content.
     *
     * @param aad the additional authenticated data to verify; passed through to the content
     *     decryption (when {@code null}, the content crypto provider decides the AAD)
     * @return the decrypted plaintext
     * @throws JOSEException if the JWE parts are missing, the algorithm or encryption method is
     *     unsupported, the CEK cannot be unwrapped on the key-rotation path, or content
     *     decryption fails. On the non-rotation path, KMS failures propagate as thrown by the
     *     {@link KmsClient}.
     */
    @Override
    public byte[] decrypt(
            JWEHeader header,
            Base64URL encryptedKey,
            Base64URL iv,
            Base64URL cipherText,
            Base64URL authTag,
            byte[] aad)
            throws JOSEException {
        validateJwe(header, encryptedKey, iv, authTag);
        DecryptResponse decryptResponse =
                keyRotationEnabled
                        ? decryptWithRotatedKeys(encryptedKey)
                        : decryptWithLegacyKey(encryptedKey);
        SecretKey cek = new SecretKeySpec(decryptResponse.plaintext().asByteArray(), "AES");
        // Forward the caller's AAD so it is actually covered by the GCM tag check; passing null
        // here would silently ignore any caller-supplied AAD.
        return ContentCryptoProvider.decrypt(
                header, aad, encryptedKey, iv, cipherText, authTag, cek, getJCAContext());
    }

    /**
     * Unwraps the CEK via the key aliases, optionally falling back to the legacy key.
     * Never returns {@code null}.
     */
    private DecryptResponse decryptWithRotatedKeys(Base64URL encryptedKey) throws JOSEException {
        LOGGER.info("Key rotation enabled. Attempting to decrypt with key aliases.");
        DecryptResponse decryptResponse = decryptWithKeyAliases(encryptedKey);
        if (decryptResponse != null) {
            return decryptResponse;
        }
        if (!keyRotationLegacyKeyFallbackEnabled) {
            String message = "Failed to decrypt with all available key aliases.";
            LOGGER.error(message);
            throw new JOSEException(message);
        }

        LOGGER.warn("Failed to decrypt with all available key aliases, falling back to legacy key.");
        Exception legacyFailure = null;
        try {
            decryptResponse = decryptWithLegacyKey(encryptedKey);
        } catch (Exception e) {
            // Keep the KMS failure so it is logged and chained, not silently discarded.
            legacyFailure = e;
        }
        if (decryptResponse == null) {
            String message = "Failed to decrypt with legacy key.";
            LOGGER.error(message, legacyFailure);
            throw new JOSEException(message, legacyFailure);
        }
        LOGGER.info("Decryption successful with legacy key");
        return decryptResponse;
    }

    /**
     * Checks that all JWE parts needed for decryption are present, and that the header's
     * algorithm and encryption method are supported.
     */
    private void validateJwe(
            JWEHeader header, Base64URL encryptedKey, Base64URL iv, Base64URL authTag)
            throws JOSEException {
        if (Objects.isNull(header)) {
            throw new JOSEException("Missing JWE header");
        }
        if (Objects.isNull(encryptedKey)) {
            throw new JOSEException("Missing JWE encrypted key");
        }
        if (Objects.isNull(iv)) {
            throw new JOSEException("Missing JWE initialization vector (IV)");
        }
        if (Objects.isNull(authTag)) {
            throw new JOSEException("Missing JWE authentication tag");
        }
        // Null checks first: Set.of(...).contains(null) throws NullPointerException.
        JWEAlgorithm alg = header.getAlgorithm();
        if (Objects.isNull(alg) || !SUPPORTED_ALGORITHMS.contains(alg)) {
            throw new JOSEException(
                    AlgorithmSupportMessage.unsupportedJWEAlgorithm(alg, supportedJWEAlgorithms()));
        }
        EncryptionMethod enc = header.getEncryptionMethod();
        if (Objects.isNull(enc) || !SUPPORTED_ENCRYPTION_METHODS.contains(enc)) {
            throw new JOSEException(
                    AlgorithmSupportMessage.unsupportedEncryptionMethod(
                            enc, supportedEncryptionMethods()));
        }
    }

    /** Unwraps the CEK with the configured legacy key ID; KMS errors propagate unchanged. */
    private DecryptResponse decryptWithLegacyKey(Base64URL encryptedKey) {
        return this.kmsClient.decrypt(buildDecryptRequest(this.keyId, encryptedKey));
    }

    /**
     * Tries each session decryption key alias in turn.
     *
     * @return the first response from a {@code Decrypt} call that did not throw, or {@code null}
     *     after emitting the {@code all_aliases_unavailable_for_decryption} metric if every alias
     *     failed
     */
    private DecryptResponse decryptWithKeyAliases(Base64URL encryptedKey) {
        String[] keyAliases = {
            SESSION_DECRYPTION_KEY_PRIMARY_ALIAS,
            SESSION_DECRYPTION_KEY_SECONDARY_ALIAS,
            SESSION_DECRYPTION_KEY_PREVIOUS_ALIAS
        };
        for (String alias : keyAliases) {
            try {
                DecryptResponse response =
                        this.kmsClient.decrypt(buildDecryptRequest("alias/" + alias, encryptedKey));
                LOGGER.info("Decryption successful with key alias: {}", alias);
                return response;
            } catch (Exception e) {
                LOGGER.warn("Failed to decrypt with key alias: {}. Error: {}", alias, e.getMessage());
            }
        }
        this.eventProbe.counterMetric(ALL_ALIASES_UNAVAILABLE);
        return null;
    }

    /**
     * Builds an {@code RSAES_OAEP_SHA_256} KMS decrypt request for the wrapped CEK.
     *
     * @param kmsKeyId a KMS key identifier as accepted by {@code DecryptRequest.keyId}, e.g. a key
     *     ID or an {@code alias/...} name
     */
    private DecryptRequest buildDecryptRequest(String kmsKeyId, Base64URL encryptedKey) {
        return DecryptRequest.builder()
                .ciphertextBlob(SdkBytes.fromByteArray(encryptedKey.decode()))
                .encryptionAlgorithm(EncryptionAlgorithmSpec.RSAES_OAEP_SHA_256)
                .keyId(kmsKeyId)
                .build();
    }
}

