/*
 * Decompiled with CFR 0.152.
 */
package org.apache.cxf.rs.security.jose.jwe;

import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Logger;
import org.apache.cxf.common.logging.LogUtils;
import org.apache.cxf.common.util.Base64UrlUtility;
import org.apache.cxf.common.util.StringUtils;
import org.apache.cxf.rs.security.jose.jwa.AlgorithmUtils;
import org.apache.cxf.rs.security.jose.jwa.KeyAlgorithm;
import org.apache.cxf.rs.security.jose.jwe.AesWrapKeyEncryptionAlgorithm;
import org.apache.cxf.rs.security.jose.jwe.JweException;
import org.apache.cxf.rs.security.jose.jwe.JweHeaders;
import org.apache.cxf.rs.security.jose.jwe.KeyEncryptionProvider;
import org.apache.cxf.rt.security.crypto.CryptoUtils;
import org.apache.cxf.rt.security.crypto.MessageDigestUtils;
import org.bouncycastle.crypto.Digest;
import org.bouncycastle.crypto.digests.SHA256Digest;
import org.bouncycastle.crypto.digests.SHA384Digest;
import org.bouncycastle.crypto.digests.SHA512Digest;
import org.bouncycastle.crypto.generators.PKCS5S2ParametersGenerator;
import org.bouncycastle.crypto.params.KeyParameter;

public class PbesHmacAesWrapKeyEncryptionAlgorithm
implements KeyEncryptionProvider {
    protected static final Logger LOG = LogUtils.getL7dLogger(PbesHmacAesWrapKeyEncryptionAlgorithm.class);
    private static final Map<String, Integer> PBES_HMAC_MAP = new HashMap<String, Integer>();
    private static final Map<String, String> PBES_AES_MAP;
    private static final Map<String, Integer> DERIVED_KEY_SIZE_MAP;
    private byte[] password;
    private int pbesCount;
    private KeyAlgorithm keyAlgoJwt;

    public PbesHmacAesWrapKeyEncryptionAlgorithm(String password, KeyAlgorithm keyAlgoJwt) {
        this(PbesHmacAesWrapKeyEncryptionAlgorithm.stringToBytes(password), keyAlgoJwt);
    }

    public PbesHmacAesWrapKeyEncryptionAlgorithm(String password, int pbesCount, KeyAlgorithm keyAlgoJwt, boolean hashLargePasswords) {
        this(PbesHmacAesWrapKeyEncryptionAlgorithm.stringToBytes(password), pbesCount, keyAlgoJwt, hashLargePasswords);
    }

    public PbesHmacAesWrapKeyEncryptionAlgorithm(char[] password, KeyAlgorithm keyAlgoJwt) {
        this(password, 4096, keyAlgoJwt, false);
    }

    public PbesHmacAesWrapKeyEncryptionAlgorithm(char[] password, int pbesCount, KeyAlgorithm keyAlgoJwt, boolean hashLargePasswords) {
        this(PbesHmacAesWrapKeyEncryptionAlgorithm.charsToBytes(password), pbesCount, keyAlgoJwt, hashLargePasswords);
    }

    public PbesHmacAesWrapKeyEncryptionAlgorithm(byte[] password, KeyAlgorithm keyAlgoJwt) {
        this(password, 4096, keyAlgoJwt, false);
    }

    public PbesHmacAesWrapKeyEncryptionAlgorithm(byte[] password, int pbesCount, KeyAlgorithm keyAlgoJwt, boolean hashLargePasswords) {
        this.keyAlgoJwt = PbesHmacAesWrapKeyEncryptionAlgorithm.validateKeyAlgorithm(keyAlgoJwt);
        this.password = PbesHmacAesWrapKeyEncryptionAlgorithm.validatePassword(password, keyAlgoJwt.getJwaName(), hashLargePasswords);
        this.pbesCount = PbesHmacAesWrapKeyEncryptionAlgorithm.validatePbesCount(pbesCount);
    }

    static byte[] validatePassword(byte[] p, String keyAlgoJwt, boolean hashLargePasswords) {
        int minLen = DERIVED_KEY_SIZE_MAP.get(keyAlgoJwt);
        if (p.length < minLen || p.length > 128) {
            LOG.warning("Invalid password length: " + p.length);
            throw new JweException(JweException.Error.KEY_ENCRYPTION_FAILURE);
        }
        if (p.length > minLen && hashLargePasswords) {
            try {
                return MessageDigestUtils.createDigest((byte[])p, (String)"SHA-256");
            }
            catch (Exception ex) {
                LOG.warning("Password hash calculation error");
                throw new JweException(JweException.Error.KEY_ENCRYPTION_FAILURE, (Throwable)ex);
            }
        }
        return p;
    }

    @Override
    public byte[] getEncryptedContentEncryptionKey(JweHeaders headers, byte[] cek) {
        int keySize = PbesHmacAesWrapKeyEncryptionAlgorithm.getKeySize(this.keyAlgoJwt.getJwaName());
        byte[] saltInput = CryptoUtils.generateSecureRandomBytes((int)keySize);
        byte[] derivedKey = PbesHmacAesWrapKeyEncryptionAlgorithm.createDerivedKey(this.keyAlgoJwt.getJwaName(), keySize, this.password, saltInput, this.pbesCount);
        headers.setHeader("p2s", Base64UrlUtility.encode((byte[])saltInput));
        headers.setIntegerHeader("p2c", this.pbesCount);
        AesWrapKeyEncryptionAlgorithm aesWrap = new AesWrapKeyEncryptionAlgorithm(derivedKey, this.keyAlgoJwt){

            @Override
            protected void checkAlgorithms(JweHeaders headers) {
            }

            @Override
            protected String getKeyEncryptionAlgoJava(JweHeaders headers) {
                return "AESWrap";
            }
        };
        return aesWrap.getEncryptedContentEncryptionKey(headers, cek);
    }

    static int getKeySize(String keyAlgoJwt) {
        return DERIVED_KEY_SIZE_MAP.get(keyAlgoJwt);
    }

    static byte[] createDerivedKey(String keyAlgoJwt, int keySize, byte[] password, byte[] saltInput, int pbesCount) {
        byte[] saltValue = PbesHmacAesWrapKeyEncryptionAlgorithm.createSaltValue(keyAlgoJwt, saltInput);
        Object digest = null;
        int macSigSize = PBES_HMAC_MAP.get(keyAlgoJwt);
        digest = macSigSize == 256 ? new SHA256Digest() : (macSigSize == 384 ? new SHA384Digest() : new SHA512Digest());
        PKCS5S2ParametersGenerator gen = new PKCS5S2ParametersGenerator((Digest)digest);
        gen.init(password, saltValue, pbesCount);
        return ((KeyParameter)gen.generateDerivedParameters(keySize * 8)).getKey();
    }

    private static byte[] createSaltValue(String keyAlgoJwt, byte[] saltInput) {
        byte[] algoBytes = PbesHmacAesWrapKeyEncryptionAlgorithm.stringToBytes(keyAlgoJwt);
        byte[] saltValue = new byte[algoBytes.length + 1 + saltInput.length];
        System.arraycopy(algoBytes, 0, saltValue, 0, algoBytes.length);
        saltValue[algoBytes.length] = 0;
        System.arraycopy(saltInput, 0, saltValue, algoBytes.length + 1, saltInput.length);
        return saltValue;
    }

    static KeyAlgorithm validateKeyAlgorithm(KeyAlgorithm algo) {
        if (!AlgorithmUtils.isPbesHsWrap(algo.getJwaName())) {
            LOG.warning("Invalid key encryption algorithm");
            throw new JweException(JweException.Error.INVALID_KEY_ALGORITHM);
        }
        return algo;
    }

    static int validatePbesCount(int count) {
        if (count < 1000) {
            LOG.warning("Iteration count is too low");
            throw new JweException(JweException.Error.KEY_ENCRYPTION_FAILURE);
        }
        return count;
    }

    static byte[] stringToBytes(String str) {
        return StringUtils.toBytesUTF8((String)str);
    }

    static byte[] charsToBytes(char[] chars) {
        ByteBuffer bb = Charset.forName("UTF-8").encode(CharBuffer.wrap(chars));
        byte[] b = new byte[bb.remaining()];
        bb.get(b);
        return b;
    }

    @Override
    public KeyAlgorithm getAlgorithm() {
        return this.keyAlgoJwt;
    }

    static {
        PBES_HMAC_MAP.put(KeyAlgorithm.PBES2_HS256_A128KW.getJwaName(), 256);
        PBES_HMAC_MAP.put(KeyAlgorithm.PBES2_HS384_A192KW.getJwaName(), 384);
        PBES_HMAC_MAP.put(KeyAlgorithm.PBES2_HS512_A256KW.getJwaName(), 512);
        PBES_AES_MAP = new HashMap<String, String>();
        PBES_AES_MAP.put(KeyAlgorithm.PBES2_HS256_A128KW.getJwaName(), KeyAlgorithm.A128KW.getJwaName());
        PBES_AES_MAP.put(KeyAlgorithm.PBES2_HS384_A192KW.getJwaName(), KeyAlgorithm.A192KW.getJwaName());
        PBES_AES_MAP.put(KeyAlgorithm.PBES2_HS512_A256KW.getJwaName(), KeyAlgorithm.A256KW.getJwaName());
        DERIVED_KEY_SIZE_MAP = new HashMap<String, Integer>();
        DERIVED_KEY_SIZE_MAP.put(KeyAlgorithm.PBES2_HS256_A128KW.getJwaName(), 16);
        DERIVED_KEY_SIZE_MAP.put(KeyAlgorithm.PBES2_HS384_A192KW.getJwaName(), 24);
        DERIVED_KEY_SIZE_MAP.put(KeyAlgorithm.PBES2_HS512_A256KW.getJwaName(), 32);
    }
}

