/*
 * (c) 2003-2020 MuleSoft, Inc. This software is protected under international copyright
 * law. All use of this software is subject to MuleSoft's Master Subscription Agreement
 * (or other master license agreement) separately entered into in writing between you and
 * MuleSoft. If such an agreement is not in place, you may not use the software.
 */
package com.mulesoft.connectivity.rest.sdk.internal.templating.sdk;

import com.mulesoft.connectivity.rest.sdk.api.RestSdkRunConfiguration;
import com.mulesoft.connectivity.rest.sdk.exception.TemplatingException;
import com.mulesoft.connectivity.rest.sdk.internal.connectormodel.ConnectorModel;
import com.mulesoft.connectivity.rest.sdk.internal.connectormodel.security.ConnectorSecurityScheme;
import com.mulesoft.connectivity.rest.sdk.internal.connectormodel.security.OAuth2AuthorizationCodeScheme;
import com.mulesoft.connectivity.rest.sdk.internal.connectormodel.security.OAuth2ClientCredentialsScheme;
import com.mulesoft.connectivity.rest.sdk.internal.connectormodel.security.OAuth2Scheme;
import com.mulesoft.connectivity.rest.sdk.internal.connectormodel.security.TestConnectionConfig;
import com.mulesoft.connectivity.rest.sdk.internal.connectormodel.security.TestConnectionValidationConfig;
import com.mulesoft.connectivity.rest.sdk.internal.connectormodel.uri.BaseUri;
import com.mulesoft.connectivity.rest.sdk.internal.templating.JavaTemplateEntity;
import com.mulesoft.connectivity.rest.commons.api.connection.BaseConnectionProvider;
import com.mulesoft.connectivity.rest.commons.api.connection.BasicAuthenticationConnectionProvider;
import com.mulesoft.connectivity.rest.commons.api.connection.DigestConnectionProvider;
import com.mulesoft.connectivity.rest.commons.api.connection.MandatoryTlsParameterGroup;
import com.mulesoft.connectivity.rest.commons.api.connection.OptionalTlsParameterGroup;
import com.mulesoft.connectivity.rest.commons.api.connection.RestConnection;
import com.mulesoft.connectivity.rest.commons.api.connection.TlsParameterGroup;
import com.mulesoft.connectivity.rest.commons.api.connection.oauth.BaseAuthorizationCodeConnectionProvider;
import com.mulesoft.connectivity.rest.commons.api.connection.oauth.BaseClientCredentialsConnectionProvider;
import com.mulesoft.connectivity.rest.commons.api.connection.validation.ConnectionValidationSettings;
import com.mulesoft.connectivity.rest.commons.internal.util.RestSdkUtils;

import org.mule.runtime.api.connection.ConnectionValidationResult;
import org.mule.runtime.api.el.ExpressionLanguage;
import org.mule.runtime.api.metadata.MediaType;
import org.mule.runtime.api.util.MultiMap;
import org.mule.runtime.extension.api.annotation.Alias;
import org.mule.runtime.extension.api.annotation.connectivity.oauth.AuthorizationCode;
import org.mule.runtime.extension.api.annotation.connectivity.oauth.ClientCredentials;
import org.mule.runtime.extension.api.annotation.param.Optional;
import org.mule.runtime.extension.api.annotation.param.Parameter;
import org.mule.runtime.extension.api.annotation.param.ParameterGroup;
import org.mule.runtime.extension.api.annotation.param.display.DisplayName;
import org.mule.runtime.extension.api.annotation.param.display.Summary;
import org.mule.runtime.extension.api.connectivity.NoConnectivityTest;
import org.mule.runtime.http.api.HttpConstants.Method;
import org.mule.runtime.http.api.client.HttpClient;
import org.mule.runtime.http.api.client.auth.HttpAuthentication;

import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.StringJoiner;

import static com.google.common.base.CaseFormat.LOWER_CAMEL;
import static com.mulesoft.connectivity.rest.sdk.internal.connectormodel.security.ConnectorSecurityScheme.SecuritySchemeType.BASIC;
import static com.mulesoft.connectivity.rest.sdk.internal.connectormodel.security.ConnectorSecurityScheme.SecuritySchemeType.DIGEST_AUTHENTICATION;
import static com.mulesoft.connectivity.rest.sdk.internal.connectormodel.security.ConnectorSecurityScheme.SecuritySchemeType.OAUTH2;
import static com.mulesoft.connectivity.rest.sdk.internal.connectormodel.security.ConnectorSecurityScheme.SecuritySchemeType.PASS_THROUGH;
import static com.mulesoft.connectivity.rest.sdk.internal.connectormodel.security.ConnectorSecurityScheme.SecuritySchemeType.UNSECURED;
import static com.mulesoft.connectivity.rest.sdk.internal.connectormodel.security.OAuth2Scheme.GrantType.AUTHORIZATION_CODE;
import static com.mulesoft.connectivity.rest.sdk.internal.connectormodel.security.OAuth2Scheme.GrantType.CLIENT_CREDENTIALS;
import static javax.lang.model.element.Modifier.PRIVATE;
import static javax.lang.model.element.Modifier.PROTECTED;
import static javax.lang.model.element.Modifier.PUBLIC;
import static org.apache.commons.lang.WordUtils.capitalize;
import static org.apache.commons.lang3.StringUtils.defaultIfEmpty;
import static org.apache.commons.lang3.StringUtils.isNotBlank;
import static com.mulesoft.connectivity.rest.sdk.internal.util.JavaUtils.getJavaUpperCamelNameFromXml;
import static com.mulesoft.connectivity.rest.sdk.internal.webapi.util.XmlUtils.getXmlName;

import com.squareup.javapoet.AnnotationSpec;
import com.squareup.javapoet.CodeBlock;
import com.squareup.javapoet.FieldSpec;
import com.squareup.javapoet.MethodSpec;
import com.squareup.javapoet.ParameterizedTypeName;
import com.squareup.javapoet.TypeName;
import com.squareup.javapoet.TypeSpec;

import javax.inject.Inject;

/**
 * Template entity that generates the SDK connection provider class for a single {@link ConnectorSecurityScheme}.
 * <p>
 * The generated class:
 * <ul>
 * <li>extends the commons connection provider selected from the scheme type ({@code getConnectionProviderClass()}),</li>
 * <li>declares a {@code baseUri} field (a {@code @Parameter} when the base URI is parameterized, a constant otherwise),</li>
 * <li>declares a TLS parameter group when the connector supports HTTPS,</li>
 * <li>declares one field per custom header and query parameter of the scheme, and</li>
 * <li>either overrides {@code validate} (when a test connection config exists and the scheme is not an OAuth2
 * authorization code scheme) or implements {@link NoConnectivityTest}.</li>
 * </ul>
 */
public class SdkConnectionProvider extends JavaTemplateEntity {

  private static final String JAVA_DOC_BASE_URI = "@return the base uri of the REST API being consumed";

  private static final String CREATE_CONNECTION_METHOD_NAME = "createConnection";
  private static final String CREATE_CONNECTION_HTTP_CLIENT_PARAMETER_NAME = "httpClient";
  private static final String CREATE_CONNECTION_AUTHENTICATION_PARAMETER_NAME = "authentication";
  private static final String CREATE_CONNECTION_QUERY_PARAMS_PARAMETER_NAME = "defaultQueryParams";
  private static final String CREATE_CONNECTION_HEADERS_PARAMETER_NAME = "defaultHeaders";

  private static final String VALIDATE_CONNECTION_METHOD_NAME = "validate";
  private static final String VALIDATE_CONNECTION_CONNECTION_PARAMETER_NAME = "restConnection";
  private static final String VALIDATE_CONNECTION_SETTINGS_VAR_NAME = "settings";

  private static final String EXPRESSION_LANGUAGE_FIELD_NAME = "expressionLanguage";

  public static final String ACCESS_TOKEN_URL = "accessTokenUrl";
  public static final String AUTHORIZATION_URL = "authorizationUrl";
  public static final String DEFAULT_SCOPES = "defaultScopes";
  public static final String TOKEN_URL = "tokenUrl";

  /** Security scheme this connection provider is generated for. */
  private final ConnectorSecurityScheme securityScheme;

  /** SDK parameters built from the security scheme's custom headers. */
  private final List<SdkParameter> headers;
  /** SDK parameters built from the security scheme's custom query parameters. */
  private final List<SdkParameter> queryParameters;

  /**
   * Creates the template entity for one security scheme and builds its header and query parameters.
   *
   * @param outputDir        directory the generated sources are written under
   * @param connectorModel   model of the connector being generated
   * @param sdkConnector     SDK connector the generated provider belongs to
   * @param securityScheme   security scheme to generate the connection provider for
   * @param runConfiguration configuration of the current generation run
   * @throws TemplatingException if building any header or query {@link SdkParameter} fails
   */
  public SdkConnectionProvider(Path outputDir, ConnectorModel connectorModel, SdkConnector sdkConnector,
                               ConnectorSecurityScheme securityScheme, RestSdkRunConfiguration runConfiguration)
      throws TemplatingException {
    super(outputDir, connectorModel, runConfiguration);
    // Must be assigned before building parameters: buildSdkParameters calls getJavaClassName(), which reads it.
    this.securityScheme = securityScheme;

    headers = buildSdkParameters(securityScheme.getHeaders(), outputDir, connectorModel, sdkConnector);
    queryParameters = buildSdkParameters(securityScheme.getQueryParameters(), outputDir, connectorModel, sdkConnector);
  }

  /**
   * Wraps each connector-model parameter into an {@link SdkParameter} owned by this provider class.
   *
   * @param parameters     connector-model parameters to convert
   * @param outputDir      directory the generated sources are written under
   * @param connectorModel model of the connector being generated
   * @param sdkConnector   SDK connector the generated provider belongs to
   * @return a new list with one {@link SdkParameter} per input parameter, in input order
   * @throws TemplatingException if creating any {@link SdkParameter} fails
   */
  private List<SdkParameter> buildSdkParameters(
      List<com.mulesoft.connectivity.rest.sdk.internal.connectormodel.parameter.Parameter> parameters,
      Path outputDir, ConnectorModel connectorModel, SdkConnector sdkConnector)
      throws TemplatingException {

    final List<SdkParameter> list = new ArrayList<>();
    for (com.mulesoft.connectivity.rest.sdk.internal.connectormodel.parameter.Parameter param : parameters) {
      list.add(new SdkParameter(outputDir, connectorModel, sdkConnector, getJavaClassName(), param, this, runConfiguration));
    }
    return list;
  }

  /**
   * @return the package of the generated class: the connector base package plus {@code .internal.connection.provider}
   */
  public String getPackage() {
    return connectorModel.getBasePackage() + ".internal.connection.provider";
  }

  /** @return the XML name derived from the security scheme's display name; used for the {@code @Alias} value */
  private String getConnectionProviderXmlName() {
    return getXmlName(securityScheme.getDisplayName());
  }

  /** @return the XML name with dashes turned into spaces, words capitalized, suffixed with " Connection Provider" */
  private String getConnectionProviderDisplayName() {
    return capitalize(getConnectionProviderXmlName().replace('-', ' ')) + " Connection Provider";
  }

  /** @return the simple name of the generated class: the upper-camel XML name suffixed with "ConnectionProvider" */
  public String getJavaClassName() {
    return getJavaUpperCamelNameFromXml(getConnectionProviderXmlName()) + "ConnectionProvider";
  }

  /**
   * Generates the connection provider class and writes it to the output directory.
   *
   * @throws TemplatingException if the scheme has no matching base connection provider or writing fails
   */
  @Override
  public void applyTemplates() throws TemplatingException {
    generateConnectionProviderClass();
  }

  /**
   * Assembles the full connection provider {@link TypeSpec} and writes it into {@link #getPackage()}.
   * The order of the builder calls below determines the member order in the generated source.
   */
  private void generateConnectionProviderClass() throws TemplatingException {

    FieldSpec baseUriField = generateBaseUriField();

    TypeSpec.Builder connectionProviderClassBuilder =
        TypeSpec
            .classBuilder(getJavaClassName())
            .addModifiers(PUBLIC)
            .superclass(getConnectionProviderClass())
            .addAnnotation(generateAliasAnnotation())
            .addAnnotation(generateDisplayNameAnnotation())
            .addField(baseUriField)
            .addMethod(generateGetter(baseUriField, LOWER_CAMEL).addAnnotation(Override.class)
                .addJavadoc("$L\n", JAVA_DOC_BASE_URI).build());

    addOAuth(connectionProviderClassBuilder);
    addTls(connectionProviderClassBuilder);
    addCreateConnectionOverrideMethod(connectionProviderClassBuilder);

    for (SdkParameter header : headers) {
      addParameterField(connectionProviderClassBuilder, header);
    }

    for (SdkParameter queryParam : queryParameters) {
      addParameterField(connectionProviderClassBuilder, queryParam);
    }

    if (shouldTestConnectivity()) {
      addValidateConnectionMethod(connectionProviderClassBuilder);
    } else {
      connectionProviderClassBuilder.addSuperinterface(NoConnectivityTest.class);
    }

    writeClassToFile(connectionProviderClassBuilder.build(), getPackage());
  }

  /** @return {@code true} when the security scheme is of type OAuth2 */
  private boolean isOAuth2() {
    return securityScheme.getSchemeType() == OAUTH2;
  }

  /**
   * @return {@code true} when a test connection config is defined and the scheme is not an OAuth2 authorization code
   *         scheme; only then is a {@code validate} override generated
   */
  private boolean shouldTestConnectivity() {
    final boolean hasTestConnectionDefinition = securityScheme.getTestConnectionConfig() != null;
    final boolean isAuthorizationCode = securityScheme instanceof OAuth2Scheme
        && ((OAuth2Scheme) securityScheme).getGrantType() == AUTHORIZATION_CODE;

    return hasTestConnectionDefinition && !isAuthorizationCode;
  }

  /**
   * Adds a private field for a custom header or query parameter. OAuth2 schemes use the parameter's OAuth field
   * variant; all other schemes use the regular parameter field.
   */
  private void addParameterField(TypeSpec.Builder connectionProviderClassBuilder, SdkParameter parameter) {

    FieldSpec.Builder fieldBuilder = isOAuth2()
        ? parameter.generateOAuthParameterField()
        : parameter.generateParameterField();

    fieldBuilder
        .addModifiers(PRIVATE)
        .addJavadoc(generateSdkParameterJavaDoc(parameter));

    connectionProviderClassBuilder.addField(fieldBuilder.build());
  }

  /**
   * Adds a {@code createConnection} override that injects the custom query parameters and headers before delegating
   * to {@code super.createConnection}. Nothing is generated when the scheme has no custom headers or query
   * parameters, or when it is an OAuth2 scheme.
   */
  private void addCreateConnectionOverrideMethod(TypeSpec.Builder connectionProviderClassBuilder) {
    boolean hasCustomParameters = !headers.isEmpty() || !queryParameters.isEmpty();
    if (!hasCustomParameters || isOAuth2()) {
      return;
    }

    MethodSpec createConnectionMethod = MethodSpec.methodBuilder(CREATE_CONNECTION_METHOD_NAME)
        .returns(TypeName.get(RestConnection.class))
        .addModifiers(PROTECTED)
        .addAnnotation(Override.class)
        .addParameter(HttpClient.class, CREATE_CONNECTION_HTTP_CLIENT_PARAMETER_NAME)
        .addParameter(HttpAuthentication.class, CREATE_CONNECTION_AUTHENTICATION_PARAMETER_NAME)
        .addParameter(getStringMultiMapTypeName(), CREATE_CONNECTION_QUERY_PARAMS_PARAMETER_NAME)
        .addParameter(getStringMultiMapTypeName(), CREATE_CONNECTION_HEADERS_PARAMETER_NAME)
        .addCode(generateCreateConnectionMethodBody())
        .build();

    connectionProviderClassBuilder.addMethod(createConnectionMethod);
  }

  /** Builds the body of the generated {@code createConnection} override. */
  private CodeBlock generateCreateConnectionMethodBody() {
    CodeBlock.Builder methodBody = CodeBlock.builder();

    addCustomParameters(methodBody, queryParameters, CREATE_CONNECTION_QUERY_PARAMS_PARAMETER_NAME);
    addCustomParameters(methodBody, headers, CREATE_CONNECTION_HEADERS_PARAMETER_NAME);

    methodBody.addStatement("return super.createConnection($L, $L, $L, $L)",
                            CREATE_CONNECTION_HTTP_CLIENT_PARAMETER_NAME,
                            CREATE_CONNECTION_AUTHENTICATION_PARAMETER_NAME,
                            CREATE_CONNECTION_QUERY_PARAMS_PARAMETER_NAME,
                            CREATE_CONNECTION_HEADERS_PARAMETER_NAME);

    return methodBody.build();
  }

  /**
   * Appends generated code that creates the target {@code MultiMap} when it is {@code null}, then puts each
   * parameter's field value under its external name if {@code RestSdkUtils.isNotBlank} holds for it.
   * Nothing is appended for an empty parameter list.
   *
   * @param methodBody              code block being built
   * @param parameters              parameters to copy into the map
   * @param parametersParameterName name of the generated method parameter holding the map
   */
  private void addCustomParameters(CodeBlock.Builder methodBody, List<SdkParameter> parameters, String parametersParameterName) {
    if (parameters.isEmpty()) {
      return;
    }

    methodBody
        .beginControlFlow("if($L == null)", parametersParameterName)
        .addStatement("$L = new $T.StringMultiMap()", parametersParameterName, MultiMap.class)
        .endControlFlow();

    for (SdkParameter parameter : parameters) {
      methodBody
          .beginControlFlow("if($T.isNotBlank($L))", RestSdkUtils.class, parameter.getJavaName())
          .addStatement("$L.put($S, $L)",
                        parametersParameterName,
                        parameter.getExternalName(),
                        parameter.getJavaName())
          .endControlFlow();
    }
  }

  /**
   * Adds an {@code @Inject}-annotated {@link ExpressionLanguage} field and a {@code validate(RestConnection)}
   * override built from the scheme's test connection config.
   */
  private void addValidateConnectionMethod(TypeSpec.Builder connectionProviderClassBuilder) {
    FieldSpec expressionLanguageField = FieldSpec
        .builder(ExpressionLanguage.class, EXPRESSION_LANGUAGE_FIELD_NAME, PRIVATE)
        .addAnnotation(AnnotationSpec.builder(Inject.class).build())
        .build();

    connectionProviderClassBuilder.addField(expressionLanguageField);

    MethodSpec validateConnectionMethod = MethodSpec.methodBuilder(VALIDATE_CONNECTION_METHOD_NAME)
        .returns(TypeName.get(ConnectionValidationResult.class))
        .addModifiers(PUBLIC)
        .addAnnotation(Override.class)
        .addParameter(RestConnection.class, VALIDATE_CONNECTION_CONNECTION_PARAMETER_NAME)
        .addCode(generateValidateConnectionMethodBody())
        .build();

    connectionProviderClassBuilder.addMethod(validateConnectionMethod);
  }

  /**
   * Builds the body of the generated {@code validate} method. It declares a {@link ConnectionValidationSettings}
   * built from the test connection config (path, optional HTTP method, valid status codes, validation expressions,
   * response media type), then returns {@code validate(restConnection, settings)}.
   * <p>
   * Must only be called when {@code securityScheme.getTestConnectionConfig()} is non-null.
   */
  private CodeBlock generateValidateConnectionMethodBody() {
    TestConnectionConfig testConnectionConfig = securityScheme.getTestConnectionConfig();

    String path = testConnectionConfig.getPath();
    String method = testConnectionConfig.getMethod() != null ? testConnectionConfig.getMethod().name() : null;
    Set<String> validStatusCodes = testConnectionConfig.getValidStatusCodes();

    // The whole builder chain is collected as one expression so it can be emitted as a single,
    // properly terminated statement.
    CodeBlock.Builder settingsExpression = CodeBlock.builder()
        .add("$T.builder($S, $L)", ConnectionValidationSettings.class, path, EXPRESSION_LANGUAGE_FIELD_NAME);

    if (isNotBlank(method)) {
      settingsExpression.add(".httpMethod($T.$L)", Method.class, method);
    }

    if (validStatusCodes != null && !validStatusCodes.isEmpty()) {
      StringJoiner validStatusCodesJoiner = new StringJoiner(", ");
      validStatusCodes.forEach(validStatusCodesJoiner::add);
      settingsExpression.add(".validStatusCodes($L)", validStatusCodesJoiner);
    }

    for (TestConnectionValidationConfig validationConfig : testConnectionConfig.getValidations()) {
      if (isNotBlank(validationConfig.getValidationExpression())) {
        if (isNotBlank(validationConfig.getErrorTemplateExpression())) {
          settingsExpression.add(".addValidation($S, $S)", validationConfig.getValidationExpression(),
                                 validationConfig.getErrorTemplateExpression());
        } else {
          settingsExpression.add(".addValidation($S)", validationConfig.getValidationExpression());
        }
      }
    }

    if (testConnectionConfig.getMediaType() != null) {
      settingsExpression.add(".responseMediaType($T.parse($S))", MediaType.class,
                             testConnectionConfig.getMediaType().toString());
    }

    settingsExpression.add(".build()");

    return CodeBlock.builder()
        .addStatement("$T $L = $L",
                      ConnectionValidationSettings.class,
                      VALIDATE_CONNECTION_SETTINGS_VAR_NAME,
                      settingsExpression.build())
        .addStatement("return $L($L, $L)",
                      VALIDATE_CONNECTION_METHOD_NAME,
                      VALIDATE_CONNECTION_CONNECTION_PARAMETER_NAME,
                      VALIDATE_CONNECTION_SETTINGS_VAR_NAME)
        .build();
  }

  /** @return the type name {@code MultiMap<String, String>} */
  private TypeName getStringMultiMapTypeName() {
    return ParameterizedTypeName.get(MultiMap.class, String.class, String.class);
  }

  /**
   * When the connector supports HTTPS, adds the TLS parameter group field plus an overriding optional getter.
   * The getter Javadoc says whether HTTP/HTTPS switching is possible, based on {@code connectorModel.supportsHTTP()}.
   */
  private void addTls(TypeSpec.Builder connectionProviderClassBuilder) {
    if (!connectorModel.supportsHTTPS()) {
      return;
    }

    FieldSpec tlsField = generateTlsField();
    CodeBlock.Builder tlsJavaDocBuilder = CodeBlock.builder();

    if (connectorModel.supportsHTTP()) {
      tlsJavaDocBuilder
          .add("\n{@link $L} that configures TLS and allows to switch between HTTP and HTTPS protocols.\n\n",
               TlsParameterGroup.class.getSimpleName());
    } else {
      tlsJavaDocBuilder.add("\n{@link $L} that configures TLS for this connection.\n\n",
                            TlsParameterGroup.class.getSimpleName());
    }

    tlsJavaDocBuilder.add("@return an optional {@link $L}", TlsParameterGroup.class.getSimpleName());

    connectionProviderClassBuilder
        .addField(tlsField)
        .addMethod(
                   generateOptionalGetter(tlsField, TlsParameterGroup.class, LOWER_CAMEL)
                       .addAnnotation(Override.class)
                       .addJavadoc(tlsJavaDocBuilder.build())
                       .build());
  }

  /**
   * For OAuth2 schemes, adds the {@link ClientCredentials} or {@link AuthorizationCode} annotation matching the
   * selected base provider, filled from the scheme's token/authorization URIs and scopes.
   *
   * @throws TemplatingException if no base connection provider matches the scheme
   */
  private void addOAuth(TypeSpec.Builder connectionProviderClassBuilder) throws TemplatingException {
    if (!isOAuth2()) {
      return;
    }

    Class<? extends BaseConnectionProvider> providerClass = getConnectionProviderClass();
    AnnotationSpec oAuthAnnotation = null;

    if (providerClass.equals(BaseClientCredentialsConnectionProvider.class)) {
      OAuth2ClientCredentialsScheme clientCredentials = (OAuth2ClientCredentialsScheme) securityScheme;
      oAuthAnnotation = AnnotationSpec
          .builder(ClientCredentials.class)
          .addMember(TOKEN_URL, "$S", clientCredentials.getAccessTokenUri())
          .addMember(DEFAULT_SCOPES, "$S", clientCredentials.getScopes())
          .build();

    } else if (providerClass.equals(BaseAuthorizationCodeConnectionProvider.class)) {
      OAuth2AuthorizationCodeScheme authorizationCode = (OAuth2AuthorizationCodeScheme) securityScheme;
      oAuthAnnotation = AnnotationSpec
          .builder(AuthorizationCode.class)
          .addMember(ACCESS_TOKEN_URL, "$S", authorizationCode.getAccessTokenUri())
          .addMember(AUTHORIZATION_URL, "$S", authorizationCode.getAuthorizationUri())
          .addMember(DEFAULT_SCOPES, "$S", authorizationCode.getScopes())
          .build();
    }

    if (oAuthAnnotation != null) {
      connectionProviderClassBuilder.addAnnotation(oAuthAnnotation);
    }
  }

  /** @return Javadoc for a parameter field: its description, or its display name when the description is empty */
  private CodeBlock generateSdkParameterJavaDoc(SdkParameter sdkParameter) {
    return CodeBlock.builder()
        .add("$L\n", defaultIfEmpty(sdkParameter.getDescription(), sdkParameter.getDisplayName())).build();
  }

  /**
   * Maps the security scheme type (and, for OAuth2, its grant type) to the commons base connection provider the
   * generated class extends.
   *
   * @return the base connection provider class
   * @throws TemplatingException if the scheme or grant type has no matching provider
   */
  private Class<? extends BaseConnectionProvider> getConnectionProviderClass() throws TemplatingException {
    ConnectorSecurityScheme.SecuritySchemeType securitySchemeType = securityScheme.getSchemeType();

    if (securitySchemeType == BASIC) {
      return BasicAuthenticationConnectionProvider.class;
    } else if (securitySchemeType == DIGEST_AUTHENTICATION) {
      return DigestConnectionProvider.class;
    } else if (securitySchemeType == PASS_THROUGH || securitySchemeType == UNSECURED) {
      return BaseConnectionProvider.class;
    } else if (securitySchemeType == OAUTH2) {
      OAuth2Scheme.GrantType grantType = ((OAuth2Scheme) securityScheme).getGrantType();
      if (grantType == AUTHORIZATION_CODE) {
        return BaseAuthorizationCodeConnectionProvider.class;
      } else if (grantType == CLIENT_CREDENTIALS) {
        return BaseClientCredentialsConnectionProvider.class;
      }
    }

    throw new TemplatingException("Connection Provider not available for " + securitySchemeType);
  }

  /** @return the {@code @DisplayName} annotation for the generated class */
  private AnnotationSpec generateDisplayNameAnnotation() {
    return AnnotationSpec
        .builder(DisplayName.class)
        .addMember(VALUE_MEMBER, "$S", getConnectionProviderDisplayName())
        .build();
  }

  /** @return the {@code @Alias} annotation for the generated class, using the connection provider XML name */
  private AnnotationSpec generateAliasAnnotation() {
    return AnnotationSpec
        .builder(Alias.class)
        .addMember(VALUE_MEMBER, "$S", getConnectionProviderXmlName())
        .build();
  }

  /**
   * Builds the private {@code baseUri} field. A parameterized base URI becomes an {@code @Parameter}, made
   * {@code @Optional} with the URI as default when one is set. Otherwise the field is initialized to the fixed URI.
   */
  private FieldSpec generateBaseUriField() {
    BaseUri baseUri = connectorModel.getBaseUri();

    FieldSpec.Builder baseUriFieldSpec = FieldSpec
        .builder(String.class, "baseUri", PRIVATE)
        .addAnnotation(AnnotationSpec.builder(DisplayName.class)
            .addMember(VALUE_MEMBER, "$S", "Base Uri").build())
        .addAnnotation(AnnotationSpec.builder(Summary.class)
            .addMember(VALUE_MEMBER, "$S", baseUri.getType().getDescription()).build())
        .addJavadoc("$L", JAVA_DOC_BASE_URI);

    if (baseUri.isParameterizedBaseUri()) {
      baseUriFieldSpec.addAnnotation(Parameter.class)
          .addAnnotation(buildBaseUriOptionalAnnotation(baseUri));
    } else {
      // $S emits an escaped Java string literal. Concatenating the URI into the format string would break on
      // '"' or '\' and would make JavaPoet interpret any '$' in the URI as a placeholder.
      baseUriFieldSpec.initializer("$S", baseUri.getUri());
    }

    return baseUriFieldSpec.build();
  }

  /**
   * @param baseUri the connector's base URI
   * @return an {@code @Optional} annotation, with {@code defaultValue} set to the base URI when it is not blank
   */
  private AnnotationSpec buildBaseUriOptionalAnnotation(BaseUri baseUri) {
    AnnotationSpec.Builder optionalAnnotationBuilder = AnnotationSpec.builder(Optional.class);

    if (isNotBlank(baseUri.getUri())) {
      optionalAnnotationBuilder.addMember("defaultValue", "$S", baseUri.getUri());
    }

    return optionalAnnotationBuilder.build();
  }

  /**
   * Builds the private {@code tlsConfig} parameter-group field. Its type is {@link OptionalTlsParameterGroup} when
   * the connector also supports plain HTTP, and {@link MandatoryTlsParameterGroup} otherwise.
   */
  private FieldSpec generateTlsField() {
    Class<?> tlsClass = connectorModel.supportsHTTP() ? OptionalTlsParameterGroup.class : MandatoryTlsParameterGroup.class;
    return FieldSpec
        .builder(tlsClass, "tlsConfig", PRIVATE)
        .addJavadoc(CodeBlock.builder()
            .add("{@link $L} references to a TLS config element. This will enable HTTPS for this config.\n",
                 tlsClass.getSimpleName())
            .build())
        .addAnnotation(AnnotationSpec.builder(ParameterGroup.class).addMember(NAME_MEMBER, "$S", "tls").build())
        .build();
  }

}
