001
002package io.vrap.rmf.base.client.http;
003
004import java.time.Duration;
005import java.util.concurrent.*;
006import java.util.function.Function;
007
008import io.vrap.rmf.base.client.*;
009import io.vrap.rmf.base.client.error.UnauthorizedException;
010import io.vrap.rmf.base.client.oauth2.AuthException;
011import io.vrap.rmf.base.client.oauth2.TokenSupplier;
012
013import dev.failsafe.*;
014import dev.failsafe.event.ExecutionAttemptedEvent;
015import dev.failsafe.spi.Scheduler;
016
017/**
018 * Default implementation for the {@link OAuthMiddleware} with automatic retry upon expired access
019 */
020public class OAuthMiddlewareImpl implements AutoCloseable, OAuthMiddleware {
021    private final OAuthHandler authHandler;
022    private static final InternalLogger logger = InternalLogger.getLogger(TokenSupplier.LOGGER_AUTH);
023    private final FailsafeExecutor<ApiHttpResponse<byte[]>> failsafeExecutor;
024
025    public OAuthMiddlewareImpl(final OAuthHandler oAuthHandler) {
026        this(Scheduler.DEFAULT, oAuthHandler, 1, false);
027    }
028
029    public OAuthMiddlewareImpl(final OAuthHandler oauthHandler, final int maxRetries, final boolean useCircuitBreaker) {
030        this(Scheduler.DEFAULT, oauthHandler, maxRetries, useCircuitBreaker);
031    }
032
033    public OAuthMiddlewareImpl(final ScheduledExecutorService executorService, final OAuthHandler oauthHandler,
034            final int maxRetries, final boolean useCircuitBreaker) {
035        this(Scheduler.of(executorService), oauthHandler, maxRetries, useCircuitBreaker);
036    }
037
038    public OAuthMiddlewareImpl(final ExecutorService executor, final OAuthHandler oauthHandler, final int maxRetries,
039            final boolean useCircuitBreaker) {
040        this(Scheduler.of(executor), oauthHandler, maxRetries, useCircuitBreaker);
041    }
042
043    public OAuthMiddlewareImpl(final Scheduler scheduler, final OAuthHandler oauthHandler, final int maxRetries,
044            final boolean useCircuitBreaker) {
045        this.authHandler = oauthHandler;
046
047        RetryPolicy<ApiHttpResponse<byte[]>> retry = RetryPolicy.<ApiHttpResponse<byte[]>> builder()
048                .handleIf((response, throwable) -> {
049                    if (throwable != null) {
050                        return throwable instanceof UnauthorizedException;
051                    }
052                    return response.getStatusCode() == 401;
053                })
054                .onRetry(event -> {
055                    logger.debug(() -> "Refresh Bearer token #" + event.getAttemptCount());
056                    authHandler.refreshToken();
057                })
058                .withMaxRetries(maxRetries)
059                .build();
060        if (useCircuitBreaker) {
061            final Fallback<ApiHttpResponse<byte[]>> fallback = Fallback
062                    .builderOfException((ExecutionAttemptedEvent<? extends ApiHttpResponse<byte[]>> event) -> {
063                        logger.debug(() -> "Convert CircuitBreakerOpenException to AuthException");
064                        logger.trace(event::getLastException);
065                        return new AuthException(400, "", null, "Authentication failed", null,
066                            event.getLastException());
067                    })
068                    .handleIf(throwable -> throwable instanceof CircuitBreakerOpenException)
069                    .build();
070
071            final CircuitBreaker<ApiHttpResponse<byte[]>> circuitBreaker = CircuitBreaker
072                    .<ApiHttpResponse<byte[]>> builder()
073                    .handleIf((response, throwable) -> {
074                        Throwable cause = throwable instanceof CompletionException ? throwable.getCause() : throwable;
075                        if (cause instanceof AuthException) {
076                            return ((AuthException) throwable.getCause()).getResponse().getStatusCode() == 400;
077                        }
078                        return response.getStatusCode() == 400;
079                    })
080                    .withDelayFn(context -> Duration
081                            .ofMillis(Math.min(100 * context.getAttemptCount() * context.getAttemptCount(), 15000)))
082                    .withFailureThreshold(5, Duration.ofMinutes(1))
083                    .withSuccessThreshold(2)
084                    .onClose(event -> logger.debug(() -> "The authentication circuit breaker was closed"))
085                    .onOpen(event -> logger.debug(() -> "The authentication circuit breaker was opened"))
086                    .onHalfOpen(event -> logger.debug(() -> "The authentication circuit breaker was half-opened"))
087                    .onFailure(event -> logger.trace(() -> "Authentication failed", event.getException()))
088                    .build();
089            this.failsafeExecutor = Failsafe.with(fallback, retry, circuitBreaker).with(scheduler);
090        }
091        else {
092            this.failsafeExecutor = Failsafe.with(retry).with(scheduler);
093        }
094    }
095
096    @Override
097    public CompletableFuture<ApiHttpResponse<byte[]>> invoke(final ApiHttpRequest request,
098            final Function<ApiHttpRequest, CompletableFuture<ApiHttpResponse<byte[]>>> next) {
099        return failsafeExecutor.getStageAsync(() -> {
100            if (request.getHeaders().getFirst(ApiHttpHeaders.AUTHORIZATION) != null) {
101                return next.apply(request);
102            }
103            AuthenticationToken token = authHandler.getToken();
104            return next.apply(request.addHeader(ApiHttpHeaders.AUTHORIZATION, OAuthHandler.authHeader(token)));
105        });
106    }
107
108    @Override
109    public void close() {
110        authHandler.close();
111    }
112}