/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package com.mulesoft.service.oauth.internal.platform;

import static java.lang.Thread.currentThread;
import static java.util.concurrent.CompletableFuture.completedFuture;

import static org.mule.oauth.client.api.state.ResourceOwnerOAuthContext.DEFAULT_RESOURCE_OWNER_ID;
import static org.mule.runtime.core.api.util.ClassUtils.withContextClassLoader;
import static org.mule.runtime.core.api.util.IOUtils.closeQuietly;
import static org.mule.runtime.core.internal.util.ConcurrencyUtils.exceptionallyCompleted;
import static org.mule.runtime.http.api.HttpConstants.HttpStatus.NOT_FOUND;

import static org.slf4j.LoggerFactory.getLogger;

import org.mule.oauth.client.api.exception.RequestAuthenticationException;
import org.mule.oauth.client.api.exception.TokenNotFoundException;
import org.mule.oauth.client.api.state.ResourceOwnerOAuthContext;
import org.mule.oauth.client.api.state.ResourceOwnerOAuthContextWithRefreshState;
import org.mule.oauth.client.internal.AbstractOAuthDancer;
import org.mule.oauth.client.internal.state.TokenResponse;
import org.mule.runtime.api.exception.MuleException;
import org.mule.runtime.api.lifecycle.LifecycleException;
import org.mule.runtime.core.api.util.IOUtils;
import org.mule.runtime.http.api.domain.message.response.HttpResponse;
import org.mule.runtime.oauth.api.PlatformManagedConnectionDescriptor;
import org.mule.runtime.oauth.api.PlatformManagedOAuthDancer;
import org.mule.runtime.oauth.api.listener.PlatformManagedOAuthStateListener;

import com.mulesoft.service.oauth.internal.platform.config.DefaultPlatformManagedDancerConfig;

import java.io.InputStream;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ExecutionException;
import java.util.function.Consumer;

import org.slf4j.Logger;

/**
 * Default implementation of {@link PlatformManagedOAuthDancer}.
 * <p>
 * Tokens and connection descriptors are obtained through an {@link OCSClient} built from the supplied
 * {@link DefaultPlatformManagedDancerConfig}. All token state is kept under the
 * {@link ResourceOwnerOAuthContext#DEFAULT_RESOURCE_OWNER_ID default resource owner}.
 * <p>
 * The {@code x-revisionToken} header of the most recent token response is remembered and sent along with the next refresh
 * request.
 *
 * @since 1.0
 */
public class DefaultPlatformManagedDancer extends AbstractOAuthDancer<DefaultPlatformManagedDancerConfig>
    implements PlatformManagedOAuthDancer {

  private static final Logger LOGGER = getLogger(DefaultPlatformManagedDancer.class);

  /**
   * Name of the response header that carries the revision token.
   */
  public static final String REVISION_TOKEN_HEADER = "x-revisionToken";

  private final OCSClient ocsClient;

  /**
   * Whether a forced token fetch has already been triggered for this dancer. Declared {@code volatile} because it is read and
   * written by whichever thread calls {@link #start()} or {@link #accessToken()}.
   */
  private volatile boolean accessTokenRefreshedOnStart = false;

  /**
   * Last value received in the {@link #REVISION_TOKEN_HEADER} header, or {@code null} if none was received yet. Declared
   * {@code volatile} because it is written from future completion callbacks, which may run on a different thread than the one
   * reading it.
   */
  private volatile String revisionToken;

  /**
   * Creates the dancer and the {@link OCSClient} it delegates to.
   *
   * @param config the dancer configuration; its platform, credential, HTTP client, lock provider and token store settings are
   *               used to build and initialise the {@link OCSClient}
   */
  public DefaultPlatformManagedDancer(DefaultPlatformManagedDancerConfig config) {
    super(config);

    OCSSettings settings =
        new OCSSettings(config.getPlatformUrl(), config.getTokenUrl(), config.getCredentialConfig().getClientId(),
                        config.getCredentialConfig().getClientSecret(), config.getEncoding(),
                        config.getCredentialConfig().getClientCredentialsLocation(),
                        config.getOrganizationId(),
                        config.getApiVersion());
    this.ocsClient = config.getOcsClientFactory().create(config.getHttpClient(), settings, config.getExpressionEvaluator(),
                                                         config.getOauthService());
    ocsClient.initCoreServicesDancer(config.getLockProvider(), config.getTokensStore(), config.getExpressionEvaluator());
  }

  /**
   * Starts the dancer and blocks until an access token has been obtained.
   * <p>
   * If the token cannot be obtained the dancer is stopped again and the failure is reported as a {@link LifecycleException}.
   *
   * @throws MuleException if the parent start fails, or a {@link LifecycleException} if the initial token could not be
   *                       obtained or the calling thread was interrupted while waiting for it
   */
  @Override
  public void start() throws MuleException {
    super.start();
    try {
      accessToken().get();
      accessTokenRefreshedOnStart = true;
    } catch (ExecutionException | CompletionException e) {
      // Fall back to the wrapper itself so that failure information is never dropped when there is no cause.
      Throwable failure = e.getCause() != null ? e.getCause() : e;
      abortStart(failure);
      throw new LifecycleException(failure, this);
    } catch (InterruptedException e) {
      abortStart(e);
      currentThread().interrupt();
      throw new LifecycleException(e, this);
    }
  }

  /**
   * Rolls back a failed {@link #start()}: clears the "already refreshed" flag so that a later start forces a fresh token fetch
   * again, and stops the dancer. A failure while stopping is logged and attached to {@code startFailure} as a suppressed
   * exception instead of replacing it, so the original cause of the failed start is what gets reported.
   *
   * @param startFailure the exception that made the start fail
   */
  private void abortStart(Throwable startFailure) {
    // accessToken() sets the flag before the fetch completes; undo that since the fetch did not succeed.
    accessTokenRefreshedOnStart = false;
    try {
      stop();
    } catch (Exception stopFailure) {
      LOGGER.warn("Could not stop dancer for connection URI {} after a failed start", config.getConnectionUri(), stopFailure);
      if (stopFailure != startFailure) {
        startFailure.addSuppressed(stopFailure);
      }
    }
  }

  /**
   * Returns the access token for the configured connection.
   * <p>
   * The first invocation always requests the token from the platform. Later invocations return the token held in the current
   * context, requesting a new one only when the context has none.
   *
   * @return a future with the access token
   */
  @Override
  public CompletableFuture<String> accessToken() {
    // TODO MULE-11858 proactively refresh if the token has already expired based on its 'expiresIn' parameter
    if (!accessTokenRefreshedOnStart) {
      // Flag is raised before the request is issued; start() lowers it again if that request fails.
      accessTokenRefreshedOnStart = true;
      return doFetchAccessToken();
    }

    final String accessToken = getContext().getAccessToken();
    if (accessToken == null) {
      LOGGER.info("Previously stored for connection URI {} token has been invalidated. Refreshing...", config.getConnectionUri());
      return doFetchAccessToken();
    }

    return completedFuture(accessToken);
  }

  /**
   * Requests the access token through the parent's {@code doRefreshToken} flow, using the default resource owner context.
   *
   * @return a future with the access token
   */
  private CompletableFuture<String> doFetchAccessToken() {
    return doRefreshToken(() -> getContext(),
                          ctx -> doAccessTokenRequest((ResourceOwnerOAuthContextWithRefreshState) ctx));
  }

  /**
   * Refreshes the token of the configured connection, sending the last known revision token along with the request.
   *
   * @return a future that completes once the refresh has been processed
   */
  @Override
  public CompletableFuture<Void> refreshToken() {
    return doRefreshToken(() -> getContext(),
                          ctx -> doRefreshTokenRequest((ResourceOwnerOAuthContextWithRefreshState) ctx));
  }

  /**
   * Obtains the descriptor of the configured connection from the {@link OCSClient}.
   *
   * @return a future with the descriptor; anything thrown while issuing the request is returned as an exceptionally completed
   *         future instead of being thrown
   */
  @Override
  public CompletableFuture<PlatformManagedConnectionDescriptor> getConnectionDescriptor() {
    try {
      return ocsClient.getConnectionDescriptor(config.getConnectionUri());
    } catch (Throwable t) {
      return exceptionallyCompleted(t);
    }
  }

  /**
   * Issues the access token request, stores the response in {@code defaultUserState}, notifies listeners through
   * {@code onAccessToken} and records the received revision token.
   *
   * @param defaultUserState the context to update with the response
   * @return a future with the received access token; failures go through the parent's {@code tokenUrlExceptionHandler}
   */
  private CompletableFuture<String> doAccessTokenRequest(ResourceOwnerOAuthContextWithRefreshState defaultUserState) {
    try {
      return ocsClient.getAccessToken(config.getConnectionUri())
          .thenApply(response -> {
            String url = ocsClient.getAccessTokenUrl(config.getConnectionUri());
            TokenResponse tokenResponse = parseTokenResponseAndUpdateState(response,
                                                                           url,
                                                                           defaultUserState,
                                                                           l -> l.onAccessToken(defaultUserState));

            extractRevisionToken(response);

            return tokenResponse.getAccessToken();
          }).exceptionally(tokenUrlExceptionHandler(defaultUserState));
    } catch (RequestAuthenticationException e) {
      return exceptionallyCompleted(e);
    }
  }

  /**
   * Remembers the value of the {@link #REVISION_TOKEN_HEADER} header of the given response. When the header is absent the
   * previously stored value is kept.
   *
   * @param response the token response to inspect
   */
  private void extractRevisionToken(HttpResponse response) {
    String receivedRevision = response.getHeaderValue(REVISION_TOKEN_HEADER);
    if (receivedRevision != null) {
      revisionToken = receivedRevision;
    } else {
      LOGGER.debug("Received a response without a '{}' header.", REVISION_TOKEN_HEADER);
    }
  }

  /**
   * Issues the refresh token request with the last known revision token, stores the response in {@code defaultUserState},
   * notifies listeners through {@code onTokenRefreshed} and records the newly received revision token.
   *
   * @param defaultUserState the context to update with the response
   * @return a future that completes once the response has been processed; failures go through the parent's
   *         {@code tokenUrlExceptionHandler}
   */
  private CompletableFuture<Void> doRefreshTokenRequest(ResourceOwnerOAuthContextWithRefreshState defaultUserState) {
    try {
      return ocsClient.refreshToken(config.getConnectionUri(), revisionToken)
          .thenApply(response -> {
            String url = ocsClient.getRefreshTokenUrl(config.getConnectionUri());
            parseTokenResponseAndUpdateState(response,
                                             url,
                                             defaultUserState,
                                             l -> l.onTokenRefreshed(defaultUserState));

            extractRevisionToken(response);
            return (Void) null;
          }).exceptionally(tokenUrlExceptionHandler(defaultUserState));
    } catch (RequestAuthenticationException e) {
      return exceptionallyCompleted(e);
    }
  }

  /**
   * Parses a token response and copies the access token, expiration and custom response parameters into
   * {@code defaultUserState}. The updated context is then handed to {@code updateOAuthContextAfterTokenResponse} and
   * {@code listenerAction} is applied to every registered listener. The state update and notifications run with this class'
   * class loader set as the context class loader.
   *
   * @param response         the HTTP response returned by the token endpoint
   * @param tokenUrl         the URL the response came from, used for error reporting while parsing
   * @param defaultUserState the context to update
   * @param listenerAction   the notification to deliver to each registered listener
   * @return the parsed token response
   */
  private TokenResponse parseTokenResponseAndUpdateState(HttpResponse response,
                                                         String tokenUrl,
                                                         ResourceOwnerOAuthContextWithRefreshState defaultUserState,
                                                         Consumer<PlatformManagedOAuthStateListener> listenerAction) {
    TokenResponse tokenResponse =
        parseTokenResponse(response, tokenUrl, false);
    withContextClassLoader(DefaultPlatformManagedDancer.class.getClassLoader(), () -> {
      // NOTE(review): this writes the raw access token to the log when DEBUG is enabled; consider masking it.
      LOGGER.debug("Retrieved access token and expires from token url are: {}, {}",
                   tokenResponse.getAccessToken(), tokenResponse.getExpiresIn());

      defaultUserState.setAccessToken(tokenResponse.getAccessToken());
      defaultUserState.setExpiresIn(tokenResponse.getExpiresIn());
      for (Map.Entry<String, Object> customResponseParameterEntry : tokenResponse.getCustomResponseParameters()
          .entrySet()) {
        defaultUserState.getTokenResponseParameters().put(customResponseParameterEntry.getKey(),
                                                          customResponseParameterEntry.getValue());
      }

      updateOAuthContextAfterTokenResponse(defaultUserState);
      forEachListener(listenerAction);
    });

    return tokenResponse;
  }

  /**
   * Parses the token response, treating a {@code 404 Not Found} status as a missing token.
   *
   * @param response             the HTTP response returned by the token endpoint
   * @param tokenUrl             the URL the response came from
   * @param retrieveRefreshToken forwarded to the parent implementation
   * @return the parsed token response
   * @throws CompletionException wrapping a {@link TokenNotFoundException} that carries the response body, when the response
   *                             status is {@code 404}
   */
  @Override
  protected TokenResponse parseTokenResponse(HttpResponse response, String tokenUrl, boolean retrieveRefreshToken) {
    if (response.getStatusCode() == NOT_FOUND.getStatusCode()) {
      InputStream content = response.getEntity().getContent();
      try {
        // TODO: EE-7175 extract error messages
        throw new CompletionException(new TokenNotFoundException(tokenUrl, response, IOUtils.toString(content)));
      } finally {
        // The body is read in full above; release the stream regardless of the outcome.
        closeQuietly(content);
      }
    }

    return super.parseTokenResponse(response, tokenUrl, retrieveRefreshToken);
  }

  /**
   * Registers a listener to be notified of token state changes.
   *
   * @param listener the listener to add
   */
  @Override
  public void addListener(PlatformManagedOAuthStateListener listener) {
    doAddListener(listener);
  }

  /**
   * Unregisters a previously added listener.
   *
   * @param listener the listener to remove
   */
  @Override
  public void removeListener(PlatformManagedOAuthStateListener listener) {
    doRemoveListener(listener);
  }

  /**
   * Invalidates the context of the default resource owner.
   */
  @Override
  public void invalidateContext() {
    invalidateContext(DEFAULT_RESOURCE_OWNER_ID);
  }

  /**
   * @return the OAuth context of the default resource owner
   */
  @Override
  public ResourceOwnerOAuthContext getContext() {
    return getContextForResourceOwner(DEFAULT_RESOURCE_OWNER_ID);
  }

  /**
   * Applies {@code action} to every registered listener, viewing each one as a {@link PlatformManagedOAuthStateListener}.
   *
   * @param action the notification to deliver
   */
  private void forEachListener(Consumer<PlatformManagedOAuthStateListener> action) {
    onEachListener(listener -> action.accept((PlatformManagedOAuthStateListener) listener));
  }

  /**
   * @return the last revision token received from a token response, or {@code null} if none was received yet
   */
  public String getRevisionToken() {
    return revisionToken;
  }
}
