package com.nimbusds.openid.connect.provider.jwksetgen;


import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.security.InvalidAlgorithmParameterException;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException;
import java.security.interfaces.ECPrivateKey;
import java.security.interfaces.ECPublicKey;
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;
import java.text.ParseException;
import java.util.LinkedList;
import java.util.List;
import javax.crypto.KeyGenerator;
import javax.crypto.SecretKey;

import com.nimbusds.jose.jwk.*;
import com.thetransactioncompany.json.pretty.PrettyJson;


/**
 * JWK set generators for Connect2id servers.
 *
 * <p>Generates a set of rotating keys (RSA and EC signing keys, an AES
 * encryption key) and a set of permanent keys (an HMAC SHA-256 key and a
 * subject encryption key). It can also prefix freshly generated rotating
 * keys to an existing JWK set.
 */
public class JWKSetGenerator {
	
	
	/**
	 * The RSA key bit size.
	 */
	public static final int RSA_KEY_BIT_SIZE = 2048;
	
	
	/**
	 * The AES key bit size.
	 */
	public static final int AES_KEY_BIT_SIZE = 128;
	
	
	/**
	 * The HMAC SHA-256 key bit size.
	 */
	private static final int HMAC_KEY_BIT_SIZE = 256;
	
	
	/**
	 * The subject encryption (AES SIV) key bit size.
	 */
	private static final int SUBJECT_KEY_BIT_SIZE = 256;
	
	
	/**
	 * Generates a 2048 bit RSA signing key with the specified key ID.
	 *
	 * @param kid The key ID, {@code null} if not specified.
	 *
	 * @return The RSA key pair.
	 *
	 * @throws NoSuchAlgorithmException If RSA key generation isn't
	 *                                  supported by the JCA providers.
	 */
	public static RSAKey generateSigningRSAKey(final String kid)
		throws NoSuchAlgorithmException {
		
		KeyPairGenerator keyPairGen = KeyPairGenerator.getInstance("RSA");
		keyPairGen.initialize(RSA_KEY_BIT_SIZE);
		KeyPair keyPair = keyPairGen.generateKeyPair();
		RSAPublicKey rsaPublicKey = (RSAPublicKey)keyPair.getPublic();
		RSAPrivateKey rsaPrivateKey = (RSAPrivateKey)keyPair.getPrivate();
		
		return new RSAKey.Builder(rsaPublicKey)
			.privateKey(rsaPrivateKey)
			.keyID(kid)
			.keyUse(KeyUse.SIGNATURE)
			.build();
	}
	
	
	/**
	 * Generates an EC signing key with the specified curve and key ID.
	 *
	 * @param crv The curve. Must not be {@code null}.
	 * @param kid The key ID, {@code null} if not specified.
	 *
	 * @return The EC key pair.
	 *
	 * @throws NoSuchAlgorithmException           If EC key generation isn't
	 *                                            supported by the JCA
	 *                                            providers.
	 * @throws InvalidAlgorithmParameterException If the curve parameters
	 *                                            are rejected by the key
	 *                                            pair generator.
	 */
	public static ECKey generateSigningECKey(final Curve crv, final String kid)
		throws NoSuchAlgorithmException, InvalidAlgorithmParameterException {
		
		KeyPairGenerator keyPairGen = KeyPairGenerator.getInstance("EC");
		keyPairGen.initialize(crv.toECParameterSpec());
		KeyPair keyPair = keyPairGen.generateKeyPair();
		ECPublicKey ecPublicKey = (ECPublicKey)keyPair.getPublic();
		ECPrivateKey ecPrivateKey = (ECPrivateKey)keyPair.getPrivate();
		
		return new ECKey.Builder(crv, ecPublicKey)
			.privateKey(ecPrivateKey)
			.keyID(kid)
			.keyUse(KeyUse.SIGNATURE)
			.build();
	}
	
	
	/**
	 * Generates a 128 bit AES encryption key with the specified key ID.
	 *
	 * @param kid The key ID, {@code null} if not specified.
	 *
	 * @return The AES key.
	 *
	 * @throws NoSuchAlgorithmException If AES key generation isn't
	 *                                  supported by the JCA providers.
	 */
	public static OctetSequenceKey generateEncryptionAESKey(final String kid)
		throws NoSuchAlgorithmException {
		
		KeyGenerator keyGen = KeyGenerator.getInstance("AES");
		keyGen.init(AES_KEY_BIT_SIZE);
		SecretKey aesKey = keyGen.generateKey();
		
		return new OctetSequenceKey.Builder(aesKey)
			.keyID(kid)
			.keyUse(KeyUse.ENCRYPTION)
			.build();
	}
	
	
	/**
	 * Generates a 256 bit HMAC SHA key with key ID "hmac".
	 *
	 * @return The HMAC SHA key.
	 *
	 * @throws NoSuchAlgorithmException If HMAC SHA-256 key generation
	 *                                  isn't supported by the JCA
	 *                                  providers.
	 */
	public static OctetSequenceKey generateHMACSHA256Key()
		throws NoSuchAlgorithmException {
		
		KeyGenerator keyGen = KeyGenerator.getInstance("HmacSha256");
		// Set the size explicitly instead of relying on the provider's
		// default, so the documented 256 bit size is guaranteed
		keyGen.init(HMAC_KEY_BIT_SIZE);
		SecretKey hmacKey = keyGen.generateKey();
		
		return new OctetSequenceKey.Builder(hmacKey)
			.keyID("hmac")
			.keyUse(KeyUse.SIGNATURE)
			.build();
	}
	
	
	/**
	 * Generates a 256 bit subject encryption key (intended for AES SIV
	 * mode) with key ID "subject-encrypt".
	 *
	 * @return The subject encryption key.
	 *
	 * @throws NoSuchAlgorithmException If AES key generation isn't
	 *                                  supported by the JCA providers.
	 */
	public static OctetSequenceKey generateSubjectEncryptionKey()
		throws NoSuchAlgorithmException {
		
		KeyGenerator keyGen = KeyGenerator.getInstance("AES");
		keyGen.init(SUBJECT_KEY_BIT_SIZE);
		SecretKey aesKey = keyGen.generateKey();
		
		return new OctetSequenceKey.Builder(aesKey)
			.keyID("subject-encrypt")
			.keyUse(KeyUse.ENCRYPTION)
			.build();
	}
	
	
	/**
	 * Generates a new set of rotating signature and encryption keys for a
	 * Connect2id server: one RSA signing key, EC signing keys on P-256,
	 * P-384 and P-521, and one AES encryption key, in that order.
	 *
	 * @param reservedKeyIDs The reserved key IDs, empty if none. The newly
	 *                       generated keys receive IDs distinct from
	 *                       these.
	 * @param withMessage    If {@code true} a message will be printed to
	 *                       standard output.
	 *
	 * @return The generated rotating keys.
	 *
	 * @throws NoSuchAlgorithmException           If a required key
	 *                                            algorithm isn't supported.
	 * @throws InvalidAlgorithmParameterException If an EC curve isn't
	 *                                            supported.
	 */
	public List<JWK> generateRotatingKeys(final KeyIDs reservedKeyIDs, final boolean withMessage)
		throws NoSuchAlgorithmException, InvalidAlgorithmParameterException {
		
		List<JWK> keys = new LinkedList<>();
		
		// Track reserved IDs plus every ID assigned below, so each new
		// key ID is unique within the resulting set
		KeyIDs keyIDs = new KeyIDs();
		keyIDs.addAll(reservedKeyIDs);
		
		RSAKey rsaKey = generateSigningRSAKey(keyIDs.addRandomUniqueKeyID());
		keys.add(rsaKey);
		if (withMessage) {
			System.out.println("[1] Generated new signing RSA " + RSA_KEY_BIT_SIZE + " bit key with ID " + rsaKey.getKeyID());
		}
		
		ECKey ecKey = generateSigningECKey(Curve.P_256, keyIDs.addRandomUniqueKeyID());
		keys.add(ecKey);
		if (withMessage) {
			System.out.println("[2] Generated new signing EC " + ecKey.getCurve() + " key with ID " + ecKey.getKeyID());
		}
		
		ecKey = generateSigningECKey(Curve.P_384, keyIDs.addRandomUniqueKeyID());
		keys.add(ecKey);
		if (withMessage) {
			System.out.println("[3] Generated new signing EC " + ecKey.getCurve() + " key with ID " + ecKey.getKeyID());
		}
		
		ecKey = generateSigningECKey(Curve.P_521, keyIDs.addRandomUniqueKeyID());
		keys.add(ecKey);
		if (withMessage) {
			System.out.println("[4] Generated new signing EC " + ecKey.getCurve() + " key with ID " + ecKey.getKeyID());
		}
		
		OctetSequenceKey secretKey = generateEncryptionAESKey(keyIDs.addRandomUniqueKeyID());
		keys.add(secretKey);
		if (withMessage) {
			System.out.println("[5] Generated new encryption AES " + AES_KEY_BIT_SIZE + " bit key with ID " + secretKey.getKeyID());
		}
		
		return keys;
	}
	
	
	/**
	 * Generates a new set of permanent keys for a Connect2id server: an
	 * HMAC SHA-256 key and a subject encryption key, in that order.
	 *
	 * @param withMessage If {@code true} a message will be printed to
	 *                    standard output.
	 *
	 * @return The generated keys.
	 *
	 * @throws NoSuchAlgorithmException If a required key algorithm isn't
	 *                                  supported.
	 */
	public List<JWK> generatePermanentKeys(final boolean withMessage)
		throws NoSuchAlgorithmException {
		
		List<JWK> keys = new LinkedList<>();
		
		// Step numbers continue from the rotating keys ([1]..[5])
		OctetSequenceKey hmacKey = generateHMACSHA256Key();
		keys.add(hmacKey);
		if (withMessage) {
			System.out.println("[6] Generated new HMAC SHA key with ID " + hmacKey.getKeyID());
		}
		
		OctetSequenceKey subjectKey = generateSubjectEncryptionKey();
		keys.add(subjectKey);
		if (withMessage) {
			System.out.println("[7] Generated new subject AES SIV encryption key with ID " + subjectKey.getKeyID());
		}
		
		return keys;
	}
	
	
	/**
	 * Generates a new JWK set for a Connect2id server, consisting of the
	 * rotating keys followed by the permanent keys.
	 *
	 * @param withMessage If {@code true} a message will be printed to
	 *                    standard output.
	 *
	 * @return The JWK set.
	 *
	 * @throws NoSuchAlgorithmException           If a required key
	 *                                            algorithm isn't supported.
	 * @throws InvalidAlgorithmParameterException If an EC curve isn't
	 *                                            supported.
	 */
	public JWKSet generate(final boolean withMessage)
		throws NoSuchAlgorithmException, InvalidAlgorithmParameterException {
		
		List<JWK> keys = new LinkedList<>();
		keys.addAll(generateRotatingKeys(new KeyIDs(), withMessage));
		keys.addAll(generatePermanentKeys(withMessage));
		return new JWKSet(keys);
	}
	
	
	/**
	 * Generates a new set of signing and encryption keys and prefixes
	 * them to the specified Connect2id server JWK set. The existing keys
	 * are kept, after the new ones.
	 *
	 * @param oldJWKSet   The Connect2id server JWK set. Must not be
	 *                    {@code null}.
	 * @param withMessage If {@code true} a message will be printed to
	 *                    standard output.
	 *
	 * @return The updated JWK set.
	 *
	 * @throws Exception If key generation failed.
	 */
	public JWKSet generateAndPrefixNewKeys(final JWKSet oldJWKSet, final boolean withMessage)
		throws Exception {
		
		// Prefix so Connect2id server can roll over to new keys; the old
		// key IDs are reserved so the new keys don't collide with them
		List<JWK> keys = new LinkedList<>();
		keys.addAll(generateRotatingKeys(new KeyIDs(oldJWKSet), withMessage));
		keys.addAll(oldJWKSet.getKeys());
		
		if (withMessage) {
			System.out.println("[6] Prefixed newly generated keys to existing JWK set");
		}
		
		return new JWKSet(keys);
	}
	
	
	/**
	 * Console method for generating a new Connect2id server JWK set, or
	 * updating an existing JWK set with new signing and encryption keys.
	 *
	 * <p>With one argument, a new JWK set is written to that file. With
	 * two arguments, the JWK set in the first file is read, new rotating
	 * keys are prefixed to it, and the result is written to the second
	 * file.
	 *
	 * @param args The command line arguments.
	 */
	public static void main(final String[] args) {
		
		System.out.println("JWK set generator for Connect2id server v6.x+");
		
		final File oldJWKSetFile;
		final File newJWKSetFile;
		
		if (args.length == 2) {
			oldJWKSetFile = new File(args[0]);
			newJWKSetFile = new File(args[1]);
		} else if (args.length == 1) {
			oldJWKSetFile = null;
			newJWKSetFile = new File(args[0]);
		} else {
			System.out.println("Usage:");
			System.out.println("1) To generate new Connect2id server JWK set: ");
			System.out.println("   java -jar jwkset-gen.jar jwkSet.json");
			System.out.println("2) To add new set of rotating keys to existing Connect2id server JWK set: ");
			System.out.println("   java -jar jwkset-gen.jar oldJWKSet.json newJWKSet.json");
			return;
		}
		
		JWKSet oldJWKSet = null;
		
		if (oldJWKSetFile != null) {
			try {
				oldJWKSet = JWKSet.load(oldJWKSetFile);
			} catch (IOException | ParseException e) {
				System.err.println("Couldn't read old JWK set file: " + e.getMessage());
				return;
			}
		}
		
		final JWKSet newJWKSet;
		final boolean withMessage = true;
		try {
			if (oldJWKSet == null) {
				newJWKSet = new JWKSetGenerator().generate(withMessage);
			} else {
				newJWKSet = new JWKSetGenerator().generateAndPrefixNewKeys(oldJWKSet, withMessage);
			}
		} catch (Exception e) {
			System.err.println("Couldn't generate JWK key: " + e.getMessage());
			return;
		}
		
		
		try {
			// false: include the private / secret key parameters, the file
			// is the server's own key store, not a public JWK set
			String json = new PrettyJson(PrettyJson.Style.COMPACT).parseAndFormat(newJWKSet.toJSONObject(false).toJSONString());
			
			// try-with-resources ensures the file is closed on every path
			try (PrintWriter writer = new PrintWriter(newJWKSetFile, "UTF-8")) {
				writer.write(json);
				writer.write("\n");
				
				// PrintWriter never throws on write, it only sets an
				// error flag; checkError() flushes and reports it
				if (writer.checkError()) {
					throw new IOException("I/O error while writing " + newJWKSetFile);
				}
			}
		} catch (ParseException | IOException e) {
			System.err.println("Couldn't write new JWK set file: " + e.getMessage());
		}
	}
}
