/*
 * Copyright (c) 2022 SAP SE or an SAP affiliate company. All rights reserved.
 */

package com.sap.cloud.sdk.cloudplatform.security;

import java.security.interfaces.RSAPublicKey;
import java.time.Duration;
import java.util.Collection;
import java.util.Collections;
import java.util.Enumeration;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.function.Supplier;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.servlet.http.HttpServletRequest;

import com.auth0.jwt.JWT;
import com.auth0.jwt.exceptions.JWTDecodeException;
import com.auth0.jwt.interfaces.DecodedJWT;
import com.google.common.net.HttpHeaders;
import com.sap.cloud.sdk.cloudplatform.requestheader.RequestHeaderContainer;
import com.sap.cloud.sdk.cloudplatform.resilience.ResilienceConfiguration;
import com.sap.cloud.sdk.cloudplatform.resilience.ResilienceConfiguration.CircuitBreakerConfiguration;
import com.sap.cloud.sdk.cloudplatform.resilience.ResilienceConfiguration.TimeLimiterConfiguration;
import com.sap.cloud.sdk.cloudplatform.resilience.ResilienceDecorator;
import com.sap.cloud.sdk.cloudplatform.resilience.ResilienceIsolationMode;
import com.sap.cloud.sdk.cloudplatform.security.exception.AuthTokenAccessException;
import com.sap.cloud.sdk.cloudplatform.security.exception.TokenRequestFailedException;
import com.sap.cloud.security.token.Token;
import com.sap.cloud.security.token.validation.CombiningValidator;
import com.sap.cloud.security.token.validation.ValidationResult;
import com.sap.cloud.security.token.validation.Validator;
import com.sap.cloud.security.xsuaa.tokenflows.XsuaaTokenFlows;

import io.vavr.control.Option;
import io.vavr.control.Try;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;

/**
 * Decodes and validates JWT bearer tokens.
 * <p>
 * Validation is first attempted with the configured SAP Cloud Security library validators. If none of them accepts
 * the token, validation falls back to a legacy implementation based on {@code AuthTokenValidator}. The legacy
 * implementation may optionally request a new access token via the XSUAA refresh token flow.
 */
@Slf4j
class AuthTokenDecoder
{
    // The "Bearer" authentication scheme is matched case-insensitively (RFC 7235, section 2.1).
    private static final String BEARER_PREFIX = "bearer ";

    @Getter
    @Nonnull
    private final OAuth2TokenServiceCache tokenServiceCache;

    @Getter
    @Nonnull
    private final List<CombiningValidator<Token>> tokenValidators;

    /**
     * Creates a decoder using {@code DefaultAuthTokenFacade.DEFAULT_TOKEN_SERVICE_CACHE} and
     * {@code DefaultAuthTokenFacade.DEFAULT_VALIDATORS}.
     */
    AuthTokenDecoder()
    {
        this(DefaultAuthTokenFacade.DEFAULT_TOKEN_SERVICE_CACHE, DefaultAuthTokenFacade.DEFAULT_VALIDATORS);
    }

    /**
     * Creates a decoder with the given token service cache and token validators.
     *
     * @param serviceCache
     *            Optional {@link OAuth2TokenServiceCache} used when requesting refreshed tokens. If {@code null}, a new
     *            one is created via {@code OAuth2TokenServiceCache.create()}.
     * @param validators
     *            Optional list of token validators. If {@code null}, validators for XSUAA/IAS are loaded via
     *            {@code DefaultAuthTokenFacade.loadOauth2Validators()}, inferred from the current application context.
     */
    public AuthTokenDecoder(
        @Nullable final OAuth2TokenServiceCache serviceCache,
        @Nullable final List<CombiningValidator<Token>> validators )
    {
        tokenServiceCache = serviceCache != null ? serviceCache : OAuth2TokenServiceCache.create();
        tokenValidators = validators != null ? validators : DefaultAuthTokenFacade.loadOauth2Validators();
    }

    /**
     * Requests a new access token via the XSUAA refresh token flow.
     * <p>
     * Despite its name, this method returns the newly issued <em>access</em> token. The request is decorated with a
     * resilience configuration with no isolation, a 6 second time limit and a 6 second circuit breaker wait duration.
     *
     * @param encodedJwt
     *            The current encoded access token. It is passed to {@code OAuth2ServiceProvider} as static access token.
     * @param refreshToken
     *            The refresh token to exchange.
     * @return The encoded access token returned by the refresh token flow.
     */
    @Nonnull
    String getRefreshToken( @Nonnull final String encodedJwt, @Nonnull final String refreshToken )
    {
        final ResilienceConfiguration resilienceConfiguration =
            ResilienceConfiguration
                .of(AuthTokenDecoder.class)
                .isolationMode(ResilienceIsolationMode.NO_ISOLATION)
                .timeLimiterConfiguration(TimeLimiterConfiguration.of().timeoutDuration(Duration.ofSeconds(6)))
                .circuitBreakerConfiguration(CircuitBreakerConfiguration.of().waitDuration(Duration.ofSeconds(6)));

        final Supplier<String> tokenSupplier =
            ResilienceDecorator
                .decorateSupplier(
                    () -> sendRefreshTokenRequestAndParseResponse(encodedJwt, refreshToken),
                    resilienceConfiguration);

        return tokenSupplier.get();
    }

    /**
     * Executes the refresh token flow and extracts the new access token from the response.
     *
     * @throws TokenRequestFailedException
     *             If the refresh token request fails.
     */
    private
        String
        sendRefreshTokenRequestAndParseResponse( @Nonnull final String encodedJwt, @Nonnull final String refreshToken )
    {
        final DecodedJWT accessToken = JWT.decode(encodedJwt);
        final XsuaaTokenFlows tokenFlows =
            OAuth2ServiceProvider
                .builder()
                .tokenServiceCache(tokenServiceCache)
                .staticAccessToken(accessToken)
                .build()
                .getXsuaaTokenFlows();

        return Try
            .of(() -> tokenFlows.refreshTokenFlow().refreshToken(refreshToken).execute().getAccessToken())
            // never log the token values themselves: both the access token and the refresh token are credentials
            .onFailure(e -> log.debug("Failed to request a new access token via the refresh token flow.", e))
            .getOrElseThrow(e -> new TokenRequestFailedException("Refresh JWT request failed", e));
    }

    /**
     * Decodes and validates the given encoded JWT.
     * <p>
     * Each configured token validator is tried in order, and the first one that accepts the token wins. If none
     * succeeds, or none is configured, the legacy validation is used as a fallback.
     *
     * @param encodedJwt
     *            The encoded JWT to validate.
     * @param refreshToken
     *            Optional refresh token. The legacy validation uses it to obtain and verify a new access token if the
     *            given one cannot be verified.
     * @return The validated token.
     * @throws AuthTokenAccessException
     *             If all validation attempts fail. The individual failures are attached as suppressed exceptions.
     */
    @Nonnull
    AuthToken decodeAndValidate( @Nonnull final String encodedJwt, @Nullable final String refreshToken )
        throws AuthTokenAccessException
    {
        // prepare an exception and add suppressed failures but only throw it if all validation attempts fail
        final AuthTokenAccessException preparedException = new AuthTokenAccessException("Failed to verify JWT bearer.");

        // first attempt validation with the Cloud Security library; the first validator accepting the token wins
        for( final CombiningValidator<Token> tokenValidator : tokenValidators ) {
            final Try<DecodedJWT> attempt = Try.of(() -> validateJwtWithSecurityLibrary(encodedJwt, tokenValidator));
            if( attempt.isSuccess() ) {
                return new AuthToken(attempt.get());
            }
            preparedException.addSuppressed(attempt.getCause());
            log.debug("JWT validation attempt failed.", attempt.getCause());
        }

        // if the validation with Cloud Security library fails, try the legacy SDK implementation
        if( tokenValidators.isEmpty() ) {
            log
                .warn(
                    "AuthTokenDecoder was instantiated without a token validator. Falling back to legacy mode for token validation.");
        } else {
            log
                .warn(
                    "Access token validation failed. Falling back to legacy mode. Issuer and JKU properties are not supported.");
        }

        return validateJwtViaLegacyImplementation(encodedJwt, refreshToken)
            .map(AuthToken::new)
            .onFailure(preparedException::addSuppressed)
            .onFailure(e -> log.debug("JWT validation attempt failed.", e))
            .getOrElseThrow(t -> preparedException);
    }

    /**
     * Decodes and validates the JWT bearer from the {@code Authorization} header of the given servlet request.
     *
     * @param request
     *            The incoming request. A {@code null} header enumeration is treated as "no header present".
     * @return The validated token, or a failure describing why no valid bearer token was found.
     */
    @Nonnull
    Try<AuthToken> decodeAndValidate( @Nonnull final HttpServletRequest request )
    {
        final Enumeration<String> headerValues =
            Option.of(request.getHeaders(HttpHeaders.AUTHORIZATION)).getOrElse(Collections::emptyEnumeration);
        return fromAuthorizationHeaders(Collections.list(headerValues));
    }

    /**
     * Decodes and validates the JWT bearer from the {@code Authorization} header in the given header container.
     *
     * @param headers
     *            The request headers.
     * @return The validated token, or a failure describing why no valid bearer token was found.
     */
    @Nonnull
    Try<AuthToken> decodeAndValidate( @Nonnull final RequestHeaderContainer headers )
    {
        return fromAuthorizationHeaders(headers.getHeaderValues(HttpHeaders.AUTHORIZATION));
    }

    /**
     * Extracts the bearer token from exactly one {@code Authorization} header value and validates it.
     * <p>
     * It is a failure if no header value is present, if more than one is present, or if the single value does not
     * use the (case-insensitive) {@code Bearer} scheme.
     */
    @Nonnull
    private Try<AuthToken> fromAuthorizationHeaders( @Nonnull final Collection<String> headerValues )
    {
        if( headerValues.isEmpty() ) {
            return Try
                .failure(
                    new AuthTokenAccessException(
                        "Failed to decode JWT bearer: no '"
                            + HttpHeaders.AUTHORIZATION
                            + "' header present in request."));
        }

        if( headerValues.size() > 1 ) {
            return Try
                .failure(
                    new AuthTokenAccessException(
                        "Failed to decode JWT bearer: multiple '"
                            + HttpHeaders.AUTHORIZATION
                            + "' headers present in request."));
        }

        final String authorizationValue = headerValues.iterator().next();

        // case-insensitive scheme check without lower-casing (and copying) the whole header value
        if( authorizationValue == null
            || !authorizationValue.regionMatches(true, 0, BEARER_PREFIX, 0, BEARER_PREFIX.length()) ) {
            return Try
                .failure(
                    new AuthTokenAccessException(
                        "Failed to decode JWT bearer: no JWT bearer present in '"
                            + HttpHeaders.AUTHORIZATION
                            + "' header of request."));
        }

        // RFC 6750 allows one or more spaces between the scheme and the token, so strip surrounding whitespace
        final String tokenValue = authorizationValue.substring(BEARER_PREFIX.length()).trim();
        return Try.of(() -> decodeAndValidate(tokenValue, null));
    }

    /**
     * Validates the token with a single Cloud Security library validator and decodes it on success.
     *
     * @throws AuthTokenAccessException
     *             If the validator rejects the token. The message includes the validator's error description.
     */
    private DecodedJWT validateJwtWithSecurityLibrary(
        @Nonnull final String encodedJwt,
        @Nonnull final Validator<Token> tokenValidator )
    {
        final Token token = Token.create(encodedJwt);
        final ValidationResult result = tokenValidator.validate(token);
        if( result.isValid() ) {
            return JWT.decode(encodedJwt);
        }
        throw new AuthTokenAccessException("The token is invalid: " + result.getErrorDescription());
    }

    /**
     * Verifies the token with the legacy {@code AuthTokenValidator}.
     * <p>
     * If the token cannot be verified and a refresh token is given, a new access token is requested and verified
     * instead. All failures, including a failing refresh request, are returned as {@link Try#failure(Throwable)}
     * rather than thrown.
     */
    @Nonnull
    private
        Try<DecodedJWT>
        validateJwtViaLegacyImplementation( @Nonnull final String encodedJwt, @Nullable final String refreshToken )
    {
        final DecodedJWT decodedJWT;
        final List<RSAPublicKey> verificationKeys;

        try {
            decodedJWT = JWT.decode(encodedJwt);
            verificationKeys = AuthTokenValidator.getVerificationPublicKeysForJwt(decodedJWT);
        }
        catch( final JWTDecodeException | AuthTokenAccessException e ) {
            return Try.failure(e);
        }

        final String targetAlgorithm = decodedJWT.getAlgorithm();
        final AuthTokenValidator authTokenValidator = new AuthTokenValidator(targetAlgorithm, verificationKeys);

        final Optional<DecodedJWT> verifiedToken = authTokenValidator.verifyToken(encodedJwt);
        if( verifiedToken.isPresent() ) {
            return Try.success(verifiedToken.get());
        }

        if( refreshToken == null ) {
            return Try
                .failure(new AuthTokenAccessException("Failed to verify JWT bearer with the legacy implementation."));
        }

        // fallback: use refresh token to request new access token and try to verify once more.
        // A failing refresh request is reported like any other validation failure instead of escaping the caller.
        final Try<String> refreshedEncodedJwt = Try.of(() -> getRefreshToken(encodedJwt, refreshToken));
        if( refreshedEncodedJwt.isFailure() ) {
            return Try.failure(refreshedEncodedJwt.getCause());
        }

        final Optional<DecodedJWT> verifiedRefreshedToken = authTokenValidator.verifyToken(refreshedEncodedJwt.get());
        if( verifiedRefreshedToken.isPresent() ) {
            return Try.success(verifiedRefreshedToken.get());
        }
        return Try
            .failure(
                new AuthTokenAccessException("Failed to verify refreshed JWT bearer with the legacy implementation."));
    }
}
