//---------------------------------------------------------------------------------------------------------------------------------
// File: SQLServerColumnEncryptionAzureKeyVaultProvider.java
//
//
// Microsoft JDBC Driver for SQL Server
// Copyright(c) Microsoft Corporation
// All rights reserved.
// MIT License
// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files(the "Software"), 
//  to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, 
//  and / or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions :
// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
// THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 
//  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 
//  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS 
//  IN THE SOFTWARE.
//---------------------------------------------------------------------------------------------------------------------------------
 
 
package com.microsoft.sqlserver.jdbc;

import static java.nio.charset.StandardCharsets.UTF_16LE;

import java.net.URI;
import java.net.URISyntaxException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.text.MessageFormat;
import java.util.Locale;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;

import org.apache.http.impl.client.HttpClientBuilder;

import com.microsoft.azure.keyvault.KeyVaultClient;
import com.microsoft.azure.keyvault.KeyVaultClientImpl;
import com.microsoft.azure.keyvault.models.KeyBundle;
import com.microsoft.azure.keyvault.models.KeyOperationResult;
import com.microsoft.azure.keyvault.webkey.JsonWebKeySignatureAlgorithm;

/**
 * Provides implementation similar to certificate store provider.
 *  A CEK encrypted with certificate store provider should be decryptable by this provider and vice versa.
 * 
 * Envolope Format for the encrypted column encryption key  
 *         version + keyPathLength + ciphertextLength + keyPath + ciphertext +  signature
 * version: A single byte indicating the format version.
 *  keyPathLength: Length of the keyPath.
 *  ciphertextLength: ciphertext length
 *  keyPath: keyPath used to encrypt the column encryption key. This is only used for troubleshooting purposes and is not verified during decryption.
 *  ciphertext: Encrypted column encryption key
 *  signature: Signature of the entire byte array. Signature is validated before decrypting the column encryption key.
 */
public class SQLServerColumnEncryptionAzureKeyVaultProvider extends SQLServerColumnEncryptionKeyStoreProvider{

	/**
	 * Column Encryption Key Store Provider string
	 */
	String name = "AZURE_KEY_VAULT";

	private final String azureKeyVaultDomainName = "vault.azure.net";

	private final String rsaEncryptionAlgorithmWithOAEPForAKV="RSA-OAEP";

	/**
	 * Algorithm version
	 */
	private final byte[] firstVersion = new byte[] { 0x01 };

	private KeyVaultClient keyVaultClient;

	private KeyVaultCredential credential;

	
	public void setName(String name)
	{
		this.name = name; 
	}

	public String getName()
	{
		return this.name;
	}
	
	/**
	 * Constructor that takes a callback function to authenticate to AAD. This is used by KeyVaultClient at runtime 
	 * to authenticate to Azure Key Vault.
	 * 
	 * @param authenticationCallback - Callback function used for authenticating to AAD.
	 * @param executorService - The ExecutorService used to create the keyVaultClient
	 * @throws SQLServerException when an error occurs
	 */
	public SQLServerColumnEncryptionAzureKeyVaultProvider(SQLServerKeyVaultAuthenticationCallback authenticationCallback, ExecutorService executorService) throws SQLServerException{
		if(null == authenticationCallback){
	       	 MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_NullValue"));
	       	 Object[] msgArgs1 = { "SQLServerKeyVaultAuthenticationCallback" };
	       	 throw new SQLServerException(form.format(msgArgs1), null);
		}
		credential = new KeyVaultCredential(authenticationCallback);
		HttpClientBuilder builder = HttpClientBuilder.create();
		keyVaultClient = new KeyVaultClientImpl(builder, executorService, credential);
	}

	/**
	 * This function uses the asymmetric key specified by the key path
	 * and decrypts an encrypted CEK with RSA encryption algorithm.
	 * 
	 * @param masterKeyPath - Complete path of an asymmetric key in AKV
	 * @param encryptionAlgorithm - Asymmetric Key Encryption Algorithm
	 * @param encryptedColumnEncryptionKey - Encrypted Column Encryption Key
	 * @return Plain text column encryption key
	 */
	@Override
	public byte[] decryptColumnEncryptionKey(String masterKeyPath, String encryptionAlgorithm,
			byte[] encryptedColumnEncryptionKey) throws SQLServerException {

		// Validate the input parameters
		this.ValidateNonEmptyAKVPath(masterKeyPath);

		if (null == encryptedColumnEncryptionKey)
		{
			throw new SQLServerException(SQLServerException.getErrString("R_NullEncryptedColumnEncryptionKey"), null);
		}

		if (0 == encryptedColumnEncryptionKey.length)
		{
			throw new SQLServerException(SQLServerException.getErrString("R_EmptyEncryptedColumnEncryptionKey"), null);
		}

		// Validate encryptionAlgorithm
		encryptionAlgorithm = this.validateEncryptionAlgorithm(encryptionAlgorithm);

		// Validate whether the key is RSA one or not and then get the key size
		int keySizeInBytes = getAKVKeySize(masterKeyPath);

		// Validate and decrypt the EncryptedColumnEncryptionKey
		// Format is 
		//           version + keyPathLength + ciphertextLength + keyPath + ciphertext +  signature
		//
		// keyPath is present in the encrypted column encryption key for identifying the original source of the asymmetric key pair and 
		// we will not validate it against the data contained in the CMK metadata (masterKeyPath).

		// Validate the version byte
		if (encryptedColumnEncryptionKey[0] != firstVersion[0])
		{
			MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_InvalidEcryptionAlgorithmVersion"));
			Object[] msgArgs = {String.format("%02X ", encryptedColumnEncryptionKey[0]), String.format("%02X ", firstVersion[0])};
			throw new SQLServerException(this, form.format(msgArgs), null, 0, false);
		}

		// Get key path length
		int currentIndex = firstVersion.length;
		short keyPathLength = convertTwoBytesToShort(encryptedColumnEncryptionKey, currentIndex);
		// We just read 2 bytes
		currentIndex += 2;

		// Get ciphertext length
		short cipherTextLength = convertTwoBytesToShort(encryptedColumnEncryptionKey, currentIndex);
		currentIndex += 2;

		// Skip KeyPath
		// KeyPath exists only for troubleshooting purposes and doesnt need validation.
		currentIndex += keyPathLength;

		// validate the ciphertext length
		if (cipherTextLength != keySizeInBytes)
		{
			MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_AKVKeyLengthError"));
			Object[] msgArgs = {cipherTextLength, keySizeInBytes, masterKeyPath};
			throw new SQLServerException(this, form.format(msgArgs), null, 0, false);
		}

		// Validate the signature length
		int signatureLength = encryptedColumnEncryptionKey.length - currentIndex - cipherTextLength;

		if (signatureLength != keySizeInBytes)
		{            
			MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_AKVSignatureLengthError"));
			Object[] msgArgs = {signatureLength, keySizeInBytes, masterKeyPath};
			throw new SQLServerException(this, form.format(msgArgs), null, 0, false);
		}

		// Get ciphertext
		byte[] cipherText = new byte[cipherTextLength];
		System.arraycopy(encryptedColumnEncryptionKey,currentIndex,cipherText,0,cipherTextLength);
		currentIndex += cipherTextLength;

		// Get signature
		byte[] signature = new byte[signatureLength];
		System.arraycopy(encryptedColumnEncryptionKey,currentIndex,signature,0,signatureLength);

		// Compute the hash to validate the signature
		byte[] hash = new byte[encryptedColumnEncryptionKey.length - signature.length];

		System.arraycopy(encryptedColumnEncryptionKey,0,hash,0,encryptedColumnEncryptionKey.length - signature.length);


		MessageDigest md = null;
		try {
			md = MessageDigest.getInstance("SHA-256");
		} catch (NoSuchAlgorithmException e) {
			throw new SQLServerException(SQLServerException.getErrString("R_NoSHA256Algorithm"), null);
		}
		md.update(hash);
		byte dataToVerify[] = md.digest();

		if (null == dataToVerify)
		{
			throw new SQLServerException(SQLServerException.getErrString("R_HashNull"), null);
		}

		// Validate the signature
		if (!AzureKeyVaultVerifySignature(dataToVerify, signature, masterKeyPath))
		{
			MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_CEKSignatureNotMatchCMK"));
			Object[] msgArgs = {masterKeyPath};
			throw new SQLServerException(this, form.format(msgArgs), null, 0, false);
		}

		// Decrypt the CEK
		byte[] decryptedCEK = this.AzureKeyVaultUnWrap(masterKeyPath, encryptionAlgorithm, cipherText);

		return decryptedCEK;
	}

	private short convertTwoBytesToShort(byte[] input, int index) throws SQLServerException
	{

		short shortVal = -1;
		if (index + 1 >= input.length)
		{
			throw new SQLServerException(
					null,
					SQLServerException.getErrString("R_ByteToShortConversion"),
					null,
					0,
					false);
		}
		ByteBuffer byteBuffer = ByteBuffer.allocate(2);
		byteBuffer.order(ByteOrder.LITTLE_ENDIAN);
		byteBuffer.put(input[index]);
		byteBuffer.put(input[index + 1]);
		shortVal = byteBuffer.getShort(0);
		return shortVal;

	}

	/**
	 * This function uses the asymmetric key specified by the key path
	 * and encrypts CEK with RSA encryption algorithm.
	 * 
	 * @param masterKeyPath - Complete path of an asymmetric key in AKV
	 * @param encryptionAlgorithm - Asymmetric Key Encryption Algorithm
	 * @param columnEncryptionKey - Plain text column encryption key
	 * @return Encrypted column encryption key
	 */
	@Override
	public byte[] encryptColumnEncryptionKey(String masterKeyPath, String encryptionAlgorithm,
			byte[] columnEncryptionKey) throws SQLServerException {

		// Validate the input parameters
		this.ValidateNonEmptyAKVPath(masterKeyPath);

		if (null == columnEncryptionKey)
		{
			throw new SQLServerException(SQLServerException.getErrString("R_NullColumnEncryptionKey"), null);
		}

		if (0 == columnEncryptionKey.length)
		{
			throw new SQLServerException(SQLServerException.getErrString("R_EmptyCEK"), null);
		}

		// Validate encryptionAlgorithm
		encryptionAlgorithm = this.validateEncryptionAlgorithm(encryptionAlgorithm);

		// Validate whether the key is RSA one or not and then get the key size
		int keySizeInBytes = getAKVKeySize(masterKeyPath);

		// Construct the encryptedColumnEncryptionKey
		// Format is 
		//          version + keyPathLength + ciphertextLength + ciphertext + keyPath + signature
		//
		// We currently only support one version
		byte[] version = new byte[] { firstVersion[0] };

		// Get the Unicode encoded bytes of cultureinvariant lower case masterKeyPath
		byte[] masterKeyPathBytes = masterKeyPath.toLowerCase().getBytes(UTF_16LE);

		byte[] keyPathLength = new byte[2];
		keyPathLength[0] = (byte)(((short)masterKeyPathBytes.length) & 0xff);
		keyPathLength[1] = (byte)(((short)masterKeyPathBytes.length) >> 8 & 0xff);

		// Encrypt the plain text
		byte[] cipherText = this.AzureKeyVaultWrap(masterKeyPath, encryptionAlgorithm, columnEncryptionKey);

		byte[] cipherTextLength = new byte[2];
		cipherTextLength[0] = (byte)(((short)cipherText.length) & 0xff);
		cipherTextLength[1] = (byte)(((short)cipherText.length) >> 8 & 0xff);

		if (cipherText.length != keySizeInBytes)
		{
			throw new SQLServerException(SQLServerException.getErrString("R_CipherTextLengthNotMatchRSASize"), null);
		}

		// Compute hash
		// SHA-2-256(version + keyPathLength + ciphertextLength + keyPath + ciphertext) 
		byte [] dataToHash=new byte[version.length + keyPathLength.length + cipherTextLength.length + masterKeyPathBytes.length + cipherText.length];
		int destinationPosition=version.length;
		System.arraycopy(version, 0, dataToHash, 0, version.length);

		System.arraycopy(keyPathLength, 0, dataToHash, destinationPosition, keyPathLength.length);
		destinationPosition+=keyPathLength.length;

		System.arraycopy(cipherTextLength, 0, dataToHash, destinationPosition, cipherTextLength.length);
		destinationPosition+=cipherTextLength.length;

		System.arraycopy(masterKeyPathBytes, 0, dataToHash, destinationPosition, masterKeyPathBytes.length);
		destinationPosition+=masterKeyPathBytes.length;

		System.arraycopy(cipherText, 0, dataToHash, destinationPosition, cipherText.length); 

		MessageDigest md = null;
		try {
			md = MessageDigest.getInstance("SHA-256");
		} catch (NoSuchAlgorithmException e) {
			throw new SQLServerException(SQLServerException.getErrString("R_NoSHA256Algorithm"), null);
		}
		md.update(dataToHash);
		byte dataToSign[] = md.digest();

		// Sign the hash
		byte[] signedHash = null;
		signedHash = AzureKeyVaultSignHashedData(dataToSign, masterKeyPath);

		if (signedHash.length != keySizeInBytes)
		{
			throw new SQLServerException(SQLServerException.getErrString("R_SignedHashLengthError"), null);
		}

		if (!this.AzureKeyVaultVerifySignature(dataToSign, signedHash, masterKeyPath))
		{
			throw new SQLServerException(SQLServerException.getErrString("R_InvalidSignatureComputed"), null);
		}

		// Construct the encrypted column encryption key
		// EncryptedColumnEncryptionKey = version + keyPathLength + ciphertextLength + keyPath + ciphertext +  signature
		int encryptedColumnEncryptionKeyLength = version.length + cipherTextLength.length + keyPathLength.length + cipherText.length + masterKeyPathBytes.length + signedHash.length;
		byte[] encryptedColumnEncryptionKey = new byte[encryptedColumnEncryptionKeyLength];

		// Copy version byte
		int currentIndex = 0;
		System.arraycopy(version, 0, encryptedColumnEncryptionKey, currentIndex, version.length);
		currentIndex += version.length;

		// Copy key path length
		System.arraycopy(keyPathLength, 0, encryptedColumnEncryptionKey, currentIndex, keyPathLength.length);
		currentIndex += keyPathLength.length;

		// Copy ciphertext length
		System.arraycopy(cipherTextLength, 0, encryptedColumnEncryptionKey, currentIndex, cipherTextLength.length);
		currentIndex += cipherTextLength.length;

		// Copy key path
		System.arraycopy(masterKeyPathBytes, 0, encryptedColumnEncryptionKey, currentIndex, masterKeyPathBytes.length);
		currentIndex += masterKeyPathBytes.length;

		// Copy ciphertext
		System.arraycopy(cipherText, 0, encryptedColumnEncryptionKey, currentIndex, cipherText.length);
		currentIndex += cipherText.length;

		// copy the signature
		System.arraycopy(signedHash, 0, encryptedColumnEncryptionKey, currentIndex, signedHash.length);

		return encryptedColumnEncryptionKey;
	}

	/**
	 * This function validates that the encryption algorithm is RSA_OAEP and if it is not,
	 * then throws an exception
	 * @param encryptionAlgorithm - Asymmetric key encryptio algorithm
	 * @return The encryption algorithm that is going to be used.
	 * @throws SQLServerException
	 */
	private String validateEncryptionAlgorithm(String encryptionAlgorithm) throws SQLServerException
	{

		if (null == encryptionAlgorithm)
		{
			throw new SQLServerException(
					null,
					SQLServerException.getErrString("R_NullKeyEncryptionAlgorithm"),
					null,
					0,
					false);
		}

		// Transform to standard format (dash instead of underscore) to support both "RSA_OAEP" and "RSA-OAEP"
		if (encryptionAlgorithm.equalsIgnoreCase("RSA_OAEP"))
		{
			encryptionAlgorithm = "RSA-OAEP";
		}

		if (!rsaEncryptionAlgorithmWithOAEPForAKV.equalsIgnoreCase(encryptionAlgorithm.trim()))
		{
			MessageFormat form = new MessageFormat(
					SQLServerException.getErrString("R_InvalidKeyEncryptionAlgorithm"));
			Object[] msgArgs = { encryptionAlgorithm, rsaEncryptionAlgorithmWithOAEPForAKV };
			throw new SQLServerException(this, form.format(msgArgs), null, 0, false);
		}


		return encryptionAlgorithm;
	}

	/**
	 * Checks if the Azure Key Vault key path is Empty or Null (and raises exception if they are).
	 * 
	 * @param masterKeyPath
	 * @throws SQLServerException
	 */
	private void ValidateNonEmptyAKVPath(String masterKeyPath) throws SQLServerException
	{
		// throw appropriate error if masterKeyPath is null or empty
		if (null == masterKeyPath || masterKeyPath.trim().isEmpty())
		{
			MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_AKVPathNull"));
			Object[] msgArgs = {masterKeyPath};
			throw new SQLServerException(null , form.format(msgArgs) , null, 0 , false);    
		}
		else
		{
			URI parsedUri = null;
			try {
				parsedUri = new URI(masterKeyPath);
			} catch (URISyntaxException e) {
				MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_AKVURLInvalid"));
				Object[] msgArgs = {masterKeyPath};
				throw new SQLServerException(null , form.format(msgArgs) , null, 0 , false);    
			}

			// A valid URI.
			// Check if it is pointing to AKV.
			if(!parsedUri.getHost().toLowerCase().endsWith(azureKeyVaultDomainName)){
				// Return an error indicating that the AKV url is invalid.
				MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_AKVMasterKeyPathInvalid"));
				Object[] msgArgs = {masterKeyPath};
				throw new SQLServerException(null , form.format(msgArgs) , null, 0 , false);   
			}
		}
	}

	/**
	 * Encrypt the text using specified Azure Key Vault key.
	 * 
	 * @param masterKeyPath - Azure Key Vault key url.
	 * @param encryptionAlgorithm - Encryption Algorithm.
	 * @param columnEncryptionKey - Plain text Column Encryption Key.
	 * @return Returns an encrypted blob or throws an exception if there are any errors.
	 * @throws SQLServerException
	 */
	private byte[] AzureKeyVaultWrap(String masterKeyPath, String encryptionAlgorithm, byte[] columnEncryptionKey) throws SQLServerException
	{
		if (null == columnEncryptionKey)
		{
			throw new SQLServerException(SQLServerException.getErrString("R_CEKNull"), null);
		}

		KeyOperationResult wrappedKey = null;
		try {
			wrappedKey = keyVaultClient.wrapKeyAsync(masterKeyPath, encryptionAlgorithm, columnEncryptionKey).get();
		} catch (InterruptedException | ExecutionException e) {
			throw new SQLServerException(SQLServerException.getErrString("R_EncryptCEKError"), null);
		}
		return wrappedKey.getResult();
	}

	/**
	 * Encrypt the text using specified Azure Key Vault key.
	 * 
	 * @param masterKeyPath - Azure Key Vault key url.
	 * @param encryptionAlgorithm - Encrypted Column Encryption Key.
	 * @param encryptedColumnEncryptionKey - Encrypted Column Encryption Key.
	 * @return Returns the decrypted plaintext Column Encryption Key or throws an exception if there are any errors.
	 * @throws SQLServerException
	 */
	private byte[] AzureKeyVaultUnWrap(String masterKeyPath, String encryptionAlgorithm, byte[] encryptedColumnEncryptionKey) throws SQLServerException
	{
		if (null == encryptedColumnEncryptionKey)
		{
			throw new SQLServerException(SQLServerException.getErrString("R_EncryptedCEKNull"), null);
		}

		if (0 == encryptedColumnEncryptionKey.length)
		{
			throw new SQLServerException(SQLServerException.getErrString("R_EmptyEncryptedCEK"), null);
		}

		KeyOperationResult unwrappedKey;
		try {
			unwrappedKey = keyVaultClient.unwrapKeyAsync(masterKeyPath, encryptionAlgorithm, encryptedColumnEncryptionKey).get();
		} catch (InterruptedException | ExecutionException e) {
			throw new SQLServerException(SQLServerException.getErrString("R_DecryptCEKError"), null);
		}
		return unwrappedKey.getResult();
	}

	/**
	 * Generates signature based on RSA PKCS#v1.5 scheme using a specified Azure Key Vault Key URL. 
	 * 
	 * @param dataToSign - Text to sign.
	 * @param masterKeyPath - Azure Key Vault key url.
	 * @return Signature
	 * @throws SQLServerException
	 */
	private byte[] AzureKeyVaultSignHashedData(byte[] dataToSign, String masterKeyPath) throws SQLServerException
	{
		assert((null != dataToSign) && (0 != dataToSign.length));

		KeyOperationResult signedData = null;
		try {
			signedData = keyVaultClient.signAsync(masterKeyPath, JsonWebKeySignatureAlgorithm.RS256, dataToSign).get();
		} catch (InterruptedException | ExecutionException e) {
			throw new SQLServerException(SQLServerException.getErrString("R_GenerateSignature"), null);
		}
		return signedData.getResult();
	}

	/**
	 * Verifies the given RSA PKCSv1.5 signature.
	 * 
	 * @param dataToVerify
	 * @param signature
	 * @param masterKeyPath - Azure Key Vault key url.
	 * @return true if signature is valid, false if it is not valid
	 * @throws SQLServerException
	 */
	private boolean AzureKeyVaultVerifySignature(byte[] dataToVerify, byte[] signature, String masterKeyPath) throws SQLServerException
	{
		assert((null != dataToVerify) && (0 != dataToVerify.length));
		assert((null != signature) && (0 != signature.length));

		boolean valid = false;
		try {
			valid = keyVaultClient.verifyAsync(masterKeyPath, JsonWebKeySignatureAlgorithm.RS256, dataToVerify, signature).get();
		} catch (InterruptedException | ExecutionException e) {
			throw new SQLServerException(SQLServerException.getErrString("R_VerifySignature"), null);
		}

		return valid;
	}

	/**
	 * Gets the public Key size in bytes
	 * 
	 * @param masterKeyPath - Azure Key Vault Key path
	 * @return Key size in bytes
	 * @throws SQLServerException when an error occurs
	 */
	private int getAKVKeySize(String masterKeyPath) throws SQLServerException
	{

		KeyBundle retrievedKey = null;
		try {
			retrievedKey = keyVaultClient.getKeyAsync(masterKeyPath).get();
		} catch (InterruptedException | ExecutionException e) {
			throw new SQLServerException(SQLServerException.getErrString("R_GetAKVKeySize"), null);
		}

		if (!retrievedKey.getKey().getKty().equalsIgnoreCase("RSA") &&
				!retrievedKey.getKey().getKty().equalsIgnoreCase("RSA-HSM"))
		{
			MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_NonRSAKey"));
			Object[] msgArgs = {retrievedKey.getKey().getKty()};
			throw new SQLServerException(null , form.format(msgArgs) , null, 0 , false);   
		}

		return retrievedKey.getKey().getN().length;
	}
}
