/*
 * (c) 2003-2023 MuleSoft, Inc. This software is protected under international copyright
 * law. All use of this software is subject to MuleSoft's Master Subscription Agreement
 * (or other master license agreement) separately entered into in writing between you and
 * MuleSoft. If such an agreement is not in place, you may not use the software.
 */
package com.mulesoft.service.oauth.internal.platform;

import static com.mulesoft.service.oauth.internal.platform.OCSClient.REVISION_TOKEN_QUERY_PARAM;
import static com.mulesoft.service.oauth.internal.platform.OCSClient.getFactory;
import static java.lang.String.format;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Collections.synchronizedList;
import static java.util.UUID.randomUUID;
import static java.util.concurrent.CompletableFuture.completedFuture;
import static java.util.concurrent.Executors.newFixedThreadPool;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.sameInstance;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.notNullValue;
import static org.junit.Assert.assertThat;
import static org.junit.rules.ExpectedException.none;
import static org.mockito.Answers.RETURNS_DEEP_STUBS;
import static org.mockito.ArgumentCaptor.forClass;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.mockito.MockitoAnnotations.initMocks;
import static org.mule.runtime.api.metadata.MediaType.APPLICATION_JSON;
import static org.mule.runtime.core.internal.util.ConcurrencyUtils.exceptionallyCompleted;
import static org.mule.runtime.http.api.HttpConstants.HttpStatus.FORBIDDEN;
import static org.mule.runtime.http.api.HttpConstants.HttpStatus.INTERNAL_SERVER_ERROR;
import static org.mule.runtime.http.api.HttpConstants.HttpStatus.OK;
import static org.mule.runtime.http.api.HttpConstants.HttpStatus.UNAUTHORIZED;
import static org.mule.runtime.http.api.HttpHeaders.Names.AUTHORIZATION;
import static org.mule.runtime.http.api.HttpHeaders.Names.CONTENT_TYPE;
import static org.mule.runtime.oauth.api.builder.ClientCredentialsLocation.BODY;
import static org.mule.runtime.oauth.internal.AbstractOAuthDancer.TOKEN_REQUEST_TIMEOUT_MILLIS;

import org.mule.runtime.api.el.MuleExpressionLanguage;
import org.mule.runtime.api.exception.DefaultMuleException;
import org.mule.runtime.api.exception.MuleException;
import org.mule.runtime.api.exception.MuleRuntimeException;
import org.mule.runtime.api.lifecycle.Initialisable;
import org.mule.runtime.api.lifecycle.Startable;
import org.mule.runtime.api.lock.LockFactory;
import org.mule.runtime.api.util.Reference;
import org.mule.runtime.api.util.concurrent.Latch;
import org.mule.runtime.http.api.client.HttpClient;
import org.mule.runtime.http.api.client.HttpRequestOptions;
import org.mule.runtime.http.api.domain.message.request.HttpRequest;
import org.mule.runtime.http.api.domain.message.response.HttpResponse;
import org.mule.runtime.oauth.api.ClientCredentialsOAuthDancer;
import org.mule.runtime.oauth.api.OAuthService;
import org.mule.runtime.oauth.api.PlatformManagedConnectionDescriptor;
import org.mule.runtime.oauth.api.builder.OAuthClientCredentialsDancerBuilder;
import org.mule.runtime.oauth.api.builder.OAuthDancerBuilder;
import org.mule.runtime.oauth.api.state.ResourceOwnerOAuthContext;
import org.mule.runtime.test.oauth.AbstractOAuthTestCase;

import java.io.ByteArrayInputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;

import javax.inject.Inject;

import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.InOrder;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.mockito.stubbing.Answer;

@RunWith(MockitoJUnitRunner.class)
/**
 * Unit tests for {@link OCSClient}.
 * <p>
 * Covers:
 * <ul>
 * <li>construction of the token URLs;</li>
 * <li>initialization of the core services client-credentials dancer, including concurrent initialization;</li>
 * <li>access/refresh token requests sent through a mocked {@link HttpClient}, including the retry after a 401;</li>
 * <li>parsing of connection descriptors.</li>
 * </ul>
 */
@RunWith(MockitoJUnitRunner.class)
public class OCSClientTestCase extends AbstractOAuthTestCase {

  private static final String SERVICE_URL = "https://cs.anypoint.com/ocs";
  private static final String AUTH_URL = "https://cs.anypoint.com/token";
  private static final String CLIENT_ID = "ocsClient";
  private static final String CLIENT_SECRET = "ocsSecret";
  private static final String ORG_ID = "fooOrg";
  private static final String ANOTHER_API_VERSION = "v80";
  private static final String CONNECTION_ID = randomUUID().toString();
  private static final String CONNECTION_NAME = "My Connection";
  private static final String CONNECTION_URI = "ocs:123122-123123123121-934834394/instagram/soylamorza-2fun4o3wf";
  private static final String FULLNAME_PARAMETER = "Bond... James Bond";
  private static final String AGE_PARAMETER = "33";
  private static final String REQUEST_METHOD = "GET";
  private static final String REQUEST_PATH = "/test";
  private static final String NUMBER = "1";
  private static final int THREAD_COUNT = 5;
  private static final String PLATFORM_ACCESS_TOKEN = "coreServicesToken";
  private static final String REFRESHED_PLATFORM_ACCESS_TOKEN = "refreshed-coreServicesToken";
  private static final String API_VERSION = "v1";
  private static final String API_PREFIX = "/api/" + API_VERSION;

  private static final String CONNECTION_DESCRIPTOR_RESPONSE = format("{\n" +
      "   \"id\":\"%s\",\n" +
      "   \"displayName\":\"%s\",\n" +
      "   \"uri\":\"%s\",\n" +
      "   \"parameters\":{\n" +
      "      \"fullname\":\"%s\",\n" +
      "      \"age\":\"%s\",\n" +
      "      \"connectivityTest\":{\n" +
      "         \"requestMethod\":\"GET\",\n" +
      "         \"requestPath\":\"/test\"\n" +
      "      },\n" +
      "      \"someConnectivityTests\":[\n" +
      "         {\n" +
      "            \"requestMethod\":\"GET\",\n" +
      "            \"requestPath\":\"/test\"\n" +
      "         },\n" +
      "         {\n" +
      "            \"requestMethod\":\"GET\",\n" +
      "            \"requestPath\":\"/test\"\n" +
      "         },\n" +
      "         {\n" +
      "            \"requestMethod\":\"GET\",\n" +
      "            \"requestPath\":\"/test\"\n" +
      "         }\n" +
      "      ],\n" +
      "      \"someNumbers\":[\n" +
      "         \"1\",\n" +
      "         \"1\",\n" +
      "         \"1\"\n" +
      "      ]\n" +
      "   }\n" +
      "}",
                                                                      CONNECTION_ID, CONNECTION_NAME, CONNECTION_URI,
                                                                      FULLNAME_PARAMETER, AGE_PARAMETER);

  // Same descriptor as above plus an unknown top-level attribute, which parsing is expected to tolerate.
  private static final String CONNECTION_DESCRIPTOR_RESPONSE_WITH_EXTRA_DATA = format("{\n" +
      "   \"id\":\"%s\",\n" +
      "   \"displayName\":\"%s\",\n" +
      "   \"uri\":\"%s\",\n" +
      "   \"extraParameter\":\"extraParameterValue\",\n" +
      "   \"parameters\":{\n" +
      "      \"fullname\":\"%s\",\n" +
      "      \"age\":\"%s\",\n" +
      "      \"connectivityTest\":{\n" +
      "         \"requestMethod\":\"GET\",\n" +
      "         \"requestPath\":\"/test\"\n" +
      "      },\n" +
      "      \"someConnectivityTests\":[\n" +
      "         {\n" +
      "            \"requestMethod\":\"GET\",\n" +
      "            \"requestPath\":\"/test\"\n" +
      "         },\n" +
      "         {\n" +
      "            \"requestMethod\":\"GET\",\n" +
      "            \"requestPath\":\"/test\"\n" +
      "         },\n" +
      "         {\n" +
      "            \"requestMethod\":\"GET\",\n" +
      "            \"requestPath\":\"/test\"\n" +
      "         }\n" +
      "      ],\n" +
      "      \"someNumbers\":[\n" +
      "         \"1\",\n" +
      "         \"1\",\n" +
      "         \"1\"\n" +
      "      ]\n" +
      "   }\n" +
      "}",
                                                                                      CONNECTION_ID, CONNECTION_NAME,
                                                                                      CONNECTION_URI,
                                                                                      FULLNAME_PARAMETER, AGE_PARAMETER);
  private static final String OCS_ERROR_MESSAGE = "OCS Error message";

  private static final String ERROR_RESPONSE_WITH_MESSAGE = "{\n"
      + "\"message\": \" " + OCS_ERROR_MESSAGE + "\",\n"
      + "\"serviceProviderResponse\": {\n"
      + "\"errorMessage\": \"The call failed!\",\n"
      + "\"code\": \"U12380\"\n"
      + "}\n"
      + "}";

  private static final String ERROR_RESPONSE_WITHOUT_MESSAGE = "{\n"
      + "\"serviceProviderResponse\": {\n"
      + "\"errorMessage\": \"The call failed!\",\n"
      + "\"code\": \"U12380\"\n"
      + "}\n"
      + "}";

  @Mock
  private HttpClient httpClient;

  @Mock
  private LockFactory lockFactory;

  @Mock
  private OAuthService oauthService;

  @Mock(extraInterfaces = {Initialisable.class, Startable.class})
  private ClientCredentialsOAuthDancer ccDancer;

  @Mock(answer = RETURNS_DEEP_STUBS)
  private HttpResponse httpResponse;

  @Mock
  private ResourceOwnerOAuthContext resourceOwnerOAuthContext;

  @Inject
  private MuleExpressionLanguage expressionLanguage;

  @Rule
  public ExpectedException expectedException = none();

  private OAuthClientCredentialsDancerBuilder ccDancerBuilder;
  private Map<String, ResourceOwnerOAuthContext> tokenStore = new ConcurrentHashMap<>();
  private ExecutorService executorService = newFixedThreadPool(THREAD_COUNT);
  private OCSClient ocsClient;
  // Set by the stubbed ccDancer.refreshToken(), read back through resourceOwnerOAuthContext.getAccessToken().
  private Reference<String> refreshedToken = new Reference<>();
  private OCSSettings settings;

  /**
   * Wires the mocked OAuth service so that building a client-credentials dancer yields {@link #ccDancer}.
   * Then creates the {@link OCSClient} under test, with an HTTP client that answers 200 by default.
   */
  @Override
  public void setupServices() throws Exception {
    initMocks(this);
    super.setupServices();

    expressionLanguage = spy(expressionLanguage);
    settings = new OCSSettings(SERVICE_URL, AUTH_URL, CLIENT_ID, CLIENT_SECRET, UTF_8, BODY, ORG_ID);
    // Fluent builder mock: build() returns the dancer, and any builder-typed method returns the builder itself.
    ccDancerBuilder = mock(OAuthClientCredentialsDancerBuilder.class, (Answer<Object>) invocation -> {
      if (invocation.getMethod().getName().equals("build")) {
        return ccDancer;
      }

      if (OAuthDancerBuilder.class.isAssignableFrom(invocation.getMethod().getReturnType())) {
        return ccDancerBuilder;
      }

      return null;
    });
    when(oauthService.clientCredentialsGrantTypeDancerBuilder(lockFactory, tokenStore, expressionLanguage))
        .thenReturn(ccDancerBuilder);

    when(ccDancer.accessToken()).thenReturn(completedFuture(PLATFORM_ACCESS_TOKEN));
    when(ccDancer.refreshToken()).thenAnswer(inv -> {
      refreshedToken.set(REFRESHED_PLATFORM_ACCESS_TOKEN);
      return completedFuture(null);
    });
    when(ccDancer.getContext()).thenReturn(resourceOwnerOAuthContext);

    ocsClient = getFactory().create(httpClient, settings, expressionLanguage, oauthService);

    when(httpClient.sendAsync(any(), any())).thenReturn(completedFuture(httpResponse));
    when(httpResponse.getStatusCode()).thenReturn(200);
  }

  /**
   * Shuts down the worker pool and verifies that stopping the OCS client stops the HTTP client it was given.
   */
  @Override
  public void teardownServices() {
    try {
      executorService.shutdown();
      verify(httpClient, never()).stop();
      ocsClient.stop();
      verify(httpClient).stop();
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
  }

  @Test
  public void getAccessTokenUrl() {
    final String expected =
        SERVICE_URL + API_PREFIX + "/organizations/" + ORG_ID + "/connections/" + CONNECTION_URI + "/token";
    assertThat(ocsClient.getAccessTokenUrl(CONNECTION_URI), equalTo(expected));

    // A trailing slash on the service URL must not produce a double slash.
    settings = new OCSSettings(SERVICE_URL + "/", AUTH_URL, CLIENT_ID, CLIENT_SECRET, UTF_8, BODY, ORG_ID);
    ocsClient = getFactory().create(httpClient, settings, expressionLanguage, oauthService);

    assertThat(ocsClient.getAccessTokenUrl(CONNECTION_URI), equalTo(expected));
  }

  @Test
  public void getRefreshTokenUrl() {
    final String expected =
        SERVICE_URL + API_PREFIX + "/organizations/" + ORG_ID + "/connections/" + CONNECTION_URI + "/token";
    assertThat(ocsClient.getRefreshTokenUrl(CONNECTION_URI), equalTo(expected));

    // A trailing slash on the service URL must not produce a double slash.
    settings = new OCSSettings(SERVICE_URL + "/", AUTH_URL, CLIENT_ID, CLIENT_SECRET, UTF_8, BODY, ORG_ID);
    ocsClient = getFactory().create(httpClient, settings, expressionLanguage, oauthService);

    assertThat(ocsClient.getRefreshTokenUrl(CONNECTION_URI), equalTo(expected));
  }

  @Test
  public void getAccessTokenUrlWithCustomApiVersion() {
    final String expected =
        SERVICE_URL + "/api/" + ANOTHER_API_VERSION + "/organizations/" + ORG_ID + "/connections/" + CONNECTION_URI + "/token";

    settings = new OCSSettings(SERVICE_URL + "/", AUTH_URL, CLIENT_ID, CLIENT_SECRET, UTF_8, BODY, ORG_ID, ANOTHER_API_VERSION);
    ocsClient = getFactory().create(httpClient, settings, expressionLanguage, oauthService);

    assertThat(ocsClient.getAccessTokenUrl(CONNECTION_URI), equalTo(expected));
  }

  @Test
  public void getRefreshTokenUrlWithCustomApiVersion() {
    final String expected =
        SERVICE_URL + "/api/" + ANOTHER_API_VERSION + "/organizations/" + ORG_ID + "/connections/" + CONNECTION_URI + "/token";

    settings = new OCSSettings(SERVICE_URL + "/", AUTH_URL, CLIENT_ID, CLIENT_SECRET, UTF_8, BODY, ORG_ID, ANOTHER_API_VERSION);
    ocsClient = getFactory().create(httpClient, settings, expressionLanguage, oauthService);

    assertThat(ocsClient.getRefreshTokenUrl(CONNECTION_URI), equalTo(expected));
  }

  @Test
  public void initCSDancer() throws Exception {
    initCSDancer(1);
  }

  @Test
  public void initCSDancerTwice() throws Exception {
    initCSDancer(2);
  }

  @Test
  public void failureToInitCSDancer() throws Exception {
    MuleException e = new DefaultMuleException("");
    doThrow(e).when((Startable) ccDancer).start();

    expectedException.expect(MuleRuntimeException.class);
    expectedException.expectCause(is(sameInstance(e)));

    initCSDancer(1);
  }

  /**
   * Releases {@link #THREAD_COUNT} workers at once to race on {@code initCoreServicesDancer}.
   * Verifies that none of them fails and that the dancer is built, initialised and started exactly once.
   */
  @Test
  public void initCSDancerConcurrently() throws Exception {
    Latch mainLatch = new Latch();
    CountDownLatch taskLatch = new CountDownLatch(THREAD_COUNT);
    CountDownLatch finishLatch = new CountDownLatch(THREAD_COUNT);

    List<Exception> exceptions = synchronizedList(new ArrayList<>(THREAD_COUNT));

    for (int i = 0; i < THREAD_COUNT; i++) {
      executorService.submit(() -> {
        try {
          taskLatch.countDown();
          // Park every worker until all are ready so the init calls actually contend.
          mainLatch.await();
          ocsClient.initCoreServicesDancer(lockFactory, tokenStore, expressionLanguage);
        } catch (InterruptedException e) {
          Thread.currentThread().interrupt();
          exceptions.add(e);
        } catch (Exception e) {
          exceptions.add(e);
        } finally {
          // Always count down so a failing worker cannot make the main thread wait for the full timeout.
          finishLatch.countDown();
        }
      });
    }

    taskLatch.await();
    mainLatch.release();
    assertThat("Concurrent dancer initialization did not finish in time", finishLatch.await(30, SECONDS), is(true));

    assertThat(exceptions, hasSize(0));
    verifyCsDancerInitialization();
  }

  @Test
  public void getAccessToken() throws Exception {
    ocsClient.initCoreServicesDancer(lockFactory, tokenStore, expressionLanguage);

    HttpResponse response = ocsClient.getAccessToken(CONNECTION_URI).get();
    assertThat(response, is(sameInstance(httpResponse)));

    ArgumentCaptor<HttpRequest> requestCaptor = forClass(HttpRequest.class);
    ArgumentCaptor<HttpRequestOptions> optionsCaptor = forClass(HttpRequestOptions.class);

    verify(httpClient).sendAsync(requestCaptor.capture(), optionsCaptor.capture());

    HttpRequest request = requestCaptor.getValue();
    HttpRequestOptions options = optionsCaptor.getValue();

    assertThat(request.getHeaderValue(AUTHORIZATION), equalTo("bearer " + PLATFORM_ACCESS_TOKEN));
    assertThat(request.getUri().toString(), equalTo(ocsClient.getAccessTokenUrl(CONNECTION_URI)));
    assertThat(request.getMethod(), equalTo("GET"));
    assertThat(options.getResponseTimeout(), equalTo(TOKEN_REQUEST_TIMEOUT_MILLIS));
  }

  @Test
  public void getAccessTokenWithExceptionFromCCDancer() throws Exception {
    Exception e = new Exception();
    when(ccDancer.accessToken()).thenReturn(exceptionallyCompleted(e));

    ocsClient.initCoreServicesDancer(lockFactory, tokenStore, expressionLanguage);
    expectedException.expect(ExecutionException.class);
    expectedException.expectCause(is(sameInstance(e)));

    ocsClient.getAccessToken(CONNECTION_URI).get();
  }

  @Test
  public void httpClientThrowsExceptionGettingAccessToken() throws Exception {
    Exception e = new RuntimeException();
    when(httpClient.sendAsync(any(), any())).thenThrow(e);

    ocsClient.initCoreServicesDancer(lockFactory, tokenStore, expressionLanguage);
    expectedException.expect(ExecutionException.class);
    expectedException.expectCause(is(sameInstance(e)));

    ocsClient.getAccessToken(CONNECTION_URI).get();
  }

  @Test
  public void getAccessTokenWithoutInitDancer() throws Exception {
    expectedException.expect(IllegalStateException.class);
    expectedException.expectMessage("Core Services Dancer not yet initialized");

    ocsClient.getAccessToken(CONNECTION_URI);
  }

  /**
   * A 401 from OCS must refresh the core services token and retry once with the refreshed token.
   */
  @Test
  public void getAccessTokenWithExpiredCSToken() throws Exception {
    when(httpResponse.getStatusCode())
        .thenReturn(UNAUTHORIZED.getStatusCode())
        .thenReturn(OK.getStatusCode());

    when(resourceOwnerOAuthContext.getAccessToken()).thenAnswer(inv -> refreshedToken.get());

    ocsClient.initCoreServicesDancer(lockFactory, tokenStore, expressionLanguage);

    HttpResponse response = ocsClient.getAccessToken(CONNECTION_URI).get();
    assertThat(response, is(sameInstance(httpResponse)));

    ArgumentCaptor<HttpRequest> requestCaptor = forClass(HttpRequest.class);
    ArgumentCaptor<HttpRequestOptions> optionsCaptor = forClass(HttpRequestOptions.class);

    InOrder inOrder = inOrder(httpClient, ccDancer, resourceOwnerOAuthContext);

    inOrder.verify(httpClient).sendAsync(any(), any());
    inOrder.verify(ccDancer).refreshToken();
    inOrder.verify(resourceOwnerOAuthContext).getAccessToken();
    inOrder.verify(httpClient).sendAsync(requestCaptor.capture(), optionsCaptor.capture());

    HttpRequest request = requestCaptor.getValue();
    HttpRequestOptions options = optionsCaptor.getValue();

    assertThat(request.getHeaderValue(AUTHORIZATION), equalTo("bearer " + REFRESHED_PLATFORM_ACCESS_TOKEN));
    assertThat(request.getUri().toString(), equalTo(ocsClient.getAccessTokenUrl(CONNECTION_URI)));
    assertThat(request.getMethod(), equalTo("GET"));
    assertThat(options.getResponseTimeout(), equalTo(TOKEN_REQUEST_TIMEOUT_MILLIS));
  }

  /**
   * When refreshing the core services token after a 401 fails, the returned future must fail with the
   * refresh exception as cause.
   */
  @Test
  public void getAccessTokenWithExpiredCSTokenAndFailToRefresh() throws Exception {
    final Exception refreshException = new Exception();

    when(httpResponse.getStatusCode())
        .thenReturn(UNAUTHORIZED.getStatusCode())
        .thenReturn(OK.getStatusCode());

    when(ccDancer.refreshToken()).thenReturn(exceptionallyCompleted(refreshException));
    ocsClient.initCoreServicesDancer(lockFactory, tokenStore, expressionLanguage);

    try {
      ocsClient.getAccessToken(CONNECTION_URI).get();
      // Without this the test would pass silently when no exception is raised.
      throw new AssertionError("Expected an ExecutionException caused by the failed token refresh");
    } catch (ExecutionException e) {
      assertThat(e.getCause(), is(sameInstance(refreshException)));
    }

    InOrder inOrder = inOrder(httpClient, ccDancer);

    inOrder.verify(httpClient).sendAsync(any(), any());
    inOrder.verify(ccDancer).refreshToken();
  }

  @Test
  public void getAccessTokenReturnsBadStatusCode() throws Exception {
    expectedException.expectMessage(OCS_ERROR_MESSAGE);
    when(httpResponse.getStatusCode()).thenReturn(FORBIDDEN.getStatusCode());
    // A fresh stream per call, since a consumed stream cannot be re-read.
    when(httpResponse.getEntity().getContent())
        .thenAnswer(inv -> new ByteArrayInputStream(ERROR_RESPONSE_WITH_MESSAGE.getBytes(UTF_8)));
    ocsClient.initCoreServicesDancer(lockFactory, tokenStore, expressionLanguage);

    ocsClient.getAccessToken(CONNECTION_URI).get();
  }

  @Test
  public void getAccessTokenReturnsBadStatusCodeAndOcsReturnsNoMessage() throws Exception {
    final String accessTokenUrl =
        SERVICE_URL + API_PREFIX + "/organizations/" + ORG_ID + "/connections/" + CONNECTION_URI + "/token";
    expectedException
        .expectMessage("Got status code 403 when trying when making a request to : " + accessTokenUrl);
    when(httpResponse.getStatusCode()).thenReturn(FORBIDDEN.getStatusCode());
    when(httpResponse.getEntity().getContent())
        .thenAnswer(inv -> new ByteArrayInputStream(ERROR_RESPONSE_WITHOUT_MESSAGE.getBytes(UTF_8)));
    ocsClient.initCoreServicesDancer(lockFactory, tokenStore, expressionLanguage);

    ocsClient.getAccessToken(CONNECTION_URI).get();
  }

  @Test
  public void refreshToken() throws Exception {
    ocsClient.initCoreServicesDancer(lockFactory, tokenStore, expressionLanguage);
    final String revisionToken = "1";

    HttpResponse response = ocsClient.refreshToken(CONNECTION_URI, revisionToken).get();
    assertThat(response, is(sameInstance(httpResponse)));

    ArgumentCaptor<HttpRequest> requestCaptor = forClass(HttpRequest.class);
    ArgumentCaptor<HttpRequestOptions> optionsCaptor = forClass(HttpRequestOptions.class);

    verify(httpClient).sendAsync(requestCaptor.capture(), optionsCaptor.capture());

    HttpRequest request = requestCaptor.getValue();
    HttpRequestOptions options = optionsCaptor.getValue();

    assertThat(request.getHeaderValue(AUTHORIZATION), equalTo("bearer " + PLATFORM_ACCESS_TOKEN));
    assertThat(request.getUri().toString(), equalTo(ocsClient.getRefreshTokenUrl(CONNECTION_URI)));
    assertThat(request.getMethod(), equalTo("POST"));
    assertThat(request.getQueryParams().get(REVISION_TOKEN_QUERY_PARAM), equalTo(revisionToken));
    assertThat(options.getResponseTimeout(), equalTo(TOKEN_REQUEST_TIMEOUT_MILLIS));
  }

  @Test
  public void refreshTokenWithExceptionFromCCDancer() throws Exception {
    Exception e = new Exception();
    when(ccDancer.accessToken()).thenReturn(exceptionallyCompleted(e));

    ocsClient.initCoreServicesDancer(lockFactory, tokenStore, expressionLanguage);
    expectedException.expect(ExecutionException.class);
    expectedException.expectCause(is(sameInstance(e)));

    ocsClient.refreshToken(CONNECTION_URI, "").get();
  }

  @Test
  public void httpClientThrowsExceptionRefreshingToken() throws Exception {
    Exception e = new RuntimeException();
    when(httpClient.sendAsync(any(), any())).thenThrow(e);

    ocsClient.initCoreServicesDancer(lockFactory, tokenStore, expressionLanguage);
    expectedException.expect(ExecutionException.class);
    expectedException.expectCause(is(sameInstance(e)));

    ocsClient.refreshToken(CONNECTION_URI, "").get();
  }

  @Test
  public void refreshTokenWithoutInitDancer() throws Exception {
    expectedException.expect(IllegalStateException.class);
    expectedException.expectMessage("Core Services Dancer not yet initialized");

    ocsClient.refreshToken(CONNECTION_URI, "");
  }

  /**
   * A 401 from OCS while refreshing must refresh the core services token and retry once with the refreshed token.
   */
  @Test
  public void refreshTokenWithExpiredCSToken() throws Exception {
    when(httpResponse.getStatusCode())
        .thenReturn(UNAUTHORIZED.getStatusCode())
        .thenReturn(OK.getStatusCode());

    when(resourceOwnerOAuthContext.getAccessToken()).thenAnswer(inv -> refreshedToken.get());

    ocsClient.initCoreServicesDancer(lockFactory, tokenStore, expressionLanguage);

    HttpResponse response = ocsClient.refreshToken(CONNECTION_URI, "").get();
    assertThat(response, is(sameInstance(httpResponse)));

    ArgumentCaptor<HttpRequest> requestCaptor = forClass(HttpRequest.class);
    ArgumentCaptor<HttpRequestOptions> optionsCaptor = forClass(HttpRequestOptions.class);

    InOrder inOrder = inOrder(httpClient, ccDancer, resourceOwnerOAuthContext);

    inOrder.verify(httpClient).sendAsync(any(), any());
    inOrder.verify(ccDancer).refreshToken();
    inOrder.verify(resourceOwnerOAuthContext).getAccessToken();
    inOrder.verify(httpClient).sendAsync(requestCaptor.capture(), optionsCaptor.capture());

    HttpRequest request = requestCaptor.getValue();
    HttpRequestOptions options = optionsCaptor.getValue();

    assertThat(request.getHeaderValue(AUTHORIZATION), equalTo("bearer " + REFRESHED_PLATFORM_ACCESS_TOKEN));
    assertThat(request.getUri().toString(), equalTo(ocsClient.getRefreshTokenUrl(CONNECTION_URI)));
    assertThat(request.getMethod(), equalTo("POST"));
    assertThat(options.getResponseTimeout(), equalTo(TOKEN_REQUEST_TIMEOUT_MILLIS));
  }

  /**
   * When refreshing the core services token after a 401 fails, the returned future must fail with the
   * refresh exception as cause.
   */
  @Test
  public void refreshTokenWithExpiredCSTokenAndFailToRefresh() throws Exception {
    final Exception refreshException = new Exception();

    when(httpResponse.getStatusCode())
        .thenReturn(UNAUTHORIZED.getStatusCode())
        .thenReturn(OK.getStatusCode());

    when(ccDancer.refreshToken()).thenReturn(exceptionallyCompleted(refreshException));
    ocsClient.initCoreServicesDancer(lockFactory, tokenStore, expressionLanguage);

    try {
      ocsClient.refreshToken(CONNECTION_URI, "").get();
      // Without this the test would pass silently when no exception is raised.
      throw new AssertionError("Expected an ExecutionException caused by the failed token refresh");
    } catch (ExecutionException e) {
      assertThat(e.getCause(), is(sameInstance(refreshException)));
    }

    InOrder inOrder = inOrder(httpClient, ccDancer);

    inOrder.verify(httpClient).sendAsync(any(), any());
    inOrder.verify(ccDancer).refreshToken();
  }

  @Test
  public void refreshTokenReturnsBadStatusCode() throws Exception {
    expectedException.expectMessage(OCS_ERROR_MESSAGE);
    when(httpResponse.getStatusCode()).thenReturn(FORBIDDEN.getStatusCode());
    when(httpResponse.getEntity().getContent())
        .thenAnswer(inv -> new ByteArrayInputStream(ERROR_RESPONSE_WITH_MESSAGE.getBytes(UTF_8)));
    ocsClient.initCoreServicesDancer(lockFactory, tokenStore, expressionLanguage);

    ocsClient.refreshToken(CONNECTION_URI, "").get();
  }

  /**
   * Calls {@code initCoreServicesDancer} {@code count} times, then verifies the dancer was configured and
   * started exactly once.
   */
  private void initCSDancer(int count) throws Exception {
    for (int i = 0; i < count; i++) {
      ocsClient.initCoreServicesDancer(lockFactory, tokenStore, expressionLanguage);
    }

    verifyCsDancerInitialization();
  }

  @Test
  public void getConnectionDescriptor() throws Exception {
    assertGetConnectionDescriptor(CONNECTION_DESCRIPTOR_RESPONSE);
  }

  @Test
  public void getConnectionDescriptorWithExtraData() throws Exception {
    assertGetConnectionDescriptor(CONNECTION_DESCRIPTOR_RESPONSE_WITH_EXTRA_DATA);
  }

  @Test
  public void getConnectionDescriptorWithInvalidStatusCode() throws Exception {
    ocsClient.initCoreServicesDancer(lockFactory, tokenStore, expressionLanguage);
    when(httpResponse.getStatusCode()).thenReturn(INTERNAL_SERVER_ERROR.getStatusCode());

    expectedException.expect(ExecutionException.class);
    expectedException.expectCause(instanceOf(MuleRuntimeException.class));

    ocsClient.getConnectionDescriptor(CONNECTION_URI).get();
  }

  /**
   * Verifies the dancer builder received the settings values, and that the dancer was built, initialised and
   * started exactly once (Mockito's default {@code times(1)}).
   */
  private void verifyCsDancerInitialization() throws MuleException {
    verify(ccDancerBuilder).name("OCS@" + SERVICE_URL);
    verify(ccDancerBuilder).encoding(UTF_8);
    verify(ccDancerBuilder).clientCredentials(CLIENT_ID, CLIENT_SECRET);
    verify(ccDancerBuilder).tokenUrl(AUTH_URL);
    verify(ccDancerBuilder).withClientCredentialsIn(BODY);

    verify(ccDancerBuilder).build();
    verify(((Initialisable) ccDancer)).initialise();
    verify(((Startable) ccDancer)).start();
  }

  /**
   * Serves {@code getConnectionDescriptorResponse} as a 200 JSON response and asserts the parsed descriptor
   * matches the test constants, including its nested parameters.
   *
   * @param getConnectionDescriptorResponse JSON body returned by the mocked OCS endpoint
   */
  @SuppressWarnings("unchecked") // descriptor parameters are untyped; the JSON fixtures fix their shape
  private void assertGetConnectionDescriptor(String getConnectionDescriptorResponse) throws Exception {
    when(httpResponse.getStatusCode()).thenReturn(OK.getStatusCode());
    when(httpResponse.getEntity().getContent())
        .thenReturn(new ByteArrayInputStream(getConnectionDescriptorResponse.getBytes(UTF_8)));
    when(httpResponse.getHeaderValue(CONTENT_TYPE)).thenReturn(APPLICATION_JSON.toRfcString());

    ocsClient.initCoreServicesDancer(lockFactory, tokenStore, expressionLanguage);

    PlatformManagedConnectionDescriptor descriptor = ocsClient.getConnectionDescriptor(CONNECTION_URI).get();
    assertThat(descriptor.getId(), equalTo(CONNECTION_ID));
    assertThat(descriptor.getDisplayName(), equalTo(CONNECTION_NAME));
    assertThat(descriptor.getUri(), equalTo(CONNECTION_URI));

    Map<String, Object> params = descriptor.getParameters();
    assertThat(params, is(notNullValue()));
    assertThat(params.size(), is(5));
    assertThat(params.get("fullname"), equalTo(FULLNAME_PARAMETER));
    assertThat(params.get("age"), equalTo(AGE_PARAMETER));

    Map<String, String> connectivityTest = (Map<String, String>) params.get("connectivityTest");
    assertThat(connectivityTest.keySet(), hasSize(2));
    assertThat(connectivityTest.get("requestMethod"), equalTo(REQUEST_METHOD));
    assertThat(connectivityTest.get("requestPath"), equalTo(REQUEST_PATH));

    List<Map<String, String>> someConnectivityTests = (List<Map<String, String>>) params.get("someConnectivityTests");
    assertThat(someConnectivityTests, hasSize(3));
    someConnectivityTests.forEach(someConnectivityTest -> {
      assertThat(someConnectivityTest.keySet(), hasSize(2));
      assertThat(someConnectivityTest.get("requestMethod"), equalTo(REQUEST_METHOD));
      assertThat(someConnectivityTest.get("requestPath"), equalTo(REQUEST_PATH));
    });

    List<String> someNumbers = (List<String>) params.get("someNumbers");
    assertThat(someNumbers, hasSize(3));
    someNumbers.forEach(someNumber -> assertThat(someNumber, equalTo(NUMBER)));
  }
}
