package com.nimbusds.openid.connect.provider.jwksetgen;


import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.text.ParseException;
import java.util.LinkedList;
import java.util.List;
import java.util.function.Consumer;

import com.thetransactioncompany.json.pretty.PrettyJson;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.CommandLineParser;
import org.apache.commons.cli.DefaultParser;
import org.apache.commons.cli.Options;

import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.jwk.*;
import com.nimbusds.jose.jwk.gen.ECKeyGenerator;
import com.nimbusds.jose.jwk.gen.OctetKeyPairGenerator;
import com.nimbusds.jose.jwk.gen.OctetSequenceKeyGenerator;
import com.nimbusds.jose.jwk.gen.RSAKeyGenerator;
import com.nimbusds.jose.util.Base64URL;
import com.nimbusds.jose.util.JSONObjectUtils;


/**
 * JWK set generator for Connect2id servers.
 */
public class JWKSetGenerator {
	
	
	/**
	 * The RSA key bit size.
	 */
	public static final int RSA_KEY_BIT_SIZE = 2048;
	
	
	/**
	 * The AES key bit size.
	 */
	public static final int AES_KEY_BIT_SIZE = 128;
	
	
	/**
	 * The HMAC SHA key bit size.
	 */
	public static final int HMAC_SHA_KEY_BIT_SIZE = 256;
	
	
	/**
	 * The subject AES SIV key bit size.
	 */
	public static final int SUBJECT_AES_SIV_KEY_BIT_SIZE = 256;
	
	
	/**
	 * The refresh token AES SIV key bit size.
	 */
	public static final int REFRESH_TOKEN_AES_SIV_KEY_BIT_SIZE = 256;
	
	
	/**
	 * Generates a 2048 bit RSA signing key with the specified key ID.
	 *
	 * @param kid The key ID, {@code null} if not specified.
	 *
	 * @return The RSA key pair.
	 *
	 * @throws JOSEException If key generation failed.
	 */
	public static RSAKey generateSigningRSAKey(final String kid)
		throws JOSEException {
		
		return new RSAKeyGenerator(RSA_KEY_BIT_SIZE)
			.keyID(kid)
			.keyUse(KeyUse.SIGNATURE)
			.generate();
	}
	
	
	/**
	 * Generates a 2048 bit RSA encryption key with the specified key ID.
	 *
	 * @param kid The key ID, {@code null} if not specified.
	 *
	 * @return The RSA key pair.
	 *
	 * @throws JOSEException If key generation failed.
	 */
	public static RSAKey generateEncryptionRSAKey(final String kid)
		throws JOSEException {
		
		return new RSAKeyGenerator(RSA_KEY_BIT_SIZE)
			.keyID(kid)
			.keyUse(KeyUse.ENCRYPTION)
			.generate();
	}
	
	
	/**
	 * Generates an EC signing key with the specified curve and key ID.
	 *
	 * @param crv The curve. Must not be {@code null}.
	 * @param kid The key ID, {@code null} if not specified.
	 *
	 * @return The EC key pair.
	 *
	 * @throws JOSEException If key generation failed.
	 */
	public static ECKey generateSigningECKey(final Curve crv, final String kid)
		throws JOSEException {
		
		return new ECKeyGenerator(crv)
			.keyID(kid)
			.keyUse(KeyUse.SIGNATURE)
			.generate();
	}
	
	
	/**
	 * Generates an EC encryption key with the specified curve and key ID.
	 *
	 * @param crv The curve. Must not be {@code null}.
	 * @param kid The key ID, {@code null} if not specified.
	 *
	 * @return The EC key pair.
	 *
	 * @throws JOSEException If key generation failed.
	 */
	public static ECKey generateEncryptionECKey(final Curve crv, final String kid)
		throws JOSEException {
		
		return new ECKeyGenerator(crv)
			.keyID(kid)
			.keyUse(KeyUse.ENCRYPTION)
			.generate();
	}
	
	
	/**
	 * Generates an Ed25519 signing key with the specified key ID.
	 *
	 * @param kid The key ID, {@code null} if not specified.
	 *
	 * @return The Ed25519 octet key pair (OKP).
	 *
	 * @throws JOSEException If key generation failed.
	 */
	public static OctetKeyPair generateSigningEd25519Key(final String kid)
		throws JOSEException {
		
		return new OctetKeyPairGenerator(Curve.Ed25519)
			.keyID(kid)
			.keyUse(KeyUse.SIGNATURE)
			.generate();
	}
	
	
	/**
	 * Generates a 128 bit AES encryption key with the specified key ID.
	 *
	 * @param kid The key ID, {@code null} if not specified.
	 *
	 * @return The AES key.
	 *
	 * @throws JOSEException If key generation failed.
	 */
	public static OctetSequenceKey generateEncryptionAESKey(final String kid)
		throws JOSEException {
		
		return new OctetSequenceKeyGenerator(AES_KEY_BIT_SIZE)
			.keyID(kid)
			.keyUse(KeyUse.ENCRYPTION)
			.generate();
	}
	
	
	/**
	 * Generates a 256 bit HMAC SHA key with key ID "hmac".
	 *
	 * @return The HMAC SHA key.
	 *
	 * @throws JOSEException If key generation failed.
	 */
	public static OctetSequenceKey generateHMACSHA256Key()
		throws JOSEException {
		
		return new OctetSequenceKeyGenerator(HMAC_SHA_KEY_BIT_SIZE)
			.keyID("hmac")
			.keyUse(KeyUse.SIGNATURE)
			.generate();
	}
	
	
	/**
	 * JWK matcher for the 256 bit HMAC SHA key with key ID "hmac".
	 */
	public static final JWKMatcher HMAC_SHA256_KEY_MATCHER = new JWKMatcher.Builder()
		.keyType(KeyType.OCT)
		.keySize(HMAC_SHA_KEY_BIT_SIZE)
		.keyID("hmac")
		.keyUse(KeyUse.SIGNATURE)
		.build();
	
	
	/**
	 * Generates a 256 bit subject encryption key (intended for AES SIV
	 * mode) with key ID "subject-encrypt".
	 *
	 * @return The subject encryption key.
	 *
	 * @throws JOSEException If key generation failed.
	 */
	public static OctetSequenceKey generateSubjectEncryptionKey()
		throws JOSEException {
		
		return new OctetSequenceKeyGenerator(SUBJECT_AES_SIV_KEY_BIT_SIZE)
			.keyID("subject-encrypt")
			.keyUse(KeyUse.ENCRYPTION)
			.generate();
	}
	
	
	/**
	 * JWK matcher for the 256 bit subject encryption key (intended for AES
	 * SIV mode) with key ID "subject-encrypt".
	 */
	public static final JWKMatcher SUBJECT_ENCRYPTION_KEY_MATCHER = new JWKMatcher.Builder()
		.keyType(KeyType.OCT)
		.keySize(SUBJECT_AES_SIV_KEY_BIT_SIZE)
		.keyID("subject-encrypt")
		.keyUse(KeyUse.ENCRYPTION)
		.build();
	
	
	/**
	 * Generates a 256 bit refresh token encryption key (intended for AES
	 * SIV mode) with key ID "refresh-token-encrypt".
	 *
	 * @return The refresh token encryption key.
	 *
	 * @throws JOSEException If key generation failed.
	 */
	public static OctetSequenceKey generateRefreshTokenEncryptionKey()
		throws JOSEException {
		
		return new OctetSequenceKeyGenerator(REFRESH_TOKEN_AES_SIV_KEY_BIT_SIZE)
			.keyID("refresh-token-encrypt")
			.keyUse(KeyUse.ENCRYPTION)
			.generate();
	}
	
	
	/**
	 * JWK matcher for the 256 bit refresh token encryption key (intended
	 * for AES SIV mode) with key ID "refresh-token-encrypt".
	 */
	public static final JWKMatcher REFRESH_TOKEN_ENCRYPTION_KEY_MATCHER = new JWKMatcher.Builder()
		.keyType(KeyType.OCT)
		.keySize(REFRESH_TOKEN_AES_SIV_KEY_BIT_SIZE)
		.keyID("refresh-token-encrypt")
		.keyUse(KeyUse.ENCRYPTION)
		.build();
	
	
	/**
	 * Passes an event message to the optional sink.
	 *
	 * @param eventMessageSink The sink, {@code null} if not specified (the
	 *                         message is then discarded).
	 * @param eventMessage     The event message.
	 */
	private static void emit(final Consumer<String> eventMessageSink, final String eventMessage) {
		
		if (eventMessageSink != null) {
			eventMessageSink.accept(eventMessage);
		}
	}
	
	
	/**
	 * Generates a new set of rotating signature and encryption keys for a
	 * Connect2id server: RSA 2048, EC P-256 / P-384 / P-521 and Ed25519
	 * signing keys, followed by RSA 2048, EC P-256 / P-384 / P-521 and
	 * AES 128 encryption keys, in that order.
	 *
	 * @param reservedKeyIDs   The reserved key IDs, empty if none. The
	 *                         IDs of the generated keys are drawn via
	 *                         {@code KeyIDs.addRandomUniqueKeyID()} from a
	 *                         copy seeded with these IDs (presumably to
	 *                         avoid collisions with them). The passed
	 *                         instance itself is not modified.
	 * @param eventMessageSink Optional sink for event messages,
	 *                         {@code null} if not specified.
	 *
	 * @return The generated rotating keys.
	 *
	 * @throws JOSEException If key generation failed.
	 */
	public List<JWK> generateRotatingKeys(final KeyIDs reservedKeyIDs, final Consumer<String> eventMessageSink)
		throws JOSEException {
		
		List<JWK> keys = new LinkedList<>();
		
		// Work on a copy so the caller's reserved set stays untouched
		KeyIDs keyIDs = new KeyIDs();
		keyIDs.addAll(reservedKeyIDs);
		
		// Signing keys
		
		RSAKey rsaKey = generateSigningRSAKey(keyIDs.addRandomUniqueKeyID());
		keys.add(rsaKey);
		emit(eventMessageSink, "Generated new signing RSA " + RSA_KEY_BIT_SIZE + " bit key with ID " + rsaKey.getKeyID());
		
		for (Curve crv: new Curve[]{Curve.P_256, Curve.P_384, Curve.P_521}) {
			ECKey ecKey = generateSigningECKey(crv, keyIDs.addRandomUniqueKeyID());
			keys.add(ecKey);
			emit(eventMessageSink, "Generated new signing EC " + ecKey.getCurve() + " key with ID " + ecKey.getKeyID());
		}
		
		OctetKeyPair okp = generateSigningEd25519Key(keyIDs.addRandomUniqueKeyID());
		keys.add(okp);
		emit(eventMessageSink, "Generated new signing " + okp.getCurve() + " key with ID " + okp.getKeyID());
		
		// Encryption keys
		
		rsaKey = generateEncryptionRSAKey(keyIDs.addRandomUniqueKeyID());
		keys.add(rsaKey);
		emit(eventMessageSink, "Generated new encryption RSA " + RSA_KEY_BIT_SIZE + " bit key with ID " + rsaKey.getKeyID());
		
		for (Curve crv: new Curve[]{Curve.P_256, Curve.P_384, Curve.P_521}) {
			ECKey ecKey = generateEncryptionECKey(crv, keyIDs.addRandomUniqueKeyID());
			keys.add(ecKey);
			emit(eventMessageSink, "Generated new encryption EC " + ecKey.getCurve() + " key with ID " + ecKey.getKeyID());
		}
		
		OctetSequenceKey secretKey = generateEncryptionAESKey(keyIDs.addRandomUniqueKeyID());
		keys.add(secretKey);
		emit(eventMessageSink, "Generated new encryption AES " + AES_KEY_BIT_SIZE + " bit key with ID " + secretKey.getKeyID());
		
		return keys;
	}
	
	
	/**
	 * Generates a new set of permanent keys for a Connect2id server: the
	 * HMAC SHA key, the subject encryption key and the refresh token
	 * encryption key, in that order.
	 *
	 * @param eventMessageSink Optional sink for event messages,
	 *                         {@code null} if not specified.
	 *
	 * @return The generated keys.
	 *
	 * @throws JOSEException If key generation failed.
	 */
	public List<JWK> generatePermanentKeys(final Consumer<String> eventMessageSink)
		throws JOSEException {
		
		List<JWK> keys = new LinkedList<>();
		
		OctetSequenceKey hmacKey = generateHMACSHA256Key();
		keys.add(hmacKey);
		emit(eventMessageSink, "Generated new HMAC SHA " + HMAC_SHA_KEY_BIT_SIZE + " bit key with ID " + hmacKey.getKeyID());
		
		OctetSequenceKey subjectKey = generateSubjectEncryptionKey();
		keys.add(subjectKey);
		emit(eventMessageSink, "Generated new subject encryption AES SIV " + SUBJECT_AES_SIV_KEY_BIT_SIZE + " bit key with ID " + subjectKey.getKeyID());
		
		OctetSequenceKey refreshTokenKey = generateRefreshTokenEncryptionKey();
		keys.add(refreshTokenKey);
		emit(eventMessageSink, "Generated new refresh token encryption AES SIV " + REFRESH_TOKEN_AES_SIV_KEY_BIT_SIZE + " bit key with ID " + refreshTokenKey.getKeyID());
		
		return keys;
	}
	
	
	/**
	 * Generates the missing permanent keys for a Connect2id server not
	 * found in the specified JWK set. A key counts as present when the set
	 * contains at least one key matching {@link #HMAC_SHA256_KEY_MATCHER},
	 * {@link #SUBJECT_ENCRYPTION_KEY_MATCHER} or
	 * {@link #REFRESH_TOKEN_ENCRYPTION_KEY_MATCHER} respectively. The
	 * specified JWK set is not modified.
	 *
	 * @param jwkSet           The JWK set.
	 * @param eventMessageSink Optional sink for event messages,
	 *                         {@code null} if not specified.
	 *
	 * @return The generated missing permanent keys, empty list if none.
	 *
	 * @throws JOSEException If key generation failed.
	 */
	public List<JWK> generateMissingPermanentKeys(final JWKSet jwkSet, final Consumer<String> eventMessageSink)
		throws JOSEException {
		
		List<JWK> keys = new LinkedList<>();
		
		if (new JWKSelector(HMAC_SHA256_KEY_MATCHER).select(jwkSet).isEmpty()) {
			OctetSequenceKey hmacKey = generateHMACSHA256Key();
			keys.add(hmacKey);
			emit(eventMessageSink, "Generated missing HMAC SHA " + HMAC_SHA_KEY_BIT_SIZE + " bit key with ID " + hmacKey.getKeyID());
		}
		
		if (new JWKSelector(SUBJECT_ENCRYPTION_KEY_MATCHER).select(jwkSet).isEmpty()) {
			OctetSequenceKey subjectKey = generateSubjectEncryptionKey();
			keys.add(subjectKey);
			emit(eventMessageSink, "Generated missing subject encryption AES SIV " + SUBJECT_AES_SIV_KEY_BIT_SIZE + " bit key with ID " + subjectKey.getKeyID());
		}
		
		if (new JWKSelector(REFRESH_TOKEN_ENCRYPTION_KEY_MATCHER).select(jwkSet).isEmpty()) {
			OctetSequenceKey refreshTokenKey = generateRefreshTokenEncryptionKey();
			keys.add(refreshTokenKey);
			emit(eventMessageSink, "Generated missing refresh token encryption AES SIV " + REFRESH_TOKEN_AES_SIV_KEY_BIT_SIZE + " bit key with ID " + refreshTokenKey.getKeyID());
		}
		
		return keys;
	}
	
	
	/**
	 * Generates a new JWK set for a Connect2id server, consisting of the
	 * rotating keys followed by the permanent keys.
	 *
	 * @param eventMessageSink Optional sink for event messages,
	 *                         {@code null} if not specified.
	 *
	 * @return The JWK set.
	 *
	 * @throws JOSEException If key generation failed.
	 */
	public JWKSet generate(final Consumer<String> eventMessageSink)
		throws JOSEException {
		
		List<JWK> keys = new LinkedList<>();
		keys.addAll(generateRotatingKeys(new KeyIDs(), eventMessageSink));
		keys.addAll(generatePermanentKeys(eventMessageSink));
		return new JWKSet(keys);
	}
	
	
	/**
	 * Generates a new set of rotating signing and encryption keys and
	 * prefixes them to the specified Connect2id server JWK set. The key
	 * IDs already present in the old set are reserved when drawing the new
	 * key IDs.
	 *
	 * @param oldJWKSet        The Connect2id server JWK set. Must not be
	 *                         {@code null}.
	 * @param eventMessageSink Optional sink for event messages,
	 *                         {@code null} if not specified.
	 *
	 * @return The updated JWK set.
	 *
	 * @throws Exception If key generation failed.
	 */
	public JWKSet generateAndPrefixNewKeys(final JWKSet oldJWKSet, final Consumer<String> eventMessageSink)
		throws Exception {
		
		// Prefix so Connect2id server can roll over to new keys
		List<JWK> keys = generateRotatingKeys(new KeyIDs(oldJWKSet), eventMessageSink);
		JWKSet jwkSet = prefixKeys(oldJWKSet, keys);
		
		emit(eventMessageSink, "Prefixed newly generated keys to existing JWK set");
		
		return jwkSet;
	}
	
	
	/**
	 * Inserts the specified keys at the beginning of a JWK set.
	 *
	 * @param jwkSet The JWK set.
	 * @param keys   The keys to insert.
	 *
	 * @return The updated JWK set, as a new instance.
	 */
	private static JWKSet prefixKeys(final JWKSet jwkSet, final List<JWK> keys) {
		
		List<JWK> updatedKeyList = new LinkedList<>(keys);
		updatedKeyList.addAll(jwkSet.getKeys());
		return new JWKSet(updatedKeyList);
	}
	
	
	/**
	 * Appends the specified keys at the end of a JWK set.
	 *
	 * @param jwkSet The JWK set.
	 * @param keys   The keys to append.
	 *
	 * @return The updated JWK set, as a new instance.
	 */
	private static JWKSet postfixKeys(final JWKSet jwkSet, final List<JWK> keys) {
		
		List<JWK> updatedKeyList = new LinkedList<>(jwkSet.getKeys());
		updatedKeyList.addAll(keys);
		return new JWKSet(updatedKeyList);
	}
	
	
	/**
	 * Prints the CLI usage to standard output.
	 */
	static void printUsage() {
		System.out.println("Usage:");
		System.out.println("1) Generate a new Connect2id server JWK set: ");
		System.out.println("   java -jar jwkset-gen.jar jwkSet.json");
		System.out.println("2) Generate a new set of rotating keys and add to an existing Connect2id server\n" +
			"   JWK set; if a required permanent key is found missing it will also be\n" +
			"   generated and appended to the JWK set (useful when upgrading to a newer\n" +
			"   Connect2id server release requiring a new type of permanent key):");
		System.out.println("   java -jar jwkset-gen.jar oldJWKSet.json newJWKSet.json");
		System.out.println("3) BASE64URL encode the output JWK set: ");
		// Program options must follow the JAR name, otherwise the JVM
		// treats "-b64" as the JAR file argument of -jar
		System.out.println("   java -jar jwkset-gen.jar -b64 oldJWKSet.json newJWKSet.json.b64");
		System.out.println("4) Generate a new OpenID Connect Federation 1.0 entity JWK set: ");
		System.out.println("   java -jar jwkset-gen.jar -federation federationJWKSet.json");
		System.out.println("5) Generate a new set of rotating keys and add to an existing OpenID Connect\n" +
			"   Federation 1.0 entity JWK set: ");
		System.out.println("   java -jar jwkset-gen.jar -federation oldFedJWKSet.json newFedJWKSet.json");
	}
	
	
	/**
	 * Console method for generating a new Connect2id server JWK set, or
	 * updating an existing JWK set with new signing and encryption keys.
	 *
	 * <p>Arguments: {@code [-b64] [-federation] [oldJWKSet.json] newJWKSet.json}.
	 * With a single file argument a complete new JWK set is generated;
	 * with two, new rotating keys are prefixed to the keys loaded from the
	 * first file and the result is written to the second. Errors are
	 * reported on standard error.
	 *
	 * @param args The command line arguments.
	 */
	public static void main(final String[] args) {
		
		String softwareVersion = JWKSetGenerator.class.getPackage().getImplementationVersion();
		
		System.out.println("JWK set generator " + (softwareVersion != null ? "v" + softwareVersion + " " : "") + "for Connect2id server v6.x+");
		
		Options options = new Options();
		options.addOption("b64", false, "BASE64URL encode the output JWK set");
		options.addOption("federation", false, "Create federation entity JWK set");
		
		CommandLineParser parser = new DefaultParser();
		CommandLine commandLine;
		try {
			commandLine = parser.parse(options, args);
		} catch (org.apache.commons.cli.ParseException e) {
			System.err.println("Command line parse error: " + e.getMessage());
			printUsage();
			return;
		}
		
		final boolean b64Encode = commandLine.hasOption("b64");
		final boolean federation = commandLine.hasOption("federation");
		
		final List<String> argList = commandLine.getArgList();
		
		final File oldJWKSetFile;
		final File newJWKSetFile;
		
		if (argList.size() == 2) {
			oldJWKSetFile = new File(argList.get(0));
			newJWKSetFile = new File(argList.get(1));
		} else if (argList.size() == 1) {
			oldJWKSetFile = null;
			newJWKSetFile = new File(argList.get(0));
		} else {
			printUsage();
			return;
		}
		
		JWKSet oldJWKSet = null;
		
		if (oldJWKSetFile != null) {
			try {
				oldJWKSet = JWKSet.load(oldJWKSetFile);
			} catch (IOException | ParseException e) {
				System.err.println("Couldn't read old JWK set file: " + e.getMessage());
				return;
			}
		}
		
		final NumberedEventPrinter printer = new NumberedEventPrinter();
		final Consumer<String> eventMessageSink = printer::print;
		
		JWKSet newJWKSet;
		try {
			if (oldJWKSet == null) {
				// Generate entire new JWK set
				if (federation) {
					// Federation entity keys
					newJWKSet = new FederationJWKSetGenerator().generate(eventMessageSink);
				} else {
					// Connect2id server keys
					newJWKSet = new JWKSetGenerator().generate(eventMessageSink);
				}
			} else {
				// Prefix new rotating keys to existing JWK set
				if (federation) {
					// Federation entity keys
					newJWKSet = new FederationJWKSetGenerator().generateAndPrefixNewKeys(oldJWKSet, eventMessageSink);
				} else {
					// Connect2id server keys
					JWKSetGenerator gen = new JWKSetGenerator();
					newJWKSet = gen.generateAndPrefixNewKeys(oldJWKSet, eventMessageSink);
					
					// Check if any of the required permanent keys are missing after an upgrade to a newer
					// Connect2id server and generate them
					List<JWK> missingPermanentKeys = gen.generateMissingPermanentKeys(newJWKSet, eventMessageSink);
					if (! missingPermanentKeys.isEmpty()) {
						newJWKSet = postfixKeys(newJWKSet, missingPermanentKeys);
						eventMessageSink.accept("Appended generated permanent keys to existing JWK set");
					}
				}
			}
		} catch (Exception e) {
			System.err.println("Couldn't generate JWK key: " + e.getMessage());
			return;
		}
		
		// publicKeysOnly = false: the output includes the private key parameters
		String json = JSONObjectUtils.toJSONString(newJWKSet.toJSONObject(false));
		
		String output;
		
		if (b64Encode) {
			output = Base64URL.encode(json).toString();
		} else {
			try {
				output = new PrettyJson(PrettyJson.Style.COMPACT).parseAndFormat(json);
			} catch (ParseException e) {
				System.err.println("Couldn't format JSON: " + e.getMessage());
				return;
			}
		}
		
		// NOTE(review): the file holds private key material; consider
		// restricting its permissions to the owner
		try (PrintWriter writer = new PrintWriter(newJWKSetFile, "UTF-8")) {
			writer.write(output);
			writer.write("\n");
			// PrintWriter never throws on write, I/O errors must be polled
			if (writer.checkError()) {
				System.err.println("Couldn't write new JWK set file: " + newJWKSetFile);
			}
		} catch (IOException e) {
			System.err.println("Couldn't write new JWK set file: " + e.getMessage());
		}
	}
}
