/*************************************************************************
 *
 * ADOBE CONFIDENTIAL
 * __________________
 *
 *  Copyright 2012 Adobe Systems Incorporated
 *  All Rights Reserved.
 *
 * NOTICE:  All information contained herein is, and remains
 * the property of Adobe Systems Incorporated and its suppliers,
 * if any.  The intellectual and technical concepts contained
 * herein are proprietary to Adobe Systems Incorporated and its
 * suppliers and are protected by trade secret or copyright law.
 * Dissemination of this information or reproduction of this material
 * is strictly forbidden unless prior written permission is obtained
 * from Adobe Systems Incorporated.
 **************************************************************************/
package com.adobe.granite.auth.saml.util;

import com.adobe.granite.auth.saml.model.AbstractRequest;
import com.adobe.granite.auth.saml.model.AuthnRequest;
import com.adobe.granite.auth.saml.model.Issuer;
import com.adobe.granite.auth.saml.model.LogoutRequest;
import com.adobe.granite.auth.saml.model.LogoutResponse;
import com.adobe.granite.auth.saml.model.Message;
import com.adobe.granite.auth.saml.model.NameIdPolicy;
import com.adobe.granite.auth.saml.model.Status;
import com.adobe.granite.auth.saml.model.xml.SamlXmlConstants;
import org.apache.xml.security.transforms.Transforms;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.Node;

import javax.xml.crypto.MarshalException;
import javax.xml.crypto.dsig.CanonicalizationMethod;
import javax.xml.crypto.dsig.DigestMethod;
import javax.xml.crypto.dsig.Reference;
import javax.xml.crypto.dsig.SignatureMethod;
import javax.xml.crypto.dsig.SignedInfo;
import javax.xml.crypto.dsig.Transform;
import javax.xml.crypto.dsig.XMLSignature;
import javax.xml.crypto.dsig.XMLSignatureException;
import javax.xml.crypto.dsig.XMLSignatureFactory;
import javax.xml.crypto.dsig.dom.DOMSignContext;
import javax.xml.crypto.dsig.spec.C14NMethodParameterSpec;
import javax.xml.crypto.dsig.spec.TransformParameterSpec;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;
import javax.xml.transform.OutputKeys;
import javax.xml.transform.Transformer;
import javax.xml.transform.TransformerConfigurationException;
import javax.xml.transform.TransformerException;
import javax.xml.transform.TransformerFactory;
import javax.xml.transform.dom.DOMSource;
import javax.xml.transform.stream.StreamResult;
import java.io.OutputStream;
import java.security.InvalidAlgorithmParameterException;
import java.security.Key;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;

/**
 * SamlWriter is responsible for writing SAML messages as XML.
 * <p>
 * Supported messages are {@link AuthnRequest}, {@link LogoutRequest} and {@link LogoutResponse}. When a private key
 * is supplied, the message is signed with an enveloped XML signature whose reference points at the message ID.
 * <p>
 * NOTE(review): all documents are built from one shared {@link DocumentBuilderFactory}, which JAXP does not
 * guarantee to be thread-safe; confirm how SamlWriter instances are shared before using one from several threads.
 */
public class SamlWriter {

    /** Namespace URI reserved for namespace declaration attributes ({@code xmlns:*}). */
    private static final String XMLNS_NAMESPACE = "http://www.w3.org/2000/xmlns/";

    private final static String DIGEST_METHOD = DigestMethod.SHA256;

    // NOTE(review): RSA-SHA1 is deprecated for XML signatures. Moving to RSA-SHA256 changes what the verifying
    // identity providers receive, so it must be coordinated with them; it is intentionally left unchanged here.
    private final static String SIGNATURE_METHOD = SignatureMethod.RSA_SHA1;

    /**
     * DocumentBuilderFactory used to construct DocumentBuilder instances (namespace aware).
     */
    private final DocumentBuilderFactory builderFactory;

    /**
     * Creates a new instance of SamlWriter.
     */
    public SamlWriter() {
        this.builderFactory = DocumentBuilderFactory.newInstance();
        this.builderFactory.setNamespaceAware(true);
    }

    /**
     * Marshals the given message as an xml document and writes it, unindented, to the given output stream.
     *
     * @param message    Message to marshal; must be an AuthnRequest, LogoutRequest or LogoutResponse.
     * @param out        OutputStream to write the xml document to.
     * @param privateKey The private Key of this SP used to sign the message (may be null for an unsigned message).
     * @throws SamlWriterException           An error occurred while building, signing or serializing the message.
     * @throws IllegalArgumentException      If {@code message} is null.
     * @throws UnsupportedOperationException If the message type is not supported.
     */
    public void write(final Message message, final OutputStream out, final Key privateKey) throws SamlWriterException {
        if (message == null) {
            throw new IllegalArgumentException("message must not be null");
        }
        final Document messageDoc;
        if (message instanceof AuthnRequest) {
            messageDoc = createRequestDocument((AuthnRequest) message, privateKey);
        } else if (message instanceof LogoutRequest) {
            messageDoc = createRequestDocument((LogoutRequest) message, privateKey);
        } else if (message instanceof LogoutResponse) {
            messageDoc = createResponseDocument((LogoutResponse) message, privateKey);
        } else {
            throw new UnsupportedOperationException("Messages of type " + message.getClass().getName() + " are not supported yet.");
        }

        final Transformer transformer;
        try {
            transformer = TransformerFactory.newInstance().newTransformer();
        } catch (final TransformerConfigurationException e) {
            throw new SamlWriterException("Unable to create a new Transformer instance", e);
        }
        // No indentation: whitespace inside a signed document would break the signature on the receiving side.
        transformer.setOutputProperty(OutputKeys.INDENT, "no");
        try {
            transformer.transform(new DOMSource(messageDoc), new StreamResult(out));
        } catch (final TransformerException e) {
            throw new SamlWriterException("An error occurred writing xml to output stream", e);
        }
    }

    // -- implementation methods --

    /**
     * Creates a logout response document from a given LogoutResponse object.
     *
     * @param response   LogoutResponse to create the document for.
     * @param privateKey Private Key to sign the LogoutResponse (may be null).
     * @return Document containing xml data created from the LogoutResponse.
     * @throws SamlWriterException An error occurred while marshalling the response.
     */
    protected Document createResponseDocument(final LogoutResponse response, final Key privateKey) throws SamlWriterException {
        final Document document = newDocument();
        createLogoutResponseElement(response, document, privateKey);
        return document;
    }

    /**
     * Creates a request Document from a given AuthnRequest.
     *
     * @param request    AuthnRequest to create the document for.
     * @param privateKey The private Key of this SP to sign the authentication request (may be null).
     * @return Document containing xml data created from the request.
     * @throws SamlWriterException An error occurred while marshaling the request.
     */
    protected Document createRequestDocument(final AuthnRequest request, final Key privateKey) throws SamlWriterException {
        final Document document = newDocument();
        createAuthnRequestElement(request, document, privateKey);
        return document;
    }

    /**
     * Creates a request Document from a given LogoutRequest.
     *
     * @param request    LogoutRequest to create the document for.
     * @param privateKey The private Key of this SP to sign the logout request (may be null).
     * @return Document containing xml data created from the request.
     * @throws SamlWriterException An error occurred while marshaling the request.
     */
    protected Document createRequestDocument(final LogoutRequest request, final Key privateKey) throws SamlWriterException {
        final Document document = newDocument();
        createLogoutRequestElement(request, document, privateKey);
        return document;
    }

    /**
     * Creates a LogoutResponse element (with Issuer and Status children) and appends it to the given parent,
     * signing it afterwards when a private key is given.
     *
     * @param logoutResponse LogoutResponse to create the element for.
     * @param parent         Document or element the created element is appended to.
     * @param privateKey     Private Key to sign the response (may be null).
     * @throws SamlWriterException An error occurred while signing the response.
     */
    protected void createLogoutResponseElement(final LogoutResponse logoutResponse, final Node parent, final Key privateKey) throws SamlWriterException {
        final Document ownerDocument = ownerDocumentOf(parent);
        final Element logoutResponseElement = ownerDocument.createElementNS(SamlXmlConstants.SAML_PROTOCOL_NAMESPACE, "samlp:" + SamlXmlConstants.LOGOUT_RESPONSE_ELEMENT);
        logoutResponseElement.setAttributeNS(XMLNS_NAMESPACE, "xmlns:samlp", SamlXmlConstants.SAML_PROTOCOL_NAMESPACE);
        logoutResponseElement.setAttributeNS(XMLNS_NAMESPACE, "xmlns:saml", SamlXmlConstants.SAML_ASSERTION_NAMESPACE);
        if (logoutResponse.getId() != null) {
            logoutResponseElement.setAttribute(SamlXmlConstants.ID_ATTR, logoutResponse.getId());
            // Register the attribute as an XML ID so the signature reference "#<id>" can be dereferenced.
            logoutResponseElement.setIdAttribute(SamlXmlConstants.ID_ATTR, true);
        }
        if (logoutResponse.getVersion() != null) {
            logoutResponseElement.setAttribute(SamlXmlConstants.VERSION_ATTR, logoutResponse.getVersion());
        }
        if (logoutResponse.getIssueInstant() != null) {
            logoutResponseElement.setAttribute(SamlXmlConstants.ISSUE_INSTANT_ATTR, SamlXmlConstants.XML_DATE_FORMATTER.print(logoutResponse.getIssueInstant().getTimeInMillis()));
        }
        if (logoutResponse.getDestination() != null) {
            logoutResponseElement.setAttribute(SamlXmlConstants.DESTINATION_ATTR, logoutResponse.getDestination());
        }
        if (logoutResponse.getIssuer() != null) {
            createIssuerElement(logoutResponse.getIssuer(), logoutResponseElement);
        }
        if (logoutResponse.getInResponseTo() != null) {
            logoutResponseElement.setAttribute(SamlXmlConstants.IN_RESPONSE_TO_ATTR, logoutResponse.getInResponseTo());
        }
        Element statusElement = null;
        if (logoutResponse.getStatus() != null) {
            statusElement = createStatusElement(logoutResponse.getStatus(), logoutResponseElement);
        }
        parent.appendChild(logoutResponseElement);
        // The signature goes right after Issuer, i.e. before Status; without a Status it is simply appended.
        signDocument(ownerDocument, statusElement, privateKey, logoutResponse.getId());
    }

    /**
     * Creates a status element containing a statuscode element from the given status object.
     *
     * @param status The status object to create the element from.
     * @param parent The parent element to append the status element to.
     * @return The created (and already appended) status element.
     */
    protected Element createStatusElement(Status status, Element parent) {
        final Document ownerDoc = parent.getOwnerDocument();
        final Element statusElement = ownerDoc.createElementNS(SamlXmlConstants.SAML_PROTOCOL_NAMESPACE, "samlp:" + SamlXmlConstants.STATUS_ELEMENT);
        statusElement.setAttributeNS(XMLNS_NAMESPACE, "xmlns:samlp", SamlXmlConstants.SAML_PROTOCOL_NAMESPACE);
        final Element statusCodeElement = ownerDoc.createElementNS(SamlXmlConstants.SAML_PROTOCOL_NAMESPACE, "samlp:" + SamlXmlConstants.STATUS_CODE_ELEMENT);
        statusCodeElement.setAttributeNS(XMLNS_NAMESPACE, "xmlns:samlp", SamlXmlConstants.SAML_PROTOCOL_NAMESPACE);
        statusCodeElement.setAttribute(SamlXmlConstants.VALUE_ATTR, status.getStatusCode());
        statusElement.appendChild(statusCodeElement);
        parent.appendChild(statusElement);
        return statusElement;
    }

    /**
     * Creates an AuthnRequest element from the given AuthnRequest and adds it to the given parent node.
     *
     * @param authnRequest AuthnRequest to create the element for.
     * @param parent       Parent node to add the created element to.
     * @param privateKey   The private Key of this SP to sign the authentication request (may be null).
     * @throws SamlWriterException An error occurred while marshaling the request.
     */
    protected void createAuthnRequestElement(final AuthnRequest authnRequest, final Node parent, final Key privateKey) throws SamlWriterException {
        final Document ownerDocument = ownerDocumentOf(parent);
        final Element authnRequestElement = ownerDocument.createElementNS(SamlXmlConstants.SAML_PROTOCOL_NAMESPACE, "samlp:" + SamlXmlConstants.AUTHN_REQUEST_ELEMENT);
        authnRequestElement.setAttributeNS(XMLNS_NAMESPACE, "xmlns:samlp", SamlXmlConstants.SAML_PROTOCOL_NAMESPACE);
        parent.appendChild(authnRequestElement);
        handleAbstractRequest(authnRequestElement, authnRequest);

        // The ACS URL takes precedence; the index is only written when no URL is set.
        if (authnRequest.hasAssertionConsumerServiceURL()) {
            authnRequestElement.setAttribute(SamlXmlConstants.ASSERTION_CONSUMER_SERVICE_URL_ATTR, authnRequest.getAssertionConsumerServiceUrl());
        } else if (authnRequest.hasAssertionConsumerServiceIndex()) {
            authnRequestElement.setAttribute(SamlXmlConstants.ASSERTION_CONSUMER_SERVICE_INDEX_ATTR, authnRequest.getAssertionConsumerServiceIndex());
        }

        if (authnRequest.hasProtocolBinding()) {
            authnRequestElement.setAttribute(SamlXmlConstants.PROTOCOL_BINDING_ATTR, authnRequest.getProtocolBinding());
        }

        if (authnRequest.hasIssuer()) {
            createIssuerElement(authnRequest.getIssuer(), authnRequestElement);
        }

        Node nameIdPolicyNode = null;
        if (authnRequest.hasNameIdPolicy()) {
            nameIdPolicyNode = createNameIdPolicyElement(authnRequest.getNameIdPolicy(), authnRequestElement);
        }
        // The signature goes right after Issuer, i.e. before NameIDPolicy; without one it is simply appended.
        signDocument(ownerDocument, nameIdPolicyNode, privateKey, authnRequest.getId());
    }

    /**
     * Creates an xml element for the given Issuer and appends it to the given parent.
     *
     * @param issuer Issuer to create the element for.
     * @param parent Node to add the element to.
     */
    protected void createIssuerElement(final Issuer issuer, final Node parent) {
        final Document ownerDoc = parent.getOwnerDocument();
        final Element issuerElement = ownerDoc.createElementNS(SamlXmlConstants.SAML_ASSERTION_NAMESPACE, "saml:" + SamlXmlConstants.ISSUER_ELEMENT);
        issuerElement.setAttributeNS(XMLNS_NAMESPACE, "xmlns:saml", SamlXmlConstants.SAML_ASSERTION_NAMESPACE);
        parent.appendChild(issuerElement);
        issuerElement.setTextContent(issuer.getValue());
    }

    /**
     * Creates an xml element for the given NameIdPolicy and appends it to the given parent.
     *
     * @param nameIdPolicy NameIdPolicy to create the xml element for.
     * @param parent       Node the created element is added to.
     * @return The NameIDPolicy Node.
     */
    protected Node createNameIdPolicyElement(final NameIdPolicy nameIdPolicy, final Node parent) {
        final Document ownerDoc = parent.getOwnerDocument();
        final Element nameIdPolicyElement = ownerDoc.createElementNS(SamlXmlConstants.SAML_PROTOCOL_NAMESPACE, "samlp:" + SamlXmlConstants.NAME_ID_POLICY_ELEMENT);
        nameIdPolicyElement.setAttributeNS(XMLNS_NAMESPACE, "xmlns:samlp", SamlXmlConstants.SAML_PROTOCOL_NAMESPACE);
        if (nameIdPolicy.hasFormat()) {
            nameIdPolicyElement.setAttribute(SamlXmlConstants.FORMAT_ATTR, nameIdPolicy.getFormat());
        }
        if (nameIdPolicy.hasAllowCreate()) {
            nameIdPolicyElement.setAttribute(SamlXmlConstants.ALLOW_CREATE_ATTR, Boolean.toString(nameIdPolicy.isAllowCreate()));
        }
        if (nameIdPolicy.hasSpNameQualifier()) {
            nameIdPolicyElement.setAttribute(SamlXmlConstants.SP_NAME_QUALIFIER_ATTR, nameIdPolicy.getSpNameQualifier());
        }
        parent.appendChild(nameIdPolicyElement);
        return nameIdPolicyElement;
    }

    /**
     * Adds attributes common to all AbstractRequest instances (Version, ID, IssueInstant and the optional
     * Consent and Destination) to the given request element.
     *
     * @param requestElement Request element to add data to.
     * @param request        AbstractRequest instance which the requestElement should be modified for.
     */
    protected void handleAbstractRequest(final Element requestElement, final AbstractRequest request) {
        // set required attributes
        requestElement.setAttribute(SamlXmlConstants.VERSION_ATTR, request.getVersion());
        requestElement.setAttribute(SamlXmlConstants.ID_ATTR, request.getId());
        // Register the attribute as an XML ID so the signature reference "#<id>" can be dereferenced.
        requestElement.setIdAttribute(SamlXmlConstants.ID_ATTR, true);

        requestElement.setAttribute(SamlXmlConstants.ISSUE_INSTANT_ATTR, SamlXmlConstants.XML_DATE_FORMATTER.print(request.getIssueInstant().getTimeInMillis()));

        // set optional attributes
        if (request.hasConsent()) {
            requestElement.setAttribute(SamlXmlConstants.CONSENT_ATTR, request.getConsent());
        }
        if (request.hasDestination()) {
            requestElement.setAttribute(SamlXmlConstants.DESTINATION_ATTR, request.getDestination());
        }
    }

    /**
     * Creates a LogoutRequest element (Issuer, optional NameID, SessionIndex children) and appends it to the given
     * parent, signing it afterwards when a private key is given.
     *
     * @param request    LogoutRequest to create the element for.
     * @param parent     Document or element the created element is appended to.
     * @param privateKey The private Key of this SP to sign the logout request (may be null).
     * @throws SamlWriterException An error occurred while signing the request.
     */
    protected void createLogoutRequestElement(LogoutRequest request, Node parent, Key privateKey) throws SamlWriterException {
        final Document ownerDocument = ownerDocumentOf(parent);

        final Element requestElement = ownerDocument.createElementNS(SamlXmlConstants.SAML_PROTOCOL_NAMESPACE, "samlp:" + SamlXmlConstants.LOGOUT_REQUEST);
        requestElement.setAttributeNS(XMLNS_NAMESPACE, "xmlns:samlp", SamlXmlConstants.SAML_PROTOCOL_NAMESPACE);
        parent.appendChild(requestElement);
        handleAbstractRequest(requestElement, request);

        Element issuerElement = null;
        if (request.getIssuer() != null) {
            issuerElement = ownerDocument.createElementNS(SamlXmlConstants.SAML_ASSERTION_NAMESPACE, "saml:" + SamlXmlConstants.ISSUER_ELEMENT);
            issuerElement.setAttributeNS(XMLNS_NAMESPACE, "xmlns:saml", SamlXmlConstants.SAML_ASSERTION_NAMESPACE);
            issuerElement.setTextContent(request.getIssuer().getValue());
            requestElement.appendChild(issuerElement);
        }

        if (request.getNameId() != null) {
            final Element nameIdElement = ownerDocument.createElementNS(SamlXmlConstants.SAML_ASSERTION_NAMESPACE, "saml:" + SamlXmlConstants.NAME_ID_ELEMENT);
            nameIdElement.setAttributeNS(XMLNS_NAMESPACE, "xmlns:saml", SamlXmlConstants.SAML_ASSERTION_NAMESPACE);
            nameIdElement.setTextContent(request.getNameId().toString());
            if (request.getNameIdFormat() != null) {
                nameIdElement.setAttribute(SamlXmlConstants.FORMAT_ATTR, request.getNameIdFormat());
            }
            if (request.getNameQualifier() != null) {
                nameIdElement.setAttribute(SamlXmlConstants.NAME_QUALIFIER_ATTR, request.getNameQualifier());
            }
            if (request.getSpNameQualifier() != null) {
                nameIdElement.setAttribute(SamlXmlConstants.SP_NAME_QUALIFIER_ATTR, request.getSpNameQualifier());
            }
            requestElement.appendChild(nameIdElement);
        }

        for (String sessionIndex : request.getSessionIndices()) {
            final Element sessionIndexElement = ownerDocument.createElementNS(SamlXmlConstants.SAML_PROTOCOL_NAMESPACE, "samlp:" + SamlXmlConstants.SESSION_INDEX_ATTR);
            sessionIndexElement.setAttributeNS(XMLNS_NAMESPACE, "xmlns:samlp", SamlXmlConstants.SAML_PROTOCOL_NAMESPACE);
            sessionIndexElement.setTextContent(sessionIndex);
            requestElement.appendChild(sessionIndexElement);
        }

        // The signature must directly follow Issuer (and precede NameID / SessionIndex). Anchoring on the node after
        // Issuer keeps that order even when no NameID is present, instead of appending after the SessionIndex list.
        final Node signatureSuccessor = issuerElement != null ? issuerElement.getNextSibling() : requestElement.getFirstChild();
        signDocument(ownerDocument, signatureSuccessor, privateKey, request.getId());
    }

    /**
     * Signs the message with an enveloped XML signature (exclusive C14N, {@link #DIGEST_METHOD} digest,
     * {@link #SIGNATURE_METHOD} signature) referencing {@code "#" + messageId}. Does nothing when no key is given.
     *
     * @param ownerDocument        Document containing the message.
     * @param placeSignatureBefore Node the signature is inserted before (its parent becomes the signature parent);
     *                             if null, the signature is appended to the document element.
     * @param privateKey           Key used for signing (may be null to skip signing).
     * @param messageId            ID of the signed message element.
     * @throws SamlWriterException The message has no ID or the signature could not be created.
     */
    protected void signDocument(final Document ownerDocument, final Node placeSignatureBefore,
                                final Key privateKey, final String messageId) throws SamlWriterException {
        if (privateKey == null) {
            return;
        }
        if (messageId == null) {
            // A "#null" reference cannot be dereferenced; fail with a clear message instead.
            throw new SamlWriterException("Cannot sign a message without an ID.", new IllegalArgumentException("messageId is null"));
        }

        final DOMSignContext domSignContext;
        if (placeSignatureBefore != null) {
            // DOMSignContext inserts the signature via parent.insertBefore(sig, nextSibling), so the parent must be
            // the actual parent of the anchor node, not necessarily the document element.
            domSignContext = new DOMSignContext(privateKey, placeSignatureBefore.getParentNode(), placeSignatureBefore);
        } else {
            domSignContext = new DOMSignContext(privateKey, ownerDocument.getDocumentElement());
        }

        final XMLSignatureFactory signatureFactory = XMLSignatureFactory.getInstance("DOM");
        try {
            final List<Transform> transforms = new ArrayList<Transform>(2);
            transforms.add(signatureFactory.newTransform(Transform.ENVELOPED, (TransformParameterSpec) null));
            // Exclusive canonicalization (without comments) as a transform; same URI as the C14N method below.
            transforms.add(signatureFactory.newTransform(CanonicalizationMethod.EXCLUSIVE, (TransformParameterSpec) null));
            final Reference reference = signatureFactory.newReference("#" + messageId,
                    signatureFactory.newDigestMethod(DIGEST_METHOD, null), transforms, null, null);
            final SignedInfo signedInfo = signatureFactory.newSignedInfo(
                    signatureFactory.newCanonicalizationMethod(CanonicalizationMethod.EXCLUSIVE, (C14NMethodParameterSpec) null),
                    signatureFactory.newSignatureMethod(SIGNATURE_METHOD, null),
                    Collections.singletonList(reference));
            final XMLSignature xmlSignature = signatureFactory.newXMLSignature(signedInfo, null);
            xmlSignature.sign(domSignContext);
        } catch (XMLSignatureException e) {
            throw new SamlWriterException("XMLSignature exception while signing document.", e);
        } catch (MarshalException e) {
            throw new SamlWriterException("MarshalException while signing document.", e);
        } catch (NoSuchAlgorithmException e) {
            throw new SamlWriterException("Signature Algorithm not available.", e);
        } catch (InvalidAlgorithmParameterException e) {
            throw new SamlWriterException("Invalid parameter for signature algorithm.", e);
        }
    }

    /**
     * Creates a new empty, namespace-aware Document.
     *
     * @return A new empty Document.
     * @throws SamlWriterException No DocumentBuilder could be created.
     */
    private Document newDocument() throws SamlWriterException {
        final DocumentBuilder builder;
        try {
            builder = this.builderFactory.newDocumentBuilder();
        } catch (final ParserConfigurationException e) {
            throw new SamlWriterException("Unable to create a new DocumentBuilder instance", e);
        }
        return builder.newDocument();
    }

    /**
     * Resolves the Document a node belongs to; a Document is its own owner.
     *
     * @param node Document or node inside a document.
     * @return The owning Document.
     */
    private static Document ownerDocumentOf(final Node node) {
        return node instanceof Document ? (Document) node : node.getOwnerDocument();
    }
}
