/*
 * (c) 2003-2022 MuleSoft, Inc. This software is protected under international copyright
 * law. All use of this software is subject to MuleSoft's Master Subscription Agreement
 * (or other master license agreement) separately entered into in writing between you and
 * MuleSoft. If such an agreement is not in place, you may not use the software.
 */
package com.mulesoft.connectivity.rest.commons.api.connection.oauth;

import static org.mule.runtime.api.util.Preconditions.checkArgument;
import static org.mule.runtime.http.api.HttpConstants.HttpStatus.UNAUTHORIZED;
import static org.mule.runtime.http.api.HttpHeaders.Names.AUTHORIZATION;

import org.mule.runtime.api.connection.ConnectionValidationResult;
import org.mule.runtime.extension.api.connectivity.oauth.AccessTokenExpiredException;
import org.mule.runtime.extension.api.connectivity.oauth.OAuthState;
import org.mule.runtime.http.api.client.HttpClient;
import org.mule.runtime.http.api.client.HttpRequestOptions;
import org.mule.runtime.http.api.domain.message.request.HttpRequest;
import org.mule.runtime.http.api.domain.message.request.HttpRequestBuilder;
import org.mule.runtime.http.api.domain.message.response.HttpResponse;

import com.mulesoft.connectivity.rest.commons.api.connection.DefaultRestConnection;

import java.io.IOException;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeoutException;
import java.util.function.Function;

/**
 * Specialization of {@link DefaultRestConnection} for resources protected with OAuth, regardless of the grant type.
 */
public class OAuthRestConnection extends DefaultRestConnection {

  private final OAuthState oauthState;
  private final String resourceOwnerId;

  /**
   * Creates a connection that authenticates every request with the access token held by {@code oauthState}.
   *
   * @param httpClient the {@link HttpClient} used to send the requests.
   * @param httpRequestOptions the options applied to the requests. Cannot be {@code null} and must not provide an
   *        authentication, since this connection adds its own {@code Authorization} header.
   * @param baseUri the base URI of the API.
   * @param oauthState the OAuth state that provides the access token. Cannot be {@code null}.
   * @param resourceOwnerId the OAuth resource owner id reported when the access token expires. Cannot be {@code null}.
   * @throws IllegalStateException if {@code httpRequestOptions} provides an authentication.
   */
  public OAuthRestConnection(HttpClient httpClient,
                             HttpRequestOptions httpRequestOptions, String baseUri,
                             OAuthState oauthState, String resourceOwnerId) {
    super(httpClient, assertNoAuthenticationProvided(httpRequestOptions), baseUri);

    checkArgument(resourceOwnerId != null, "resourceOwnerId cannot be null");
    this.resourceOwnerId = resourceOwnerId;

    checkArgument(oauthState != null, "oauthState cannot be null");
    this.oauthState = oauthState;
  }

  /**
   * Validates the request options before they reach the super constructor (hence static). Authentication is handled
   * by {@link #authenticate(HttpRequestBuilder)}, so options carrying their own authentication are rejected.
   *
   * @param httpRequestOptions the options to validate.
   * @return the same {@code httpRequestOptions} instance.
   * @throws IllegalStateException if the options provide an authentication.
   */
  private static HttpRequestOptions assertNoAuthenticationProvided(HttpRequestOptions httpRequestOptions) {
    // Previously a null value failed here with a bare NullPointerException; report it explicitly instead.
    checkArgument(httpRequestOptions != null, "httpRequestOptions cannot be null");
    if (httpRequestOptions.getAuthentication().isPresent()) {
      throw new IllegalStateException("OAuthRestConnection should not be created with an HttpRequestOptions that provides an Authentication");
    }
    return httpRequestOptions;
  }

  /**
   * Adds an {@code Authorization: Bearer <access token>} header, using the current access token of the OAuth state.
   */
  @Override
  protected void authenticate(HttpRequestBuilder httpRequestBuilder) {
    httpRequestBuilder.addHeader(AUTHORIZATION, "Bearer " + oauthState.getAccessToken());
  }

  /**
   * Sends the request asynchronously. If the response indicates that the access token must be refreshed (see
   * {@link #throwAccessTokenExpiredExceptionIfNeedsRefreshToken(int, String)}), the returned future completes
   * exceptionally with that exception instead of the response.
   *
   * @param request the request to send.
   * @return a future completed with the response, or exceptionally with the failure of the request or of the check.
   */
  @Override
  public CompletableFuture<HttpResponse> sendAsync(HttpRequest request) {
    CompletableFuture<HttpResponse> future = new CompletableFuture<>();

    super.sendAsync(request).whenComplete((response, exception) -> {
      if (exception != null) {
        future.completeExceptionally(exception);
        return;
      }

      try {
        throwAccessTokenExpiredExceptionIfNeedsRefreshToken(response.getStatusCode(), resourceOwnerId);
        future.complete(response);
      } catch (Throwable t) {
        // Catch everything so the returned future is always completed: anything escaping this callback would only
        // fail the (unobserved) stage returned by whenComplete and leave callers of 'future' waiting forever.
        // The raw throwable is forwarded (not wrapped) so callers can detect AccessTokenExpiredException directly.
        future.completeExceptionally(t);
      }
    });

    return future;
  }

  /**
   * Sends the request synchronously, throwing if the response indicates that the access token must be refreshed.
   *
   * @param request the request to send.
   * @return the response from the server.
   * @throws AccessTokenExpiredException if the response requires an access token refresh.
   */
  @Override
  public HttpResponse send(HttpRequest request) throws IOException, TimeoutException {
    HttpResponse httpResponse = super.send(request);
    throwAccessTokenExpiredExceptionIfNeedsRefreshToken(httpResponse.getStatusCode(), resourceOwnerId);
    return httpResponse;
  }

  /**
   * Validates the connection, checking whether the access token must be refreshed before {@code whenComplete} is
   * applied to the response.
   *
   * @param request the request used to validate the connection.
   * @param whenComplete maps the response to a validation result; only invoked if no token refresh is needed.
   * @param onError maps a failure to a validation result.
   * @return the validation result produced by the parent implementation.
   */
  @Override
  public ConnectionValidationResult validate(HttpRequest request,
                                             Function<HttpResponse, ConnectionValidationResult> whenComplete,
                                             Function<Exception, ConnectionValidationResult> onError) {
    return super.validate(request,
                          httpResponse -> {
                            throwAccessTokenExpiredExceptionIfNeedsRefreshToken(httpResponse.getStatusCode(),
                                                                                resourceOwnerId);
                            return whenComplete.apply(httpResponse);
                          },
                          onError);
  }

  /**
   * Template method that decides, from the status code of a response, whether the access token must be refreshed.
   * Override it to support servers that are not compliant with HTTP status codes when signaling an expired token.
   * <p>
   * The default implementation throws an {@link AccessTokenExpiredException} when the status code is
   * {@link org.mule.runtime.http.api.HttpConstants.HttpStatus#UNAUTHORIZED}, so that the Mule Runtime/SDK refreshes the
   * access token, and returns normally otherwise.
   *
   * @param statusCode the status code of the {@link HttpResponse} received from the server.
   * @param resourceOwnerId the OAuth resource owner id to be provided by the {@link AccessTokenExpiredException}.
   * @throws AccessTokenExpiredException if the access token needs to be refreshed.
   */
  protected void throwAccessTokenExpiredExceptionIfNeedsRefreshToken(int statusCode, String resourceOwnerId) {
    if (UNAUTHORIZED.getStatusCode() == statusCode) {
      throw new AccessTokenExpiredException(resourceOwnerId);
    }
  }

}
