package org.mule.module.xmlsecurity.keyinfo;

import org.mule.module.xmlsecurity.algorithms.CanonicalizationAlgorithm;
import org.mule.module.xmlsecurity.algorithms.SignatureMethodAlgorithm;

import javax.xml.crypto.dsig.Reference;
import javax.xml.crypto.dsig.SignedInfo;
import javax.xml.crypto.dsig.keyinfo.KeyInfo;
import javax.xml.crypto.dsig.keyinfo.KeyInfoFactory;
import javax.xml.crypto.dsig.keyinfo.KeyValue;
import javax.xml.crypto.dsig.spec.C14NMethodParameterSpec;
import java.security.GeneralSecurityException;
import java.security.Key;
import java.security.KeyException;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException;
import java.util.Collections;

import static org.mule.module.xmlsecurity.Signer.fac;

/**
 * <p>{@link KeyInfoProvider} backed by a public/private key pair generated in memory.</p>
 *
 * <p>The key pair is generated once, in the constructor, via {@link KeyPairGenerator}. The public key
 * is exposed through {@link #getKeyInfo()} as a {@code KeyValue}. The private key is returned by
 * {@link #getSecretKey()}. Only {@code "DSA"} and {@code "RSA"} key pairs are accepted, because those
 * are the only algorithms {@link #getSignatureMethod()} can map to an XML signature method.</p>
 *
 * @author fernando.federico@mulesource.com
 */
public class KeyPairInfoProvider implements KeyInfoProvider {

    /**
     * Default key size in bits. 1024 is the largest DSA key usable with SHA1withDSA (the DSA_SHA1
     * signature method selected below). It is also the minimum RSA/DSA size the JDK's XML-signature
     * secure validation accepts. The former value of 512 bits is trivially breakable.
     */
    private static final int DEFAULT_KEY_SIZE = 1024;

    private final KeyPair keyPair;
    private final CanonicalizationAlgorithm canonicalizationAlgorithm;
    private final String keyPairAlgorithm;

    /**
     * Constructor. Generates a key pair of {@value #DEFAULT_KEY_SIZE} bits.
     *
     * @param canonicalizationAlgorithm The canonicalization algorithm
     * @param keyPairAlgorithm The algorithm used to create the key pair ("DSA" or "RSA", case-insensitive)
     * @throws IllegalArgumentException if {@code keyPairAlgorithm} is neither DSA nor RSA
     */
    public KeyPairInfoProvider(CanonicalizationAlgorithm canonicalizationAlgorithm, String keyPairAlgorithm) {
        this(canonicalizationAlgorithm, keyPairAlgorithm, DEFAULT_KEY_SIZE);
    }

    /**
     * Constructor with an explicit key size.
     *
     * @param canonicalizationAlgorithm The canonicalization algorithm
     * @param keyPairAlgorithm The algorithm used to create the key pair ("DSA" or "RSA", case-insensitive)
     * @param keySize The key size in bits, passed to {@link KeyPairGenerator#initialize(int)}
     * @throws IllegalArgumentException if {@code keyPairAlgorithm} is neither DSA nor RSA, or if the
     *         generator rejects {@code keySize}
     */
    public KeyPairInfoProvider(CanonicalizationAlgorithm canonicalizationAlgorithm, String keyPairAlgorithm,
                               int keySize) {
        // Fail fast: an unsupported algorithm would otherwise only surface later, in getSignedInfo().
        if (!isDsa(keyPairAlgorithm) && !isRsa(keyPairAlgorithm)) {
            throw new IllegalArgumentException(
                    "Unsupported key pair algorithm: " + keyPairAlgorithm + " (expected DSA or RSA)");
        }
        this.canonicalizationAlgorithm = canonicalizationAlgorithm;
        this.keyPairAlgorithm = keyPairAlgorithm;
        this.keyPair = generateKeyPair(keyPairAlgorithm, keySize);
    }

    /**
     * {@inheritDoc}
     *
     * <p>The returned KeyInfo holds a single KeyValue wrapping the generated public key.</p>
     */
    @Override
    public KeyInfo getKeyInfo() {
        try {
            KeyInfoFactory kif = fac.getKeyInfoFactory();
            KeyValue kv = kif.newKeyValue(keyPair.getPublic());

            return kif.newKeyInfo(Collections.singletonList(kv));
        } catch (KeyException e) {
            throw new RuntimeException("Could not create " + keyPairAlgorithm + " Key Info", e);
        }
    }

    /**
     * {@inheritDoc}
     *
     * <p>Returns the private half of the generated key pair.</p>
     */
    @Override
    public Key getSecretKey() {
        return keyPair.getPrivate();
    }

    /**
     * {@inheritDoc}
     *
     * <p>Builds a SignedInfo containing only {@code ref}. It uses the configured canonicalization algorithm
     * (with no parameters) and the signature method returned by {@link #getSignatureMethod()}.</p>
     */
    @Override
    public SignedInfo getSignedInfo(Reference ref) {
        try {
            return fac.newSignedInfo(
                    fac.newCanonicalizationMethod(canonicalizationAlgorithm.getAlgorithm(),
                            (C14NMethodParameterSpec) null),
                    fac.newSignatureMethod(getSignatureMethod(), null),
                    Collections.singletonList(ref));
        } catch (GeneralSecurityException e) {
            // Covers NoSuchAlgorithmException and InvalidAlgorithmParameterException from the factory.
            throw new RuntimeException("Error while creating " + keyPairAlgorithm + " Signed Info", e);
        }
    }

    /**
     * Returns the XML signature method URI that matches the key pair algorithm.
     *
     * <p>NOTE(review): both mappings use SHA-1. Consider SHA-256 variants if
     * SignatureMethodAlgorithm offers them.</p>
     *
     * @return the DSA_SHA1 algorithm for DSA keys, or the RSA_SHA1 algorithm for RSA keys
     * @throws IllegalStateException if the key pair algorithm is neither DSA nor RSA
     *         (not reachable through the public constructors)
     */
    public String getSignatureMethod() {
        if (isDsa(keyPairAlgorithm)) {
            return SignatureMethodAlgorithm.DSA_SHA1.getAlgorithm();
        }
        if (isRsa(keyPairAlgorithm)) {
            return SignatureMethodAlgorithm.RSA_SHA1.getAlgorithm();
        }
        throw new IllegalStateException("Invalid algorithm: " + keyPairAlgorithm);
    }

    /** Null-safe, case-insensitive check for DSA (JCA algorithm names are case-insensitive). */
    private static boolean isDsa(String algorithm) {
        return "DSA".equalsIgnoreCase(algorithm);
    }

    /** Null-safe, case-insensitive check for RSA (JCA algorithm names are case-insensitive). */
    private static boolean isRsa(String algorithm) {
        return "RSA".equalsIgnoreCase(algorithm);
    }

    /** Generates a key pair of the given algorithm and size, wrapping a missing provider in a RuntimeException. */
    private static KeyPair generateKeyPair(String algorithm, int keySize) {
        try {
            KeyPairGenerator kpg = KeyPairGenerator.getInstance(algorithm);
            kpg.initialize(keySize);
            return kpg.generateKeyPair();
        } catch (NoSuchAlgorithmException e) {
            throw new RuntimeException("Could not create " + algorithm + " key pair", e);
        }
    }
}
