/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package com.mulesoft.test.service.oauth.internal.platform;

import static com.mulesoft.service.oauth.internal.platform.DefaultPlatformManagedDancer.REVISION_TOKEN_HEADER;
import static java.lang.Thread.currentThread;
import static java.lang.Thread.sleep;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.concurrent.CompletableFuture.completedFuture;
import static java.util.concurrent.Executors.newFixedThreadPool;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
import static org.hamcrest.Matchers.sameInstance;
import static org.junit.Assert.assertThat;
import static org.mockito.Answers.RETURNS_DEEP_STUBS;
import static org.mockito.ArgumentCaptor.forClass;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.same;
import static org.mockito.Mockito.clearInvocations;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
import static org.mockito.MockitoAnnotations.initMocks;

import static org.mule.oauth.client.api.builder.ClientCredentialsLocation.BODY;
import static org.mule.oauth.client.api.state.DancerState.NO_TOKEN;
import static org.mule.runtime.api.metadata.MediaType.APPLICATION_JSON;
import static org.mule.runtime.core.api.lifecycle.LifecycleUtils.stopIfNeeded;
import static org.mule.runtime.core.internal.util.FunctionalUtils.safely;
import static org.mule.runtime.http.api.HttpConstants.HttpStatus.INTERNAL_SERVER_ERROR;
import static org.mule.runtime.http.api.HttpConstants.HttpStatus.NOT_FOUND;
import static org.mule.runtime.http.api.HttpConstants.HttpStatus.OK;
import static org.mule.runtime.http.api.HttpHeaders.Names.CONTENT_TYPE;

import org.mule.oauth.client.api.builder.OAuthDancerBuilder;
import org.mule.oauth.client.api.exception.TokenNotFoundException;
import org.mule.oauth.client.api.exception.TokenUrlResponseException;
import org.mule.oauth.client.api.http.HttpClientFactory;
import org.mule.oauth.client.api.state.ResourceOwnerOAuthContext;
import org.mule.oauth.client.api.state.ResourceOwnerOAuthContextWithRefreshState;
import org.mule.runtime.api.el.MuleExpressionLanguage;
import org.mule.runtime.api.exception.MuleException;
import org.mule.runtime.api.exception.MuleRuntimeException;
import org.mule.runtime.api.lifecycle.LifecycleException;
import org.mule.runtime.http.api.client.HttpClient;
import org.mule.runtime.http.api.domain.entity.InputStreamHttpEntity;
import org.mule.runtime.http.api.domain.message.response.HttpResponse;
import org.mule.runtime.oauth.api.PlatformManagedConnectionDescriptor;
import org.mule.runtime.oauth.api.PlatformManagedOAuthDancer;
import org.mule.runtime.oauth.api.builder.OAuthPlatformManagedDancerBuilder;
import org.mule.runtime.oauth.api.listener.PlatformManagedOAuthStateListener;
import org.mule.runtime.test.oauth.AbstractOAuthTestCase;
import org.mule.service.oauth.internal.DefaultOAuthService;
import org.mule.tck.SimpleUnitTestSupportSchedulerService;

import com.mulesoft.service.oauth.internal.EEOAuthService;
import com.mulesoft.service.oauth.internal.platform.DefaultOAuthPlatformManagedDancerBuilder;
import com.mulesoft.service.oauth.internal.platform.DefaultPlatformManagedDancer;
import com.mulesoft.service.oauth.internal.platform.OCSClient;
import com.mulesoft.service.oauth.internal.platform.OCSClient.OCSClientFactory;
import com.mulesoft.service.oauth.internal.platform.OCSSettings;

import java.io.ByteArrayInputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;

import javax.inject.Inject;

import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.stubbing.Answer;

/**
 * Unit tests for {@link DefaultPlatformManagedDancer}.
 * <p>
 * The {@link OCSClient} is mocked so that access-token, refresh-token and connection-descriptor requests complete
 * asynchronously (after a short delay) with canned {@link HttpResponse}s. This lets the dancer's token lifecycle,
 * revision-token tracking and listener notifications be verified without a real platform.
 */
public class PlatformManagedDancerTestCase extends AbstractOAuthTestCase {

  private static final String CLIENT_ID = "clientId";
  private static final String CLIENT_SECRET = "clientSecret";
  private static final String SERVICE_URL = "https://cs.anypoint.com/ocs";
  private static final String AUTH_URL = "https://cs.anypoint.com/token";
  private static final String ORG_ID = "fooOrg";
  private static final String CONNECTION_URI = "ocs:123122-123123123121-934834394/instagram/soylamorza-2fun4o3wf";
  private static final String ACCESS_TOKEN = "myToken";
  // Deliberately different from ACCESS_TOKEN so that the refresh tests can tell a refreshed token apart from the
  // originally fetched one.
  private static final String REFRESHED_ACCESS_TOKEN = "myRefreshedToken";
  private static final String REFRESH_TOKEN = "refreshMe";
  private static final String REFRESHED_TOKEN_NEW_REVISION = "2";

  private static final String ANOTHER_API_VERSION = "v2";

  /** JSON body returned by the mocked OCS access-token request. */
  private static final String ACCESS_TOKEN_RESPONSE = "{\n"
      + "\"access_token\": \"" + ACCESS_TOKEN + "\",\n"
      + "\"refresh_token\": \"" + REFRESH_TOKEN + "\"\n"
      + "}";

  /** JSON body returned by the mocked OCS refresh-token request. */
  private static final String REFRESH_TOKEN_RESPONSE = "{\n"
      + "\"access_token\": \"" + REFRESHED_ACCESS_TOKEN + "\",\n"
      + "\"refresh_token\": \"" + REFRESH_TOKEN + "\"\n"
      + "}";


  @Mock(answer = RETURNS_DEEP_STUBS)
  private OCSClient ocsClient;

  @Mock(answer = RETURNS_DEEP_STUBS)
  private HttpResponse accessTokenHttpResponse;

  @Mock(answer = RETURNS_DEEP_STUBS)
  private HttpResponse refreshTokenHttpResponse;

  @Mock(answer = RETURNS_DEEP_STUBS)
  private HttpResponse connectionDescriptorResponse;

  @Mock
  private OCSClientFactory ocsClientFactory;

  @Mock
  private PlatformManagedOAuthStateListener listener;

  @Mock
  private HttpClientFactory httpClientFactory;

  @Mock
  private HttpClient httpClient;

  @Inject
  private MuleExpressionLanguage expressionLanguage;

  @Rule
  public ExpectedException expected = ExpectedException.none();

  private final Map<String, ?> tokensStore = new HashMap<>();
  private OAuthPlatformManagedDancerBuilder builder;
  private DefaultPlatformManagedDancer dancer;
  private final SimpleUnitTestSupportSchedulerService schedulerService = new SimpleUnitTestSupportSchedulerService();
  private DefaultOAuthService service;
  /** Revision token reported by the platform alongside the initial access token. */
  private final String revisionToken = "1";

  /**
   * Wires the mocks, creates the OAuth service and starts a dancer that every test can use through {@link #dancer}.
   */
  @Override
  public void setupServices() throws Exception {
    initMocks(this);
    super.setupServices();
    service = new EEOAuthService(httpService, schedulerService);

    // Spied so that teardown can verify which expressions were (not) evaluated.
    expressionLanguage = spy(expressionLanguage);
    mockResponseBody(accessTokenHttpResponse, ACCESS_TOKEN_RESPONSE, revisionToken);
    mockResponseBody(refreshTokenHttpResponse, REFRESH_TOKEN_RESPONSE, REFRESHED_TOKEN_NEW_REVISION);
    when(ocsClient.getAccessToken(CONNECTION_URI)).thenAnswer(delayedResponse(accessTokenHttpResponse));
    when(ocsClient.refreshToken(anyString(), anyString())).thenAnswer(delayedResponse(refreshTokenHttpResponse));
    when(ocsClient.getConnectionDescriptor(CONNECTION_URI)).thenAnswer(delayedResponse(connectionDescriptorResponse));

    when(ocsClientFactory.create(same(httpClient), any(), same(expressionLanguage), same(service))).thenReturn(ocsClient);
    when(httpClientFactory.create(any(), any())).thenReturn(httpClient);
    builder = basePlatformOAuthDancerBuilder(tokensStore);
    dancer = createDancer(builder);
  }

  /**
   * Stops the services and, when a dancer was successfully started, stops it and verifies that its HTTP client was
   * stopped. The dancer is stopped even if the expression-language verification fails, so it is not left running.
   */
  @Override
  public void teardownServices() {
    super.teardownServices();
    safely(schedulerService::stop);
    try {
      // None of the tests may cause the refresh_token expression to be evaluated.
      verify(expressionLanguage, never()).evaluate(eq("#[payload.refresh_token]"), any(), any());
    } finally {
      if (dancer != null) {
        safely(() -> stopIfNeeded(dancer));
        verify(httpClient).stop();
      }
    }
  }

  @Test
  public void clientPropertyConfigured() {
    ArgumentCaptor<OCSSettings> settingsCaptor = forClass(OCSSettings.class);
    verify(ocsClientFactory).create(same(httpClient), settingsCaptor.capture(), same(expressionLanguage), same(service));

    OCSSettings settings = settingsCaptor.getValue();
    assertCommonSettings(settings);
    assertThat(settings.getApiVersion(), is(nullValue()));
  }

  @Test
  public void clientPropertyConfiguredWithAnApiVersion() throws Exception {
    reset(ocsClient);
    builder = basePlatformOAuthDancerBuilder(tokensStore, ANOTHER_API_VERSION);
    dancer = createDancer(builder);

    // One creation from setupServices plus the one above; the captor holds the latest settings.
    ArgumentCaptor<OCSSettings> settingsCaptor = forClass(OCSSettings.class);
    verify(ocsClientFactory, times(2)).create(same(httpClient), settingsCaptor.capture(), same(expressionLanguage),
                                              same(service));

    OCSSettings settings = settingsCaptor.getValue();
    assertCommonSettings(settings);
    assertThat(settings.getApiVersion(), is(ANOTHER_API_VERSION));
  }

  @Test
  public void accessToken() throws Exception {
    String accessToken = dancer.accessToken().get();
    assertThat(accessToken, equalTo(ACCESS_TOKEN));

    verify(listener).onAccessToken(any());
    assertThat(dancer.getRevisionToken(), equalTo(revisionToken));
  }

  @Test
  public void platformReturnsErrorOnStart() throws Exception {
    when(accessTokenHttpResponse.getStatusCode()).thenReturn(INTERNAL_SERVER_ERROR.getStatusCode());
    expected.expect(LifecycleException.class);
    expected.expectCause(is(instanceOf(TokenUrlResponseException.class)));

    clearInvocations(listener);

    try {
      dancer = createDancer(builder);
    } finally {
      verify(listener, never()).onAccessToken(any());
    }
  }

  @Test
  public void platformReturnsNotFoundOnStart() throws Exception {
    when(accessTokenHttpResponse.getStatusCode()).thenReturn(NOT_FOUND.getStatusCode());
    expected.expect(LifecycleException.class);
    expected.expectCause(is(instanceOf(TokenNotFoundException.class)));

    clearInvocations(listener);

    try {
      dancer = createDancer(builder);
    } finally {
      verify(listener, never()).onAccessToken(any());
    }
  }

  @Test
  public void refreshToken() throws Exception {
    dancer.refreshToken().get();
    String accessToken = dancer.accessToken().get();
    assertThat(accessToken, equalTo(REFRESHED_ACCESS_TOKEN));

    assertRefreshTokenListened();
    assertThat(dancer.getRevisionToken(), equalTo(REFRESHED_TOKEN_NEW_REVISION));
  }

  @Test
  public void sendCorrectRevisionAfterRefresh() throws Exception {
    refreshToken();

    final String thirdRevision = "3";
    mockResponseBody(refreshTokenHttpResponse, REFRESH_TOKEN_RESPONSE, thirdRevision);

    dancer.refreshToken().get();
    // The second refresh must send the revision obtained from the first refresh.
    ArgumentCaptor<String> revisionCaptor = forClass(String.class);
    verify(ocsClient, times(2)).refreshToken(anyString(), revisionCaptor.capture());
    assertThat(revisionCaptor.getValue(), equalTo(REFRESHED_TOKEN_NEW_REVISION));
    assertThat(dancer.getRevisionToken(), equalTo(thirdRevision));
  }

  @Test
  public void refreshTokenDoesntProvideNewRevision() throws Exception {
    mockResponseBody(refreshTokenHttpResponse, REFRESH_TOKEN_RESPONSE, null);
    dancer.refreshToken().get();
    String accessToken = dancer.accessToken().get();
    assertThat(accessToken, equalTo(REFRESHED_ACCESS_TOKEN));

    assertRefreshTokenListened();
    assertThat(dancer.getRevisionToken(), equalTo(revisionToken));
  }

  @Test
  public void platformReturnsErrorOnRefreshToken() throws Exception {
    when(refreshTokenHttpResponse.getStatusCode()).thenReturn(INTERNAL_SERVER_ERROR.getStatusCode());
    expected.expect(ExecutionException.class);
    expected.expectCause(is(instanceOf(TokenUrlResponseException.class)));

    try {
      dancer.refreshToken().get();
    } finally {
      verify(listener, never()).onTokenRefreshed(any());
    }
  }

  @Test
  public void platformReturnsNotFoundOnRefreshToken() throws Exception {
    when(refreshTokenHttpResponse.getStatusCode()).thenReturn(NOT_FOUND.getStatusCode());
    expected.expect(ExecutionException.class);
    expected.expectCause(is(instanceOf(TokenNotFoundException.class)));

    try {
      dancer.refreshToken().get();
    } finally {
      verify(listener, never()).onTokenRefreshed(any());
    }
  }

  @Test
  public void refetchAccessTokenAfterInvalidate() throws Exception {
    assertThat(dancer.accessToken().get(), not(nullValue()));
    dancer.invalidateContext();

    clearInvocations(listener);
    assertThat(dancer.accessToken().get(), equalTo(ACCESS_TOKEN));

    assertAccessTokenListened(ACCESS_TOKEN);
    verify(ocsClient, times(2)).getAccessToken(CONNECTION_URI);
  }

  @Test
  public void refreshTokenOnceAtATime() throws Exception {
    ExecutorService executor = newFixedThreadPool(2);
    tokensStore.clear();
    try {
      List<Future<?>> futures = new ArrayList<>();
      for (int i = 0; i < 2; ++i) {
        futures.add(executor.submit(() -> {
          try {
            dancer.refreshToken().get();
          } catch (Exception e) {
            throw new MuleRuntimeException(e);
          }
        }));
      }

      for (Future<?> future : futures) {
        future.get(RECEIVE_TIMEOUT * 10, MILLISECONDS);
      }

      // 2 refreshes, only one actual outbound request
      verify(ocsClient).refreshToken(CONNECTION_URI, revisionToken);
    } finally {
      executor.shutdownNow();
    }
  }

  @Test
  public void refreshTokenOnceAtATimeSequential() throws Exception {
    final CompletableFuture<Void> refreshToken1 = dancer.refreshToken();
    final CompletableFuture<Void> refreshToken2 = dancer.refreshToken();
    refreshToken1.get();
    refreshToken2.get();

    // 2 refreshes, only one actual outbound request
    verify(ocsClient).refreshToken(CONNECTION_URI, revisionToken);
  }

  @Test
  public void exceptionOnTokenRequest() throws Exception {
    final IllegalStateException thrown = new IllegalStateException();
    when(ocsClient.getAccessToken(CONNECTION_URI)).thenThrow(thrown);

    // A dedicated store and builder, so the state left behind by the failed start can be inspected in isolation.
    final Map<String, ResourceOwnerOAuthContextWithRefreshState> isolatedTokensStore = new HashMap<>();
    final OAuthPlatformManagedDancerBuilder isolatedBuilder = basePlatformOAuthDancerBuilder(isolatedTokensStore);

    expected.expect(sameInstance(thrown));
    try {
      createDancer(isolatedBuilder);
    } finally {
      assertThat(isolatedTokensStore.get("default").getDancerState(), is(NO_TOKEN));
    }
  }

  @Test
  public void invalidateContext() {
    assertThat(tokensStore.size(), is(1));
    dancer.invalidateContext();
    assertThat(tokensStore.size(), is(0));

    ResourceOwnerOAuthContext context = dancer.getContext();

    assertThat(context.getAccessToken(), is(nullValue()));
    assertThat(context.getRefreshToken(), is(nullValue()));

    verify(listener).onTokenInvalidated();
  }

  @Test
  public void addListener() throws Exception {
    PlatformManagedOAuthStateListener addedListener = mock(PlatformManagedOAuthStateListener.class);
    dancer.addListener(addedListener);

    invalidateContext();
    verify(addedListener).onTokenInvalidated();

    clearInvocations(listener);
    accessToken();
    assertAccessTokenListened(ACCESS_TOKEN, addedListener);

    refreshToken();
    assertRefreshTokenListened(addedListener);
  }

  @Test
  public void getConnectionDescriptor() throws Exception {
    PlatformManagedConnectionDescriptor descriptor = mock(PlatformManagedConnectionDescriptor.class);
    when(ocsClient.getConnectionDescriptor(CONNECTION_URI)).thenReturn(completedFuture(descriptor));

    assertThat(dancer.getConnectionDescriptor().get(), is(sameInstance(descriptor)));
  }

  @Test
  public void removeListener() throws Exception {
    dancer.removeListener(listener);
    clearInvocations(listener);

    dancer.invalidateContext();
    dancer.accessToken().get();
    dancer.refreshToken().get();

    verifyNoInteractions(listener);
  }

  /**
   * Asserts the OCS settings that every dancer built by {@link #basePlatformOAuthDancerBuilder} shares.
   *
   * @param settings the settings captured from the {@link OCSClientFactory}
   */
  private void assertCommonSettings(OCSSettings settings) {
    assertThat(settings.getPlatformUrl(), is(SERVICE_URL));
    assertThat(settings.getOrganizationId(), is(ORG_ID));
    assertThat(settings.getClientCredentialsLocation(), is(BODY));
    assertThat(settings.getTokenUrl(), is(AUTH_URL));
    assertThat(settings.getEncoding(), is(UTF_8));
    assertThat(settings.getClientId(), is(CLIENT_ID));
    assertThat(settings.getClientSecret(), is(CLIENT_SECRET));
  }

  /**
   * Resets {@code response} and stubs it as a 200 JSON response carrying {@code body} and the given revision header.
   *
   * @param response the mocked response to (re)stub
   * @param body the JSON body; a fresh stream over it is returned on every {@code getEntity()} call
   * @param revision the value of the revision-token header, or {@code null} to simulate its absence
   */
  private void mockResponseBody(HttpResponse response, String body, String revision) {
    reset(response);
    when(response.getEntity()).thenAnswer(inv -> new InputStreamHttpEntity(new ByteArrayInputStream(body.getBytes(UTF_8))));
    when(response.getStatusCode()).thenReturn(OK.getStatusCode());
    when(response.getHeaderValue(CONTENT_TYPE)).thenReturn(APPLICATION_JSON.toRfcString());
    when(response.getHeaderValue(REVISION_TOKEN_HEADER)).thenReturn(revision);
  }

  /**
   * Builds a Mockito answer that returns a future completed with {@code response} about 10 ms later, on the
   * {@code httpClientCallbackExecutor} inherited from the base test case. If the sleep is interrupted, the future is
   * completed exceptionally instead and the thread's interrupt flag is restored.
   *
   * @param response the response to complete the future with
   * @return an answer producing a {@link CompletableFuture} of {@code response}
   */
  private Answer<?> delayedResponse(HttpResponse response) {
    return inv -> {
      final CompletableFuture<HttpResponse> responseFuture = new CompletableFuture<>();

      httpClientCallbackExecutor.execute(() -> {
        try {
          sleep(10);
        } catch (InterruptedException e) {
          currentThread().interrupt();
          responseFuture.completeExceptionally(e);
          return;
        }
        responseFuture.complete(response);
      });

      return responseFuture;
    };
  }

  /**
   * Creates a builder configured with this test's connection settings and no explicit API version.
   *
   * @param tokensStore the map in which the dancer keeps its contexts
   * @return the configured builder
   */
  protected OAuthPlatformManagedDancerBuilder basePlatformOAuthDancerBuilder(Map<String, ?> tokensStore) {
    return basePlatformOAuthDancerBuilder(tokensStore, null);
  }

  /**
   * Creates a builder configured with this test's connection settings, the shared {@link #listener} and the mocked
   * factories.
   *
   * @param tokensStore the map in which the dancer keeps its contexts
   * @param apiVersion the OCS API version to configure, or {@code null} to leave it unset
   * @return the configured builder
   */
  @SuppressWarnings("unchecked") // the builder only stores ResourceOwnerOAuthContext values in the given map
  protected OAuthPlatformManagedDancerBuilder basePlatformOAuthDancerBuilder(Map<String, ?> tokensStore, String apiVersion) {
    final OAuthPlatformManagedDancerBuilder builder = new DefaultOAuthPlatformManagedDancerBuilder(service,
                                                                                                   ocsClientFactory,
                                                                                                   schedulerService,
                                                                                                   lockFactory,
                                                                                                   (Map<String, ResourceOwnerOAuthContext>) tokensStore,
                                                                                                   httpClientFactory,
                                                                                                   expressionLanguage);
    // NOTE(review): the builder returned by the service is discarded here. This is presumably done only to exercise
    // the service's factory method; confirm whether it is still needed.
    service.platformManagedOAuthDancerBuilder(lockFactory, tokensStore, expressionLanguage);

    builder
        .connectionUri(CONNECTION_URI)
        .platformUrl(SERVICE_URL)
        .organizationId(ORG_ID)
        .apiVersion(apiVersion)
        .addListener(listener)
        .tokenUrl(AUTH_URL)
        .withClientCredentialsIn(BODY)
        .encoding(UTF_8)
        .clientCredentials(CLIENT_ID, CLIENT_SECRET);

    return builder;
  }

  /**
   * Starts a dancer from {@code builder} and verifies that the initial access token was requested and reported to
   * the listener.
   * <p>
   * If starting fails, the {@link #dancer} field is set to {@code null} so that teardown does not try to stop it or
   * verify its HTTP client.
   *
   * @param builder the builder to start the dancer from
   * @return the started dancer
   * @throws MuleException if the dancer fails to start
   */
  private DefaultPlatformManagedDancer createDancer(OAuthDancerBuilder<PlatformManagedOAuthDancer> builder) throws MuleException {
    try {
      DefaultPlatformManagedDancer startedDancer = (DefaultPlatformManagedDancer) startDancer(builder);
      verify(ocsClient).getAccessToken(CONNECTION_URI);
      assertAccessTokenListened(ACCESS_TOKEN);
      return startedDancer;
    } catch (MuleException e) {
      dancer = null;
      throw e;
    }
  }

  /** Verifies that {@link #listener} was notified once of {@code accessToken} without a refresh token. */
  private void assertAccessTokenListened(String accessToken) {
    assertAccessTokenListened(accessToken, listener);
  }

  /**
   * Verifies that {@code listener} received exactly one {@code onAccessToken} call, whose context carries
   * {@code accessToken} and no refresh token.
   */
  private void assertAccessTokenListened(String accessToken, PlatformManagedOAuthStateListener listener) {
    ArgumentCaptor<ResourceOwnerOAuthContext> listenerContextCaptor = forClass(ResourceOwnerOAuthContext.class);
    verify(listener).onAccessToken(listenerContextCaptor.capture());
    ResourceOwnerOAuthContext listenerContext = listenerContextCaptor.getValue();

    assertThat(listenerContext.getAccessToken(), equalTo(accessToken));
    assertThat(listenerContext.getRefreshToken(), is(nullValue()));
  }

  /** Verifies that {@link #listener} was notified once of the refreshed token. */
  private void assertRefreshTokenListened() {
    assertRefreshTokenListened(listener);
  }

  /**
   * Verifies that {@code listener} received exactly one {@code onTokenRefreshed} call, whose context carries
   * {@link #REFRESHED_ACCESS_TOKEN} and no refresh token.
   */
  private void assertRefreshTokenListened(PlatformManagedOAuthStateListener listener) {
    ArgumentCaptor<ResourceOwnerOAuthContext> listenerContextCaptor = forClass(ResourceOwnerOAuthContext.class);
    verify(listener).onTokenRefreshed(listenerContextCaptor.capture());
    ResourceOwnerOAuthContext listenerContext = listenerContextCaptor.getValue();

    assertThat(listenerContext.getAccessToken(), equalTo(REFRESHED_ACCESS_TOKEN));
    assertThat(listenerContext.getRefreshToken(), is(nullValue()));
  }
}
