/*
 * Copyright 2021-2024 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.metaeffekt.artifact.analysis.flow.ng;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.metaeffekt.artifact.analysis.flow.ng.exception.DecryptionImpossibleException;
import com.metaeffekt.artifact.analysis.flow.ng.keyholder.UserKeysForConsumer;
import com.metaeffekt.artifact.analysis.flow.ng.keyholder.UserKeysForSupplier;
import com.metaeffekt.artifact.analysis.flow.ng.keyholder.UserKeysWithHmacSecret;
import org.apache.commons.codec.digest.HmacAlgorithms;
import org.apache.commons.codec.digest.HmacUtils;
import org.bouncycastle.crypto.*;
import org.bouncycastle.crypto.digests.SHA256Digest;
import org.bouncycastle.crypto.engines.AESWrapEngine;
import org.bouncycastle.crypto.generators.KDF2BytesGenerator;
import org.bouncycastle.crypto.kems.RSAKEMExtractor;
import org.bouncycastle.crypto.kems.RSAKEMGenerator;
import org.bouncycastle.crypto.params.KeyParameter;
import org.bouncycastle.util.BigIntegers;

import javax.crypto.Mac;
import java.math.BigInteger;
import java.nio.ByteBuffer;
import java.security.MessageDigest;
import java.security.SecureRandom;

/**
 * The intermediate hmac and ciphertexts required to find one's slot and decrypt the content key.
 * <br>
 * A slot holds an RSA-KEM encapsulation of a key encryption key (KEK), the content key wrapped with that KEK
 * using AES key wrap, the length of the content key and an HMAC-SHA256 binding these values to a user's hmac secret.
 */
public class DecryptableKeyslot {
    /**
     * Number of bytes used to store the content key length (written as a short via {@link ByteBuffer}).
     */
    private static final int CONTENT_KEY_LENGTH_FIELD_SIZE = 2;

    /**
     * Encapsulated key encryption key (KEK). Once extracted, the KEK can unwrap the content encryption key.
     */
    @JsonProperty("encapsulatedKek")
    private byte[] encapsulatedKek;
    /**
     * The encrypted content encryption key. The resulting key is able to decrypt the corresponding content file.
     */
    @JsonProperty("wrappedContentKey")
    private byte[] wrappedContentKey;
    /**
     * The hmac, calculated over encapsulatedKek, wrappedContentKey and contentKeyLength (in this order).
     */
    @JsonProperty("slotHmac")
    private byte[] slotHmac;
    /**
     * Size of the extractable content key in bytes, stored as a two-byte short.
     */
    @JsonProperty("contentKeyLength")
    private byte[] contentKeyLength;

    // source of the throwaway keys used by the randomized hmac comparison in hmacEquals
    SecureRandom secureRandom = new SecureRandom();

    /**
     * Creates a new keyslot from the given keys, using a fresh {@link SecureRandom} for key encapsulation.
     * @param contentKey the content key
     * @param userKeysForSupplier key package containing all keys needed to access this keyslot
     * @return returns a new keyslot object
     */
    public static DecryptableKeyslot createDecryptableKeyslot(byte[] contentKey,
                                                              UserKeysForSupplier userKeysForSupplier) {
        return createDecryptableKeyslot(contentKey, userKeysForSupplier, new SecureRandom());
    }

    /**
     * Creates a new keyslot from the given keys.
     * @param contentKey the content key
     * @param userKeysForSupplier key package containing all keys needed to access this keyslot
     * @param secureRandom random used in keyslot (in-between keys) generation
     * @return returns a new keyslot object
     */
    public static DecryptableKeyslot createDecryptableKeyslot(byte[] contentKey,
                                                              UserKeysForSupplier userKeysForSupplier,
                                                              SecureRandom secureRandom) {
        // initialize key encapsulation. will output a random key, use the same keysize as contentKey for intermediates
        EncapsulatedSecretGenerator generator =
                new RSAKEMGenerator(contentKey.length, getSha256Kdf2(), secureRandom);

        // generate an encapsulated key using the recipient's public key. this is the KEK for the contentKey.
        SecretWithEncapsulation secretWithEncapsulation =
                generator.generateEncapsulated(userKeysForSupplier.getRsaPublicKey());

        // assemble the entry
        DecryptableKeyslot entry = new DecryptableKeyslot();
        entry.encapsulatedKek = secretWithEncapsulation.getEncapsulation();

        // wrap the contentKey in a symmetric wrap. the KEK is the key the RSA-KEM generator just gave us.
        // key wrapping is deterministic; this is acceptable here because the KEK is freshly generated per slot.
        entry.wrappedContentKey = wrapWithAes(secretWithEncapsulation.getSecret(), contentKey);
        entry.contentKeyLength = encodeKeyLength(contentKey.length);

        // generate HMAC over the ciphertexts and the key length; all three fields must be set at this point
        entry.slotHmac = entry.calculateHmac(userKeysForSupplier);

        return entry;
    }

    /**
     * Recovers the content key from this slot.
     * <br>
     * First verifies the slot hmac with the given keys, then extracts the KEK from the encapsulation using the
     * RSA private key and finally unwraps the content key with it.
     * @param userKeys the consumer's keys (hmac secret and RSA private key)
     * @return the unwrapped content key
     * @throws DecryptionImpossibleException if the hmac does not match the given keys, the stored content key
     *                                       length is missing or invalid, or the wrapped key fails to unwrap
     * @throws InvalidCipherTextException declared for callers; unwrap failures are rethrown as
     *                                    {@link DecryptionImpossibleException}
     */
    public byte[] getContentKeyUsing(UserKeysForConsumer userKeys) throws InvalidCipherTextException {
        if (!checkHmac(userKeys)) {
            // happens if the supplied keys don't match this slot
            throw new DecryptionImpossibleException("Can't decrypt this slot using the provided keys");
        }

        // fail with a meaningful exception instead of a NullPointerException / BufferUnderflowException
        if (contentKeyLength == null || contentKeyLength.length < CONTENT_KEY_LENGTH_FIELD_SIZE) {
            throw new DecryptionImpossibleException("Can't decrypt this slot: content key length is missing");
        }

        short keyLength = ByteBuffer.wrap(this.contentKeyLength).getShort();
        if (keyLength <= 0) {
            throw new DecryptionImpossibleException("Can't decrypt this slot: content key length is invalid");
        }

        // the KEK has the same length as the content key (see createDecryptableKeyslot)
        EncapsulatedSecretExtractor extractor = new RSAKEMExtractor(
                userKeys.getRsaPrivateKey(),
                keyLength,
                getSha256Kdf2()
        );

        byte[] kek = extractor.extractSecret(encapsulatedKek);

        AESWrapEngine aesWrapEngine = new AESWrapEngine();
        aesWrapEngine.init(false, new KeyParameter(kek));

        try {
            return aesWrapEngine.unwrap(wrappedContentKey, 0, wrappedContentKey.length);
        } catch (InvalidCipherTextException e) {
            // likely the keyslot or keys have been tampered with
            throw new DecryptionImpossibleException("Can't unwrap. Possible tampering detected!", e);
        }
    }

    /**
     * Creates the key derivation function used on both sides of the RSA-KEM (KDF2 with SHA-256).
     * @return a new KDF2 generator instance
     */
    private static DerivationFunction getSha256Kdf2() {
        return new KDF2BytesGenerator(new SHA256Digest());
    }

    /**
     * Wraps the given key with AES key wrap.
     * @param kek the key encryption key
     * @param key the key to be wrapped
     * @return the wrapped key
     */
    private static byte[] wrapWithAes(byte[] kek, byte[] key) {
        AESWrapEngine aesWrapEngine = new AESWrapEngine();
        aesWrapEngine.init(true, new KeyParameter(kek));
        return aesWrapEngine.wrap(key, 0, key.length);
    }

    /**
     * Encodes a key length the way it is stored in the slot: as a two-byte short.
     * @param length key length in bytes
     * @return the encoded length
     */
    private static byte[] encodeKeyLength(int length) {
        return ByteBuffer.allocate(CONTENT_KEY_LENGTH_FIELD_SIZE).putShort((short) length).array();
    }

    /**
     * Returns a copy of the given array, passing null through.
     * @param source array to copy, may be null
     * @return a copy of the array or null
     */
    private static byte[] copyOrNull(byte[] source) {
        return source == null ? null : source.clone();
    }

    /**
     * Creates a useless keyslot, imitating a real one for chaffing, hiding the number of real slots.
     * @param random the SecureRandom object to use in generation.
     * @param realContentKeyLength length of the real content key used when generating real keyslots.
     * @return a bogus keyslot, resembling a real one but generated with useless keys and better performance.
     */
    public static DecryptableKeyslot generateBogusKeyslot(SecureRandom random, int realContentKeyLength) {
        // start off by generating a random "content key" and a random hmac key, as an actual user might have
        // the point is: neither of these will correspond to any keys that are actually in use
        byte[] bogusContentKey = new byte[realContentKeyLength];
        random.nextBytes(bogusContentKey);
        byte[] bogusHmacKey = new byte[32];
        random.nextBytes(bogusHmacKey);

        // this public mod will never actually be used and only serves as a guide for fake ciphertext generation.
        // force the top bit: otherwise the random value may be shorter than rsaSize bits, which would make the
        // fake encapsulation shorter (and smaller on average) than one produced for a full-size modulus.
        // NOTE(review): assumes KeyConstants.rsaSize is the modulus size in bits - confirm.
        BigInteger bogusPublicMod = BigIntegers.createRandomBigInteger(
                KeyConstants.rsaSize,
                random
        ).setBit(KeyConstants.rsaSize - 1);

        // bad fake of how rsa ciphertext generation might work (exponentiation, then modulus)
        BigInteger bogusCiphertext = BigIntegers.createRandomInRange(BigInteger.ZERO, bogusPublicMod, random);

        // start writing results to the entry
        DecryptableKeyslot entry = new DecryptableKeyslot();

        // our ciphertext is the fake encapsulated key encryption key
        // the extracted kek from this would otherwise be used in aes wrapping.
        entry.encapsulatedKek = BigIntegers.asUnsignedByteArray((bogusPublicMod.bitLength() + 7) / 8, bogusCiphertext);

        // fake extracted key from encapsulation. same length as the content key by default
        byte[] bogusEncapsulatedSecret = new byte[realContentKeyLength];
        random.nextBytes(bogusEncapsulatedSecret);

        // aes is fast, so do what actual keyslot generation does but with bogus key and data
        entry.wrappedContentKey = wrapWithAes(bogusEncapsulatedSecret, bogusContentKey);

        // set the length before calculating the hmac so the mac covers the same fields as in a real slot
        entry.contentKeyLength = encodeKeyLength(realContentKeyLength);

        // calculate hmac using bogus keys and bogus ciphertexts
        entry.slotHmac = entry.calculateHmac(bogusHmacKey);

        return entry;
    }

    /**
     * Checks a recalculated hmac against the one stored in the keyslot.
     * <br>
     * Matching implies that this slot really is valid for the given keys.
     * @param userKeys the keys to use in recalculating the mac
     * @return true if the hmac matches; false if it differs or no hmac is stored
     */
    public boolean checkHmac(UserKeysWithHmacSecret userKeys) {
        if (slotHmac == null || slotHmac.length == 0) {
            return false;
        }

        return hmacEquals(calculateHmac(userKeys));
    }

    /**
     * Compares hmacs, attempting to achieve decent security through a "double randomized hmac".
     *
     * @param toCompare hmac to compare against this keyslot's hmac; null or empty never matches
     * @return true if the two hmacs are the same
     */
    private boolean hmacEquals(byte[] toCompare) {
        // refusing to compare nulls
        if (toCompare == null || toCompare.length == 0) {
            return false;
        }

        // generate random key for double hmac comparison
        byte[] randomKey = new byte[32];
        secureRandom.nextBytes(randomKey);

        // implements a randomized double hmac compare. relies on the hash in hmac to be pseudorandom.
        // simple way to mitigate timing attacks: the compared values are re-keyed with a throwaway key.

        // use a larger hash for the temporary value
        Mac hmacGenerator = HmacUtils.getInitializedMac(HmacAlgorithms.HMAC_SHA_512, randomKey);
        // doFinal resets the mac to its initialized state, so the same instance (and key) serves both inputs
        byte[] doubleMacToCompare = hmacGenerator.doFinal(toCompare);
        byte[] doubleMacInternal = hmacGenerator.doFinal(slotHmac);

        // use any compare algorithm to compare the results
        return MessageDigest.isEqual(doubleMacInternal, doubleMacToCompare);
    }

    /**
     * Calculates hmac over encapsulatedKek, wrappedContentKey and contentKeyLength.
     * @param userKeys provides the hmac key to use
     * @return returns the hmac digest output
     */
    private byte[] calculateHmac(UserKeysWithHmacSecret userKeys) {
        return calculateHmac(userKeys.getHmacSecretKey());
    }

    /**
     * Calculates hmac (HMAC-SHA256) over encapsulatedKek, wrappedContentKey and contentKeyLength, in this order.
     * @param key the hmac key to use
     * @return returns the hmac digest output
     */
    private byte[] calculateHmac(byte[] key) {
        Mac hmacGenerator = HmacUtils.getInitializedMac(HmacAlgorithms.HMAC_SHA_256, key);
        hmacGenerator.update(encapsulatedKek);
        hmacGenerator.update(wrappedContentKey);
        hmacGenerator.update(contentKeyLength);

        return hmacGenerator.doFinal();
    }

    /**
     * Gets the encapsulated key encryption key.
     * @return a copy of the encapsulation, or null if none is set
     */
    public byte[] getEncapsulatedKek() {
        return copyOrNull(encapsulatedKek);
    }

    /**
     * Gets the content key in its ciphertext form.
     * @return a copy of the encrypted content key, or null if none is set
     */
    public byte[] getWrappedContentKey() {
        return copyOrNull(wrappedContentKey);
    }

    /**
     * Get the stored hmac for this slot.
     * @return a copy of the stored hmac for this slot, or null if none is set
     */
    public byte[] getSlotHmac() {
        return copyOrNull(slotHmac);
    }

    /**
     * Get the stored length of the content key.
     * <br>
     * This is important since key encapsulation doesn't store the length of the extracted key by itself.
     * @return a copy of the encoded length of the content key, or null if none is set
     */
    public byte[] getContentKeyLength() {
        // hand out a copy, like the other getters, so callers can't modify the slot's state
        return copyOrNull(contentKeyLength);
    }
}
