package org.mule.module.xmlsecurity.validator.selector;

import javax.security.auth.x500.X500Principal;
import javax.xml.crypto.*;
import javax.xml.crypto.dsig.SignatureMethod;
import javax.xml.crypto.dsig.keyinfo.KeyInfo;
import javax.xml.crypto.dsig.keyinfo.X509Data;
import javax.xml.crypto.dsig.keyinfo.X509IssuerSerial;
import java.io.FileInputStream;
import java.io.IOException;
import java.security.Key;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.PublicKey;
import java.security.cert.CertSelector;
import java.security.cert.X509CertSelector;
import java.security.cert.X509Certificate;
import java.util.Enumeration;
import java.util.Iterator;

/**
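 * <p>A {@link KeySelector} that resolves XML-signature verification keys by matching the
 * X509Data entries of a signature's KeyInfo against certificates held in a local keystore.</p>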
 *
 * @author fernando.federico@mulesource.com
 */
public class X509KeySelector extends KeySelector {

    /** Keystore type used by the two-argument constructor. */
    private static final String DEFAULT_KEYSTORE_TYPE = "JKS";

    /** OID of the DSA public-key algorithm (id-dsa, RFC 3279). */
    private static final String DSA_KEY_ALG_OID = "1.2.840.10040.4.1";

    /**
     * OID of the RSA public-key algorithm (rsaEncryption, RFC 3279). The trailing ".1" matters:
     * "1.2.840.113549.1.1" is only the PKCS#1 arc, so a selector constrained to it would not
     * match RSA certificates.
     */
    private static final String RSA_KEY_ALG_OID = "1.2.840.113549.1.1.1";

    /** ASN.1 DER tag value of an OCTET STRING. */
    private static final byte DER_OCTET_STRING_TAG = 0x04;

    /** Keystore whose certificates are searched for verification keys. */
    private final KeyStore ks;

    /**
     * Creates a selector backed by a JKS keystore read from the file system.
     *
     * @param keystore The Keystore path
     * @param password the Keystore password; may be {@code null} to load without the
     *                 integrity check
     * @throws RuntimeException if the keystore cannot be opened or loaded
     */
    public X509KeySelector(String keystore, String password) {
        this(keystore, password, DEFAULT_KEYSTORE_TYPE);
    }

    /**
     * Creates a selector backed by a keystore of the given type read from the file system.
     *
     * @param keystore     The Keystore path
     * @param password     the Keystore password; may be {@code null} to load without the
     *                     integrity check
     * @param keystoreType keystore type passed to {@link KeyStore#getInstance(String)},
     *                     e.g. "JKS" or "PKCS12"
     * @throws RuntimeException if the keystore cannot be opened or loaded
     */
    public X509KeySelector(String keystore, String password, String keystoreType) {
        ks = loadKeyStore(keystore, password, keystoreType);
    }

    /**
     * Opens and loads a keystore, always closing the underlying file stream.
     */
    private static KeyStore loadKeyStore(String path, String password, String type) {
        FileInputStream in = null;
        try {
            KeyStore store = KeyStore.getInstance(type);
            in = new FileInputStream(path);
            store.load(in, password == null ? null : password.toCharArray());
            return store;
        } catch (Exception e) {
            throw new RuntimeException("Could not read the keystore", e);
        } finally {
            if (in != null) {
                try {
                    in.close();
                } catch (IOException ignored) {
                    // Best effort: loading has already succeeded or failed at this point, and a
                    // failure to close a read-only stream leaves nothing to recover.
                }
            }
        }
    }

    /**
     * Looks through the X509Data entries of {@code keyInfo} for a certificate present in the
     * keystore and returns its public key.
     *
     * <p>{@code purpose} and {@code context} are not consulted.</p>
     *
     * @return a result holding the matching public key, or a result whose key is {@code null}
     *         when {@code keyInfo} is {@code null}, the keystore is empty, or nothing matches
     * @throws KeySelectorException if {@code method} is not a {@link SignatureMethod}, a
     *         distinguished name in the KeyInfo is malformed, or the keystore cannot be read
     */
    public KeySelectorResult select(KeyInfo keyInfo,
                                    KeySelector.Purpose purpose, AlgorithmMethod method,
                                    XMLCryptoContext context) throws KeySelectorException {

        // This selector only knows how to pick keys for signature verification.
        if (!(method instanceof SignatureMethod)) {
            throw new KeySelectorException("Unsupported algorithm method: " + method);
        }
        SignatureMethod sm = (SignatureMethod) method;

        try {
            if (keyInfo == null || ks.size() == 0) {
                return new SimpleKeySelectorResult(null);
            }

            for (Object kiType : keyInfo.getContent()) {
                if (kiType instanceof X509Data) {
                    SimpleKeySelectorResult ksr = x509DataSelect((X509Data) kiType, sm);
                    if (ksr != null) {
                        return ksr;
                    }
                }
            }
        } catch (KeyStoreException kse) {
            throw new KeySelectorException(kse);
        }

        return new SimpleKeySelectorResult(null);
    }

    /**
     * Searches the keystore for a certificate that matches the criteria specified in the
     * CertSelector.
     *
     * @return a SimpleKeySelectorResult containing the first matching cert's public key, or
     *         {@code null} if no keystore certificate matches
     */
    private SimpleKeySelectorResult keyStoreSelect(CertSelector cs)
            throws KeyStoreException {
        Enumeration<String> aliases = ks.aliases();
        while (aliases.hasMoreElements()) {
            java.security.cert.Certificate cert = ks.getCertificate(aliases.nextElement());
            if (cert != null && cs.match(cert)) {
                return new SimpleKeySelectorResult(cert.getPublicKey());
            }
        }
        return null;
    }

    /**
     * Looks up {@code xcert} itself in the keystore and returns its public key if that key is
     * compatible with the signature method.
     *
     * @return a SimpleKeySelectorResult containing the cert's public key if there is a match;
     *         otherwise {@code null}
     */
    private SimpleKeySelectorResult certSelect(X509Certificate xcert,
                                               SignatureMethod sm) throws KeyStoreException {
        // Skip certificates whose KeyUsage extension does not allow digitalSignature (bit 0).
        boolean[] keyUsage = xcert.getKeyUsage();
        if (keyUsage != null && keyUsage.length > 0 && !keyUsage[0]) {
            return null;
        }
        String alias = ks.getCertificateAlias(xcert);
        if (alias != null) {
            PublicKey pk = ks.getCertificate(alias).getPublicKey();
            // make sure algorithm is compatible with method
            if (AlgorithmEqualityChecker.algEquals(sm.getAlgorithm(), pk.getAlgorithm())) {
                return new SimpleKeySelectorResult(pk);
            }
        }
        return null;
    }

    /**
     * Returns the OID of a public-key algorithm compatible with the specified signature
     * algorithm URI.
     *
     * @return the key algorithm OID, or {@code null} for unrecognized URIs. A {@code null}
     *         value places no algorithm constraint on an {@link X509CertSelector}.
     */
    private String getPKAlgorithmOID(String algURI) {
        if (algURI.equalsIgnoreCase(SignatureMethod.DSA_SHA1)) {
            return DSA_KEY_ALG_OID;
        } else if (algURI.equalsIgnoreCase(SignatureMethod.RSA_SHA1)) {
            return RSA_KEY_ALG_OID;
        } else {
            return null;
        }
    }

    /**
     * A simple KeySelectorResult containing a (possibly {@code null}) public key.
     */
    private static class SimpleKeySelectorResult implements KeySelectorResult {
        private final Key key;

        SimpleKeySelectorResult(Key key) {
            this.key = key;
        }

        public Key getKey() {
            return key;
        }
    }

    /**
     * Searches the keystore for a certificate that matches an entry of the specified X509Data
     * and contains a public key that is compatible with the specified SignatureMethod.
     *
     * <p>Supported entries: X509Certificate, X509IssuerSerial, X509SubjectName (String) and
     * X509SKI (byte[]). Every other entry type is skipped. X509CRL is not supported; CRL
     * handling belongs in the CertPath API.</p>
     *
     * @return a SimpleKeySelectorResult containing the cert's public key if there is a match;
     *         otherwise {@code null}
     */
    private SimpleKeySelectorResult x509DataSelect(X509Data xd, SignatureMethod sm)
            throws KeyStoreException, KeySelectorException {

        // convert signature algorithm to compatible public-key alg OID
        String algOID = getPKAlgorithmOID(sm.getAlgorithm());

        for (Object o : xd.getContent()) {
            SimpleKeySelectorResult ksr;
            if (o instanceof X509Certificate) {
                ksr = certSelect((X509Certificate) o, sm);
            } else if (o instanceof X509IssuerSerial) {
                X509IssuerSerial xis = (X509IssuerSerial) o;
                X509CertSelector xcs = newCertSelector(algOID);
                xcs.setSerialNumber(xis.getSerialNumber());
                xcs.setIssuer(parseDistinguishedName(xis.getIssuerName()));
                ksr = keyStoreSelect(xcs);
            } else if (o instanceof String) {
                // X509SubjectName
                X509CertSelector xcs = newCertSelector(algOID);
                xcs.setSubject(parseDistinguishedName((String) o));
                ksr = keyStoreSelect(xcs);
            } else if (o instanceof byte[]) {
                // X509SKI: X509CertSelector expects the DER-encoded OCTET STRING, not raw bytes
                X509CertSelector xcs = newCertSelector(algOID);
                xcs.setSubjectKeyIdentifier(derEncodeOctetString((byte[]) o));
                ksr = keyStoreSelect(xcs);
            } else {
                // skip all other entries
                continue;
            }
            if (ksr != null) {
                return ksr;
            }
        }
        return null;
    }

    /**
     * Creates a certificate selector restricted to the given public-key algorithm OID. A
     * {@code null} OID means no constraint.
     */
    private static X509CertSelector newCertSelector(String algOID) throws KeySelectorException {
        X509CertSelector xcs = new X509CertSelector();
        try {
            xcs.setSubjectPublicKeyAlgID(algOID);
        } catch (IOException ioe) {
            throw new KeySelectorException(ioe);
        }
        return xcs;
    }

    /**
     * Parses a distinguished name taken from the (untrusted) KeyInfo. A malformed name is
     * reported through the selector's checked exception instead of escaping as an
     * IllegalArgumentException.
     */
    private static X500Principal parseDistinguishedName(String dn) throws KeySelectorException {
        try {
            return new X500Principal(dn);
        } catch (IllegalArgumentException iae) {
            throw new KeySelectorException("Malformed distinguished name: " + dn, iae);
        }
    }

    /**
     * DER-encodes {@code content} as an ASN.1 OCTET STRING. Lengths of 128 bytes or more use
     * the long form.
     */
    private static byte[] derEncodeOctetString(byte[] content) {
        int len = content.length;
        byte[] lengthBytes;
        if (len < 0x80) {
            // short form: single length byte
            lengthBytes = new byte[] {(byte) len};
        } else {
            // long form: 0x80 | number-of-length-bytes, then the length big-endian
            int count = 0;
            for (int v = len; v > 0; v >>>= 8) {
                count++;
            }
            lengthBytes = new byte[count + 1];
            lengthBytes[0] = (byte) (0x80 | count);
            int remaining = len;
            for (int i = count; i > 0; i--) {
                lengthBytes[i] = (byte) remaining;
                remaining >>>= 8;
            }
        }
        byte[] encoded = new byte[1 + lengthBytes.length + content.length];
        encoded[0] = DER_OCTET_STRING_TAG;
        System.arraycopy(lengthBytes, 0, encoded, 1, lengthBytes.length);
        System.arraycopy(content, 0, encoded, 1 + lengthBytes.length, content.length);
        return encoded;
    }
}