/*
 * (c) 2003-2018 MuleSoft, Inc. This software is protected under international copyright
 * law. All use of this software is subject to MuleSoft's Master Subscription Agreement
 * (or other master license agreement) separately entered into in writing between you and
 * MuleSoft. If such an agreement is not in place, you may not use the software.
 */
package org.mule.connectivity.restconnect.internal.templating.sdk;

import org.mule.connectivity.restconnect.exception.TemplatingException;
import org.mule.connectivity.restconnect.internal.connectormodel.ConnectorModel;
import org.mule.connectivity.restconnect.internal.connectormodel.security.ConnectorSecurityScheme;
import org.mule.connectivity.restconnect.internal.connectormodel.security.OAuth2AuthorizationCodeScheme;
import org.mule.connectivity.restconnect.internal.connectormodel.security.OAuth2ClientCredentialsScheme;
import org.mule.connectivity.restconnect.internal.connectormodel.security.OAuth2Scheme;
import org.mule.connectivity.restconnect.internal.connectormodel.uri.BaseUri;
import org.mule.connectivity.restconnect.internal.templating.JavaTemplateEntity;
import org.mule.connectors.restconnect.commons.api.connection.BaseConnectionProvider;
import org.mule.connectors.restconnect.commons.api.connection.BasicAuthenticationConnectionProvider;
import org.mule.connectors.restconnect.commons.api.connection.DigestConnectionProvider;
import org.mule.connectors.restconnect.commons.api.connection.MandatoryTlsParameterGroup;
import org.mule.connectors.restconnect.commons.api.connection.OptionalTlsParameterGroup;
import org.mule.connectors.restconnect.commons.api.connection.RestConnection;
import org.mule.connectors.restconnect.commons.api.connection.TlsParameterGroup;
import org.mule.connectors.restconnect.commons.api.connection.oauth.BaseAuthorizationCodeConnectionProvider;
import org.mule.connectors.restconnect.commons.api.connection.oauth.BaseClientCredentialsConnectionProvider;
import org.mule.connectors.restconnect.commons.internal.util.RestConnectUtils;
import org.mule.runtime.api.util.MultiMap;
import org.mule.runtime.extension.api.annotation.Alias;
import org.mule.runtime.extension.api.annotation.connectivity.oauth.AuthorizationCode;
import org.mule.runtime.extension.api.annotation.connectivity.oauth.ClientCredentials;
import org.mule.runtime.extension.api.annotation.param.Optional;
import org.mule.runtime.extension.api.annotation.param.Parameter;
import org.mule.runtime.extension.api.annotation.param.ParameterGroup;
import org.mule.runtime.extension.api.annotation.param.display.DisplayName;
import org.mule.runtime.extension.api.annotation.param.display.Summary;
import org.mule.runtime.http.api.client.HttpClient;
import org.mule.runtime.http.api.client.auth.HttpAuthentication;

import java.nio.file.Path;
import java.util.List;

import static java.util.stream.Collectors.toList;
import static javax.lang.model.element.Modifier.PROTECTED;
import static org.apache.commons.lang.WordUtils.capitalize;
import static org.apache.commons.lang3.StringUtils.defaultIfEmpty;
import static org.apache.commons.lang3.StringUtils.isNotBlank;
import static org.mule.connectivity.restconnect.internal.connectormodel.security.ConnectorSecurityScheme.SecuritySchemeType.OAUTH2;
import static org.mule.connectivity.restconnect.internal.util.JavaUtils.getJavaUpperCamelNameFromXml;
import static org.mule.connectivity.restconnect.internal.webapi.util.XmlUtils.getXmlName;

import com.google.common.base.CaseFormat;
import com.squareup.javapoet.AnnotationSpec;
import com.squareup.javapoet.CodeBlock;
import com.squareup.javapoet.FieldSpec;
import com.squareup.javapoet.MethodSpec;
import com.squareup.javapoet.ParameterizedTypeName;
import com.squareup.javapoet.TypeName;
import com.squareup.javapoet.TypeSpec;

import javax.lang.model.element.Modifier;

/**
 * Templating entity that generates the Java source of an SDK connection provider for a single
 * {@link ConnectorSecurityScheme} of a {@link ConnectorModel}.
 * <p>
 * The generated class:
 * <ul>
 * <li>extends a base connection provider chosen from the security scheme type;</li>
 * <li>declares a {@code baseUri} field and an overriding getter for it;</li>
 * <li>gets an OAuth annotation when the scheme is OAuth2;</li>
 * <li>gets a TLS parameter group when the model supports HTTPS;</li>
 * <li>gets a {@code createConnection} override when the scheme declares headers or query parameters and is not
 * OAuth2;</li>
 * <li>gets one field per security-scheme header and query parameter.</li>
 * </ul>
 */
public class SdkConnectionProvider extends JavaTemplateEntity {

  private static final String JAVA_DOC_BASE_URI = "@return the base uri of the REST API being consumed";

  private static final String CREATE_CONNECTION_METHOD_NAME = "createConnection";
  private static final String CREATE_CONNECTION_HTTP_CLIENT_PARAMETER_NAME = "httpClient";
  private static final String CREATE_CONNECTION_AUTHENTICATION_PARAMETER_NAME = "authentication";
  private static final String CREATE_CONNECTION_QUERY_PARAMS_PARAMETER_NAME = "defaultQueryParams";
  private static final String CREATE_CONNECTION_HEADERS_PARAMETER_NAME = "defaultHeaders";

  /** Security scheme the connection provider is generated for. */
  private final ConnectorSecurityScheme securityScheme;

  /** Headers declared by the security scheme, wrapped as SDK parameters of the generated class. */
  private final List<SdkParameter> headers;
  /** Query parameters declared by the security scheme, wrapped as SDK parameters of the generated class. */
  private final List<SdkParameter> queryParameters;

  /**
   * Creates the templating entity for one security scheme.
   *
   * @param outputDir output directory, forwarded to {@link JavaTemplateEntity} and to every {@link SdkParameter}
   * @param connectorModel model of the connector being generated
   * @param sdkConnector SDK connector the generated parameters belong to
   * @param securityScheme security scheme whose headers and query parameters become fields of the provider
   */
  public SdkConnectionProvider(Path outputDir, ConnectorModel connectorModel, SdkConnector sdkConnector,
                               ConnectorSecurityScheme securityScheme) {
    super(outputDir, connectorModel);
    this.securityScheme = securityScheme;

    this.headers = securityScheme.getHeaders().stream()
        .map(parameter -> new SdkParameter(outputDir, connectorModel, sdkConnector, getJavaClassName(), parameter))
        .collect(toList());

    this.queryParameters = securityScheme.getQueryParameters().stream()
        .map(parameter -> new SdkParameter(outputDir, connectorModel, sdkConnector, getJavaClassName(), parameter))
        .collect(toList());
  }

  /**
   * @return the package of the generated class: the model's base package plus
   *         {@code .internal.connection.provider}
   */
  public String getPackage() {
    return connectorModel.getBasePackage() + ".internal.connection.provider";
  }

  /** @return the XML name derived from the security scheme's display name; also used as the {@code @Alias} value */
  private String getConnectionProviderXmlName() {
    return getXmlName(securityScheme.getDisplayName());
  }

  /**
   * @return the XML name with dashes turned into spaces, words capitalized and suffixed with
   *         {@code " Connection Provider"}
   */
  private String getConnectionProviderDisplayName() {
    return capitalize(getConnectionProviderXmlName().replace('-', ' ')) + " Connection Provider";
  }

  /** @return the simple name of the generated class: the upper-camel XML name plus {@code "ConnectionProvider"} */
  public String getJavaClassName() {
    return getJavaUpperCamelNameFromXml(getConnectionProviderXmlName()) + "ConnectionProvider";
  }

  @Override
  public void applyTemplates() throws TemplatingException {
    generateConnectionProviderClass();
  }

  /**
   * Builds the whole connection provider type and writes it to the package returned by {@link #getPackage()}.
   *
   * @throws TemplatingException if no base connection provider exists for the security scheme
   */
  private void generateConnectionProviderClass() throws TemplatingException {

    FieldSpec baseUriField = generateBaseUriField();

    TypeSpec.Builder connectionProviderClassBuilder =
        TypeSpec
            .classBuilder(getJavaClassName())
            .addModifiers(Modifier.PUBLIC)
            .superclass(getConnectionProviderClass())
            .addAnnotation(generateAliasAnnotation())
            .addAnnotation(generateDisplayNameAnnotation())
            .addField(baseUriField)
            .addMethod(generateGetter(baseUriField, CaseFormat.LOWER_CAMEL).addAnnotation(Override.class)
                .addJavadoc(CodeBlock.builder().add(JAVA_DOC_BASE_URI).add("\n").build()).build());

    addOAuth(connectionProviderClassBuilder);
    addTls(connectionProviderClassBuilder);
    addCreateConnectionOverrideMethod(connectionProviderClassBuilder);

    for (SdkParameter header : headers) {
      addParameterField(connectionProviderClassBuilder, header);
    }

    for (SdkParameter queryParam : queryParameters) {
      addParameterField(connectionProviderClassBuilder, queryParam);
    }

    writeClassToFile(connectionProviderClassBuilder.build(), getPackage());
  }

  /**
   * Adds a private field for {@code parameter}. It uses the OAuth flavour of the field for OAuth2 schemes and the
   * regular parameter field otherwise.
   */
  private void addParameterField(TypeSpec.Builder connectionProviderClassBuilder, SdkParameter parameter) {

    FieldSpec.Builder fieldBuilder = securityScheme.getSecuritySchemeType().equals(OAUTH2)
        ? parameter.generateOAuthParameterField() : parameter.generateParameterField();

    fieldBuilder
        .addModifiers(Modifier.PRIVATE)
        .addJavadoc(generateSdkParameterJavaDoc(parameter));

    connectionProviderClassBuilder.addField(fieldBuilder.build());
  }

  /**
   * Adds a {@code protected createConnection(...)} override that injects the scheme's headers and query parameters
   * before delegating to {@code super}. Nothing is added when the scheme is OAuth2 or declares no headers and no
   * query parameters.
   */
  private void addCreateConnectionOverrideMethod(TypeSpec.Builder connectionProviderClassBuilder) {
    boolean hasCustomParameters = !headers.isEmpty() || !queryParameters.isEmpty();
    if (hasCustomParameters && !securityScheme.getSecuritySchemeType().equals(OAUTH2)) {
      MethodSpec createConnectionMethod = MethodSpec.methodBuilder(CREATE_CONNECTION_METHOD_NAME)
          .returns(TypeName.get(RestConnection.class))
          .addModifiers(PROTECTED)
          .addAnnotation(Override.class)
          .addParameter(HttpClient.class, CREATE_CONNECTION_HTTP_CLIENT_PARAMETER_NAME)
          .addParameter(HttpAuthentication.class, CREATE_CONNECTION_AUTHENTICATION_PARAMETER_NAME)
          .addParameter(getStringMultiMapTypeName(), CREATE_CONNECTION_QUERY_PARAMS_PARAMETER_NAME)
          .addParameter(getStringMultiMapTypeName(), CREATE_CONNECTION_HEADERS_PARAMETER_NAME)
          .addCode(generateCreateConnectionMethodBody())
          .build();

      connectionProviderClassBuilder.addMethod(createConnectionMethod);
    }
  }

  /**
   * @return the generated body of {@code createConnection}: query parameters, then headers, then the
   *         {@code super.createConnection(...)} call with the (possibly replaced) multimaps
   */
  private CodeBlock generateCreateConnectionMethodBody() {
    CodeBlock.Builder methodBody = CodeBlock.builder();

    addCustomParameters(methodBody, queryParameters, CREATE_CONNECTION_QUERY_PARAMS_PARAMETER_NAME);
    addCustomParameters(methodBody, headers, CREATE_CONNECTION_HEADERS_PARAMETER_NAME);

    methodBody.addStatement("return super.createConnection($L, $L, $L, $L)",
                            CREATE_CONNECTION_HTTP_CLIENT_PARAMETER_NAME,
                            CREATE_CONNECTION_AUTHENTICATION_PARAMETER_NAME,
                            CREATE_CONNECTION_QUERY_PARAMS_PARAMETER_NAME,
                            CREATE_CONNECTION_HEADERS_PARAMETER_NAME);

    return methodBody.build();
  }

  /**
   * Appends generated code that fills the multimap held by the generated method parameter
   * {@code parametersParameterName}. It first replaces a {@code null} multimap with a new
   * {@code MultiMap.StringMultiMap}. Then it puts each parameter's field value under its external name, guarded by
   * {@code RestConnectUtils.isNotBlank}. Nothing is emitted when {@code parameters} is empty.
   *
   * @param methodBody builder the statements are appended to
   * @param parameters headers or query parameters to put into the multimap
   * @param parametersParameterName name of the generated method parameter holding the multimap
   */
  private void addCustomParameters(CodeBlock.Builder methodBody, List<SdkParameter> parameters, String parametersParameterName) {
    if (!parameters.isEmpty()) {
      methodBody
          .beginControlFlow("if($L == null)", parametersParameterName)
          .addStatement("$L = new $T.StringMultiMap()", parametersParameterName, MultiMap.class)
          .endControlFlow();

      for (SdkParameter parameter : parameters) {
        methodBody
            .beginControlFlow("if($T.isNotBlank($L))", RestConnectUtils.class, parameter.getJavaName())
            .addStatement("$L.put($S, $L)",
                          parametersParameterName,
                          parameter.getExternalName(),
                          parameter.getJavaName())
            .endControlFlow();
      }
    }
  }

  /** @return the type name {@code MultiMap<String, String>} */
  private TypeName getStringMultiMapTypeName() {
    return ParameterizedTypeName.get(MultiMap.class, String.class, String.class);
  }

  /**
   * When the model supports HTTPS, adds the TLS parameter-group field and an overriding optional getter for it. The
   * getter's Javadoc says whether switching between HTTP and HTTPS is possible, based on whether HTTP is also
   * supported.
   */
  private void addTls(TypeSpec.Builder connectionProviderClassBuilder) {
    if (connectorModel.supportsHTTPS()) {
      FieldSpec tlsField = generateTlsField();
      CodeBlock.Builder tlsJavaDocBuilder = CodeBlock.builder();

      if (connectorModel.supportsHTTP()) {
        tlsJavaDocBuilder
            .add("\n{@link $L} that configures TLS and allows to switch between HTTP and HTTPS protocols.\n\n",
                 TlsParameterGroup.class.getSimpleName());
      } else {
        tlsJavaDocBuilder.add("\n{@link $L} that configures TLS for this connection.\n\n",
                              TlsParameterGroup.class.getSimpleName());
      }

      tlsJavaDocBuilder.add("@return an optional {@link $L}", TlsParameterGroup.class.getSimpleName());

      connectionProviderClassBuilder
          .addField(tlsField)
          .addMethod(
                     generateOptionalGetter(tlsField, TlsParameterGroup.class, CaseFormat.LOWER_CAMEL)
                         .addAnnotation(Override.class)
                         .addJavadoc(tlsJavaDocBuilder.build())
                         .build());
    }
  }

  /**
   * For OAuth2 schemes, adds {@code @ClientCredentials} or {@code @AuthorizationCode}, populated from the scheme's
   * token/authorization URIs and scopes. The annotation matches the base provider selected for the grant type.
   *
   * @throws TemplatingException if no base connection provider exists for the security scheme
   */
  private void addOAuth(TypeSpec.Builder connectionProviderClassBuilder) throws TemplatingException {
    if (securityScheme.getSecuritySchemeType().equals(OAUTH2)) {
      Class<? extends BaseConnectionProvider> providerClass = getConnectionProviderClass();
      AnnotationSpec oAuthAnnotation = null;
      if (providerClass.equals(BaseClientCredentialsConnectionProvider.class)) {
        OAuth2ClientCredentialsScheme clientCredentials = (OAuth2ClientCredentialsScheme) securityScheme;
        oAuthAnnotation = AnnotationSpec
            .builder(ClientCredentials.class)
            .addMember("tokenUrl", "$S", clientCredentials.getAccessTokenUri())
            .addMember("defaultScopes", "$S", clientCredentials.getScopes())
            .build();
      } else if (providerClass.equals(BaseAuthorizationCodeConnectionProvider.class)) {
        OAuth2AuthorizationCodeScheme authorizationCode = (OAuth2AuthorizationCodeScheme) securityScheme;
        oAuthAnnotation = AnnotationSpec
            .builder(AuthorizationCode.class)
            .addMember("accessTokenUrl", "$S", authorizationCode.getAccessTokenUri())
            .addMember("authorizationUrl", "$S", authorizationCode.getAuthorizationUri())
            .addMember("defaultScopes", "$S", authorizationCode.getScopes())
            .build();
      }

      if (oAuthAnnotation != null) {
        connectionProviderClassBuilder.addAnnotation(oAuthAnnotation);
      }
    }
  }

  /** @return a Javadoc block with the parameter description, or its display name when the description is empty */
  private CodeBlock generateSdkParameterJavaDoc(SdkParameter sdkParameter) {
    return CodeBlock.builder()
        .add("$L\n", defaultIfEmpty(sdkParameter.getDescription(), sdkParameter.getDisplayName())).build();
  }

  /**
   * Maps the security scheme to the base class the generated provider extends.
   *
   * @return the base connection provider class for the scheme type (and, for OAuth2, the grant type)
   * @throws TemplatingException if the scheme type, or the OAuth2 grant type, has no matching provider
   */
  private Class<? extends BaseConnectionProvider> getConnectionProviderClass() throws TemplatingException {
    ConnectorSecurityScheme.SecuritySchemeType securitySchemeType = securityScheme.getSecuritySchemeType();

    if (securitySchemeType.equals(ConnectorSecurityScheme.SecuritySchemeType.BASIC)) {
      return BasicAuthenticationConnectionProvider.class;
    } else if (securitySchemeType.equals(ConnectorSecurityScheme.SecuritySchemeType.DIGEST_AUTHENTICATION)) {
      return DigestConnectionProvider.class;
    } else if (securitySchemeType.equals(ConnectorSecurityScheme.SecuritySchemeType.PASS_THROUGH)
        || securitySchemeType.equals(ConnectorSecurityScheme.SecuritySchemeType.UNSECURED)) {
      return BaseConnectionProvider.class;
    } else if (securitySchemeType.equals(OAUTH2)) {
      OAuth2Scheme.GrantType grantType = ((OAuth2Scheme) securityScheme).getGrantType();
      if (grantType.equals(OAuth2Scheme.GrantType.AUTHORIZATION_CODE)) {
        return BaseAuthorizationCodeConnectionProvider.class;
      } else if (grantType.equals(OAuth2Scheme.GrantType.CLIENT_CREDENTIALS)) {
        return BaseClientCredentialsConnectionProvider.class;
      }
    }

    throw new TemplatingException("Connection Provider not available for " + securitySchemeType);
  }

  /** @return {@code @DisplayName} carrying {@link #getConnectionProviderDisplayName()} */
  private AnnotationSpec generateDisplayNameAnnotation() {
    return AnnotationSpec
        .builder(DisplayName.class)
        .addMember(VALUE_MEMBER, "$S", getConnectionProviderDisplayName())
        .build();
  }

  /** @return {@code @Alias} carrying {@link #getConnectionProviderXmlName()} */
  private AnnotationSpec generateAliasAnnotation() {
    return AnnotationSpec
        .builder(Alias.class)
        .addMember(VALUE_MEMBER, "$S", getConnectionProviderXmlName())
        .build();
  }

  /**
   * Builds the private {@code baseUri} field. If the model's base URI is parameterized, the field becomes an
   * {@code @Optional} {@code @Parameter}. Otherwise it is initialized with the model's fixed URI.
   */
  private FieldSpec generateBaseUriField() {
    BaseUri baseUri = connectorModel.getBaseUri();
    FieldSpec.Builder baseUriFieldSpec = FieldSpec
        .builder(String.class, "baseUri", Modifier.PRIVATE)
        .addAnnotation(AnnotationSpec.builder(DisplayName.class)
            .addMember(VALUE_MEMBER, "$S", "Base Uri").build())
        .addAnnotation(AnnotationSpec.builder(Summary.class)
            .addMember(VALUE_MEMBER, "$S", baseUri.getType().getDescription()).build())
        .addJavadoc(CodeBlock.builder().add(JAVA_DOC_BASE_URI).build());
    if (baseUri.isParameterizedBaseUri()) {
      baseUriFieldSpec.addAnnotation(Parameter.class)
          .addAnnotation(buildBaseUriOptionalAnnotation(baseUri));
    } else {
      // Pass the URI as a $S argument rather than concatenating it into the format string. JavaPoet then escapes it
      // as a Java string literal. Concatenation failed on '$' (parsed as a format placeholder) and produced
      // uncompilable code for '"' or '\'.
      baseUriFieldSpec.initializer("$S", baseUri.getUri());
    }

    return baseUriFieldSpec.build();
  }

  /**
   * @param baseUri base URI whose value, when not blank, becomes the {@code defaultValue} of the annotation
   * @return an {@code @Optional} annotation, with {@code defaultValue} only when {@code baseUri} has a non-blank URI
   */
  private AnnotationSpec buildBaseUriOptionalAnnotation(BaseUri baseUri) {
    AnnotationSpec.Builder optionalAnnotationBuilder = AnnotationSpec.builder(Optional.class);

    String uri = baseUri.getUri();
    if (isNotBlank(uri)) {
      optionalAnnotationBuilder.addMember("defaultValue", "$S", uri);
    }

    return optionalAnnotationBuilder.build();
  }

  /**
   * Builds the private {@code tlsConfig} field as a {@code @ParameterGroup(name = "tls")}. It uses
   * {@code OptionalTlsParameterGroup} when plain HTTP is also supported and {@code MandatoryTlsParameterGroup}
   * otherwise.
   */
  private FieldSpec generateTlsField() {
    Class<?> tlsClass = connectorModel.supportsHTTP() ? OptionalTlsParameterGroup.class : MandatoryTlsParameterGroup.class;
    return FieldSpec
        .builder(tlsClass, "tlsConfig", Modifier.PRIVATE)
        .addJavadoc(CodeBlock.builder()
            .add("{@link $L} references to a TLS config element. This will enable HTTPS for this config.\n",
                 tlsClass.getSimpleName())
            .build())
        .addAnnotation(AnnotationSpec.builder(ParameterGroup.class).addMember(NAME_MEMBER, "$S", "tls").build())
        .build();
  }

}
