package shz.encrypt;

import shz.PRException;
import shz.msg.ServerFailure;

import javax.crypto.Cipher;
import java.io.Serializable;
import java.security.*;
import java.security.spec.PKCS8EncodedKeySpec;
import java.security.spec.X509EncodedKeySpec;

/**
 * Process-wide RSA helper that holds one generated {@link KeyPair} and performs
 * single-block RSA encryption/decryption with caller-supplied encoded keys.
 *
 * <p>Public keys are expected in X.509 encoding and private keys in PKCS#8 encoding, e.g. the
 * bytes returned by {@code getKeyPair().getPublic().getEncoded()} and
 * {@code getKeyPair().getPrivate().getEncoded()}.
 *
 * <p>Any failure raised while generating keys or running the cipher is rethrown as
 * {@code PRException.of(cause)}.
 */
public final class RsaEncipher implements Serializable {
    private static final String KEY_ALGORITHM = "RSA";
    /**
     * Fully specified cipher transformation. Plain {@code "RSA"} leaves mode and padding to the
     * provider default (SunJCE picks ECB/PKCS1Padding, but other providers may differ), so it is
     * spelled out here.
     */
    private static final String TRANSFORMATION = "RSA/ECB/PKCS1Padding";
    private static final long serialVersionUID = 2157700228208207661L;
    /** Modulus size, in bits, of the currently held key pair. */
    private int keySize;
    /** Current key pair; volatile so a pair regenerated by one thread is visible to all readers. */
    private volatile KeyPair keyPair;
    private static volatile RsaEncipher instance;

    private RsaEncipher(int keySize) {
        // NOTE(review): presumably throws when its argument is true, i.e. refuses a second
        // construction (e.g. via reflection) once the singleton exists — confirm.
        ServerFailure.IllegalStateException.requireNon(instance != null);
        initKeyPair(keySize);
    }

    /**
     * Generates a fresh RSA key pair of the given size and replaces the current one.
     *
     * <p>The held state is only updated after generation succeeds, so an invalid size leaves the
     * previous key pair and key size untouched.
     *
     * @param keySize RSA modulus size in bits
     */
    public synchronized void initKeyPair(int keySize) {
        try {
            KeyPairGenerator generator = KeyPairGenerator.getInstance(KEY_ALGORITHM);
            generator.initialize(keySize);
            KeyPair generated = generator.generateKeyPair();
            // Commit only after generation succeeded, keeping keySize and keyPair consistent.
            this.keySize = keySize;
            this.keyPair = generated;
        } catch (Exception e) {
            throw PRException.of(e);
        }
    }

    /** Regenerates the key pair using the current key size. */
    public synchronized void initKeyPair() {
        initKeyPair(keySize);
    }

    /**
     * Returns the shared instance, creating it with a key pair of {@code keySize} bits on first
     * use. Once the instance exists the argument is ignored; use {@link #initKeyPair(int)} to
     * change the key size afterwards.
     *
     * @param keySize RSA modulus size in bits, used only when the instance is first created
     * @return the shared instance
     */
    public static RsaEncipher getInstance(int keySize) {
        if (instance == null) {
            synchronized (RsaEncipher.class) {
                if (instance == null) instance = new RsaEncipher(keySize);
            }
        }
        return instance;
    }

    /**
     * Returns the shared instance, creating it with a 1024-bit key pair on first use.
     *
     * <p>NOTE(review): 1024-bit RSA is below current recommendations (2048+ bits); the default is
     * kept for compatibility with existing callers.
     *
     * @return the shared instance
     */
    public static RsaEncipher getInstance() {
        return getInstance(1024);
    }

    /**
     * Resolves a deserialized object to the shared instance. If no instance exists yet in this
     * JVM (e.g. a fresh process reading a serialized copy), the deserialized object becomes the
     * shared instance instead of deserialization yielding {@code null}.
     */
    private Object readResolve() {
        synchronized (RsaEncipher.class) {
            if (instance == null) instance = this;
            return instance;
        }
    }

    /**
     * Encrypts a single block with an X.509-encoded public key.
     *
     * @param plaintext data to encrypt; with PKCS#1 padding at most (modulus bytes - 11) long
     * @param publicKey X.509-encoded RSA public key
     * @return the ciphertext
     */
    public byte[] encryptByPublicKey(byte[] plaintext, byte[] publicKey) {
        try {
            return runCipher(Cipher.ENCRYPT_MODE, decodePublicKey(publicKey), plaintext);
        } catch (Exception e) {
            throw PRException.of(e);
        }
    }

    /**
     * Decrypts a single block with a PKCS#8-encoded private key.
     *
     * @param ciphertext data produced by {@link #encryptByPublicKey(byte[], byte[])}
     * @param privateKey PKCS#8-encoded RSA private key
     * @return the recovered plaintext
     */
    public byte[] decryptByPrivateKey(byte[] ciphertext, byte[] privateKey) {
        try {
            return runCipher(Cipher.DECRYPT_MODE, decodePrivateKey(privateKey), ciphertext);
        } catch (Exception e) {
            throw PRException.of(e);
        }
    }

    /**
     * Encrypts a single block with a PKCS#8-encoded private key.
     *
     * @param plaintext data to encrypt; with PKCS#1 padding at most (modulus bytes - 11) long
     * @param privateKey PKCS#8-encoded RSA private key
     * @return the ciphertext
     */
    public byte[] encryptByPrivateKey(byte[] plaintext, byte[] privateKey) {
        try {
            return runCipher(Cipher.ENCRYPT_MODE, decodePrivateKey(privateKey), plaintext);
        } catch (Exception e) {
            throw PRException.of(e);
        }
    }

    /**
     * Decrypts a single block with an X.509-encoded public key.
     *
     * @param ciphertext data produced by {@link #encryptByPrivateKey(byte[], byte[])}
     * @param publicKey X.509-encoded RSA public key
     * @return the recovered plaintext
     */
    public byte[] decryptByPublicKey(byte[] ciphertext, byte[] publicKey) {
        try {
            return runCipher(Cipher.DECRYPT_MODE, decodePublicKey(publicKey), ciphertext);
        } catch (Exception e) {
            throw PRException.of(e);
        }
    }

    /**
     * Returns the key pair currently held by this instance.
     *
     * @return the current key pair
     */
    public KeyPair getKeyPair() {
        return keyPair;
    }

    /** Decodes an X.509-encoded RSA public key. */
    private static PublicKey decodePublicKey(byte[] encoded) throws GeneralSecurityException {
        return KeyFactory.getInstance(KEY_ALGORITHM).generatePublic(new X509EncodedKeySpec(encoded));
    }

    /** Decodes a PKCS#8-encoded RSA private key. */
    private static PrivateKey decodePrivateKey(byte[] encoded) throws GeneralSecurityException {
        return KeyFactory.getInstance(KEY_ALGORITHM).generatePrivate(new PKCS8EncodedKeySpec(encoded));
    }

    /** Runs one RSA operation in the given cipher mode with the given key over {@code input}. */
    private static byte[] runCipher(int mode, Key key, byte[] input) throws GeneralSecurityException {
        Cipher cipher = Cipher.getInstance(TRANSFORMATION);
        cipher.init(mode, key);
        return cipher.doFinal(input);
    }
}