/*
 * (c) 2003-2020 MuleSoft, Inc. This software is protected under international copyright
 * law. All use of this software is subject to MuleSoft's Master Subscription Agreement
 * (or other master license agreement) separately entered into in writing between you and
 * MuleSoft. If such an agreement is not in place, you may not use the software.
 */
package com.mulesoft.service.oauth.internal.platform;

import static java.lang.String.format;
import static java.util.concurrent.CompletableFuture.completedFuture;
import static org.mule.runtime.api.i18n.I18nMessageFactory.createStaticMessage;
import static org.mule.runtime.api.metadata.DataType.JSON_STRING;
import static org.mule.runtime.api.metadata.DataType.STRING;
import static org.mule.runtime.api.metadata.DataType.fromType;
import static org.mule.runtime.api.util.Preconditions.checkState;
import static org.mule.runtime.core.api.lifecycle.LifecycleUtils.initialiseIfNeeded;
import static org.mule.runtime.core.api.lifecycle.LifecycleUtils.startIfNeeded;
import static org.mule.runtime.core.api.util.IOUtils.closeQuietly;
import static org.mule.runtime.core.api.util.StringUtils.isBlank;
import static org.mule.runtime.http.api.HttpConstants.HttpStatus.UNAUTHORIZED;
import static org.mule.runtime.http.api.HttpConstants.Method.GET;
import static org.mule.runtime.http.api.HttpConstants.Method.POST;
import static org.mule.runtime.http.api.HttpHeaders.Names.AUTHORIZATION;
import static org.mule.runtime.oauth.internal.AbstractOAuthDancer.TOKEN_REQUEST_TIMEOUT_MILLIS;
import static org.slf4j.LoggerFactory.getLogger;

import org.mule.runtime.api.el.BindingContext;
import org.mule.runtime.api.el.ExpressionLanguage;
import org.mule.runtime.api.el.MuleExpressionLanguage;
import org.mule.runtime.api.exception.MuleException;
import org.mule.runtime.api.exception.MuleRuntimeException;
import org.mule.runtime.api.lifecycle.Stoppable;
import org.mule.runtime.api.lock.LockFactory;
import org.mule.runtime.api.metadata.DataType;
import org.mule.runtime.api.metadata.TypedValue;
import org.mule.runtime.api.util.MultiMap;
import org.mule.runtime.api.util.MultiMap.StringMultiMap;
import org.mule.runtime.http.api.HttpConstants.Method;
import org.mule.runtime.http.api.client.HttpClient;
import org.mule.runtime.http.api.client.HttpRequestOptions;
import org.mule.runtime.http.api.domain.message.request.HttpRequest;
import org.mule.runtime.http.api.domain.message.request.HttpRequestBuilder;
import org.mule.runtime.http.api.domain.message.response.HttpResponse;
import org.mule.runtime.oauth.api.ClientCredentialsOAuthDancer;
import org.mule.runtime.oauth.api.OAuthService;
import org.mule.runtime.oauth.api.PlatformManagedConnectionDescriptor;
import org.mule.runtime.oauth.api.exception.RequestAuthenticationException;
import org.mule.runtime.oauth.api.state.ResourceOwnerOAuthContext;

import java.io.InputStream;
import java.util.Map;
import java.util.concurrent.CompletableFuture;

import org.slf4j.Logger;

/**
 * Client for the OCS API.
 * <p>
 * This client uses the Client Credentials grant type to OAuth into the Anypoint Platform in order to access the OCS API.
 * <p>
 * {@link #initCoreServicesDancer(LockFactory, Map, MuleExpressionLanguage)} must have completed successfully before any of the
 * request methods is invoked; otherwise those methods fail their "dancer initialized" precondition check.
 *
 * @since 1.0
 */
public class OCSClient implements Stoppable {

  private static final Logger LOGGER = getLogger(OCSClient.class);
  private static final DataType CONNECTION_DESCRIPTOR_DATA_TYPE = fromType(DefaultPlatformManagedConnectionDescriptor.class);
  private static final String DESCRIPTOR_MAPPING_EXPRESSION =
      "#[{ id : payload.id, displayName: payload.displayName, uri : payload.uri, parameters:  payload.parameters}]";
  private static final HttpRequestOptions OCS_REQUEST_OPTIONS = HttpRequestOptions.builder()
      .responseTimeout(TOKEN_REQUEST_TIMEOUT_MILLIS)
      .build();
  private static final String DEFAULT_API_VERSION = "v1";
  private static final String API_PREFIX_FORMAT = "/api/%s";

  /**
   * Name of the query parameter used to send the revision token on refresh token requests.
   */
  public static final String REVISION_TOKEN_QUERY_PARAM = "rev";

  /**
   * Factory of {@link OCSClient} instances. Obtain one through {@link OCSClient#getFactory()}.
   */
  public static class OCSClientFactory {

    /**
     * Creates a new {@link OCSClient}.
     *
     * @param client             the {@link HttpClient} used to send the requests to the OCS API
     * @param settings           the {@link OCSSettings} describing the platform to connect to
     * @param expressionLanguage the {@link ExpressionLanguage} used to parse the OCS responses
     * @param oauthService       the {@link OAuthService} used to build the client credentials dancer
     * @return a new {@link OCSClient}
     */
    public OCSClient create(HttpClient client,
                            OCSSettings settings,
                            ExpressionLanguage expressionLanguage,
                            OAuthService oauthService) {
      return new OCSClient(client, settings, expressionLanguage, oauthService);
    }

    private OCSClientFactory() {}
  }

  /**
   * @return a new {@link OCSClientFactory}
   */
  public static OCSClientFactory getFactory() {
    return new OCSClientFactory();
  }

  private final HttpClient httpClient;
  private final OCSSettings settings;
  private final OAuthService oauthService;
  private final ExpressionLanguage expressionLanguage;
  private final String apiPrefix;
  // volatile because it is read without holding the monitor, both by the first null check in initCoreServicesDancer and by the
  // request methods. Without it the double-checked initialization below is not safe.
  private volatile ClientCredentialsOAuthDancer dancer;

  private OCSClient(HttpClient httpClient,
                    OCSSettings settings,
                    ExpressionLanguage expressionLanguage,
                    OAuthService oauthService) {
    this.httpClient = httpClient;
    this.settings = settings;
    this.expressionLanguage = expressionLanguage;
    this.oauthService = oauthService;
    this.apiPrefix = getApiPrefix(settings);
  }

  /**
   * Builds the API path prefix, falling back to {@link #DEFAULT_API_VERSION} when the settings carry no api version.
   */
  private String getApiPrefix(OCSSettings settings) {
    String apiVersion = settings.getApiVersion();
    return format(API_PREFIX_FORMAT, apiVersion == null ? DEFAULT_API_VERSION : apiVersion);
  }

  /**
   * Stops the underlying {@link HttpClient}.
   * <p>
   * NOTE(review): the dancer started by {@code initCoreServicesDancer} is not stopped here; confirm its lifecycle is managed
   * elsewhere.
   */
  @Override
  public void stop() throws MuleException {
    httpClient.stop();
  }

  /**
   * Gets the access token for the given {@code connectionUri}
   *
   * @param connectionUri the uri of the connection which token we want
   * @return a {@link CompletableFuture} with the obtained {@link HttpResponse}
   * @throws RequestAuthenticationException in case of failure to OAuth into the Anypoint Platform
   */
  public CompletableFuture<HttpResponse> getAccessToken(String connectionUri) throws RequestAuthenticationException {
    LOGGER.info("Fetching access token for connection {}", connectionUri);

    return getCoreServicesAccessToken().thenCompose(csToken -> {
      String uri = getAccessTokenUrl(connectionUri);
      return ocsRequest(csToken, uri, GET);
    });
  }

  /**
   * Refreshes the token for the given {@code connectionUri}
   *
   * @param connectionUri the uri of the connection which token we want to refresh
   * @param revisionToken the revision token of the token being refreshed. When {@code null}, no
   *                      {@link #REVISION_TOKEN_QUERY_PARAM} query parameter is sent
   * @return a {@link CompletableFuture} with the obtained {@link HttpResponse}
   * @throws RequestAuthenticationException in case of failure to OAuth into the Anypoint Platform
   */
  public CompletableFuture<HttpResponse> refreshToken(String connectionUri, String revisionToken)
      throws RequestAuthenticationException {
    LOGGER.info("Refreshing token for connection {} using lastUpdatedTimestamp {}", connectionUri, revisionToken);

    return getCoreServicesAccessToken().thenCompose(csToken -> {
      String uri = getRefreshTokenUrl(connectionUri);
      StringMultiMap queryParams = new StringMultiMap();
      if (revisionToken != null) {
        queryParams.put(REVISION_TOKEN_QUERY_PARAM, revisionToken);
      }
      return ocsRequest(csToken, uri, POST, queryParams);
    });
  }

  /**
   * Obtains a {@link PlatformManagedConnectionDescriptor} describing the connection of the given {@code connectionUri}
   *
   * @param connectionUri the uri of the connection which information we seek
   * @return a {@link CompletableFuture} with the obtained {@link PlatformManagedConnectionDescriptor}
   * @throws RequestAuthenticationException in case of failure to OAuth into the Anypoint Platform
   */
  public CompletableFuture<PlatformManagedConnectionDescriptor> getConnectionDescriptor(String connectionUri)
      throws RequestAuthenticationException {

    LOGGER.debug("Fetching connection descriptor for connection {}", connectionUri);

    return getCoreServicesAccessToken().thenCompose(csToken -> {
      String uri = getConnectionDescriptorUrl(connectionUri);
      return ocsRequest(csToken, uri, GET).thenApply(response -> {

        InputStream responseBody = response.getEntity().getContent();
        try {
          BindingContext bindingContext = BindingContext.builder()
              .addBinding("payload", new TypedValue<>(responseBody, JSON_STRING))
              .build();

          // A checked cast on the evaluated value replaces an unchecked cast of the whole TypedValue
          PlatformManagedConnectionDescriptor descriptor =
              (PlatformManagedConnectionDescriptor) expressionLanguage.evaluate(DESCRIPTOR_MAPPING_EXPRESSION,
                                                                                CONNECTION_DESCRIPTOR_DATA_TYPE,
                                                                                bindingContext)
                  .getValue();

          return new ImmutablePlatformManagedConnectionDescriptor(descriptor);
        } finally {
          // the body stream is always released, whether or not the mapping succeeded
          closeQuietly(responseBody);
        }
      });
    });
  }

  /**
   * Sends a request without query parameters, retrying once with a refreshed token if it comes back as unauthorized.
   */
  private CompletableFuture<HttpResponse> ocsRequest(String csToken, String uri, Method method) {
    return ocsRequest(csToken, uri, method, null, true);
  }

  /**
   * Sends a request with the given query parameters, retrying once with a refreshed token if it comes back as unauthorized.
   */
  private CompletableFuture<HttpResponse> ocsRequest(String csToken,
                                                     String uri,
                                                     Method method,
                                                     MultiMap<String, String> queryParams) {
    return ocsRequest(csToken, uri, method, queryParams, true);
  }

  /**
   * Sends an authenticated request to the OCS API and inspects the status code of the response.
   * <ul>
   * <li>Unauthorized and {@code refreshOnUnauthorized} is set: the platform access token is refreshed and the request is sent
   * one more time, with no further retries.</li>
   * <li>Unauthorized and {@code refreshOnUnauthorized} is not set: the condition is logged and the unauthorized response is
   * handed back to the caller as is.</li>
   * <li>Any other non 2xx status code: {@link #handleOcsError(HttpResponse, int, String)} throws a
   * {@link MuleRuntimeException} from within the composed stage.</li>
   * <li>2xx: the response is returned.</li>
   * </ul>
   *
   * @param csToken               the platform access token to send in the {@code Authorization} header
   * @param uri                   the target uri
   * @param method                the http method to use
   * @param queryParams           the query parameters to send, or {@code null} for none
   * @param refreshOnUnauthorized whether an unauthorized response should trigger a token refresh followed by a retry
   * @return a {@link CompletableFuture} with the obtained {@link HttpResponse}
   */
  private CompletableFuture<HttpResponse> ocsRequest(String csToken,
                                                     String uri,
                                                     Method method,
                                                     MultiMap<String, String> queryParams,
                                                     boolean refreshOnUnauthorized) {
    return httpClient.sendAsync(coreServicesRequest(uri, csToken, method, queryParams), OCS_REQUEST_OPTIONS)
        .thenCompose(originalResponse -> {
          final int statusCode = originalResponse.getStatusCode();

          if (isStatusCodeUnauthorized(statusCode)) {
            if (refreshOnUnauthorized) {
              LOGGER.info(
                          "Anypoint Platform access token expired. Request to {} returned status code {}. Attempting to refresh access token",
                          uri, statusCode);
              // the retry passes refreshOnUnauthorized=false so a second unauthorized response cannot loop
              return refreshCoreServicesAccessToken()
                  .thenCompose(refreshedToken -> ocsRequest(refreshedToken, uri, method, queryParams, false));
            } else {
              LOGGER.info("Refresh token of Anypoint Platform access token failed with status code {}. Will not retry.",
                          statusCode);
            }
          } else if (!isStatusCodeSuccessful(statusCode)) {
            // always throws
            handleOcsError(originalResponse, statusCode, uri);
          }
          return completedFuture(originalResponse);
        });
  }

  /**
   * @param connectionUri the uri of the connection
   * @return the url from which the access token of the given connection is fetched
   */
  String getAccessTokenUrl(String connectionUri) {
    return connectionUrl(connectionUri) + "/token";
  }

  /**
   * @param connectionUri the uri of the connection
   * @return the url against which the token of the given connection is refreshed. It is the same url as
   *         {@link #getAccessTokenUrl(String)}; {@link #refreshToken(String, String)} targets it with a POST instead of a GET
   */
  String getRefreshTokenUrl(String connectionUri) {
    return connectionUrl(connectionUri) + "/token";
  }

  /**
   * @param connectionUri the uri of the connection
   * @return the url from which the descriptor of the given connection is fetched
   */
  String getConnectionDescriptorUrl(String connectionUri) {
    return connectionUrl(connectionUri);
  }

  /**
   * Builds the url of the given connection's resource within the configured platform, api version and organization.
   */
  private String connectionUrl(String connectionUri) {
    return settings.getPlatformUrl() + apiPrefix + "/organizations/" + settings.getOrganizationId()
        + "/connections/" + connectionUri;
  }

  /**
   * Builds a request carrying the given platform access token as a bearer {@code Authorization} header.
   */
  private HttpRequest coreServicesRequest(String uri, String csAccessToken, Method method, MultiMap<String, String> queryParams) {
    HttpRequestBuilder builder = HttpRequest.builder()
        .uri(uri)
        .method(method)
        .addHeader(AUTHORIZATION, "bearer " + csAccessToken);

    if (queryParams != null) {
      builder.queryParams(queryParams);
    }

    return builder.build();
  }

  /**
   * @return a future with the platform access token, as provided by the dancer
   */
  private CompletableFuture<String> getCoreServicesAccessToken() throws RequestAuthenticationException {
    assertDancerSet();
    return dancer.accessToken();
  }

  /**
   * Asks the dancer to refresh the platform access token.
   *
   * @return a future with the access token found in the dancer's context once the refresh completes
   */
  private CompletableFuture<String> refreshCoreServicesAccessToken() {
    assertDancerSet();

    // single read of the volatile field so both calls below are made on the same instance
    final ClientCredentialsOAuthDancer csDancer = dancer;
    return csDancer.refreshToken().thenApply(v -> csDancer.getContext().getAccessToken());
  }

  /**
   * Creates, initialises and starts the dancer used to obtain the platform access token. Invocations made after a dancer has
   * been successfully set up are no-ops.
   * <p>
   * The dancer is only published to the {@code dancer} field once it has been initialised and started, so concurrent callers
   * never observe a half set up instance, and a failed attempt leaves the field unset.
   *
   * @param lockFactory        the {@link LockFactory} handed to the dancer builder
   * @param tokenStore         the store handed to the dancer builder
   * @param expressionLanguage the {@link MuleExpressionLanguage} handed to the dancer builder
   * @throws MuleRuntimeException if the dancer could not be built, initialised or started
   */
  void initCoreServicesDancer(LockFactory lockFactory,
                              Map<String, ResourceOwnerOAuthContext> tokenStore,
                              MuleExpressionLanguage expressionLanguage) {
    if (dancer != null) {
      return;
    }

    synchronized (this) {
      if (dancer != null) {
        return;
      }
      checkState(oauthService != null, "oauthService has not been set");
      try {
        ClientCredentialsOAuthDancer newDancer =
            oauthService.clientCredentialsGrantTypeDancerBuilder(lockFactory, tokenStore, expressionLanguage)
                .name("OCS@" + settings.getPlatformUrl())
                .encoding(settings.getEncoding())
                .clientCredentials(settings.getClientId(), settings.getClientSecret())
                .tokenUrl(settings.getTokenUrl())
                .withClientCredentialsIn(settings.getClientCredentialsLocation())
                .build();

        initialiseIfNeeded(newDancer);
        startIfNeeded(newDancer);

        // publish last: readers outside the monitor must only ever see a started dancer
        dancer = newDancer;
      } catch (Exception e) {
        throw new MuleRuntimeException(createStaticMessage("Could not obtain access token from Anypoint Platform"), e);
      }
    }
  }

  private boolean isStatusCodeSuccessful(int statusCode) {
    return statusCode >= 200 && statusCode < 300;
  }

  private boolean isStatusCodeUnauthorized(int statusCode) {
    return statusCode == UNAUTHORIZED.getStatusCode();
  }

  /**
   * Turns an error response into a {@link MuleRuntimeException}. This method never returns normally.
   * <p>
   * It attempts to read the {@code message} field out of the response body. If that fails, the failure is logged and the
   * exception is thrown without the message detail.
   *
   * @param errorResponse the non successful response
   * @param statusCode    the status code of the response
   * @param uri           the uri the request was sent to
   * @throws MuleRuntimeException always
   */
  private void handleOcsError(HttpResponse errorResponse, int statusCode, String uri) {
    String errorMessage = null;
    InputStream content = errorResponse.getEntity().getContent();
    try {
      BindingContext bindingContext =
          BindingContext.builder().addBinding("payload", new TypedValue<>(content, JSON_STRING)).build();
      errorMessage = (String) expressionLanguage.evaluate("#[payload.message]", STRING, bindingContext).getValue();
    } catch (Exception e) {
      // best effort: being unable to parse the body must not hide the original http error
      if (LOGGER.isErrorEnabled()) {
        LOGGER.error(format("Failed to retrieve error message from request to %s with status code %d .", uri, statusCode), e);
      }
    } finally {
      closeQuietly(content);
    }
    if (isBlank(errorMessage)) {
      throw new MuleRuntimeException(createStaticMessage(
                                                         format("Got status code %d when trying when making a request to : %s",
                                                                statusCode, uri)));
    } else {
      throw new MuleRuntimeException(createStaticMessage(
                                                         format("Got status code %d when trying when making a request to : %s . Message : %s",
                                                                statusCode, uri, errorMessage)));
    }
  }

  private void assertDancerSet() {
    checkState(dancer != null, "Core Services Dancer not yet initialized");
  }
}
