/*
 * JBoss, Home of Professional Open Source.
 * Copyright 2008, Red Hat Middleware LLC, and individual contributors
 * as indicated by the @author tags. See the copyright.txt file in the
 * distribution for a full listing of individual contributors. 
 *
 * This is free software; you can redistribute it and/or modify it
 * under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2.1 of
 * the License, or (at your option) any later version.
 *
 * This software is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this software; if not, write to the Free
 * Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
 * 02110-1301 USA, or see the FSF site: http://www.fsf.org.
 */
package org.jboss.identity.federation.web.filters;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.MalformedURLException;
import java.net.URL;
import java.security.GeneralSecurityException;
import java.security.Principal;
import java.security.PublicKey;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import javax.xml.bind.JAXBElement;
import javax.xml.bind.JAXBException;
import javax.xml.crypto.MarshalException;
import javax.xml.crypto.dsig.XMLSignatureException;

import org.apache.log4j.Logger;
import org.jboss.identity.federation.api.saml.v2.request.SAML2Request;
import org.jboss.identity.federation.api.saml.v2.response.SAML2Response;
import org.jboss.identity.federation.core.config.KeyProviderType;
import org.jboss.identity.federation.core.config.SPType;
import org.jboss.identity.federation.core.config.TrustType;
import org.jboss.identity.federation.core.exceptions.ConfigurationException;
import org.jboss.identity.federation.core.exceptions.ParsingException;
import org.jboss.identity.federation.core.saml.v2.common.IDGenerator;
import org.jboss.identity.federation.core.saml.v2.common.SAMLDocumentHolder;
import org.jboss.identity.federation.core.saml.v2.constants.JBossSAMLURIConstants;
import org.jboss.identity.federation.core.saml.v2.exceptions.AssertionExpiredException;
import org.jboss.identity.federation.core.saml.v2.exceptions.IssuerNotTrustedException;
import org.jboss.identity.federation.core.saml.v2.holders.DestinationInfoHolder;
import org.jboss.identity.federation.core.saml.v2.util.AssertionUtil;
import org.jboss.identity.federation.core.util.XMLSignatureUtil;
import org.jboss.identity.federation.saml.v2.assertion.AssertionType;
import org.jboss.identity.federation.saml.v2.assertion.AttributeStatementType;
import org.jboss.identity.federation.saml.v2.assertion.AttributeType;
import org.jboss.identity.federation.saml.v2.assertion.EncryptedElementType;
import org.jboss.identity.federation.saml.v2.assertion.NameIDType;
import org.jboss.identity.federation.saml.v2.assertion.SubjectType;
import org.jboss.identity.federation.saml.v2.protocol.AuthnRequestType;
import org.jboss.identity.federation.saml.v2.protocol.ResponseType;
import org.jboss.identity.federation.saml.v2.protocol.StatusType;
import org.jboss.identity.federation.web.interfaces.IRoleValidator;
import org.jboss.identity.federation.web.interfaces.TrustKeyConfigurationException;
import org.jboss.identity.federation.web.interfaces.TrustKeyManager;
import org.jboss.identity.federation.web.interfaces.TrustKeyProcessingException;
import org.jboss.identity.federation.web.roles.DefaultRoleValidator;
import org.jboss.identity.federation.web.util.ConfigurationUtil;
import org.jboss.identity.federation.web.util.PostBindingUtil;
import org.w3c.dom.Document;
import org.xml.sax.SAXException;

/**
 * Servlet filter that protects a web application as a SAML2 Service Provider
 * using the HTTP/POST binding.
 * <p>
 * Requests that carry no <code>SAMLResponse</code> are let through when the
 * session already holds a principal under {@link #PRINCIPAL_ID}; otherwise an
 * authentication request is posted to the configured identity provider.
 * A POST carrying a <code>SAMLResponse</code> is validated (signature, issuer
 * trust, status, expiry, roles) and, on success, the resulting principal is
 * stored in the session before the request continues down the chain.
 * </p>
 * <p>
 * Configuration is read in {@link #init(FilterConfig)} from {@link #configFile};
 * the optional filter init parameters <code>ROLE_VALIDATOR</code> and
 * <code>ROLES</code> configure the role check.
 * </p>
 *
 * @author Anil.Saldhana@redhat.com
 * @since Aug 21, 2009
 */
public class SPFilter implements Filter
{ 
   private static Logger log = Logger.getLogger(SPFilter.class);
   private boolean trace = log.isTraceEnabled();

   /** Session attribute under which the authenticated principal is stored. */
   public static final String PRINCIPAL_ID = "jboss_identity.principal"; 
   /** Session attribute name reserved for the roles of the principal. */
   public static final String ROLES_ID = "jboss_identity.roles";

   protected SPType spConfiguration = null;
   protected String configFile = "/WEB-INF/jboss-idfed.xml";

   protected String serviceURL = null;
   protected String identityURL = null;

   private TrustKeyManager keyManager;
   
   private ServletContext context = null;
   
   private IRoleValidator roleValidator = new DefaultRoleValidator();

   /**
    * Nothing to release.
    */
   public void destroy()
   {
   }

   /**
    * Entry point of the filter.
    * <ul>
    *   <li>POST with a non-empty <code>SAMLResponse</code> parameter: the response
    *       from the IDP is processed; on success the principal is stored in the
    *       session and the chain continues, otherwise a 403 is sent.</li>
    *   <li>Any other request: the chain continues if the session already holds a
    *       principal, otherwise an authentication request is sent to the IDP.</li>
    * </ul>
    *
    * @param servletRequest incoming request; must be an {@link HttpServletRequest}
    * @param servletResponse outgoing response; must be an {@link HttpServletResponse}
    * @param filterChain remainder of the filter chain
    * @throws IOException on I/O failure while writing the response
    * @throws ServletException if the SAML request cannot be created/sent or the
    *         SAML response cannot be parsed, trusted or has expired
    */
   public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, 
         FilterChain filterChain) 
   throws IOException, ServletException
   {
      HttpServletRequest request = (HttpServletRequest) servletRequest;
      HttpServletResponse response = (HttpServletResponse) servletResponse;
      
      HttpSession session = request.getSession();

      //The HTTP/POST binding delivers the SAMLResponse as a POST parameter only
      String samlResponse = null;
      if("POST".equalsIgnoreCase(request.getMethod()))
         samlResponse = request.getParameter("SAMLResponse");

      if(samlResponse == null || samlResponse.length() == 0)
      {
         //Not a reply from the IDP: check if we are already authenticated
         Principal knownPrincipal = (Principal) session.getAttribute(PRINCIPAL_ID);
         if(knownPrincipal != null)
         {
            filterChain.doFilter(servletRequest, servletResponse);
            return;
         }

         //We need to send request to IDP
         redirectToIDP(response);
         return;
      }

      //We got a response from IDP
      Principal userPrincipal = authenticate(request, samlResponse);
      if(userPrincipal == null)
      {
         //Must stop here: continuing down the chain would serve the protected resource
         response.sendError(HttpServletResponse.SC_FORBIDDEN);
         return;
      }

      //Remember the login so that subsequent requests skip the IDP round trip.
      //NOTE(review): the principal built by handleSAMLResponse is an anonymous
      //class and is not Serializable - confirm this is acceptable if sessions
      //are ever replicated/persisted.
      session.setAttribute(PRINCIPAL_ID, userPrincipal);

      filterChain.doFilter(request, servletResponse);
   }

   /**
    * Create an authentication request for the configured service/identity URLs
    * and post it to the IDP.
    */
   private void redirectToIDP(HttpServletResponse response) throws ServletException
   {
      String relayState = null;
      try
      {
         AuthnRequestType authnRequest = createSAMLRequest(serviceURL, identityURL);
         sendRequestToIDP(authnRequest, relayState, response);
      }
      catch (Exception e)
      {
         throw new ServletException(e);
      } 
   }

   /**
    * Validate and process the base64 encoded SAML response posted by the IDP.
    *
    * @return the authenticated principal or <code>null</code> if the role check failed
    */
   private Principal authenticate(HttpServletRequest request, String samlResponse)
   throws ServletException
   {
      boolean isValid = false;
      try
      {
         isValid = this.validate(request);
      }
      catch (Exception e)
      {
         throw new ServletException(e);
      }
      if(!isValid)
         throw new ServletException("Validity check failed");
      
      //deal with SAML response from IDP
      byte[] base64DecodedResponse = PostBindingUtil.base64Decode(samlResponse);
      InputStream is = new ByteArrayInputStream(base64DecodedResponse);

      try
      {
         SAML2Response saml2Response = new SAML2Response();
         
         ResponseType responseType = saml2Response.getResponseType(is);
         
         SAMLDocumentHolder samlDocumentHolder = saml2Response.getSamlDocumentHolder();
         
         boolean validSignature = this.verifySignature(samlDocumentHolder);
         
         if(validSignature == false)
            throw new IssuerNotTrustedException("Signature in saml document is invalid");
         
         this.isTrusted(issuerOf(responseType));
         
         List<Object> assertions = responseType.getAssertionOrEncryptedAssertion();
         if(assertions.size() == 0)
            throw new IllegalStateException("No assertions in reply from IDP"); 
         
         Object assertion = assertions.get(0);
         if(assertion instanceof EncryptedElementType)
         {
            responseType = this.decryptAssertion(responseType);
         }
         
         return handleSAMLResponse(request, responseType);
      }
      catch (ParsingException e)
      {
         if(trace)
            log.trace("Parsing Exception:", e);
         throw new ServletException("Parsing Exception", e);
      }
      catch (ConfigurationException e)
      {
         if(trace)
            log.trace("ConfigurationException:", e);
         throw new ServletException("Config Exception", e);
      }
      catch (IssuerNotTrustedException e)
      {
         if(trace)
            log.trace("IssuerNotTrustedException:", e);
         throw new ServletException("Issuer Not Trusted Exception", e);
      }
      catch (AssertionExpiredException e)
      {
         if(trace)
            log.trace("AssertionExpiredException:", e);
         throw new ServletException("Assertion expired Exception", e);
      } 
   }

   /**
    * Extract the issuer value of a response.
    *
    * @throws IssuerNotTrustedException if the response carries no issuer
    */
   private static String issuerOf(ResponseType response) throws IssuerNotTrustedException
   {
      if(response == null || response.getIssuer() == null)
         throw new IssuerNotTrustedException("Issue missing");

      String issuerID = response.getIssuer().getValue();
      if(issuerID == null)
         throw new IssuerNotTrustedException("Issue missing");
      return issuerID;
   }

   /**
    * Read the service provider configuration from {@link #configFile}, create the
    * configured {@link TrustKeyManager} and set up the role validator.
    *
    * @param filterConfig filter configuration; the init parameters
    *        <code>ROLE_VALIDATOR</code> (class name of an {@link IRoleValidator})
    *        and <code>ROLES</code> (passed to the validator as option) are optional
    * @throws RuntimeException if the configuration file is missing or unreadable,
    *         no key provider is configured, or a configured class cannot be created
    */
   public void init(FilterConfig filterConfig) throws ServletException
   {
      this.context = filterConfig.getServletContext();
      InputStream is = context.getResourceAsStream(configFile);
      if(is == null)
         throw new RuntimeException(configFile + " missing");
      try
      {
         spConfiguration = ConfigurationUtil.getSPConfiguration(is);
         this.identityURL = spConfiguration.getIdentityURL();
         this.serviceURL = spConfiguration.getServiceURL();
         log.trace("Identity Provider URL=" + this.identityURL); 
      }
      catch (Exception e)
      {
         throw new RuntimeException(e);
      }
      finally
      {
         //The stream was opened here, so it is released here
         try
         {
            is.close();
         }
         catch (IOException ignored)
         {
            //Nothing useful can be done about a failed close of a read-only stream
         }
      }
      KeyProviderType keyProvider = this.spConfiguration.getKeyProvider();
      if(keyProvider == null)
         throw new RuntimeException("KeyProvider is null");
      try
      {
         ClassLoader tcl = SecurityActions.getContextClassLoader();
         String keyManagerClassName = keyProvider.getClassName();
         if(keyManagerClassName == null)
            throw new RuntimeException("KeyManager class name is null");
         
         Class<?> clazz = tcl.loadClass(keyManagerClassName);
         this.keyManager = (TrustKeyManager) clazz.newInstance();
         keyManager.setAuthProperties(keyProvider.getAuth());
         keyManager.setValidatingAlias(keyProvider.getValidatingAlias());
      }
      catch(Exception e)
      {
         log.error("Exception reading configuration:",e);
         //Keep the cause so the real failure shows up in the stack trace
         throw new RuntimeException(e.getLocalizedMessage(), e);
      }
      log.trace("Key Provider=" + keyProvider.getClassName());
      
      //Get the Role Validator if configured
      String roleValidatorName = filterConfig.getInitParameter("ROLE_VALIDATOR");
      if(roleValidatorName != null && !"".equals(roleValidatorName))
      {
         try
         {
            Class<?> clazz = SecurityActions.getContextClassLoader().loadClass(roleValidatorName);
            this.roleValidator = (IRoleValidator) clazz.newInstance();
         }
         catch (Exception e)
         {
            throw new RuntimeException(e);
         } 
      }
      
      Map<String,String> options = new HashMap<String, String>();
      String roles = filterConfig.getInitParameter("ROLES");
      if(trace)
         log.trace("Found Roles in SPFilter config="+roles);
      if(roles != null)
      {
         options.put("ROLES", roles);
      }
      this.roleValidator.intialize(options); 
   }

   /**
    * Create a SAML2 auth request
    * @param serviceURL URL of the service
    * @param identityURL URL of the identity provider
    * @return the authentication request built by {@link SAML2Request} with a freshly generated id
    * @throws ConfigurationException if the request cannot be created
    * @throws IllegalArgumentException if either URL is <code>null</code>
    */
   private AuthnRequestType createSAMLRequest(String serviceURL, String identityURL) throws ConfigurationException
   {
      if(serviceURL == null)
         throw new IllegalArgumentException("serviceURL is null");
      if(identityURL == null)
         throw new IllegalArgumentException("identityURL is null");
      
      SAML2Request saml2Request = new SAML2Request();
      String id = IDGenerator.create("ID_");
      return saml2Request.createAuthnRequestType(id, serviceURL, identityURL, serviceURL); 
   }
   
   /**
    * Marshall the authentication request, base64 encode it and post it to the
    * destination named in the request.
    *
    * @param authnRequest request to send
    * @param relayState relay state passed along with the message; may be <code>null</code>
    * @param response response used to deliver the POST to the browser
    */
   protected void sendRequestToIDP(AuthnRequestType authnRequest, String relayState, 
         HttpServletResponse response)
   throws IOException, SAXException, JAXBException,GeneralSecurityException
   {
      SAML2Request saml2Request = new SAML2Request();
      ByteArrayOutputStream baos = new ByteArrayOutputStream();
      saml2Request.marshall(authnRequest, baos);
 
      //NOTE(review): toString() decodes with the platform default charset; left
      //unchanged because the charset expected by base64Encode(String) is not
      //visible from here - confirm both sides agree.
      String samlMessage = PostBindingUtil.base64Encode(baos.toString());  
      String destination = authnRequest.getDestination();
      PostBindingUtil.sendPost(new DestinationInfoHolder(destination, samlMessage, relayState),
             response, true);
   }
   
   /**
    * Basic validity check of an incoming request. This implementation only
    * checks that a <code>SAMLResponse</code> parameter is present; subclasses
    * may add further checks.
    *
    * @param request request posted by the IDP
    * @return <code>true</code> if the request carries a <code>SAMLResponse</code> parameter
    */
   protected boolean validate(HttpServletRequest request) throws IOException, GeneralSecurityException
   {
      return request.getParameter("SAMLResponse") != null; 
   }
    
   /**
    * Verify the XML signature of the SAML response with the validating key that
    * the key manager holds for the issuer's host.
    *
    * @param samlDocumentHolder holder of the parsed response and its DOM document
    * @return <code>true</code> if the signature is valid; <code>false</code> if it is
    *         invalid, no key is available or the verification failed with an error
    * @throws IssuerNotTrustedException if the issuer is missing or is not a valid URL
    */
   protected boolean verifySignature(SAMLDocumentHolder samlDocumentHolder) throws IssuerNotTrustedException
   {   
      Document samlResponse = samlDocumentHolder.getSamlDocument();
      ResponseType response = (ResponseType) samlDocumentHolder.getSamlObject();
      
      String issuerID = issuerOf(response);
      
      URL issuerURL;
      try
      {
         issuerURL = new URL(issuerID);
      }
      catch (MalformedURLException e1)
      {
         throw new IssuerNotTrustedException(e1);
      }
      
      try
      {
         PublicKey publicKey = keyManager.getValidatingKey(issuerURL.getHost());
         if(publicKey == null)
         {
            //Without a key nothing can be verified: treat as failed verification
            log.error("Unable to verify signature: no validating key for " + issuerURL.getHost());
            return false;
         }
         log.trace("Going to verify signature in the saml response from IDP"); 
         boolean sigResult =  XMLSignatureUtil.validate(samlResponse, publicKey);
         log.trace("Signature verification="+sigResult);
         return sigResult;
      }
      catch (TrustKeyConfigurationException e)
      {
         log.error("Unable to verify signature",e);
      }
      catch (TrustKeyProcessingException e)
      {
         log.error("Unable to verify signature",e);
      }
      catch (MarshalException e)
      {
         log.error("Unable to verify signature",e);
      }
      catch (XMLSignatureException e)
      {
         log.error("Unable to verify signature",e);
      }
      return false;
   }  
   
   /**
    * Check that the host of the issuer is one of the trusted domains of the
    * service provider configuration. If no trust section is configured, every
    * issuer with a well formed URL is accepted.
    *
    * @param issuer issuer URL taken from the SAML response
    * @throws IssuerNotTrustedException if the issuer is not a valid URL or its
    *         host is not listed in the trusted domains
    */
   protected void isTrusted(String issuer) throws IssuerNotTrustedException
   {
      String issuerDomain;
      try
      {
         issuerDomain = new URL(issuer).getHost(); 
      }
      catch (MalformedURLException e)
      {
         throw new IssuerNotTrustedException(e.getLocalizedMessage(),e);
      }

      TrustType idpTrust =  spConfiguration.getTrust();
      if(idpTrust == null)
         return;

      if(!isDomainListed(idpTrust.getDomains(), issuerDomain))
         throw new IssuerNotTrustedException(issuer); 
   }

   /**
    * Exact, case-insensitive lookup of a host in a comma/whitespace separated
    * list of domains. A substring test is deliberately not used: it would accept
    * an empty host or any fragment of a trusted domain name.
    */
   private static boolean isDomainListed(String domainsTrusted, String issuerDomain)
   {
      if(domainsTrusted == null || issuerDomain == null || issuerDomain.length() == 0)
         return false;

      String[] domains = domainsTrusted.split("[,\\s]+");
      for(String domain : domains)
      {
         if(domain.equalsIgnoreCase(issuerDomain))
            return true;
      }
      return false;
   }
   
   /**
    * Hook for decrypting an encrypted assertion. This implementation does not
    * support encryption and always fails.
    *
    * @param responseType response whose first assertion is encrypted
    * @return never returns normally in this implementation
    * @throws RuntimeException always
    */
   protected ResponseType decryptAssertion(ResponseType responseType)
   {
      throw new RuntimeException("This authenticator does not handle encryption");
   }
   
   /**
    * Handle the SAMLResponse from the IDP: check the status and the expiry of
    * the first assertion, read the user name from the subject, collect the
    * String values of all attributes of all attribute statements as roles and
    * let the role validator decide.
    *
    * @param request entire request from IDP
    * @param responseType ResponseType that has been generated
    * @return the principal named after the subject, or <code>null</code> if the
    *         role validator rejected the roles
    * @throws ConfigurationException if the expiry check cannot be performed
    * @throws AssertionExpiredException if the assertion has expired
    * @throws IllegalArgumentException if an argument, the status or the status code is <code>null</code>
    * @throws IllegalStateException if the response has no usable assertion or subject
    * @throws SecurityException if the status reported by the IDP is not success
    */
   @SuppressWarnings("unchecked")
   public Principal handleSAMLResponse(HttpServletRequest request, ResponseType responseType) 
   throws ConfigurationException, AssertionExpiredException
   {
      if(request == null)
         throw new IllegalArgumentException("request is null");
      if(responseType == null)
         throw new IllegalArgumentException("response type is null");
      
      StatusType statusType = responseType.getStatus();
      if(statusType == null)
         throw new IllegalArgumentException("Status Type from the IDP is null");
      if(statusType.getStatusCode() == null)
         throw new IllegalArgumentException("Status Code from the IDP is null");

      String statusValue = statusType.getStatusCode().getValue();
      if(JBossSAMLURIConstants.STATUS_SUCCESS.get().equals(statusValue) == false)
         throw new SecurityException("IDP forbid the user");

      List<Object> assertions = responseType.getAssertionOrEncryptedAssertion();
      if(assertions.size() == 0)
         throw new IllegalStateException("No assertions in reply from IDP"); 
      
      Object firstAssertion = assertions.get(0);
      if(!(firstAssertion instanceof AssertionType))
         throw new IllegalStateException("First assertion in reply from IDP is not a plain assertion");
      AssertionType assertion = (AssertionType) firstAssertion;

      //Check for validity of assertion
      boolean expiredAssertion = AssertionUtil.hasExpired(assertion);
      if(expiredAssertion)
         throw new AssertionExpiredException();
      
      SubjectType subject = assertion.getSubject(); 
      if(subject == null || subject.getContent().isEmpty())
         throw new IllegalStateException("No subject in assertion from IDP");
      JAXBElement<NameIDType> jnameID = (JAXBElement<NameIDType>) subject.getContent().get(0);
      NameIDType nameID = jnameID.getValue();
      final String userName = nameID.getValue();
      List<String> roles = new ArrayList<String>();

      //Let us get the roles. The attribute statement is not necessarily the first
      //statement of the assertion, so all statements are inspected.
      for(Object statement : assertion.getStatementOrAuthnStatementOrAuthzDecisionStatement())
      {
         if(!(statement instanceof AttributeStatementType))
            continue;

         AttributeStatementType attributeStatement = (AttributeStatementType) statement;
         for(Object obj : attributeStatement.getAttributeOrEncryptedAttribute())
         {
            //Encrypted attributes cannot be read here
            if(!(obj instanceof AttributeType))
               continue;

            AttributeType attr = (AttributeType) obj;
            //An attribute may carry several values: each one is a role
            for(Object attrValue : attr.getAttributeValue())
            {
               if(attrValue instanceof String)
                  roles.add((String) attrValue);
            }
         }
      }
      
      Principal principal = new Principal()
      {
         public String getName()
         {
            return userName;
         }
      };     
      
      //Validate the roles
      boolean validRole = roleValidator.userInRole(principal, roles);
      if(!validRole)
      {
         if(trace)
            log.trace("Invalid role:" + roles);
         principal = null;
      }
      return principal;
   } 
}