package org.mule.connectivity.restconnect.internal.modelGeneration.ramlParser.security;

import edu.emory.mathcs.backport.java.util.Arrays;
import org.mule.connectivity.restconnect.exception.UnsupportedSecuritySchemeException;
import org.mule.connectivity.restconnect.internal.model.parameter.Parameter;
import org.mule.connectivity.restconnect.internal.model.security.*;
import org.mule.connectivity.restconnect.internal.modelGeneration.JsonSchemaPool;
import org.mule.connectivity.restconnect.internal.modelGeneration.common.security.SecuritySchemeFactory;
import org.raml.v2.api.model.v10.api.Api;
import org.raml.v2.api.model.v10.datamodel.TypeInstance;
import org.raml.v2.api.model.v10.datamodel.TypeInstanceProperty;
import org.raml.v2.api.model.v10.declarations.AnnotationRef;
import org.raml.v2.api.model.v10.methods.Method;
import org.raml.v2.api.model.v10.resources.Resource;
import org.raml.v2.api.model.v10.security.SecurityScheme;
import org.raml.v2.api.model.v10.security.SecuritySchemePart;

import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.stream.Stream;

import static java.lang.Boolean.FALSE;
import static java.util.stream.Collectors.toList;
import static org.mule.connectivity.restconnect.internal.model.parameter.ParameterType.HEADER;
import static org.mule.connectivity.restconnect.internal.model.parameter.ParameterType.QUERY;
import static org.mule.connectivity.restconnect.internal.model.security.OAuth2Scheme.OAUTH2_GRANT_AUTHORIZATION_CODE;
import static org.mule.connectivity.restconnect.internal.model.security.OAuth2Scheme.OAUTH2_GRANT_CLIENT_CREDENTIALS;
import static org.mule.connectivity.restconnect.internal.modelGeneration.ramlParser.util.RamlParserUtils.getAnnotatedRenewTokenExpression;
import static org.mule.connectivity.restconnect.internal.modelGeneration.ramlParser.util.RamlParserUtils.getParameterList;
import static org.mule.connectivity.restconnect.internal.modelGeneration.ramlParser.util.RamlParserUtils.getValueFromAnnotableString;
import static org.mule.connectivity.restconnect.internal.model.security.APISecurityScheme.getSecuritySchemeParameter;
import static org.mule.connectivity.restconnect.internal.model.typesource.PrimitiveTypeSource.PrimitiveType.STRING;

/**
 * Builds the {@link APISecurityScheme} model objects for an operation out of the security schemes
 * declared in a RAML 1.0 document (API level, resource level and method level).
 */
public class RamlParserSecuritySchemeFactory {

    /**
     * Builds the security schemes that apply to the given method.
     *
     * <p>The schemes referenced by the API, by the method's resource and by the method itself are
     * collected and handed to {@link SecuritySchemeFactory#getSecuritySchemesForOperation}, which
     * decides which of them are effective. A {@code null} reference in a {@code securedBy} list is
     * kept as a {@code null} scheme and later translated into an {@link UnsecuredScheme}.
     *
     * @param api            the parsed RAML API
     * @param method         the method (operation) whose schemes are requested
     * @param jsonSchemaPool pool forwarded to the parameter builders
     * @return the schemes for the operation; a single {@link UnsecuredScheme} when none is declared
     * @throws Exception if schemes are declared but none of them can be translated, or if a
     *                   parameter cannot be built
     */
    public static List<APISecurityScheme> getOperationSecuritySchemes(Api api, Method method, JsonSchemaPool jsonSchemaPool) throws Exception {
        List<SecurityScheme> globalSchemes = api.securedBy().stream()
                .map(ref -> ref != null ? ref.securityScheme() : null)
                .collect(toList());

        Resource endPoint = method.resource();
        List<SecurityScheme> endPointSchemes = endPoint.securedBy().stream()
                .map(ref -> ref != null ? ref.securityScheme() : null)
                .collect(toList());

        return getMethodSecuritySchemes(method, endPointSchemes, globalSchemes, jsonSchemaPool);
    }

    /**
     * Translates every source scheme and returns the distinct results in encounter order.
     *
     * @throws UnsupportedSecuritySchemeException if {@code sourceSchemes} is not empty but no
     *                                            scheme could be generated from it
     */
    @SuppressWarnings("Duplicates")
    private static List<APISecurityScheme> getAPISecuritySchemes(List<SecurityScheme> sourceSchemes, JsonSchemaPool jsonSchemaPool) throws Exception {
        List<APISecurityScheme> returnSchemes = new ArrayList<>();

        for (SecurityScheme scheme : sourceSchemes) {
            for (APISecurityScheme generatedScheme : createSecuritySchemes(scheme, jsonSchemaPool)) {
                // Skip schemes equal to one that was already generated.
                if (returnSchemes.stream().noneMatch(existing -> existing.equals(generatedScheme))) {
                    returnSchemes.add(generatedScheme);
                }
            }
        }

        // If the API is secured but we don't support any defined schemes, then we must throw a generation exception.
        if (!sourceSchemes.isEmpty() && returnSchemes.isEmpty()) {
            throw new UnsupportedSecuritySchemeException("None of the specified security schemes ( " + listSchemes(sourceSchemes) + ") are supported.");
        }

        return returnSchemes;
    }

    /**
     * Resolves the effective schemes for the method and translates them; falls back to a single
     * {@link UnsecuredScheme} when no scheme applies to the operation.
     */
    @SuppressWarnings("Duplicates")
    private static List<APISecurityScheme> getMethodSecuritySchemes(Method method, List<SecurityScheme> endPointSchemes, List<SecurityScheme> globalSchemes, JsonSchemaPool jsonSchemaPool) throws Exception {
        List<SecurityScheme> methodSchemes = method.securedBy().stream()
                .map(ref -> ref != null ? ref.securityScheme() : null)
                .collect(toList());

        List<SecurityScheme> securitySchemesForOperation = SecuritySchemeFactory.getSecuritySchemesForOperation(methodSchemes, endPointSchemes, globalSchemes);

        if (!securitySchemesForOperation.isEmpty()) {
            return getAPISecuritySchemes(securitySchemesForOperation, jsonSchemaPool);
        }

        // If there are no security schemes defined for this operation, we create an UnsecuredScheme
        List<APISecurityScheme> returnSchemes = new ArrayList<>();
        returnSchemes.add(new UnsecuredScheme());
        return returnSchemes;
    }

    /**
     * Translates one RAML scheme into zero or more model schemes.
     *
     * <p>A {@code null} scheme yields an {@link UnsecuredScheme}. An OAuth 2.0 scheme yields one
     * model scheme per recognised grant (authorization code, client credentials). A scheme whose
     * type is not recognised by any of the naming checks yields an empty list.
     */
    private static List<APISecurityScheme> createSecuritySchemes(SecurityScheme securityScheme, JsonSchemaPool jsonSchemaPool) throws Exception {
        List<APISecurityScheme> apiSecuritySchemes = new ArrayList<>();

        if (securityScheme == null) {
            apiSecuritySchemes.add(new UnsecuredScheme());
            return apiSecuritySchemes;
        }

        String schemeType = securityScheme.type();
        if (RamlParserSecuritySchemesNaming.isBasicAuth(schemeType)) {
            apiSecuritySchemes.add(new BasicAuthScheme());
        } else if (RamlParserSecuritySchemesNaming.isPassThrough(schemeType)) {
            apiSecuritySchemes.add(buildPassThroughSecurityScheme(securityScheme, jsonSchemaPool));
        } else if (RamlParserSecuritySchemesNaming.isOauth2(schemeType)) {
            // Grants that are neither authorization code nor client credentials are ignored.
            for (String grant : getAuthorizationGrants(securityScheme)) {
                if (RamlParserOauth2FlowsNaming.isAuthorizationCode(grant)) {
                    apiSecuritySchemes.add(buildOAuth2AuthorizationCodeSecurityScheme(securityScheme));
                } else if (RamlParserOauth2FlowsNaming.isClientCredentials(grant)) {
                    apiSecuritySchemes.add(buildOAuth2ClientCredentialsSecurityScheme(securityScheme));
                }
            }
        } else if (RamlParserSecuritySchemesNaming.isDigestAuth(schemeType)) {
            apiSecuritySchemes.add(new DigestAuthenticationScheme());
        } else if (RamlParserSecuritySchemesNaming.isJwtAuth(schemeType)) {
            apiSecuritySchemes.add(buildJwtAuthenticationScheme(securityScheme, jsonSchemaPool));
        } else if (RamlParserSecuritySchemesNaming.isCustom(schemeType)) {
            // Custom ("x-" style) security scheme, see
            // https://github.com/raml-org/raml-spec/blob/master/versions/raml-10/raml-10.md#security-scheme-declaration
            apiSecuritySchemes.add(buildCustomAuthenticationSecurityScheme(securityScheme, jsonSchemaPool));
        }

        return apiSecuritySchemes;
    }

    /**
     * Returns the authorization grants declared in the scheme settings, or an empty list when the
     * scheme has no settings or no grants.
     */
    private static List<String> getAuthorizationGrants(SecurityScheme securityScheme) {
        // NOTE(review): defensive guard - assumes the parser may return null for an absent
        // "settings" section; confirm against the RAML parser version in use.
        if (securityScheme.settings() == null || securityScheme.settings().authorizationGrants() == null) {
            return Collections.emptyList();
        }
        return securityScheme.settings().authorizationGrants();
    }

    /** Builds a pass-through scheme from the query parameters and headers of {@code describedBy}. */
    private static PassThroughScheme buildPassThroughSecurityScheme(SecurityScheme securityScheme, JsonSchemaPool jsonSchemaPool) throws Exception {
        List<Parameter> queryParameters = new ArrayList<>();
        List<Parameter> headers = new ArrayList<>();

        // A scheme without a describedBy section contributes no parameters instead of failing with an NPE.
        SecuritySchemePart describedBy = securityScheme.describedBy();
        if (describedBy != null) {
            queryParameters = getParameterList(describedBy.queryParameters(), QUERY, jsonSchemaPool);
            headers = getParameterList(describedBy.headers(), HEADER, jsonSchemaPool);
        }

        return new PassThroughScheme(queryParameters, headers);
    }

    /**
     * Builds a JWT scheme from the structured value of the annotation attached to the scheme's
     * {@code describedBy} section.
     *
     * <p>Read from that value: the well-known headers and claims under {@code body.jwt}, the
     * {@code custom} headers and claims, the {@code body.parameters} entries and the {@code url}.
     * Missing sections are skipped.
     *
     * @throws NoSuchElementException if the scheme carries no annotation with a structured value
     */
    private static JwtAuthenticationScheme buildJwtAuthenticationScheme(SecurityScheme securityScheme, JsonSchemaPool jsonSchemaPool) throws Exception {
        TypeInstance root = getJwtAnnotationValue(securityScheme);

        List<Parameter> headers = new ArrayList<>();
        addStandardParameter(headers, root, "body.jwt.headers.alg", "alg", "Encryption Algorithm", "Algorithm used to sign and encrypt the JWT tokens.");
        addStandardParameter(headers, root, "body.jwt.headers.typ", "typ", "Token Media Type", "Header Parameter defined by JWT applications to declare the media type of this complete JWT.");
        addStandardParameter(headers, root, "body.jwt.headers.cty", "cty", "Token Content Type", "Header Parameter defined by JWT applications to convey structural information about the token.");
        addCustomParameters(headers, get(root, "body.jwt.headers"), "custom");

        List<Parameter> claims = new ArrayList<>();
        addStandardParameter(claims, root, "body.jwt.claims.iss", "iss", "Issuer", "The \"iss\" (issuer) claim identifies the principal that issued the JWT.  The processing of this claim is generally application specific. The \"iss\" value is a case-sensitive string containing a StringOrURI value.");
        addStandardParameter(claims, root, "body.jwt.claims.sub", "sub", "Subject", "The \"sub\" (subject) claim identifies the principal that is the subject of the JWT.");
        addStandardParameter(claims, root, "body.jwt.claims.aud", "aud", "Audience", "The \"aud\" (audience) claim identifies the recipients that the JWT is intended for.  Each principal intended to process the JWT MUST identify itself with a value in the audience claim.  If the principal processing the claim does not identify itself with a value in the \"aud\" claim when this claim is present, then the JWT MUST be rejected.");
        addStandardParameter(claims, root, "body.jwt.claims.jti", "jti", "JWT ID", "The \"jti\" (JWT ID) claim provides a unique identifier for the JWT. The identifier value MUST be assigned in a manner that ensures that there is a negligible probability that the same value will be accidentally assigned to a different data object.");
        addStandardParameter(claims, root, "body.jwt.claims.exp", "exp", "Expiration Time", "The \"exp\" (expiration time) claim identifies the expiration time on or after which the JWT MUST NOT be accepted for processing.");
        addStandardParameter(claims, root, "body.jwt.claims.iat", "iat", "Issued At", "The \"iat\" (issued at) claim identifies the time at which the JWT was issued.  This claim can be used to determine the age of the JWT.  Its value MUST be a number containing a NumericDate value.");
        addStandardParameter(claims, root, "body.jwt.claims.nbf", "nbf", "Not Before", "The \"nbf\" (not before) claim identifies the time before which the JWT MUST NOT be accepted for processing.  The processing of the \"nbf\" claim requires that the current date/time MUST be after or equal to the not-before date/time listed in the \"nbf\" claim.");
        addCustomParameters(claims, get(root, "body.jwt.claims"), "custom");

        List<Parameter> accessTokenRequestParameters = new ArrayList<>();
        addCustomParameters(accessTokenRequestParameters, get(root, "body"), "parameters");
        Optional.ofNullable(get(root, "url")).map(url -> getSecuritySchemeParameter("url",
                "URL",
                STRING,
                "Access Token retrieval URL.",
                null,
                null,
                false,
                false,
                false,
                url.value().toString()))
                .ifPresent(accessTokenRequestParameters::add);

        return new JwtAuthenticationScheme(headers, claims, accessTokenRequestParameters);
    }

    /**
     * Returns the first non-null structured annotation value found on the scheme's
     * {@code describedBy} section.
     *
     * @throws NoSuchElementException if there is no such value
     */
    private static TypeInstance getJwtAnnotationValue(SecurityScheme securityScheme) {
        Optional<TypeInstance> root = Optional.empty();

        SecuritySchemePart describedBy = securityScheme.describedBy();
        if (describedBy != null) {
            root = describedBy.annotations()
                    .stream()
                    .map(AnnotationRef::structuredValue)
                    .filter(value -> value != null)
                    .findFirst();
        }

        return root.orElseThrow(() -> new NoSuchElementException(
                "JWT security scheme <" + securityScheme.name() + "> does not declare the annotation that describes the token."));
    }

    /** Adds the parameter described at {@code path} (dot separated) to {@code target}, if that path exists. */
    private static void addStandardParameter(List<Parameter> target, TypeInstance root, String path, String name, String displayName, String description) {
        Optional.ofNullable(get(root, path))
                .map(instance -> parse(instance, name, displayName, description))
                .ifPresent(target::add);
    }

    /**
     * Adds one parameter per value of the {@code propertyName} property of {@code container} to
     * {@code target}. Does nothing when {@code container} is {@code null} or lacks the property.
     */
    private static void addCustomParameters(List<Parameter> target, TypeInstance container, String propertyName) {
        getProperty(container, propertyName)
                .map(TypeInstanceProperty::values)
                .ifPresent(values -> values.stream()
                        .map(RamlParserSecuritySchemeFactory::getCustomParameter)
                        .forEach(target::add));
    }

    /**
     * Builds a string parameter with a fixed name, display name and description, reading
     * {@code defaultValue}, {@code required}, {@code generated} and {@code restrictedValues} from
     * the given instance.
     */
    private static Parameter parse(TypeInstance instance, String name, String displayName, String description) {
        return getSecuritySchemeParameter(name,
                displayName,
                STRING,
                description,
                getStringValue(instance, "defaultValue", null),
                null,
                getBooleanValue(instance, "required", FALSE),
                false,
                getBooleanValue(instance, "generated", FALSE),
                getRestrictedValues(instance));
    }

    /**
     * Builds a string parameter whose {@code name}, {@code displayName} (defaults to the name),
     * {@code description}, {@code defaultValue}, {@code required} and {@code restrictedValues} are
     * all read from the given instance.
     */
    private static Parameter getCustomParameter(TypeInstance instance) {
        String name = getStringValue(instance, "name", null);
        return getSecuritySchemeParameter(name,
                getStringValue(instance, "displayName", name),
                STRING,
                getStringValue(instance, "description", null),
                getStringValue(instance, "defaultValue", null),
                null,
                getBooleanValue(instance, "required", FALSE),
                false,
                getRestrictedValues(instance));
    }

    /** Returns the values of the {@code restrictedValues} property as strings; empty when absent. */
    private static String[] getRestrictedValues(TypeInstance instance) {
        return getProperty(instance, "restrictedValues")
                .map(TypeInstanceProperty::values)
                .map(restrictedValueInstances -> restrictedValueInstances.stream()
                        .map(TypeInstance::value)
                        .map(String.class::cast)
                        .toArray(String[]::new))
                .orElse(new String[]{});
    }

    private static String getStringValue(TypeInstance instance, String key, String option) {
        return getValue(instance, String.class, key, option);
    }

    private static Boolean getBooleanValue(TypeInstance instance, String key, Boolean option) {
        return getValue(instance, Boolean.class, key, option);
    }

    /**
     * Returns the scalar value of property {@code key} cast to {@code type}, or {@code option} when
     * the property or its value is missing.
     */
    private static <T> T getValue(TypeInstance instance, Class<T> type, String key, T option) {
        return getProperty(instance, key)
                .map(TypeInstanceProperty::value)
                .map(TypeInstance::value)
                .map(type::cast)
                .orElse(option);
    }

    /**
     * Walks a dot separated property path starting at {@code instance}.
     *
     * @return the instance found at the end of the path, or {@code null} if any segment is missing
     */
    private static TypeInstance get(TypeInstance instance, String key) {
        TypeInstance current = instance;
        for (String segment : key.split("\\.")) {
            if (current == null) {
                return null;
            }
            current = getProperty(current, segment)
                    .map(TypeInstanceProperty::value)
                    .orElse(null);
        }
        return current;
    }

    /**
     * Finds the first property with the given name.
     *
     * @return the property, or an empty Optional when {@code instance} is {@code null} or has no
     *         property with that name
     */
    private static Optional<TypeInstanceProperty> getProperty(TypeInstance instance, String propertyName) {
        // Callers pass the result of get(...), which is null for a missing section.
        if (instance == null) {
            return Optional.empty();
        }
        return instance.properties().stream()
                .filter(property -> propertyName.equals(property.name()))
                .findFirst();
    }

    /** Builds a custom scheme from the query parameters and headers of {@code describedBy}. */
    private static APISecurityScheme buildCustomAuthenticationSecurityScheme(SecurityScheme securityScheme, JsonSchemaPool jsonSchemaPool) throws Exception {
        List<Parameter> queryParameters = new ArrayList<>();
        List<Parameter> headers = new ArrayList<>();

        // A scheme without a describedBy section contributes no parameters instead of failing with an NPE.
        SecuritySchemePart describedBy = securityScheme.describedBy();
        if (describedBy != null) {
            queryParameters = getParameterList(describedBy.queryParameters(), QUERY, jsonSchemaPool);
            headers = getParameterList(describedBy.headers(), HEADER, jsonSchemaPool);
        }

        return new CustomAuthenticationScheme(queryParameters, headers);
    }

    /** Builds the authorization-code variant of an OAuth 2.0 scheme from the scheme settings. */
    private static OAuth2AuthorizationCodeScheme buildOAuth2AuthorizationCodeSecurityScheme(SecurityScheme securityScheme) {
        String authorizationUri = getValueFromAnnotableString(securityScheme.settings().authorizationUri());
        String accessTokenUri = getValueFromAnnotableString(securityScheme.settings().accessTokenUri());
        List<String> authorizationGrants = Collections.singletonList(OAUTH2_GRANT_AUTHORIZATION_CODE);
        List<String> scopes = securityScheme.settings().scopes();
        String renewTokenExpression = getAnnotatedRenewTokenExpression(securityScheme);

        return new OAuth2AuthorizationCodeScheme(authorizationUri, accessTokenUri, authorizationGrants, scopes, renewTokenExpression);
    }

    /** Builds the client-credentials variant of an OAuth 2.0 scheme from the scheme settings. */
    private static OAuth2ClientCredentialsScheme buildOAuth2ClientCredentialsSecurityScheme(SecurityScheme securityScheme) {
        String authorizationUri = getValueFromAnnotableString(securityScheme.settings().authorizationUri());
        String accessTokenUri = getValueFromAnnotableString(securityScheme.settings().accessTokenUri());
        List<String> authorizationGrants = Collections.singletonList(OAUTH2_GRANT_CLIENT_CREDENTIALS);
        List<String> scopes = securityScheme.settings().scopes();
        String renewTokenExpression = getAnnotatedRenewTokenExpression(securityScheme);

        return new OAuth2ClientCredentialsScheme(authorizationUri, accessTokenUri, authorizationGrants, scopes, renewTokenExpression);
    }

    /**
     * Renders the schemes as {@code type<name>. } entries for error messages; OAuth 2.0 entries
     * additionally list their authorization grants.
     */
    private static String listSchemes(List<SecurityScheme> securitySchemes) {
        StringBuilder builder = new StringBuilder();
        for (SecurityScheme securityScheme : securitySchemes) {
            builder.append(securityScheme.type())
                    .append("<")
                    .append(securityScheme.name())
                    .append(">");
            // Constant-first comparison so a scheme without a type cannot raise an NPE here.
            if (APISecurityScheme.OAUTH2.equals(securityScheme.type())) {
                builder.append(" :");
                for (String grant : getAuthorizationGrants(securityScheme)) {
                    builder.append(" ").append(grant);
                }
            }
            builder.append(". ");
        }

        return builder.toString();
    }

}
