/*
 * (c) 2003-2020 MuleSoft, Inc. This software is protected under international copyright
 * law. All use of this software is subject to MuleSoft's Master Subscription Agreement
 * (or other master license agreement) separately entered into in writing between you and
 * MuleSoft. If such an agreement is not in place, you may not use the software.
 */
package com.mulesoft.service.oauth.internal.platform;

import static java.lang.Thread.currentThread;
import static java.util.concurrent.CompletableFuture.completedFuture;
import static org.mule.runtime.core.api.util.ClassUtils.withContextClassLoader;
import static org.mule.runtime.core.api.util.IOUtils.closeQuietly;
import static org.mule.runtime.core.internal.util.ConcurrencyUtils.exceptionallyCompleted;
import static org.mule.runtime.http.api.HttpConstants.HttpStatus.NOT_FOUND;
import static org.mule.runtime.oauth.api.state.ResourceOwnerOAuthContext.DEFAULT_RESOURCE_OWNER_ID;
import static org.slf4j.LoggerFactory.getLogger;

import org.mule.runtime.api.el.MuleExpressionLanguage;
import org.mule.runtime.api.exception.MuleException;
import org.mule.runtime.api.lifecycle.LifecycleException;
import org.mule.runtime.api.lock.LockFactory;
import org.mule.runtime.api.scheduler.SchedulerService;
import org.mule.runtime.core.api.util.IOUtils;
import org.mule.runtime.http.api.client.HttpClient;
import org.mule.runtime.http.api.domain.message.response.HttpResponse;
import org.mule.runtime.oauth.api.OAuthService;
import org.mule.runtime.oauth.api.PlatformManagedConnectionDescriptor;
import org.mule.runtime.oauth.api.PlatformManagedOAuthDancer;
import org.mule.runtime.oauth.api.builder.ClientCredentialsLocation;
import org.mule.runtime.oauth.api.exception.RequestAuthenticationException;
import org.mule.runtime.oauth.api.exception.TokenNotFoundException;
import org.mule.runtime.oauth.api.listener.OAuthStateListener;
import org.mule.runtime.oauth.api.listener.PlatformManagedOAuthStateListener;
import org.mule.runtime.oauth.api.state.ResourceOwnerOAuthContext;
import org.mule.runtime.oauth.api.state.ResourceOwnerOAuthContextWithRefreshState;
import org.mule.service.oauth.internal.AbstractOAuthDancer;
import org.mule.service.oauth.internal.state.TokenResponse;

import com.mulesoft.service.oauth.internal.platform.OCSClient.OCSClientFactory;

import java.io.InputStream;
import java.nio.charset.Charset;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ExecutionException;
import java.util.function.Consumer;
import java.util.function.Function;

import org.slf4j.Logger;

/**
 * Default implementation of {@link PlatformManagedOAuthDancer}.
 * <p>
 * Tokens for a single {@code connectionUri} are obtained through an {@link OCSClient}, and the resulting state is kept in the
 * OAuth context of {@link ResourceOwnerOAuthContext#DEFAULT_RESOURCE_OWNER_ID}.
 *
 * @since 1.0
 */
public class DefaultPlatformManagedDancer extends AbstractOAuthDancer implements PlatformManagedOAuthDancer {

  private static final Logger LOGGER = getLogger(DefaultPlatformManagedDancer.class);

  private final OCSClient ocsClient;
  private final String connectionUri;

  // While false, accessToken() always fetches a new token instead of returning the access token already held by the
  // context. accessToken() flips it to true as soon as that first fetch is issued.
  private boolean accessTokenRefreshedOnStart = false;

  /**
   * Creates a dancer for the platform-managed connection identified by {@code connectionUri}.
   * <p>
   * Besides delegating to {@link AbstractOAuthDancer}, this builds the {@link OCSClient} used for every token request and
   * initializes its core services dancer with the same lock factory, tokens store and expression language.
   */
  public DefaultPlatformManagedDancer(String name,
                                      String connectionUri,
                                      String organizationId,
                                      String platformUrl,
                                      OCSClientFactory ocsClientFactory,
                                      OAuthService oauthService,
                                      String clientId,
                                      String clientSecret,
                                      String tokenUrl,
                                      String scopes,
                                      ClientCredentialsLocation clientCredentialsLocation,
                                      Charset encoding,
                                      String responseAccessTokenExpr,
                                      String responseRefreshTokenExpr,
                                      String responseExpiresInExpr,
                                      Map<String, String> customParametersExtractorsExprs,
                                      Function<String, String> resourceOwnerIdTransformer,
                                      SchedulerService schedulerService,
                                      LockFactory lockProvider,
                                      Map<String, ResourceOwnerOAuthContext> tokensStore,
                                      HttpClient httpClient,
                                      MuleExpressionLanguage expressionEvaluator,
                                      List<? extends OAuthStateListener> listeners) {
    super(name, clientId, clientSecret, tokenUrl, encoding, scopes, clientCredentialsLocation, responseAccessTokenExpr,
          responseRefreshTokenExpr, responseExpiresInExpr, customParametersExtractorsExprs, resourceOwnerIdTransformer,
          schedulerService, lockProvider, tokensStore, httpClient, expressionEvaluator, listeners);

    OCSSettings settings =
        new OCSSettings(platformUrl, tokenUrl, clientId, clientSecret, encoding, clientCredentialsLocation, organizationId);
    this.ocsClient = ocsClientFactory.create(httpClient, settings, expressionEvaluator, oauthService);
    ocsClient.initCoreServicesDancer(lockProvider, tokensStore, expressionEvaluator);
    this.connectionUri = connectionUri;
  }

  /**
   * Starts the dancer and blocks until an access token has been fetched.
   * <p>
   * If the fetch fails, or the calling thread is interrupted while waiting for it, the dancer is stopped again and a
   * {@link LifecycleException} wrapping the failure is thrown.
   *
   * @throws MuleException if the parent start fails or the initial access token cannot be obtained
   */
  @Override
  public void start() throws MuleException {
    super.start();
    try {
      accessToken().get();
      accessTokenRefreshedOnStart = true;
    } catch (ExecutionException | CompletionException e) {
      // Report the failure the future completed with; fall back to the wrapper if it carries no cause.
      final Throwable cause = e.getCause() != null ? e.getCause() : e;
      throw stopAfterFailedStart(cause);
    } catch (InterruptedException e) {
      final LifecycleException failure = stopAfterFailedStart(e);
      currentThread().interrupt();
      throw failure;
    }
  }

  /**
   * Rolls back a failed {@link #start()}: resets the start-refresh flag, stops the dancer and builds the exception to throw.
   *
   * @param cause the failure that aborted the start
   * @return a {@link LifecycleException} wrapping {@code cause}; any failure raised by {@code stop()} is attached to it as a
   *         suppressed exception
   */
  private LifecycleException stopAfterFailedStart(Throwable cause) {
    // accessToken() sets this flag before its fetch completes. Clear it so a later start forces a new fetch instead of
    // accepting whatever access token the context currently holds.
    accessTokenRefreshedOnStart = false;
    final LifecycleException failure = new LifecycleException(cause, this);
    try {
      stop();
    } catch (Exception stopFailure) {
      // Keep the start failure as the primary exception; do not let a stop failure hide it.
      failure.addSuppressed(stopFailure);
    }
    return failure;
  }

  /**
   * Returns the access token for the connection.
   * <p>
   * The first call always fetches a new token. Later calls return the token held by the context, fetching a new one only when
   * the context no longer has an access token.
   *
   * @return a future completed with the access token
   */
  @Override
  public CompletableFuture<String> accessToken() {
    // TODO MULE-11858 proactively refresh if the token has already expired based on its 'expiresIn' parameter
    if (!accessTokenRefreshedOnStart) {
      accessTokenRefreshedOnStart = true;
      return doFetchAccessToken();
    }

    final String accessToken = getContext().getAccessToken();
    if (accessToken == null) {
      LOGGER.info("Previously stored for connection URI {} token has been invalidated. Refreshing...", connectionUri);
      return doFetchAccessToken();
    }

    return completedFuture(accessToken);
  }

  /**
   * Requests a new access token through {@link #doRefreshToken}, using the OCS access token endpoint.
   */
  private CompletableFuture<String> doFetchAccessToken() {
    return doRefreshToken(() -> getContext(),
                          ctx -> doAccessTokenRequest((ResourceOwnerOAuthContextWithRefreshState) ctx));
  }

  /**
   * Refreshes the token through the OCS refresh endpoint and updates the context with the response.
   *
   * @return a future completed once the context has been updated
   */
  @Override
  public CompletableFuture<Void> refreshToken() {
    return doRefreshToken(() -> getContext(),
                          ctx -> doRefreshTokenRequest((ResourceOwnerOAuthContextWithRefreshState) ctx));
  }

  /**
   * Retrieves the descriptor of the platform-managed connection from OCS.
   *
   * @return a future with the descriptor; any exception thrown synchronously while issuing the request is reported as an
   *         exceptionally completed future instead of being thrown
   */
  @Override
  public CompletableFuture<PlatformManagedConnectionDescriptor> getConnectionDescriptor() {
    try {
      return ocsClient.getConnectionDescriptor(connectionUri);
    } catch (Throwable t) {
      return exceptionallyCompleted(t);
    }
  }

  /**
   * Asks OCS for an access token and, on success, stores it in {@code defaultUserState} and notifies listeners through
   * {@code onAccessToken}.
   *
   * @param defaultUserState the context to update with the token response
   * @return a future completed with the new access token; failures go through {@code tokenUrlExceptionHandler}
   */
  private CompletableFuture<String> doAccessTokenRequest(ResourceOwnerOAuthContextWithRefreshState defaultUserState) {
    try {
      return ocsClient.getAccessToken(connectionUri)
          .thenApply(response -> {
            String url = ocsClient.getAccessTokenUrl(connectionUri);
            TokenResponse tokenResponse = parseTokenResponseAndUpdateState(response,
                                                                           url,
                                                                           defaultUserState,
                                                                           l -> l.onAccessToken(defaultUserState));

            return tokenResponse.getAccessToken();
          }).exceptionally(tokenUrlExceptionHandler(defaultUserState));
    } catch (RequestAuthenticationException e) {
      return exceptionallyCompleted(e);
    }
  }

  /**
   * Asks OCS to refresh the token and, on success, stores the result in {@code defaultUserState} and notifies listeners
   * through {@code onTokenRefreshed}.
   *
   * @param defaultUserState the context to update with the token response
   * @return a future completed once the context has been updated; failures go through {@code tokenUrlExceptionHandler}
   */
  private CompletableFuture<Void> doRefreshTokenRequest(ResourceOwnerOAuthContextWithRefreshState defaultUserState) {
    try {
      return ocsClient.refreshToken(connectionUri)
          .thenApply(response -> {
            String url = ocsClient.getRefreshTokenUrl(connectionUri);
            parseTokenResponseAndUpdateState(response,
                                             url,
                                             defaultUserState,
                                             l -> l.onTokenRefreshed(defaultUserState));
            return (Void) null;
          }).exceptionally(tokenUrlExceptionHandler(defaultUserState));
    } catch (RequestAuthenticationException e) {
      return exceptionallyCompleted(e);
    }
  }

  /**
   * Parses a token response and copies its access token, expiry and custom parameters into {@code defaultUserState}.
   * <p>
   * It then calls {@code updateOAuthContextAfterTokenResponse} and applies {@code listenerAction} to every listener. All of
   * this runs with this class's class loader as the thread context class loader.
   *
   * @param response the HTTP response returned by OCS
   * @param tokenUrl the URL the response came from, used for error reporting and logging
   * @param defaultUserState the context to update
   * @param listenerAction the notification to send to each listener
   * @return the parsed token response
   */
  private TokenResponse parseTokenResponseAndUpdateState(HttpResponse response,
                                                         String tokenUrl,
                                                         ResourceOwnerOAuthContextWithRefreshState defaultUserState,
                                                         Consumer<PlatformManagedOAuthStateListener> listenerAction) {
    TokenResponse tokenResponse = parseTokenResponse(response, tokenUrl, false);
    withContextClassLoader(DefaultPlatformManagedDancer.class.getClassLoader(), () -> {
      // Never log the access token itself: it is a live credential.
      LOGGER.debug("Retrieved access token from token url {}; expires in: {}", tokenUrl, tokenResponse.getExpiresIn());

      defaultUserState.setAccessToken(tokenResponse.getAccessToken());
      defaultUserState.setExpiresIn(tokenResponse.getExpiresIn());
      defaultUserState.getTokenResponseParameters().putAll(tokenResponse.getCustomResponseParameters());

      updateOAuthContextAfterTokenResponse(defaultUserState);
      forEachListener(listenerAction);
    });

    return tokenResponse;
  }

  /**
   * Parses a token response, mapping a {@code 404 Not Found} status to a {@link TokenNotFoundException}.
   * <p>
   * On a 404, the response body is read, the exception is thrown wrapped in a {@link CompletionException}, and the body stream
   * is closed quietly. Any other status is delegated to the parent implementation.
   */
  @Override
  protected TokenResponse parseTokenResponse(HttpResponse response, String tokenUrl, boolean retrieveRefreshToken) {
    if (response.getStatusCode() == NOT_FOUND.getStatusCode()) {
      InputStream content = response.getEntity().getContent();
      try {
        // TODO: EE-7175 extract error messages
        throw new CompletionException(new TokenNotFoundException(tokenUrl, response, IOUtils.toString(content)));
      } finally {
        closeQuietly(content);
      }
    }

    return super.parseTokenResponse(response, tokenUrl, retrieveRefreshToken);
  }

  /**
   * Registers a listener to be notified of token events.
   */
  @Override
  public void addListener(PlatformManagedOAuthStateListener listener) {
    doAddListener(listener);
  }

  /**
   * Unregisters a previously added listener.
   */
  @Override
  public void removeListener(PlatformManagedOAuthStateListener listener) {
    doRemoveListener(listener);
  }

  /**
   * Invalidates the context of {@link ResourceOwnerOAuthContext#DEFAULT_RESOURCE_OWNER_ID}.
   */
  @Override
  public void invalidateContext() {
    invalidateContext(DEFAULT_RESOURCE_OWNER_ID);
  }

  /**
   * Returns the context of {@link ResourceOwnerOAuthContext#DEFAULT_RESOURCE_OWNER_ID}, the only resource owner this dancer
   * uses.
   */
  @Override
  public ResourceOwnerOAuthContext getContext() {
    return getContextForResourceOwner(DEFAULT_RESOURCE_OWNER_ID);
  }

  /**
   * Applies {@code action} to every registered listener.
   * <p>
   * NOTE(review): every listener is cast to {@link PlatformManagedOAuthStateListener}. Listeners passed to the constructor
   * are typed only as {@link OAuthStateListener}, so a non-platform listener would cause a ClassCastException — confirm
   * callers never pass one.
   */
  private void forEachListener(Consumer<PlatformManagedOAuthStateListener> action) {
    onEachListener(listener -> action.accept((PlatformManagedOAuthStateListener) listener));
  }
}
