/*
 * (c) 2003-2019 MuleSoft, Inc. This software is protected under international copyright
 * law. All use of this software is subject to MuleSoft's Master Subscription Agreement
 * (or other master license agreement) separately entered into in writing between you and
 * MuleSoft. If such an agreement is not in place, you may not use the software.
 */
package com.mulesoft.modules.oauth2.provider;

import static com.mulesoft.modules.oauth2.provider.api.Constants.RequestGrantType.AUTHORIZATION_CODE;
import static com.mulesoft.modules.oauth2.provider.api.Constants.RequestGrantType.PASSWORD;
import static java.util.Collections.singletonMap;
import static java.util.concurrent.Executors.newFixedThreadPool;
import static java.util.concurrent.TimeUnit.SECONDS;
import static net.smartam.leeloo.client.request.OAuthClientRequest.tokenLocation;
import static net.smartam.leeloo.common.message.types.GrantType.REFRESH_TOKEN;
import static org.apache.commons.lang3.RandomStringUtils.randomAlphanumeric;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.mule.runtime.http.api.HttpConstants.HttpStatus.BAD_REQUEST;
import static org.mule.runtime.http.api.HttpConstants.HttpStatus.MOVED_TEMPORARILY;
import static org.mule.runtime.http.api.HttpConstants.HttpStatus.OK;
import static org.mule.runtime.http.api.HttpHeaders.Names.AUTHORIZATION;

import com.mulesoft.modules.oauth2.provider.api.Constants.RequestGrantType;
import com.mulesoft.modules.oauth2.provider.api.token.AccessTokenStoreHolder;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;

import net.smartam.leeloo.client.request.OAuthClientRequest;
import net.smartam.leeloo.common.exception.OAuthSystemException;
import net.smartam.leeloo.common.message.types.GrantType;
import org.apache.commons.httpclient.HttpClient;
import org.apache.commons.httpclient.methods.GetMethod;
import org.apache.commons.httpclient.methods.PostMethod;

import org.junit.Test;

/**
 * Functional tests for the {@code refresh_token} grant of the OAuth2 provider module, run against the
 * {@code oauth2-refresh-token-http-config.xml} configuration.
 */
public class OAuth2ProviderModuleRefreshTokenTestCase extends AbstractOAuth2ProviderModuleTestCase {

  private static final String PROTECTED_RESOURCE_PATH = "/protected";

  @Override
  protected String doGetConfigFile() {
    return "oauth2-refresh-token-http-config.xml";
  }

  /**
   * Additionally authorizes the test client for the authorization code and refresh token grant types and
   * stores the updated client via {@code updateClientInOS()}.
   */
  @Override
  protected void doSetUp() throws Exception {
    super.doSetUp();
    client.getAuthorizedGrantTypes().add(AUTHORIZATION_CODE);
    client.getAuthorizedGrantTypes().add(RequestGrantType.REFRESH_TOKEN);
    updateClientInOS();
  }

  /**
   * Exchanging the well-known test authorization code at the token endpoint returns 200 OK with a
   * scope-less token response.
   */
  @Test
  public void tokenExchangeSuccess() throws Exception {
    final OAuthClientRequest oAuthClientRequest = tokenLocation(getTokenEndpointURL())
        .setGrantType(GrantType.AUTHORIZATION_CODE)
        .setCode(TEST_AUTHORIZATION_CODE)
        .setClientId(TEST_CLIENT_ID)
        .setRedirectURI(TEST_REDIRECT_URI)
        .buildBodyMessage();

    oAuthClientRequest
        .setHeaders(singletonMap(AUTHORIZATION, getValidBasicAuthHeaderValue(TEST_CLIENT_ID, TEST_CLIENT_PASSWORD)));

    final PostMethod postToken = postOAuthClientRequestExpectingStatus(oAuthClientRequest, OK.getStatusCode());

    validateSuccessfulTokenResponseNoScope(getContentAsMap(postToken), true);
  }

  /**
   * A {@code refresh_token} grant request that omits the refresh token parameter is rejected with
   * 400 Bad Request.
   */
  @Test
  public void refreshTokenMissingToken() throws Exception {
    final String accessToken = randomAlphanumeric(20);
    final String refreshToken = randomAlphanumeric(20);
    addAccessTokenToStore(accessToken, refreshToken);

    final OAuthClientRequest oAuthClientRequest = tokenLocation(getTokenEndpointURL())
        .setGrantType(REFRESH_TOKEN)
        .buildBodyMessage();

    oAuthClientRequest
        .setHeaders(singletonMap(AUTHORIZATION, getValidBasicAuthHeaderValue(TEST_CLIENT_ID, TEST_CLIENT_PASSWORD)));

    postOAuthClientRequestExpectingStatus(oAuthClientRequest, BAD_REQUEST.getStatusCode());
  }

  /**
   * Refreshing a freshly stored token while requesting {@code TEST_SCOPE} (not added to the stored token
   * here) is rejected with 400 and an {@code invalid_scope} error body.
   */
  @Test
  public void refreshTokenInvalidScope() throws Exception {
    final String accessToken = randomAlphanumeric(20);
    final String refreshToken = randomAlphanumeric(20);
    addAccessTokenToStore(accessToken, refreshToken);

    final OAuthClientRequest oAuthClientRequest = tokenLocation(getTokenEndpointURL())
        .setGrantType(REFRESH_TOKEN)
        .setRefreshToken(refreshToken)
        .setScope(TEST_SCOPE)
        .buildBodyMessage();

    oAuthClientRequest
        .setHeaders(singletonMap(AUTHORIZATION, getValidBasicAuthHeaderValue(TEST_CLIENT_ID, TEST_CLIENT_PASSWORD)));

    final PostMethod postToken = postOAuthClientRequestExpectingStatus(oAuthClientRequest, BAD_REQUEST.getStatusCode());

    assertEqualJsonObj("{\"error\":\"invalid_scope\",\"error_description\":\"\"}",
                       postToken);
  }

  /**
   * Refreshing a stored token with a valid refresh token and no scope returns 200 OK with a scope-less
   * token response.
   */
  @Test
  public void refreshTokenSuccess() throws Exception {
    final String accessToken = randomAlphanumeric(20);
    final String refreshToken = randomAlphanumeric(20);
    addAccessTokenToStore(accessToken, refreshToken);

    final OAuthClientRequest oAuthClientRequest = tokenLocation(getTokenEndpointURL())
        .setGrantType(REFRESH_TOKEN)
        .setRefreshToken(refreshToken)
        .buildBodyMessage();

    oAuthClientRequest
        .setHeaders(singletonMap(AUTHORIZATION, getValidBasicAuthHeaderValue(TEST_CLIENT_ID, TEST_CLIENT_PASSWORD)));

    final PostMethod postToken = postOAuthClientRequestExpectingStatus(oAuthClientRequest, OK.getStatusCode());

    validateSuccessfulTokenResponseNoScope(getContentAsMap(postToken), true);
  }

  /**
   * When the original grant carried {@code USER_SCOPE} and the refresh request names no scope, the
   * refreshed token response carries {@code USER_SCOPE}.
   */
  @Test
  public void refreshTokenReceivingGrantedScopeSuccess() throws Exception {
    client.getScopes().add(USER_SCOPE);
    updateClientInOS();

    final String accessToken = randomAlphanumeric(20);
    final String refreshToken = randomAlphanumeric(20);
    final AccessTokenStoreHolder accessTokenStoreHolder = addAccessTokenToStore(accessToken, refreshToken);
    accessTokenStoreHolder.getAuthorizationRequest().getScopes().add(USER_SCOPE);
    accessTokenStoreHolder.getAccessToken().getScopes().add(USER_SCOPE);
    updateAccessTokenHolderInOS(accessTokenStoreHolder);

    final OAuthClientRequest oAuthClientRequest = tokenLocation(getTokenEndpointURL())
        .setGrantType(REFRESH_TOKEN)
        .setRefreshToken(refreshToken)
        .buildBodyMessage();

    oAuthClientRequest
        .setHeaders(singletonMap(AUTHORIZATION, getValidBasicAuthHeaderValue(TEST_CLIENT_ID, TEST_CLIENT_PASSWORD)));

    final PostMethod postToken = postOAuthClientRequestExpectingStatus(oAuthClientRequest, OK.getStatusCode());

    validateSuccessfulTokenResponse(getContentAsMap(postToken), USER_SCOPE, true);
  }

  /**
   * Explicitly requesting the originally granted {@code USER_SCOPE} on refresh succeeds. The request goes
   * to the token endpoint URL suffixed with {@code WithUserScope} (presumably a flow configured for that
   * scope; see the XML config).
   */
  @Test
  public void refreshTokenRequestingGrantedScopeSuccess() throws Exception {
    client.getScopes().add(USER_SCOPE);
    updateClientInOS();

    final String accessToken = randomAlphanumeric(20);
    final String refreshToken = randomAlphanumeric(20);
    final AccessTokenStoreHolder accessTokenStoreHolder = addAccessTokenToStore(accessToken, refreshToken);
    accessTokenStoreHolder.getAuthorizationRequest().getScopes().add(USER_SCOPE);
    accessTokenStoreHolder.getAccessToken().getScopes().add(USER_SCOPE);
    updateAccessTokenHolderInOS(accessTokenStoreHolder);

    final OAuthClientRequest oAuthClientRequest = tokenLocation(getTokenEndpointURL() + "WithUserScope")
        .setGrantType(REFRESH_TOKEN)
        .setRefreshToken(refreshToken)
        .setScope(USER_SCOPE)
        .buildBodyMessage();

    oAuthClientRequest
        .setHeaders(singletonMap(AUTHORIZATION, getValidBasicAuthHeaderValue(TEST_CLIENT_ID, TEST_CLIENT_PASSWORD)));

    final PostMethod postToken = postOAuthClientRequestExpectingStatus(oAuthClientRequest, OK.getStatusCode());

    validateSuccessfulTokenResponse(getContentAsMap(postToken), USER_SCOPE, true);
  }

  /**
   * Requesting more scope on refresh ({@code USER_SCOPE} plus {@code TEST_SCOPE}) than the original grant
   * ({@code USER_SCOPE} only) is rejected with 400 and an {@code invalid_scope} error body.
   */
  @Test
  public void refreshTokenRequestingBeyondGrantedScopeFailure() throws Exception {
    client.getScopes().add(USER_SCOPE);
    updateClientInOS();

    final String accessToken = randomAlphanumeric(20);
    final String refreshToken = randomAlphanumeric(20);
    final AccessTokenStoreHolder accessTokenStoreHolder = addAccessTokenToStore(accessToken, refreshToken);
    accessTokenStoreHolder.getAuthorizationRequest().getScopes().add(USER_SCOPE);
    accessTokenStoreHolder.getAccessToken().getScopes().add(USER_SCOPE);
    // Persist the granted scope, as the sibling scope tests do. Without this the stored holder may not
    // reflect USER_SCOPE, and the rejection below could be caused by that rather than by the extra
    // TEST_SCOPE under test.
    updateAccessTokenHolderInOS(accessTokenStoreHolder);

    final OAuthClientRequest oAuthClientRequest = tokenLocation(getTokenEndpointURL() + "WithUserScope")
        .setGrantType(REFRESH_TOKEN)
        .setRefreshToken(refreshToken)
        .setScope(USER_SCOPE + " " + TEST_SCOPE)
        .buildBodyMessage();

    oAuthClientRequest
        .setHeaders(singletonMap(AUTHORIZATION, getValidBasicAuthHeaderValue(TEST_CLIENT_ID, TEST_CLIENT_PASSWORD)));

    final PostMethod postToken = postOAuthClientRequestExpectingStatus(oAuthClientRequest, BAD_REQUEST.getStatusCode());

    assertEqualJsonObj("{\"error\":\"invalid_scope\",\"error_description\":\"\"}",
                       postToken);
  }

  /**
   * Runs the full authorization code flow. The resource owner logs in and receives a code via redirect,
   * the code is exchanged for tokens, and the refresh token is then exercised.
   */
  @Test
  public void performAuthorizationCodeGrantOAuth2DanceAndTestRefreshToken() throws Exception {
    final OAuthClientRequest authorizationRequest = OAuthClientRequest
        .authorizationLocation(getAuthorizationEndpointUrl())
        .setResponseType("code")
        .setClientId(TEST_CLIENT_ID)
        .setRedirectURI(TEST_REDIRECT_URI)
        .setParameter("username", TEST_RESOURCE_OWNER_USERNAME)
        .setParameter("password", TEST_RESOURCE_OWNER_PASSWORD)
        .buildBodyMessage();

    final PostMethod postCredentials =
        postOAuthClientRequestExpectingStatus(authorizationRequest, MOVED_TEMPORARILY.getStatusCode());

    final Map<String, List<String>> authorizationResponse = validateSuccessfulLoginResponse(postCredentials, "code");
    final String authorizationCode = authorizationResponse.get("code").get(0);

    final OAuthClientRequest tokenExchangeRequest = tokenLocation(getTokenEndpointURL())
        .setGrantType(GrantType.AUTHORIZATION_CODE)
        .setCode(authorizationCode)
        .setClientId(TEST_CLIENT_ID)
        .setClientSecret(TEST_CLIENT_SECRET)
        .setRedirectURI(TEST_REDIRECT_URI)
        .buildBodyMessage();

    doGetAccessTokenAndTryRefreshIt(tokenExchangeRequest);
  }

  /**
   * Runs the resource owner password credentials flow (after authorizing the client for it) and then
   * exercises the refresh token.
   */
  @Test
  public void performResourceOwnerPasswordCredentialsGrantOAuth2DanceAndTestRefreshToken() throws Exception {
    client.getAuthorizedGrantTypes().add(PASSWORD);
    updateClientInOS();

    final OAuthClientRequest tokenExchangeRequest = tokenLocation(getTokenEndpointURL())
        .setGrantType(GrantType.PASSWORD)
        .setParameter("username", TEST_RESOURCE_OWNER_USERNAME)
        .setParameter("password", TEST_RESOURCE_OWNER_PASSWORD)
        .buildBodyMessage();

    tokenExchangeRequest
        .setHeaders(singletonMap(AUTHORIZATION, getValidBasicAuthHeaderValue(TEST_CLIENT_ID, TEST_CLIENT_PASSWORD)));

    doGetAccessTokenAndTryRefreshIt(tokenExchangeRequest);
  }

  /**
   * Exchanges {@code tokenExchangeRequest} for an access/refresh token pair and checks the access token
   * reaches the protected resource. It then refreshes the pair, asserts both new tokens differ from the
   * originals, and checks the new access token also reaches the protected resource.
   *
   * @param tokenExchangeRequest a fully built token request that the endpoint should answer with 200 OK
   */
  private void doGetAccessTokenAndTryRefreshIt(final OAuthClientRequest tokenExchangeRequest)
      throws IOException, OAuthSystemException, InterruptedException {
    PostMethod postToken = postOAuthClientRequestExpectingStatus(tokenExchangeRequest, OK.getStatusCode());

    Map<String, Object> tokenResponse = validateSuccessfulTokenResponseNoScope(getContentAsMap(postToken), true);
    final String accessToken1 = (String) tokenResponse.get("access_token");
    final String refreshToken1 = (String) tokenResponse.get("refresh_token");

    GetMethod getProtectedResource = new GetMethod(getProtectedResourceURL(PROTECTED_RESOURCE_PATH)
        + "?access_token=" + accessToken1);
    executeHttpMethodExpectingStatus(getProtectedResource, OK.getStatusCode());
    assertThat(getProtectedResource.getResponseBodyAsString(), is(equalTo(PROTECTED_RESOURCE_CONTENT)));

    final OAuthClientRequest refreshTokenRequest = tokenLocation(getTokenEndpointURL())
        .setGrantType(REFRESH_TOKEN)
        .setRefreshToken(refreshToken1)
        .buildBodyMessage();
    refreshTokenRequest
        .setHeaders(singletonMap(AUTHORIZATION, getValidBasicAuthHeaderValue(TEST_CLIENT_ID, TEST_CLIENT_PASSWORD)));

    postToken = postOAuthClientRequestExpectingStatus(refreshTokenRequest, OK.getStatusCode());

    tokenResponse = validateSuccessfulTokenResponseNoScope(getContentAsMap(postToken), true);
    final String accessToken2 = (String) tokenResponse.get("access_token");
    final String refreshToken2 = (String) tokenResponse.get("refresh_token");

    assertThat(accessToken2, is(not(equalTo(accessToken1))));
    assertThat(refreshToken2, is(not(equalTo(refreshToken1))));

    getProtectedResource = new GetMethod(getProtectedResourceURL(PROTECTED_RESOURCE_PATH)
        + "?access_token=" + accessToken2);
    executeHttpMethodExpectingStatus(getProtectedResource, OK.getStatusCode());
    assertThat(getProtectedResource.getResponseBodyAsString(), is(equalTo(PROTECTED_RESOURCE_CONTENT)));
  }

  /**
   * Sends many identical refresh requests for the same refresh token concurrently and asserts that
   * exactly one is answered with 200 OK while all the others receive a non-OK response.
   */
  @Test
  public void concurrentRefreshTokenSuccess() throws Exception {
    final String accessToken = randomAlphanumeric(20);
    final String refreshToken = randomAlphanumeric(20);
    addAccessTokenToStore(accessToken, refreshToken);

    final OAuthClientRequest oAuthClientRequest = tokenLocation(getTokenEndpointURL())
        .setGrantType(REFRESH_TOKEN)
        .setRefreshToken(refreshToken)
        .buildBodyMessage();

    oAuthClientRequest
        .setHeaders(singletonMap(AUTHORIZATION, getValidBasicAuthHeaderValue(TEST_CLIENT_ID, TEST_CLIENT_PASSWORD)));

    final int numberOfThreads = 100;
    // Counted down by this thread once per submitted task, so no worker sends its request until every
    // task has been submitted; this makes the refresh requests overlap as much as possible.
    final CountDownLatch interleavingLatch = new CountDownLatch(numberOfThreads);
    // Requests that reached the server and got a response.
    final List<PostMethod> results = new ArrayList<>(numberOfThreads);
    // Failures while building or sending a request. These are reported instead of swallowed: they must
    // not be mistaken for server-side rejections, nor surface later as a null entry in `results`.
    final List<Throwable> errors = new ArrayList<>();
    final ExecutorService executorService = newFixedThreadPool(numberOfThreads);
    for (int i = 0; i < numberOfThreads; i++) {
      executorService.execute(() -> {
        try {
          interleavingLatch.await();
          final PostMethod method = getPostOAuthClientRequest(oAuthClientRequest);
          new HttpClient().executeMethod(method);
          synchronized (results) {
            results.add(method);
          }
        } catch (InterruptedException e) {
          // Restore the interrupt flag so the pool can observe it.
          Thread.currentThread().interrupt();
          synchronized (errors) {
            errors.add(e);
          }
        } catch (Exception e) {
          synchronized (errors) {
            errors.add(e);
          }
        }
      });
      interleavingLatch.countDown();
    }
    executorService.shutdown();
    final boolean terminated = executorService.awaitTermination(getTestTimeoutSecs(), SECONDS);
    if (!terminated) {
      // Interrupt any stragglers so they do not outlive the test.
      executorService.shutdownNow();
    }
    assertThat("Concurrent refresh requests did not finish within the test timeout", terminated, is(true));
    assertThat("Unexpected errors while sending refresh requests: " + errors, errors.isEmpty(), is(true));

    final List<PostMethod> okResults = new ArrayList<>();
    final List<PostMethod> failureResults = new ArrayList<>();
    for (PostMethod method : results) {
      if (method.getStatusCode() == OK.getStatusCode()) {
        okResults.add(method);
      } else {
        failureResults.add(method);
      }
    }

    assertThat(okResults.size(), is(1));
    assertThat(failureResults.size(), is(numberOfThreads - 1));
  }
}
