/*
 * Decompiled with CFR 0.152.
 */
package io.strimzi.kafka.oauth.validator;

import com.fasterxml.jackson.databind.JsonNode;
import com.nimbusds.jose.crypto.factories.DefaultJWSVerifierFactory;
import com.nimbusds.jose.jwk.ECKey;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.KeyUse;
import com.nimbusds.jose.jwk.RSAKey;
import com.nimbusds.jwt.SignedJWT;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import io.strimzi.kafka.oauth.common.HttpUtil;
import io.strimzi.kafka.oauth.common.JSONUtil;
import io.strimzi.kafka.oauth.common.PrincipalExtractor;
import io.strimzi.kafka.oauth.common.TimeUtil;
import io.strimzi.kafka.oauth.common.TokenInfo;
import io.strimzi.kafka.oauth.jsonpath.JsonPathFilterQuery;
import io.strimzi.kafka.oauth.validator.BackOffTaskScheduler;
import io.strimzi.kafka.oauth.validator.DaemonThreadFactory;
import io.strimzi.kafka.oauth.validator.TokenExpiredException;
import io.strimzi.kafka.oauth.validator.TokenSignatureException;
import io.strimzi.kafka.oauth.validator.TokenValidationException;
import io.strimzi.kafka.oauth.validator.TokenValidator;
import java.net.URI;
import java.net.URISyntaxException;
import java.security.Key;
import java.security.PublicKey;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLSocketFactory;
import org.apache.kafka.common.utils.Time;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class JWTSignatureValidator
implements TokenValidator {
    private static final Logger log = LoggerFactory.getLogger(JWTSignatureValidator.class);
    private static DefaultJWSVerifierFactory verifierFactory = new DefaultJWSVerifierFactory();
    private final BackOffTaskScheduler fastScheduler;
    private final URI keysUri;
    private final String issuerUri;
    private final int maxStaleSeconds;
    private final boolean checkAccessTokenType;
    private final String audience;
    private final JsonPathFilterQuery customClaimMatcher;
    private final SSLSocketFactory socketFactory;
    private final HostnameVerifier hostnameVerifier;
    private final PrincipalExtractor principalExtractor;
    private long lastFetchTime;
    private Map<String, PublicKey> cache = Collections.emptyMap();
    private Map<String, PublicKey> oldCache = Collections.emptyMap();

    public JWTSignatureValidator(String keysEndpointUri, SSLSocketFactory socketFactory, HostnameVerifier verifier, PrincipalExtractor principalExtractor, String validIssuerUri, int refreshSeconds, int refreshMinPauseSeconds, int expirySeconds, boolean checkAccessTokenType, String audience, String customClaimCheck) {
        if (keysEndpointUri == null) {
            throw new IllegalArgumentException("keysEndpointUri == null");
        }
        try {
            this.keysUri = new URI(keysEndpointUri);
        }
        catch (URISyntaxException e) {
            throw new IllegalArgumentException("Invalid keysEndpointUri: " + keysEndpointUri, e);
        }
        if (socketFactory != null && !"https".equals(this.keysUri.getScheme())) {
            throw new IllegalArgumentException("SSL socket factory set but keysEndpointUri not 'https'");
        }
        this.socketFactory = socketFactory;
        if (verifier != null && !"https".equals(this.keysUri.getScheme())) {
            throw new IllegalArgumentException("Certificate hostname verifier set but keysEndpointUri not 'https'");
        }
        this.hostnameVerifier = verifier;
        this.principalExtractor = principalExtractor;
        if (validIssuerUri != null) {
            try {
                new URI(validIssuerUri);
            }
            catch (URISyntaxException e) {
                throw new IllegalArgumentException("Value of validIssuerUri not a valid URI: " + validIssuerUri, e);
            }
        }
        this.issuerUri = validIssuerUri;
        this.validateRefreshConfig(refreshSeconds, expirySeconds);
        this.maxStaleSeconds = expirySeconds;
        this.checkAccessTokenType = checkAccessTokenType;
        this.audience = audience;
        this.customClaimMatcher = this.parseCustomClaimCheck(customClaimCheck);
        this.fetchKeys();
        ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor(new DaemonThreadFactory());
        this.fastScheduler = new BackOffTaskScheduler(executor, refreshMinPauseSeconds, refreshSeconds, () -> this.fetchKeys());
        this.setupRefreshKeysJob(executor, refreshSeconds);
        if (log.isDebugEnabled()) {
            log.debug("Configured JWTSignatureValidator:\n    keysEndpointUri: " + keysEndpointUri + "\n    sslSocketFactory: " + socketFactory + "\n    hostnameVerifier: " + this.hostnameVerifier + "\n    principalExtractor: " + principalExtractor + "\n    validIssuerUri: " + validIssuerUri + "\n    certsRefreshSeconds: " + refreshSeconds + "\n    certsRefreshMinPauseSeconds: " + refreshMinPauseSeconds + "\n    certsExpirySeconds: " + expirySeconds + "\n    checkAccessTokenType: " + checkAccessTokenType + "\n    audience: " + audience + "\n    customClaimCheck: " + customClaimCheck);
        }
    }

    private JsonPathFilterQuery parseCustomClaimCheck(String customClaimCheck) {
        if (customClaimCheck != null) {
            String query = customClaimCheck.trim();
            if (query.length() == 0) {
                throw new IllegalArgumentException("Value of customClaimCheck is empty");
            }
            return JsonPathFilterQuery.parse(query);
        }
        return null;
    }

    private void validateRefreshConfig(int refreshSeconds, int expirySeconds) {
        if (refreshSeconds <= 0) {
            throw new IllegalArgumentException("refreshSeconds has to be a positive number - (refreshSeconds=" + refreshSeconds + ")");
        }
        if (expirySeconds < refreshSeconds + 60) {
            throw new IllegalArgumentException("expirySeconds has to be at least 60 seconds longer than refreshSeconds - (expirySeconds=" + expirySeconds + ", refreshSeconds=" + refreshSeconds + ")");
        }
    }

    private void setupRefreshKeysJob(ScheduledExecutorService executor, int refreshSeconds) {
        executor.scheduleAtFixedRate(() -> {
            try {
                this.fastScheduler.scheduleTask();
            }
            catch (Exception e) {
                log.error(e.getMessage(), (Throwable)e);
            }
        }, refreshSeconds, refreshSeconds, TimeUnit.SECONDS);
    }

    private PublicKey getPublicKey(String id) {
        return this.getKeyUnlessStale(id);
    }

    private PublicKey getKeyUnlessStale(String id) {
        if (this.lastFetchTime + (long)this.maxStaleSeconds * 1000L > System.currentTimeMillis()) {
            PublicKey result = this.cache.get(id);
            if (result == null) {
                log.warn("No public key for id: " + id);
            }
            return result;
        }
        log.warn("The cached public key with id '" + id + "' is expired!");
        return null;
    }

    private void fetchKeys() {
        try {
            Map<String, PublicKey> newCache = new HashMap();
            JWKSet jwks = JWKSet.parse((String)HttpUtil.get(this.keysUri, this.socketFactory, this.hostnameVerifier, null, String.class));
            for (JWK jwk : jwks.getKeys()) {
                PublicKey publicKey;
                if (!jwk.getKeyUse().equals((Object)KeyUse.SIGNATURE)) continue;
                if (jwk instanceof ECKey) {
                    publicKey = ((ECKey)jwk).toPublicKey();
                } else if (jwk instanceof RSAKey) {
                    publicKey = ((RSAKey)jwk).toPublicKey();
                } else {
                    log.warn("Unsupported JWK key type: " + jwk.getKeyType());
                    continue;
                }
                newCache.put(jwk.getKeyID(), publicKey);
            }
            if (!this.cache.equals(newCache = Collections.unmodifiableMap(newCache))) {
                log.info("JWKS keys change detected. Keys updated.");
                this.oldCache = this.cache;
                this.cache = newCache;
            }
            this.lastFetchTime = System.currentTimeMillis();
        }
        catch (Throwable ex) {
            throw new RuntimeException("Failed to fetch public keys needed to validate JWT signatures: " + this.keysUri, ex);
        }
    }

    @Override
    @SuppressFBWarnings(value={"BC_UNCONFIRMED_CAST_OF_RETURN_VALUE"}, justification="We tell TokenVerifier to parse AccessToken. It will return AccessToken or fail.")
    public TokenInfo validate(String token) {
        JsonNode t;
        String kid;
        SignedJWT jwt;
        try {
            jwt = SignedJWT.parse((String)token);
            kid = jwt.getHeader().getKeyID();
        }
        catch (Exception e) {
            throw new TokenValidationException("Token validation failed: Failed to parse JWT: " + token, e).status(TokenValidationException.Status.INVALID_TOKEN);
        }
        try {
            PublicKey pub = this.getPublicKey(kid);
            if (pub == null) {
                if (this.oldCache.get(kid) != null) {
                    throw new TokenValidationException("Token validation failed: The signing key is no longer valid (kid:" + kid + ")");
                }
                this.fastScheduler.scheduleTask();
                throw new TokenValidationException("Token validation failed: Unknown signing key (kid:" + kid + ")");
            }
            if (!jwt.verify(verifierFactory.createJWSVerifier(jwt.getHeader(), (Key)pub))) {
                throw new TokenSignatureException("Signature check failed: Invalid token signature");
            }
            t = JSONUtil.asJson(jwt.getPayload().toJSONObject());
        }
        catch (TokenValidationException e) {
            throw e;
        }
        catch (Exception e) {
            throw new TokenValidationException("Token validation failed", e);
        }
        JsonNode exp = t.get("exp");
        if (exp == null) {
            throw new TokenValidationException("Token validation failed: Expiry not set");
        }
        long expiresMillis = (long)exp.asInt(0) * 1000L;
        if (Time.SYSTEM.milliseconds() > expiresMillis) {
            throw new TokenExpiredException("Token expired at: " + expiresMillis + " (" + TimeUtil.formatIsoDateTimeUTC(expiresMillis) + " UTC)");
        }
        this.validateTokenPayload(t);
        if (this.customClaimMatcher != null && !this.customClaimMatcher.matches(t)) {
            throw new TokenValidationException("Token validation failed: Custom claim check failed");
        }
        String principal = this.extractPrincipal(t);
        return new TokenInfo(t, token, principal);
    }

    private String extractPrincipal(JsonNode tokenJson) {
        String principal = null;
        if (this.principalExtractor.isConfigured()) {
            principal = this.principalExtractor.getPrincipal(tokenJson);
        }
        if (principal == null && !this.principalExtractor.isConfigured()) {
            principal = this.principalExtractor.getSub(tokenJson);
        }
        if (principal == null) {
            throw new RuntimeException("Failed to extract principal - check usernameClaim, fallbackUsernameClaim configuration");
        }
        return principal;
    }

    private void validateTokenPayload(JsonNode token) {
        if (this.issuerUri != null) {
            JsonNode iss = token.get("iss");
            if (iss == null) {
                throw new TokenValidationException("Token validation failed: Issuer not set");
            }
            String issuer = iss.asText();
            if (!this.issuerUri.equals(issuer)) {
                throw new TokenValidationException("Token validation failed: Issuer not allowed: " + issuer);
            }
        }
        if (this.checkAccessTokenType) {
            JsonNode typ = token.get("typ");
            if (typ == null) {
                throw new TokenValidationException("Token validation failed: Token type not set");
            }
            String type = typ.asText();
            if (!"Bearer".equals(type)) {
                throw new TokenValidationException("Token validation failed: Token type not allowed: " + type);
            }
        }
        if (this.audience != null) {
            List<Object> aud;
            JsonNode audNode = token.get("aud");
            List<Object> list = aud = audNode == null ? Collections.emptyList() : JSONUtil.asListOfString(audNode);
            if (!aud.contains(this.audience)) {
                throw new TokenValidationException("Token validation failed: Expected audience not available in the token");
            }
        }
    }
}

