package org.apereo.cas.support.saml.web.idp.profile.ecp;

import org.apereo.cas.authentication.Authentication;
import org.apereo.cas.authentication.AuthenticationException;
import org.apereo.cas.authentication.Credential;
import org.apereo.cas.authentication.credential.UsernamePasswordCredential;
import org.apereo.cas.support.saml.SamlIdPConstants;
import org.apereo.cas.support.saml.SamlIdPUtils;
import org.apereo.cas.support.saml.SamlProtocolConstants;
import org.apereo.cas.support.saml.web.idp.profile.AbstractSamlIdPProfileHandlerController;
import org.apereo.cas.support.saml.web.idp.profile.SamlProfileHandlerConfigurationContext;
import org.apereo.cas.support.saml.web.idp.profile.builders.SamlProfileBuilderContext;
import org.apereo.cas.util.CollectionUtils;
import org.apereo.cas.util.LoggingUtils;

import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.commons.lang3.tuple.Pair;
import org.opensaml.messaging.context.MessageContext;
import org.opensaml.saml.common.xml.SAMLConstants;
import org.opensaml.saml.saml2.core.AuthnRequest;
import org.opensaml.soap.messaging.context.SOAP11Context;
import org.pac4j.core.context.CallContext;
import org.pac4j.core.credentials.UsernamePasswordCredentials;
import org.pac4j.core.credentials.extractor.BasicAuthExtractor;
import org.pac4j.core.exception.CredentialsException;
import org.pac4j.jee.context.JEEContext;
import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.PostMapping;

import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.util.LinkedHashMap;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

/**
 * This is {@link ECPSamlIdPProfileHandlerController}.
 * <p>
 * Handles SAML2 ECP requests posted to the IdP over SOAP/PAOS. Flow:
 * <ol>
 *     <li>Extract the caller's credential from HTTP basic authentication.</li>
 *     <li>Decode and verify the embedded {@link AuthnRequest}.</li>
 *     <li>Authenticate the credential on behalf of the request issuer.</li>
 *     <li>Write either a SAML2 response or a fault response back to the client.</li>
 * </ol>
 *
 * @author Misagh Moayyed
 * @since 5.1.0
 */
@Slf4j
public class ECPSamlIdPProfileHandlerController extends AbstractSamlIdPProfileHandlerController {
    public ECPSamlIdPProfileHandlerController(final SamlProfileHandlerConfigurationContext configurationContext) {
        super(configurationContext);
    }

    /**
     * Handle ecp request.
     * <p>
     * Response status by outcome:
     * <ul>
     *     <li>{@code 401 Unauthorized}: no usable basic-authentication credential is present, including a malformed
     *     {@code Authorization} header.</li>
     *     <li>{@code 400 Bad Request}: no SOAP message context could be decoded from the request.</li>
     *     <li>Otherwise, processing is delegated to {@link #handleEcpRequest(SamlProfileBuilderContext, Credential)}.</li>
     * </ul>
     *
     * @param response the response
     * @param request  the request
     * @throws Exception the exception
     */
    @PostMapping(path = SamlIdPConstants.ENDPOINT_SAML2_IDP_ECP_PROFILE_SSO,
        consumes = {MediaType.TEXT_XML_VALUE, SamlIdPConstants.ECP_SOAP_PAOS_CONTENT_TYPE},
        produces = {MediaType.TEXT_XML_VALUE, SamlIdPConstants.ECP_SOAP_PAOS_CONTENT_TYPE})
    public void handleEcpRequest(final HttpServletResponse response,
                                 final HttpServletRequest request) throws Exception {
        val soapContext = decodeSoapRequest(request);
        val credential = extractBasicAuthenticationCredential(request, response);

        if (credential == null) {
            LOGGER.error("Credentials could not be extracted from the SAML ECP request");
            response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
            return;
        }
        if (soapContext == null) {
            LOGGER.error("SAML ECP request could not be determined from the authentication request");
            // Report the client error explicitly, instead of leaving the default (successful) status with an empty body.
            response.setStatus(HttpServletResponse.SC_BAD_REQUEST);
            return;
        }

        val buildContext = SamlProfileBuilderContext.builder()
            .httpRequest(request)
            .httpResponse(response)
            .binding(SAMLConstants.SAML2_PAOS_BINDING_URI)
            .messageContext(soapContext)
            .build();
        handleEcpRequest(buildContext, credential);
    }

    /**
     * Handle ecp request.
     * <p>
     * Verifies the {@link AuthnRequest} carried by the message context, authenticates the given credential,
     * builds a CAS assertion for the request issuer and writes the SAML2 response. Any exception raised
     * inside that sequence is logged and converted into a fault response via
     * {@link #buildEcpFaultResponse(SamlProfileBuilderContext, String)}.
     *
     * @param context    the context
     * @param credential the credential
     * @throws Exception the exception
     */
    protected void handleEcpRequest(final SamlProfileBuilderContext context, final Credential credential) throws Exception {
        LOGGER.debug("Handling ECP request for SOAP context [{}]", context.getMessageContext());

        val envelope = context.getMessageContext().getSubcontext(SOAP11Context.class).getEnvelope();
        getConfigurationContext().getOpenSamlConfigBean().logObject(envelope);

        val authnRequest = (AuthnRequest) context.getMessageContext().getMessage();
        val authenticationContext = Pair.of(authnRequest, context.getMessageContext());
        try {
            LOGGER.trace("Verifying ECP authentication request [{}]", authnRequest);
            val serviceRequest = verifySamlAuthenticationRequest(authenticationContext, context.getHttpRequest());

            LOGGER.trace("Attempting to authenticate ECP request for credential id [{}]", credential.getId());
            val authentication = authenticateEcpRequest(credential, authenticationContext);
            LOGGER.debug("Authenticated [{}] successfully with authenticated principal [{}]",
                credential.getId(), authentication.getPrincipal());

            LOGGER.trace("Building ECP SAML response for [{}]", credential.getId());
            val issuer = SamlIdPUtils.getIssuerFromSamlObject(authnRequest);
            val service = getConfigurationContext().getWebApplicationServiceFactory().createService(issuer);
            val casAssertion = buildCasAssertion(authentication, service, serviceRequest.getKey(), new LinkedHashMap<>(0));

            LOGGER.trace("CAS assertion to use for building ECP SAML2 response is [{}]", casAssertion);
            buildSamlResponse(context.getHttpResponse(), context.getHttpRequest(),
                authenticationContext, Optional.of(casAssertion), context.getBinding());
        } catch (final AuthenticationException e) {
            LoggingUtils.error(LOGGER, e);
            val handlerErrors = e.getHandlerErrors().values()
                .stream()
                .map(Throwable::getMessage)
                .filter(Objects::nonNull)
                .collect(Collectors.joining(","));
            // If no handler reported a message, use the exception's own message instead of an empty fault reason.
            val error = handlerErrors.isBlank() ? e.getMessage() : handlerErrors;
            buildEcpFaultResponse(context, error);
        } catch (final Exception e) {
            LoggingUtils.error(LOGGER, e);
            buildEcpFaultResponse(context, e.getMessage());
        }
    }

    /**
     * Build ecp fault response.
     * <p>
     * Stores the error text as the {@link SamlIdPConstants#REQUEST_ATTRIBUTE_ERROR} request attribute, then
     * delegates to the configured SAML fault response builder.
     *
     * @param context the context
     * @param error   the error text to expose to the fault builder; may be {@code null}
     * @throws Exception the exception
     */
    protected void buildEcpFaultResponse(final SamlProfileBuilderContext context, final String error) throws Exception {
        context.getHttpRequest().setAttribute(SamlIdPConstants.REQUEST_ATTRIBUTE_ERROR, error);
        getConfigurationContext().getSamlFaultResponseBuilder().build(context);
    }

    /**
     * Authenticate ecp request.
     * <p>
     * Creates a service for the issuer of the authentication request. The issuer is also recorded as the
     * {@link SamlProtocolConstants#PARAMETER_ENTITY_ID} service attribute. The authentication transaction is
     * then finalized for that service and credential.
     *
     * @param credential   the credential
     * @param authnRequest the authn request paired with its message context
     * @return the authentication
     */
    protected Authentication authenticateEcpRequest(final Credential credential,
                                                    final Pair<AuthnRequest, MessageContext> authnRequest) {
        val issuer = SamlIdPUtils.getIssuerFromSamlObject(authnRequest.getKey());
        LOGGER.debug("Located issuer [{}] from request prior to authenticating [{}]", issuer, credential.getId());

        val service = getConfigurationContext().getWebApplicationServiceFactory().createService(issuer);
        service.getAttributes().put(SamlProtocolConstants.PARAMETER_ENTITY_ID, CollectionUtils.wrapList(issuer));
        LOGGER.debug("Executing authentication request for service [{}] on behalf of credential id [{}]", service, credential.getId());
        val authenticationResult = getConfigurationContext()
            .getAuthenticationSystemSupport().finalizeAuthenticationTransaction(service, credential);
        return authenticationResult.getAuthentication();
    }

    /**
     * Extract a username/password credential from the request's HTTP basic-authentication header.
     *
     * @param request  the request
     * @param response the response
     * @return the credential, or {@code null} when no header is present or it is malformed
     */
    private Credential extractBasicAuthenticationCredential(final HttpServletRequest request,
                                                            final HttpServletResponse response) {
        val extractor = new BasicAuthExtractor();
        val webContext = new JEEContext(request, response);
        val callContext = new CallContext(webContext, getConfigurationContext().getSessionStore());
        try {
            val credentialsResult = extractor.extract(callContext);
            if (credentialsResult.isPresent()) {
                val credentials = (UsernamePasswordCredentials) credentialsResult.get();
                // Log only the username, so the log never depends on the third-party credential's toString().
                LOGGER.debug("Received basic authentication ECP request for user [{}]", credentials.getUsername());
                return new UsernamePasswordCredential(credentials.getUsername(), credentials.getPassword());
            }
        } catch (final CredentialsException e) {
            // A malformed basic-auth header is a client error. Treat it as "no credential" so the caller
            // responds with 401 instead of letting the exception escape the controller.
            LOGGER.warn("Unable to extract basic authentication credentials from the SAML ECP request: [{}]", e.getMessage());
        }
        return null;
    }
}
