/*
 * Decompiled with CFR 0.152.
 */
package io.strimzi.kafka.oauth.validator;

import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import io.strimzi.kafka.oauth.common.HttpUtil;
import io.strimzi.kafka.oauth.common.JSONUtil;
import io.strimzi.kafka.oauth.common.PrincipalExtractor;
import io.strimzi.kafka.oauth.common.TimeUtil;
import io.strimzi.kafka.oauth.common.TokenInfo;
import io.strimzi.kafka.oauth.validator.BackOffTaskScheduler;
import io.strimzi.kafka.oauth.validator.DaemonThreadFactory;
import io.strimzi.kafka.oauth.validator.ECDSASignatureVerifierContext;
import io.strimzi.kafka.oauth.validator.TokenExpiredException;
import io.strimzi.kafka.oauth.validator.TokenSignatureException;
import io.strimzi.kafka.oauth.validator.TokenValidationException;
import io.strimzi.kafka.oauth.validator.TokenValidator;
import java.net.URI;
import java.net.URISyntaxException;
import java.security.Key;
import java.security.Provider;
import java.security.PublicKey;
import java.security.Security;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLSocketFactory;
import org.apache.kafka.common.utils.Time;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.keycloak.TokenVerifier;
import org.keycloak.crypto.AsymmetricSignatureVerifierContext;
import org.keycloak.crypto.KeyWrapper;
import org.keycloak.crypto.SignatureVerifierContext;
import org.keycloak.exceptions.TokenSignatureInvalidException;
import org.keycloak.jose.jwk.JSONWebKeySet;
import org.keycloak.jose.jwk.JWK;
import org.keycloak.representations.AccessToken;
import org.keycloak.util.JWKSUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link TokenValidator} that validates JWT access tokens locally by checking their signature
 * against public keys fetched from a JWKS endpoint.
 *
 * <p>Keys are fetched once during construction. After that they are refreshed every
 * {@code refreshSeconds} on a dedicated single-threaded scheduled executor. An additional refresh
 * is requested (through {@link BackOffTaskScheduler}) whenever a token references an unknown key
 * id. If no successful fetch has happened within {@code expirySeconds}, the cached keys are
 * treated as stale and validation fails.
 *
 * <p>Besides the signature, validation optionally checks the issuer, the token type
 * ({@code "Bearer"}) and the audience, and always checks the token expiration time.
 */
public class JWTSignatureValidator implements TokenValidator {

    private static final Logger log = LoggerFactory.getLogger(JWTSignatureValidator.class);

    /** Ensures the BouncyCastle provider is inserted at most once, however many validators are created. */
    private static final AtomicBoolean bouncyInstalled = new AtomicBoolean(false);

    private static final TokenVerifier.TokenTypeCheck TOKEN_TYPE_CHECK = new TokenVerifier.TokenTypeCheck("Bearer");

    private final BackOffTaskScheduler fastScheduler;

    private final URI keysUri;
    private final String issuerUri;
    private final int maxStaleSeconds;
    private final boolean checkAccessTokenType;
    private final String audience;
    private final SSLSocketFactory socketFactory;
    private final HostnameVerifier hostnameVerifier;
    private final PrincipalExtractor principalExtractor;

    // These fields are written by the key-refresh executor thread (and once by the constructor)
    // and read by whichever threads call validate(). 'volatile' makes newly published key maps
    // and fetch timestamps visible to those readers.
    private volatile long lastFetchTime;
    private volatile Map<String, PublicKey> cache = Collections.emptyMap();
    private volatile Map<String, PublicKey> oldCache = Collections.emptyMap();

    /**
     * Creates the validator, performs the initial JWKS fetch and schedules periodic key refreshes.
     *
     * @param keysEndpointUri JWKS endpoint URI; must be non-null and syntactically valid
     * @param socketFactory optional SSL socket factory; only allowed if the endpoint scheme is {@code https}
     * @param verifier optional hostname verifier; only allowed if the endpoint scheme is {@code https}
     * @param principalExtractor used to derive the principal name from a validated token
     * @param validIssuerUri if non-null, the required token issuer; must be a syntactically valid URI
     * @param refreshSeconds interval between periodic key refreshes; must be positive
     * @param refreshMinPauseSeconds minimum pause handed to {@link BackOffTaskScheduler} for on-demand refreshes
     * @param expirySeconds how long cached keys stay usable after the last successful fetch;
     *                      must be at least {@code refreshSeconds + 60}
     * @param checkAccessTokenType if true, tokens must have type {@code "Bearer"}
     * @param audience if non-null, the audience the token must contain
     * @param enableBouncyCastleProvider whether to insert the BouncyCastle security provider
     * @param bouncyCastleProviderPosition position passed to {@link Security#insertProviderAt}
     * @throws IllegalArgumentException if the configuration is invalid
     * @throws RuntimeException if the initial key fetch fails
     */
    public JWTSignatureValidator(String keysEndpointUri,
                                 SSLSocketFactory socketFactory,
                                 HostnameVerifier verifier,
                                 PrincipalExtractor principalExtractor,
                                 String validIssuerUri,
                                 int refreshSeconds,
                                 int refreshMinPauseSeconds,
                                 int expirySeconds,
                                 boolean checkAccessTokenType,
                                 String audience,
                                 boolean enableBouncyCastleProvider,
                                 int bouncyCastleProviderPosition) {
        if (keysEndpointUri == null) {
            throw new IllegalArgumentException("keysEndpointUri == null");
        }
        try {
            this.keysUri = new URI(keysEndpointUri);
        } catch (URISyntaxException e) {
            throw new IllegalArgumentException("Invalid keysEndpointUri: " + keysEndpointUri, e);
        }

        boolean https = "https".equals(keysUri.getScheme());
        if (socketFactory != null && !https) {
            throw new IllegalArgumentException("SSL socket factory set but keysEndpointUri not 'https'");
        }
        this.socketFactory = socketFactory;

        if (verifier != null && !https) {
            throw new IllegalArgumentException("Certificate hostname verifier set but keysEndpointUri not 'https'");
        }
        this.hostnameVerifier = verifier;

        this.principalExtractor = principalExtractor;

        if (validIssuerUri != null) {
            try {
                // Only syntax is validated here; the string itself is what gets compared.
                new URI(validIssuerUri);
            } catch (URISyntaxException e) {
                throw new IllegalArgumentException("Value of validIssuerUri not a valid URI: " + validIssuerUri, e);
            }
        }
        this.issuerUri = validIssuerUri;

        validateRefreshConfig(refreshSeconds, expirySeconds);

        this.maxStaleSeconds = expirySeconds;
        this.checkAccessTokenType = checkAccessTokenType;
        this.audience = audience;

        if (enableBouncyCastleProvider && !bouncyInstalled.getAndSet(true)) {
            installBouncyCastleProvider(bouncyCastleProviderPosition);
        }

        // Initial synchronous fetch so the validator is usable as soon as construction succeeds.
        fetchKeys();

        // A single thread runs both the periodic refresh and on-demand refreshes, so fetchKeys()
        // never runs concurrently with itself after construction.
        ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor(new DaemonThreadFactory());
        this.fastScheduler = new BackOffTaskScheduler(executor, refreshMinPauseSeconds, refreshSeconds, this::fetchKeys);
        setupRefreshKeysJob(executor, refreshSeconds);

        if (log.isDebugEnabled()) {
            log.debug("Configured JWTSignatureValidator:"
                    + "\n    keysEndpointUri: " + keysEndpointUri
                    + "\n    sslSocketFactory: " + socketFactory
                    + "\n    hostnameVerifier: " + hostnameVerifier
                    + "\n    principalExtractor: " + principalExtractor
                    + "\n    validIssuerUri: " + validIssuerUri
                    + "\n    certsRefreshSeconds: " + refreshSeconds
                    + "\n    certsRefreshMinPauseSeconds: " + refreshMinPauseSeconds
                    + "\n    certsExpirySeconds: " + expirySeconds
                    + "\n    checkAccessTokenType: " + checkAccessTokenType
                    + "\n    audience: " + audience
                    + "\n    enableBouncyCastleProvider: " + enableBouncyCastleProvider
                    + "\n    bouncyCastleProviderPosition: " + bouncyCastleProviderPosition);
        }
    }

    /**
     * Inserts the BouncyCastle provider at the requested position and logs the resulting
     * provider list at debug level.
     */
    private static void installBouncyCastleProvider(int position) {
        int installedPosition = Security.insertProviderAt(new BouncyCastleProvider(), position);
        log.info("BouncyCastle security provider installed at position: {}", installedPosition);

        if (log.isDebugEnabled()) {
            StringBuilder sb = new StringBuilder("Installed security providers:\n");
            for (Provider p : Security.getProviders()) {
                sb.append("  - ").append(p).append("  [").append(p.getClass().getName()).append("]\n");
                sb.append("   ").append(p.getInfo()).append('\n');
            }
            log.debug(sb.toString());
        }
    }

    /**
     * Checks that the refresh interval is positive and that keys expire at least 60 seconds
     * after the next scheduled refresh.
     */
    private void validateRefreshConfig(int refreshSeconds, int expirySeconds) {
        if (refreshSeconds <= 0) {
            throw new IllegalArgumentException("refreshSeconds has to be a positive number - (refreshSeconds=" + refreshSeconds + ")");
        }
        if (expirySeconds < refreshSeconds + 60) {
            throw new IllegalArgumentException("expirySeconds has to be at least 60 seconds longer than refreshSeconds - (expirySeconds=" + expirySeconds + ", refreshSeconds=" + refreshSeconds + ")");
        }
    }

    /**
     * Schedules a periodic refresh request every {@code refreshSeconds}. Failures are logged and
     * not rethrown, because an exception would cancel all subsequent runs of the periodic task.
     */
    private void setupRefreshKeysJob(ScheduledExecutorService executor, int refreshSeconds) {
        executor.scheduleAtFixedRate(() -> {
            try {
                fastScheduler.scheduleTask();
            } catch (Exception e) {
                log.error(e.getMessage(), e);
            }
        }, refreshSeconds, refreshSeconds, TimeUnit.SECONDS);
    }

    /** Returns the public key for {@code id}, or {@code null} if it is unknown or the key set is stale. */
    private PublicKey getPublicKey(String id) {
        return getKeyUnlessStale(id);
    }

    /**
     * Looks up {@code id} in the current key set, provided the last successful fetch happened
     * less than {@code maxStaleSeconds} ago.
     *
     * @return the key, or {@code null} if the key set is stale or does not contain {@code id}
     */
    private PublicKey getKeyUnlessStale(String id) {
        if (lastFetchTime + (long) maxStaleSeconds * 1000L > System.currentTimeMillis()) {
            PublicKey result = cache.get(id);
            if (result == null) {
                log.warn("No public key for id: {}", id);
            }
            return result;
        }
        log.warn("The cached public key with id '{}' is expired!", id);
        return null;
    }

    /**
     * Fetches the JWKS from {@link #keysUri} and keeps its signature-use keys. When the key set
     * changes, the previous set is kept in {@link #oldCache} so that tokens signed with a removed
     * key can be reported as such.
     *
     * @throws RuntimeException wrapping any failure to fetch or parse the key set
     */
    private void fetchKeys() {
        try {
            JSONWebKeySet jwks = HttpUtil.get(keysUri, socketFactory, hostnameVerifier, null, JSONWebKeySet.class);
            Map<String, PublicKey> newCache = Collections.unmodifiableMap(JWKSUtils.getKeysForUse(jwks, JWK.Use.SIG));

            if (!cache.equals(newCache)) {
                log.info("JWKS keys change detected. Keys updated.");
                oldCache = cache;
                cache = newCache;
            }
            lastFetchTime = System.currentTimeMillis();
        } catch (Exception ex) {
            throw new RuntimeException("Failed to fetch public keys needed to validate JWT signatures: " + keysUri, ex);
        }
    }

    /**
     * Validates the signature and the configured checks of {@code token}, then extracts the principal.
     *
     * @param token raw JWT access token
     * @return information about the validated token, including the extracted principal
     * @throws TokenValidationException if the header cannot be parsed, the signing key is unknown or
     *         no longer valid, or any configured check fails
     * @throws TokenSignatureException if the signature is invalid
     * @throws TokenExpiredException if the token's expiration time has passed
     * @throws RuntimeException if no principal can be extracted
     */
    @Override
    @SuppressFBWarnings(value = {"BC_UNCONFIRMED_CAST_OF_RETURN_VALUE"},
            justification = "We tell TokenVerifier to parse AccessToken. It will return AccessToken or fail.")
    public TokenInfo validate(String token) {
        TokenVerifier<AccessToken> tokenVerifier = TokenVerifier.create(token, AccessToken.class);

        if (issuerUri != null) {
            tokenVerifier.realmUrl(issuerUri);
        }
        if (checkAccessTokenType) {
            tokenVerifier.withChecks(TOKEN_TYPE_CHECK);
        }
        if (audience != null) {
            tokenVerifier.audience(audience);
        }

        String kid;
        try {
            kid = tokenVerifier.getHeader().getKeyId();
        } catch (Exception e) {
            // Do not put the raw token in the message: it is a bearer credential and messages end up in logs.
            throw new TokenValidationException("Token signature validation failed: unable to parse token header", e)
                    .status(TokenValidationException.Status.INVALID_TOKEN);
        }

        AccessToken t;
        try {
            PublicKey pub = getPublicKey(kid);
            if (pub == null) {
                if (oldCache.get(kid) != null) {
                    throw new TokenValidationException("Token validation failed: The signing key is no longer valid (kid:" + kid + ")");
                }
                // The key may have been added since the last fetch; request a (back-off limited) refresh.
                fastScheduler.scheduleTask();
                throw new TokenValidationException("Token validation failed: Unknown signing key (kid:" + kid + ")");
            }

            KeyWrapper keywrap = new KeyWrapper();
            keywrap.setPublicKey(pub);
            keywrap.setAlgorithm(tokenVerifier.getHeader().getAlgorithm().name());
            keywrap.setKid(kid);

            log.debug("Signature algorithm used: [{}]", pub.getAlgorithm());

            SignatureVerifierContext ctx = isAlgorithmEC(pub.getAlgorithm())
                    ? new ECDSASignatureVerifierContext(keywrap)
                    : new AsymmetricSignatureVerifierContext(keywrap);
            tokenVerifier.verifierContext(ctx);
            log.debug("SignatureVerifierContext set to: {}", ctx);

            tokenVerifier.verify();
            t = tokenVerifier.getToken();
        } catch (TokenSignatureInvalidException e) {
            throw new TokenSignatureException("Signature check failed:", e);
        } catch (TokenValidationException e) {
            throw e;
        } catch (Exception e) {
            throw new TokenValidationException("Token validation failed:", e);
        }

        long expiresMillis = (long) t.getExpiration() * 1000L;
        if (Time.SYSTEM.milliseconds() > expiresMillis) {
            throw new TokenExpiredException("Token expired at: " + expiresMillis + " (" + TimeUtil.formatIsoDateTimeUTC(expiresMillis) + " UTC)");
        }

        // A configured extractor is authoritative: if its claims are missing we fail rather than
        // silently falling back to 'sub'.
        String principal;
        if (principalExtractor.isConfigured()) {
            principal = principalExtractor.getPrincipal(JSONUtil.asJson(t));
        } else {
            principal = principalExtractor.getSub(t);
        }

        if (principal == null) {
            throw new RuntimeException("Failed to extract principal - check usernameClaim, fallbackUsernameClaim configuration");
        }

        return new TokenInfo(t, token, principal);
    }

    /** Returns true if the JCA key algorithm name denotes an elliptic-curve key. */
    private static boolean isAlgorithmEC(String algorithm) {
        return "EC".equals(algorithm) || "ECDSA".equals(algorithm);
    }
}

