/*
 * Decompiled with CFR 0.152.
 */
package com.amazonaws.encryptionsdk.internal;

import com.amazonaws.encryptionsdk.CryptoAlgorithm;
import com.amazonaws.encryptionsdk.DataKey;
import com.amazonaws.encryptionsdk.MasterKey;
import com.amazonaws.encryptionsdk.MasterKeyProvider;
import com.amazonaws.encryptionsdk.exception.AwsCryptoException;
import com.amazonaws.encryptionsdk.exception.BadCiphertextException;
import com.amazonaws.encryptionsdk.exception.CannotUnwrapDataKeyException;
import com.amazonaws.encryptionsdk.internal.BlockDecryptionHandler;
import com.amazonaws.encryptionsdk.internal.CipherHandler;
import com.amazonaws.encryptionsdk.internal.CryptoHandler;
import com.amazonaws.encryptionsdk.internal.FrameDecryptionHandler;
import com.amazonaws.encryptionsdk.internal.MessageCryptoHandler;
import com.amazonaws.encryptionsdk.internal.ProcessingSummary;
import com.amazonaws.encryptionsdk.internal.Utils;
import com.amazonaws.encryptionsdk.model.CiphertextFooters;
import com.amazonaws.encryptionsdk.model.CiphertextHeaders;
import com.amazonaws.encryptionsdk.model.CiphertextType;
import com.amazonaws.encryptionsdk.model.ContentType;
import com.amazonaws.util.Base64;
import java.security.GeneralSecurityException;
import java.security.InvalidKeyException;
import java.security.PublicKey;
import java.security.Signature;
import java.security.SignatureException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import javax.crypto.Cipher;
import javax.crypto.SecretKey;
import org.bouncycastle.crypto.params.ECDomainParameters;
import org.bouncycastle.crypto.params.ECPublicKeyParameters;
import org.bouncycastle.jcajce.provider.asymmetric.ec.BCECPublicKey;
import org.bouncycastle.jce.ECNamedCurveTable;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.bouncycastle.jce.spec.ECNamedCurveParameterSpec;
import org.bouncycastle.jce.spec.ECParameterSpec;
import org.bouncycastle.math.ec.ECPoint;

/**
 * Incrementally decrypts a message: parses the ciphertext header, recovers the data key through
 * the supplied {@link MasterKeyProvider}, authenticates the header, streams the body through a
 * frame- or single-block content handler and, for algorithms that carry a trailing signature,
 * parses the footer and verifies that signature over the header and body bytes.
 *
 * <p>Instances hold mutable parsing state and perform no synchronization.
 *
 * @param <K> the master-key type produced by the provider
 */
public class DecryptionHandler<K extends MasterKey<K>>
implements MessageCryptoHandler<K> {
    /** Provider asked to decrypt one of the encrypted data keys listed in the header. */
    private final MasterKeyProvider<K> masterKeyProvider_;
    /** Header being parsed from the input, or the pre-parsed header given to the constructor. */
    private final CiphertextHeaders ciphertextHeaders_;
    /** Footer holding the trailing signature; only parsed for signing algorithms. */
    private final CiphertextFooters ciphertextFooters_;
    /** True once the header has been fully read and the content handler created. */
    private boolean ciphertextHeadersParsed_;
    /** Frame or single-block body decryptor; null until the header has been processed. */
    private CryptoHandler contentCryptoHandler_;
    /** Data key recovered by the master key provider; null until the header has been processed. */
    private DataKey<K> dataKey_;
    /** Key derived from the data key and header; used for header authentication and the body. */
    private SecretKey decryptionKey_;
    /** Algorithm suite named in the header. */
    private CryptoAlgorithm cryptoAlgo_;
    /** Trailing-signature verification state; both null when the algorithm has no signature. */
    private PublicKey trailingPublicKey_;
    private Signature trailingSig_;
    /** Encryption context taken from the header; null until the header has been processed. */
    private Map<String, String> encryptionContext_ = null;
    /** Input bytes carried over between calls while the header is still incomplete. */
    private byte[] unparsedBytes_ = new byte[0];
    /** True once the body is finished and, for signing algorithms, the signature has verified. */
    private boolean complete_ = false;

    /**
     * Creates a handler that parses the header from the bytes passed to {@link #processBytes}.
     *
     * @param customerMasterKeyProvider provider used to decrypt the message's data key; must not be null
     * @throws AwsCryptoException if the provider is null (via {@code Utils.assertNonNull})
     */
    public DecryptionHandler(MasterKeyProvider<K> customerMasterKeyProvider) throws AwsCryptoException {
        Utils.assertNonNull(customerMasterKeyProvider, "customerMasterKeyProvider");
        this.masterKeyProvider_ = customerMasterKeyProvider;
        this.ciphertextHeaders_ = new CiphertextHeaders();
        this.ciphertextFooters_ = new CiphertextFooters();
    }

    /**
     * Creates a handler for a message whose header has already been parsed. All bytes later
     * passed to {@link #processBytes} are treated as the message body (and footer).
     *
     * @param customerMasterKeyProvider provider used to decrypt the message's data key; must not be null
     * @param headers the already-parsed, complete ciphertext header
     * @throws AwsCryptoException if the provider is null, or the header cannot be processed
     * @throws BadCiphertextException if the header is invalid or fails its integrity check
     */
    public DecryptionHandler(MasterKeyProvider<K> customerMasterKeyProvider, CiphertextHeaders headers) throws AwsCryptoException {
        Utils.assertNonNull(customerMasterKeyProvider, "customerMasterKeyProvider");
        this.masterKeyProvider_ = customerMasterKeyProvider;
        this.ciphertextHeaders_ = headers;
        this.ciphertextFooters_ = new CiphertextFooters();
        this.readHeaderFields(headers);
        this.updateTrailingSignature(headers);
        // The header is already fully processed. Without this flag the first processBytes call
        // would run header parsing and readHeaderFields again, decrypting the data key twice.
        this.ciphertextHeadersParsed_ = true;
    }

    /**
     * Feeds ciphertext bytes into the handler, writing any plaintext produced into {@code out}.
     *
     * <p>Header bytes that cannot yet be parsed are buffered internally and reported as
     * processed. Once the body is complete, the footer is parsed and the trailing signature
     * verified; this happens only for algorithms that have one.
     *
     * @param in source buffer
     * @param off offset of the first byte to read from {@code in}
     * @param len number of bytes to read from {@code in}
     * @param out destination buffer for plaintext
     * @param outOff offset in {@code out} at which to start writing
     * @return the number of plaintext bytes written to {@code out} and input bytes consumed
     * @throws AwsCryptoException on a negative offset/length or when the buffered input exceeds
     *     {@code Integer.MAX_VALUE} bytes
     * @throws BadCiphertextException when the header, body or trailing signature is invalid
     */
    @Override
    public ProcessingSummary processBytes(byte[] in, int off, int len, byte[] out, int outOff) throws BadCiphertextException, AwsCryptoException {
        if (len < 0 || off < 0) {
            throw new AwsCryptoException(String.format("Invalid values for input offset: %d and length: %d", off, len));
        }
        if (in.length == 0 || len == 0) {
            return ProcessingSummary.ZERO;
        }
        long totalBytesToParse = (long)this.unparsedBytes_.length + (long)len;
        if (totalBytesToParse > Integer.MAX_VALUE) {
            throw new AwsCryptoException("Size of the total bytes to parse and decrypt exceeded allowed maximum:2147483647");
        }
        // Bytes buffered by an earlier call were already reported as processed by that call, so
        // they are subtracted from the count this call reports.
        int leftoverBytes = this.unparsedBytes_.length;
        byte[] bytesToParse = new byte[(int)totalBytesToParse];
        System.arraycopy(this.unparsedBytes_, 0, bytesToParse, 0, leftoverBytes);
        System.arraycopy(in, off, bytesToParse, leftoverBytes, len);

        int totalParsedBytes = 0;
        if (!this.ciphertextHeadersParsed_) {
            totalParsedBytes += this.ciphertextHeaders_.deserialize(bytesToParse, 0);
            if (!this.ciphertextHeaders_.isComplete().booleanValue()) {
                // Not enough bytes for the whole header yet; keep the remainder for the next call.
                this.unparsedBytes_ = Arrays.copyOfRange(bytesToParse, totalParsedBytes, bytesToParse.length);
                return new ProcessingSummary(0, len);
            }
            this.readHeaderFields(this.ciphertextHeaders_);
            this.updateTrailingSignature(this.ciphertextHeaders_);
            this.ciphertextHeadersParsed_ = true;
            this.unparsedBytes_ = new byte[0];
        }

        int actualOutLen = 0;
        if (!this.contentCryptoHandler_.isComplete()) {
            if (bytesToParse.length - totalParsedBytes > 0) {
                ProcessingSummary contentResult = this.contentCryptoHandler_.processBytes(bytesToParse, totalParsedBytes, bytesToParse.length - totalParsedBytes, out, outOff);
                // The trailing signature covers the body bytes exactly as they were consumed.
                this.updateTrailingSignature(bytesToParse, totalParsedBytes, contentResult.getBytesProcessed());
                actualOutLen = contentResult.getBytesWritten();
                totalParsedBytes += contentResult.getBytesProcessed();
            }
            if (this.contentCryptoHandler_.isComplete()) {
                actualOutLen += this.contentCryptoHandler_.doFinal(out, outOff + actualOutLen);
            }
        }

        // Guard on !complete_ so the signature is verified exactly once. A java.security.Signature
        // resets after verify(), so verifying again would spuriously fail.
        if (this.contentCryptoHandler_.isComplete() && !this.complete_) {
            if (this.trailingSig_ == null) {
                // Only signing algorithms append a footer (it holds the signature), so the message
                // ends with the body.
                this.complete_ = true;
            } else {
                totalParsedBytes += this.ciphertextFooters_.deserialize(bytesToParse, totalParsedBytes);
                if (this.ciphertextFooters_.isComplete()) {
                    this.verifyTrailingSignature();
                    this.complete_ = true;
                }
            }
        }
        return new ProcessingSummary(actualOutLen, totalParsedBytes - leftoverBytes);
    }

    /**
     * Delegates to the content handler's {@code doFinal}.
     *
     * <p>NOTE(review): {@link #processBytes} already calls the content handler's {@code doFinal}
     * when the body completes, so this may invoke it a second time. Confirm that the content
     * handlers tolerate that.
     *
     * @param out destination buffer for any remaining plaintext
     * @param outOff offset in {@code out} at which to start writing
     * @return bytes written, or 0 if the header has not been processed yet
     * @throws BadCiphertextException if the content handler rejects the ciphertext
     */
    @Override
    public int doFinal(byte[] out, int outOff) throws BadCiphertextException {
        if (this.contentCryptoHandler_ == null) {
            return 0;
        }
        return this.contentCryptoHandler_.doFinal(out, outOff);
    }

    /**
     * Estimates the plaintext size for {@code inLen} more ciphertext bytes.
     *
     * @param inLen number of additional input bytes
     * @return the content handler's estimate once the header is processed; before that,
     *     {@code inLen} clamped to be non-negative
     */
    @Override
    public int estimateOutputSize(int inLen) {
        if (this.contentCryptoHandler_ != null) {
            return this.contentCryptoHandler_.estimateOutputSize(inLen);
        }
        return inLen > 0 ? inLen : 0;
    }

    /**
     * Returns the encryption context read from the header.
     *
     * @return the context map, or null if the header has not been processed yet
     */
    @Override
    public Map<String, String> getEncryptionContext() {
        return this.encryptionContext_;
    }

    /**
     * Authenticates the header by running its tag through a decrypt-mode {@link CipherHandler}.
     * The serialized authenticated header fields are passed as the third constructor argument
     * (presumably the AAD).
     *
     * @throws BadCiphertextException if the header tag does not verify
     */
    private void verifyHeaderIntegrity(CiphertextHeaders ciphertextHeaders) throws BadCiphertextException {
        CipherHandler cipherHandler = new CipherHandler(this.decryptionKey_, ciphertextHeaders.getHeaderNonce(), ciphertextHeaders.serializeAuthenticatedFields(), Cipher.DECRYPT_MODE, this.cryptoAlgo_);
        try {
            byte[] headerTag = ciphertextHeaders.getHeaderTag();
            cipherHandler.cipherData(headerTag, 0, headerTag.length);
        }
        catch (BadCiphertextException e) {
            throw new BadCiphertextException("Header integrity check failed.", e);
        }
    }

    /**
     * Asks the master key provider to decrypt one of the header's encrypted data keys.
     *
     * @throws CannotUnwrapDataKeyException if the provider returns null
     */
    private DataKey<K> getDataKey(CiphertextHeaders ciphertextHeaders) {
        DataKey<K> result = this.masterKeyProvider_.decryptDataKey(this.cryptoAlgo_, ciphertextHeaders.getEncryptedKeyBlobs(), ciphertextHeaders.getEncryptionContextMap());
        if (result == null) {
            throw new CannotUnwrapDataKeyException("Could not decrypt any data keys");
        }
        return result;
    }

    /**
     * Validates a complete header and derives all per-message state from it.
     *
     * <p>The steps are:
     * <ol>
     *   <li>Check the version and type.</li>
     *   <li>Set up trailing-signature verification.</li>
     *   <li>Decrypt the data key and derive the decryption key.</li>
     *   <li>Authenticate the header.</li>
     *   <li>Create the body handler for the header's content type.</li>
     * </ol>
     *
     * @throws BadCiphertextException on an unsupported version, type or content type, a missing
     *     signing public key, or a failed header integrity check
     * @throws AwsCryptoException if signature or key setup fails
     */
    private void readHeaderFields(CiphertextHeaders ciphertextHeaders) {
        byte version = ciphertextHeaders.getVersion();
        if (version != 1) {
            throw new BadCiphertextException("Invalid version in ciphertext.");
        }
        this.cryptoAlgo_ = ciphertextHeaders.getCryptoAlgoId();
        CiphertextType ciphertextType = ciphertextHeaders.getType();
        if (ciphertextType != CiphertextType.CUSTOMER_AUTHENTICATED_ENCRYPTED_DATA) {
            throw new BadCiphertextException("Invalid type in ciphertext.");
        }
        byte[] messageId = ciphertextHeaders.getMessageId();
        this.encryptionContext_ = ciphertextHeaders.getEncryptionContextMap();
        // Set up signature verification before contacting the key provider, so a malformed
        // signing key fails fast.
        this.initTrailingSignature();

        ContentType contentType = ciphertextHeaders.getContentType();
        short nonceLen = ciphertextHeaders.getNonceLength();
        int frameLen = ciphertextHeaders.getFrameLength();
        this.dataKey_ = this.getDataKey(ciphertextHeaders);
        try {
            this.decryptionKey_ = this.cryptoAlgo_.getEncryptionKeyFromDataKey(this.dataKey_.getKey(), ciphertextHeaders);
        }
        catch (InvalidKeyException ex) {
            throw new AwsCryptoException(ex);
        }
        this.verifyHeaderIntegrity(ciphertextHeaders);
        switch (contentType) {
            case FRAME: {
                this.contentCryptoHandler_ = new FrameDecryptionHandler(this.decryptionKey_, (byte)nonceLen, this.cryptoAlgo_, messageId, frameLen);
                break;
            }
            case SINGLEBLOCK: {
                this.contentCryptoHandler_ = new BlockDecryptionHandler(this.decryptionKey_, (byte)nonceLen, this.cryptoAlgo_, messageId);
                break;
            }
            default: {
                // Otherwise contentCryptoHandler_ would stay null and processBytes would NPE.
                throw new BadCiphertextException("Unsupported content type in ciphertext.");
            }
        }
    }

    /**
     * Initializes {@link #trailingSig_} for verification when {@link #cryptoAlgo_} has a trailing
     * signature, and clears it otherwise.
     *
     * <p>The public key is read from the {@code aws-crypto-public-key} encryption-context entry.
     *
     * @throws BadCiphertextException if the public key entry is missing
     * @throws AwsCryptoException if the key cannot be decoded or the signature cannot be created
     */
    private void initTrailingSignature() {
        if (this.cryptoAlgo_.getTrailingSignatureLength() <= 0) {
            this.trailingPublicKey_ = null;
            this.trailingSig_ = null;
            return;
        }
        String serializedPublicKey = this.encryptionContext_.get("aws-crypto-public-key");
        if (serializedPublicKey == null) {
            throw new BadCiphertextException("Missing trailing signature public key in encryption context.");
        }
        try {
            this.trailingPublicKey_ = this.deserializeTrailingKeyFromEc(serializedPublicKey);
            this.trailingSig_ = Signature.getInstance(this.cryptoAlgo_.getTrailingSignatureAlgo(), "BC");
            this.trailingSig_.initVerify(this.trailingPublicKey_);
        }
        catch (GeneralSecurityException ex) {
            throw new AwsCryptoException(ex);
        }
    }

    /**
     * Decodes a Base64-encoded EC point into a BouncyCastle public key on the curve that matches
     * {@link #cryptoAlgo_}: secp256r1 for the P256 suite, secp384r1 for the P384 suites.
     *
     * @param pubKey Base64 encoding of the EC point
     * @return the decoded public key
     * @throws IllegalStateException if the algorithm has no trailing signature
     * @throws GeneralSecurityException declared for callers; key construction may fail
     */
    private PublicKey deserializeTrailingKeyFromEc(String pubKey) throws GeneralSecurityException {
        ECNamedCurveParameterSpec ecSpec;
        switch (this.cryptoAlgo_) {
            case ALG_AES_128_GCM_IV12_TAG16_HKDF_SHA256_ECDSA_P256: {
                ecSpec = ECNamedCurveTable.getParameterSpec("secp256r1");
                break;
            }
            case ALG_AES_192_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384:
            case ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384: {
                ecSpec = ECNamedCurveTable.getParameterSpec("secp384r1");
                break;
            }
            default: {
                throw new IllegalStateException("Algorithm does not support trailing signature");
            }
        }
        ECPoint q = ecSpec.getCurve().decodePoint(Base64.decode(pubKey));
        ECPublicKeyParameters keyParams = new ECPublicKeyParameters(q, new ECDomainParameters(ecSpec.getCurve(), ecSpec.getG(), ecSpec.getN(), ecSpec.getH()));
        return new BCECPublicKey("ECDSA", keyParams, (ECParameterSpec)ecSpec, BouncyCastleProvider.CONFIGURATION);
    }

    /**
     * Feeds the serialized form of {@code headers} into the trailing signature, if one is in use.
     *
     * @param headers the header whose bytes are covered by the signature
     */
    private void updateTrailingSignature(CiphertextHeaders headers) {
        if (this.trailingSig_ != null) {
            byte[] reserializedHeaders = headers.toByteArray();
            this.updateTrailingSignature(reserializedHeaders, 0, reserializedHeaders.length);
        }
    }

    /**
     * Feeds {@code len} bytes of {@code input} starting at {@code offset} into the trailing
     * signature. Does nothing when the algorithm has no trailing signature.
     *
     * @throws AwsCryptoException wrapping any {@link SignatureException}
     */
    private void updateTrailingSignature(byte[] input, int offset, int len) {
        if (this.trailingSig_ != null) {
            try {
                this.trailingSig_.update(input, offset, len);
            }
            catch (SignatureException ex) {
                throw new AwsCryptoException(ex);
            }
        }
    }

    /**
     * Verifies the accumulated trailing signature against the footer's signature bytes.
     *
     * @throws BadCiphertextException if verification fails or the signature is malformed
     */
    private void verifyTrailingSignature() throws BadCiphertextException {
        try {
            if (!this.trailingSig_.verify(this.ciphertextFooters_.getMAuth())) {
                throw new BadCiphertextException("Bad trailing signature");
            }
        }
        catch (SignatureException ex) {
            throw new BadCiphertextException("Bad trailing signature", ex);
        }
    }

    /** Returns the header being parsed, or the pre-parsed header supplied at construction. */
    @Override
    public CiphertextHeaders getHeaders() {
        return this.ciphertextHeaders_;
    }

    /**
     * Returns the master key that decrypted the data key.
     *
     * @return a singleton list with that key, or an empty list if the header has not been
     *     processed yet
     */
    @Override
    public List<K> getMasterKeys() {
        if (this.dataKey_ == null) {
            return Collections.emptyList();
        }
        return Collections.singletonList(this.dataKey_.getMasterKey());
    }

    /**
     * Reports whether the whole message has been processed: the body is complete and, for
     * signing algorithms, the trailing signature has been verified.
     */
    @Override
    public boolean isComplete() {
        return this.complete_;
    }
}

