/*
 * Decompiled with CFR 0.152.
 */
package io.strimzi.kafka.oauth.validator;

import com.fasterxml.jackson.databind.JsonNode;
import com.nimbusds.jose.crypto.factories.DefaultJWSVerifierFactory;
import com.nimbusds.jose.jwk.ECKey;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.KeyUse;
import com.nimbusds.jose.jwk.RSAKey;
import com.nimbusds.jwt.SignedJWT;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import io.strimzi.kafka.oauth.common.HttpUtil;
import io.strimzi.kafka.oauth.common.JSONUtil;
import io.strimzi.kafka.oauth.common.PrincipalExtractor;
import io.strimzi.kafka.oauth.common.TimeUtil;
import io.strimzi.kafka.oauth.common.TokenInfo;
import io.strimzi.kafka.oauth.jsonpath.JsonPathFilterQuery;
import io.strimzi.kafka.oauth.jsonpath.JsonPathQuery;
import io.strimzi.kafka.oauth.metrics.JwksHttpSensorKeyProducer;
import io.strimzi.kafka.oauth.metrics.SensorKeyProducer;
import io.strimzi.kafka.oauth.services.OAuthMetrics;
import io.strimzi.kafka.oauth.services.ServiceException;
import io.strimzi.kafka.oauth.services.Services;
import io.strimzi.kafka.oauth.validator.BackOffTaskScheduler;
import io.strimzi.kafka.oauth.validator.DaemonThreadFactory;
import io.strimzi.kafka.oauth.validator.TokenExpiredException;
import io.strimzi.kafka.oauth.validator.TokenSignatureException;
import io.strimzi.kafka.oauth.validator.TokenValidationException;
import io.strimzi.kafka.oauth.validator.TokenValidator;
import io.strimzi.kafka.oauth.validator.ValidationException;
import java.net.URI;
import java.net.URISyntaxException;
import java.security.Key;
import java.security.PublicKey;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLSocketFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link TokenValidator} that validates JWT tokens locally.
 * <p>
 * It checks the token signature against public keys fetched from a JWKS endpoint. It then checks:
 * <ul>
 *   <li>expiry ({@code exp});</li>
 *   <li>optionally the issuer ({@code iss});</li>
 *   <li>optionally the token type ({@code typ} / {@code token_type});</li>
 *   <li>optionally the audience ({@code aud});</li>
 *   <li>optionally a custom JsonPath claim check.</li>
 * </ul>
 * <p>
 * Keys are fetched once during construction. They are then refreshed every {@code refreshSeconds} on a
 * single-threaded scheduled executor created with a {@link DaemonThreadFactory}. An extra refresh is also
 * requested through the {@link BackOffTaskScheduler} when a token references an unknown key id.
 * <p>
 * The key-cache state is written by the refresh thread and read by callers of {@link #validate(String)}.
 * It is therefore held in {@code volatile} fields.
 */
public class JWTSignatureValidator implements TokenValidator {

    private static final Logger log = LoggerFactory.getLogger(JWTSignatureValidator.class);
    private static final DefaultJWSVerifierFactory VERIFIER_FACTORY = new DefaultJWSVerifierFactory();

    private final String validatorId;
    private final URI keysUri;
    private final String issuerUri;
    private final int maxStaleSeconds;
    private final boolean checkAccessTokenType;
    private final String audience;
    private final JsonPathFilterQuery customClaimMatcher;
    private final JsonPathQuery groupsQuery;
    private final String groupsDelimiter;
    private final SSLSocketFactory socketFactory;
    private final HostnameVerifier hostnameVerifier;
    private final PrincipalExtractor principalExtractor;
    private final boolean ignoreKeyUse;
    private final int connectTimeout;
    private final int readTimeout;

    // The next three fields are written by fetchKeys() on the refresh thread and read by validating
    // threads. They are volatile so that a completed refresh becomes visible to readers.
    private volatile long lastFetchTime;
    private volatile Map<String, PublicKey> cache = Collections.emptyMap();
    private volatile Map<String, PublicKey> oldCache = Collections.emptyMap();

    private BackOffTaskScheduler fastScheduler;
    private final boolean enableMetrics;
    private final OAuthMetrics metrics;
    private final SensorKeyProducer jwksHttpSensorKeyProducer;

    /**
     * Creates the validator, performs the initial JWKS fetch and schedules the periodic key refresh.
     *
     * @param validatorId identifier of this validator. It is returned by {@link #getValidatorId()} and used
     *                    for JWKS metric keys. Must not be {@code null}.
     * @param keysEndpointUri URI of the JWKS endpoint. Must not be {@code null} and must be a valid URI.
     * @param socketFactory optional SSL socket factory. Only allowed when the JWKS URI scheme is {@code https}.
     * @param verifier optional hostname verifier. Only allowed when the JWKS URI scheme is {@code https}.
     * @param principalExtractor used to obtain the principal name from the token payload.
     *                           NOTE(review): not null-checked here; {@link #validate(String)} dereferences it.
     * @param groupsClaimQuery optional JsonPath query selecting the groups. Must not be blank when set.
     * @param groupsClaimDelimiter optional delimiter passed to {@code JSONUtil.asListOfString} when converting
     *                             the groups query result. Defaults to {@code ","}. Must not be empty when set.
     * @param validIssuerUri optional expected value of the {@code iss} claim. Must be a valid URI when set.
     * @param refreshSeconds interval between periodic JWKS refreshes. Must be positive.
     * @param refreshMinPauseSeconds passed to the {@link BackOffTaskScheduler} that runs on-demand refreshes.
     * @param expirySeconds how long keys are considered usable after the last successful fetch.
     *                      Must be at least {@code refreshSeconds + 60}.
     * @param ignoreKeyUse when {@code false}, only keys whose {@code use} is {@code sig} are cached.
     * @param checkAccessTokenType when {@code true}, the {@code typ} or {@code token_type} claim must be
     *                             {@code "Bearer"}.
     * @param audience optional value that must be present in the {@code aud} claim.
     * @param customClaimCheck optional JsonPath filter query that the token payload must match.
     *                         Must not be blank when set.
     * @param connectTimeoutSeconds connect timeout passed to the JWKS HTTP request.
     * @param readTimeoutSeconds read timeout passed to the JWKS HTTP request.
     * @param enableMetrics whether JWKS HTTP request timings are recorded.
     * @param failFast when {@code true}, a failed initial JWKS fetch is rethrown from the constructor.
     *                 Otherwise it is logged and a retry is scheduled.
     * @throws IllegalArgumentException if the configuration is invalid
     */
    public JWTSignatureValidator(String validatorId, String keysEndpointUri, SSLSocketFactory socketFactory, HostnameVerifier verifier, PrincipalExtractor principalExtractor, String groupsClaimQuery, String groupsClaimDelimiter, String validIssuerUri, int refreshSeconds, int refreshMinPauseSeconds, int expirySeconds, boolean ignoreKeyUse, boolean checkAccessTokenType, String audience, String customClaimCheck, int connectTimeoutSeconds, int readTimeoutSeconds, boolean enableMetrics, boolean failFast) {
        if (validatorId == null) {
            throw new IllegalArgumentException("validatorId == null");
        }
        this.validatorId = validatorId;

        if (keysEndpointUri == null) {
            throw new IllegalArgumentException("keysEndpointUri == null");
        }
        try {
            this.keysUri = new URI(keysEndpointUri);
        } catch (URISyntaxException e) {
            throw new IllegalArgumentException("Invalid keysEndpointUri: " + keysEndpointUri, e);
        }

        // TLS customisation is only meaningful for an https JWKS endpoint.
        if (socketFactory != null && !"https".equals(this.keysUri.getScheme())) {
            throw new IllegalArgumentException("SSL socket factory set but keysEndpointUri not 'https'");
        }
        this.socketFactory = socketFactory;
        if (verifier != null && !"https".equals(this.keysUri.getScheme())) {
            throw new IllegalArgumentException("Certificate hostname verifier set but keysEndpointUri not 'https'");
        }
        this.hostnameVerifier = verifier;
        this.principalExtractor = principalExtractor;

        if (validIssuerUri != null) {
            try {
                // Parsed only to validate the syntax; the original string is what gets compared.
                new URI(validIssuerUri);
            } catch (URISyntaxException e) {
                throw new IllegalArgumentException("Value of validIssuerUri not a valid URI: " + validIssuerUri, e);
            }
        }
        this.issuerUri = validIssuerUri;

        validateRefreshConfig(refreshSeconds, expirySeconds);
        this.maxStaleSeconds = expirySeconds;
        this.checkAccessTokenType = checkAccessTokenType;
        this.audience = audience;
        this.customClaimMatcher = parseCustomClaimCheck(customClaimCheck);
        this.groupsQuery = parseGroupsQuery(groupsClaimQuery);
        this.groupsDelimiter = parseGroupsDelimiter(groupsClaimDelimiter);
        this.connectTimeout = connectTimeoutSeconds;
        this.readTimeout = readTimeoutSeconds;
        this.enableMetrics = enableMetrics;
        this.ignoreKeyUse = ignoreKeyUse;

        try {
            this.metrics = enableMetrics ? Services.getInstance().getMetrics() : null;
            this.jwksHttpSensorKeyProducer = new JwksHttpSensorKeyProducer(validatorId, this.keysUri);
            // All configuration fields above must be set before the first fetch, which reads them.
            ScheduledExecutorService executor = setupExecutorAndFetchInitialKeys(refreshSeconds, refreshMinPauseSeconds, failFast);
            setupRefreshKeysJob(executor, refreshSeconds);
        } finally {
            if (log.isDebugEnabled()) {
                log.debug("Configured JWTSignatureValidator:\n    validatorId: " + validatorId + "\n    keysEndpointUri: " + keysEndpointUri + "\n    sslSocketFactory: " + socketFactory + "\n    hostnameVerifier: " + this.hostnameVerifier + "\n    principalExtractor: " + principalExtractor + "\n    groupsClaimQuery: " + groupsClaimQuery + "\n    groupsClaimDelimiter: " + groupsClaimDelimiter + "\n    validIssuerUri: " + validIssuerUri + "\n    certsRefreshSeconds: " + refreshSeconds + "\n    certsRefreshMinPauseSeconds: " + refreshMinPauseSeconds + "\n    certsExpirySeconds: " + expirySeconds + "\n    certsIgnoreKeyUse: " + ignoreKeyUse + "\n    checkAccessTokenType: " + checkAccessTokenType + "\n    audience: " + audience + "\n    customClaimCheck: " + customClaimCheck + "\n    connectTimeoutSeconds: " + connectTimeoutSeconds + "\n    readTimeoutSeconds: " + readTimeoutSeconds + "\n    enableMetrics: " + enableMetrics + "\n    failFast: " + failFast);
            }
        }
    }

    /**
     * Performs the initial key fetch, then creates the refresh executor and its back-off scheduler.
     * If the initial fetch fails and {@code failFast} is {@code false}, an immediate retry is scheduled.
     *
     * @return the executor on which the periodic refresh job is to be scheduled
     */
    private ScheduledExecutorService setupExecutorAndFetchInitialKeys(int refreshSeconds, int refreshMinPauseSeconds, boolean failFast) {
        boolean initFetchFailed = false;
        try {
            fetchKeys();
        } catch (Exception e) {
            if (failFast) {
                throw e;
            }
            initFetchFailed = true;
            log.warn("[IGNORED] Fetching JWKS keys has failed, but fail-fast is disabled: ", e);
        }

        ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor(new DaemonThreadFactory());
        this.fastScheduler = new BackOffTaskScheduler(executor, refreshMinPauseSeconds, refreshSeconds, this::fetchKeys);
        if (initFetchFailed) {
            fastScheduler.scheduleTask();
        }
        return executor;
    }

    /**
     * Parses the optional custom claim check.
     *
     * @return the parsed filter query, or {@code null} if none is configured
     * @throws IllegalArgumentException if the value is blank
     */
    private JsonPathFilterQuery parseCustomClaimCheck(String customClaimCheck) {
        if (customClaimCheck == null) {
            return null;
        }
        String query = customClaimCheck.trim();
        if (query.isEmpty()) {
            throw new IllegalArgumentException("Value of customClaimCheck is empty");
        }
        return JsonPathFilterQuery.parse(query);
    }

    /**
     * Parses the optional groups claim query.
     *
     * @return the parsed query, or {@code null} if none is configured
     * @throws IllegalArgumentException if the value is blank
     */
    private JsonPathQuery parseGroupsQuery(String groupsQuery) {
        if (groupsQuery == null) {
            return null;
        }
        String query = groupsQuery.trim();
        if (query.isEmpty()) {
            throw new IllegalArgumentException("Value of groupsClaimQuery is empty");
        }
        return JsonPathQuery.parse(query);
    }

    /**
     * Resolves the groups delimiter.
     *
     * @return the configured delimiter, or {@code ","} if none is configured
     * @throws IllegalArgumentException if the configured delimiter is the empty string
     */
    private String parseGroupsDelimiter(String groupsDelimiter) {
        if (groupsDelimiter == null) {
            return ",";
        }
        if (groupsDelimiter.isEmpty()) {
            throw new IllegalArgumentException("Value of groupsClaimDelimiter is empty");
        }
        // Previously "," was returned unconditionally, silently discarding the configured delimiter.
        return groupsDelimiter;
    }

    /**
     * Checks that the refresh interval is positive.
     * Also checks that keys stay usable for at least 60 seconds beyond one refresh interval.
     *
     * @throws IllegalArgumentException if either constraint is violated
     */
    private void validateRefreshConfig(int refreshSeconds, int expirySeconds) {
        if (refreshSeconds <= 0) {
            throw new IllegalArgumentException("refreshSeconds has to be a positive number - (refreshSeconds=" + refreshSeconds + ")");
        }
        if (expirySeconds < refreshSeconds + 60) {
            throw new IllegalArgumentException("expirySeconds has to be at least 60 seconds longer than refreshSeconds - (expirySeconds=" + expirySeconds + ", refreshSeconds=" + refreshSeconds + ")");
        }
    }

    /**
     * Schedules a fixed-rate job that asks the back-off scheduler for a key refresh every
     * {@code refreshSeconds}. Errors are logged so that they do not cancel the periodic job.
     */
    private void setupRefreshKeysJob(ScheduledExecutorService executor, int refreshSeconds) {
        executor.scheduleAtFixedRate(() -> {
            try {
                fastScheduler.scheduleTask();
            } catch (Throwable e) {
                // An exception escaping a scheduleAtFixedRate task would suppress all later runs.
                log.error("{}", e.getMessage(), e);
            }
        }, refreshSeconds, refreshSeconds, TimeUnit.SECONDS);
    }

    /**
     * Looks up the public key for a key id.
     *
     * @return the key, or {@code null} if it is unknown or the cached key set is stale
     */
    private PublicKey getPublicKey(String id) {
        return getKeyUnlessStale(id);
    }

    /**
     * Returns the cached key for {@code id}.
     * This only succeeds while the last successful fetch is younger than {@code maxStaleSeconds}.
     *
     * @return the cached key, or {@code null} (with a warning logged) if it is missing or the cache is stale
     */
    private PublicKey getKeyUnlessStale(String id) {
        if (lastFetchTime + maxStaleSeconds * 1000L > System.currentTimeMillis()) {
            PublicKey result = cache.get(id);
            if (result == null) {
                log.warn("No public key for id: {}", id);
            }
            return result;
        }
        log.warn("The cached public key with id '{}' is expired!", id);
        return null;
    }

    /**
     * Downloads the JWKS and rebuilds the key cache from its EC and RSA keys.
     * <p>
     * Unless {@code ignoreKeyUse} is set, only keys with {@code use=sig} are kept. When the key set has
     * changed, the previous set is kept in {@code oldCache} so that tokens signed with a rotated-out key
     * get a specific error message. The fetch time is recorded on every successful fetch.
     *
     * @throws ServiceException wrapping any failure of the HTTP request or of JWKS parsing
     */
    private void fetchKeys() {
        long requestStartTime = System.currentTimeMillis();
        try {
            String response = HttpUtil.get(keysUri, socketFactory, hostnameVerifier, null, String.class, connectTimeout, readTimeout);
            addJwksHttpMetricSuccessTime(requestStartTime);

            Map<String, PublicKey> fetched = new HashMap<>();
            for (JWK jwk : JWKSet.parse(response).getKeys()) {
                if (!ignoreKeyUse && !KeyUse.SIGNATURE.equals(jwk.getKeyUse())) {
                    continue;
                }
                PublicKey publicKey;
                if (jwk instanceof ECKey) {
                    publicKey = ((ECKey) jwk).toPublicKey();
                } else if (jwk instanceof RSAKey) {
                    publicKey = ((RSAKey) jwk).toPublicKey();
                } else {
                    log.warn("Unsupported JWK key type: {}", jwk.getKeyType());
                    continue;
                }
                fetched.put(jwk.getKeyID(), publicKey);
            }

            Map<String, PublicKey> newCache = Collections.unmodifiableMap(fetched);
            if (!cache.equals(newCache)) {
                log.info("JWKS keys change detected. Keys updated.");
                oldCache = cache;
                cache = newCache;
            }
            // Written after the cache so that a reader seeing the new timestamp also sees the new keys.
            lastFetchTime = System.currentTimeMillis();
        } catch (Throwable ex) {
            addJwksHttpMetricErrorTime(ex, requestStartTime);
            throw new ServiceException("Failed to fetch public keys needed to validate JWT signatures: " + keysUri, ex);
        }
    }

    /**
     * Parses and validates a JWT.
     * <p>
     * Validation covers the signature, expiry, issuer, token type, audience and the custom claim check.
     * The principal and groups are then extracted from the payload.
     *
     * @param token the raw JWT string
     * @return information about the validated token
     * @throws TokenValidationException if the token is malformed, signed with an unknown or retired key,
     *                                  or fails any claim check
     * @throws TokenExpiredException if the {@code exp} time has passed
     * @throws ValidationException if no principal can be extracted
     */
    @Override
    @SuppressFBWarnings(value = {"BC_UNCONFIRMED_CAST_OF_RETURN_VALUE"}, justification = "We tell TokenVerifier to parse AccessToken. It will return AccessToken or fail.")
    public TokenInfo validate(String token) {
        SignedJWT jwt;
        String kid;
        try {
            jwt = SignedJWT.parse(token);
            kid = jwt.getHeader().getKeyID();
        } catch (Exception e) {
            throw new TokenValidationException("Token validation failed: Failed to parse JWT.", e).status(TokenValidationException.Status.INVALID_TOKEN);
        }

        JsonNode t;
        try {
            PublicKey pub = getPublicKey(kid);
            if (pub == null) {
                if (oldCache.get(kid) != null) {
                    throw new TokenValidationException("Token validation failed: The signing key is no longer valid (kid:" + kid + ")");
                }
                // Unknown key id: the JWKS may have been rotated since the last fetch, so request a refresh.
                try {
                    fastScheduler.scheduleTask();
                } catch (RuntimeException e) {
                    log.error("Failed to reschedule JWKS keys refresh: ", e);
                }
                throw new TokenValidationException("Token validation failed: Unknown signing key (kid:" + kid + ")");
            }
            if (!jwt.verify(VERIFIER_FACTORY.createJWSVerifier(jwt.getHeader(), pub))) {
                throw new TokenSignatureException("Signature check failed: Invalid token signature");
            }
            t = JSONUtil.asJson(jwt.getPayload().toJSONObject());
        } catch (TokenValidationException e) {
            throw e;
        } catch (Exception e) {
            throw new TokenValidationException("Token validation failed", e);
        }

        JsonNode exp = t.get("exp");
        if (exp == null) {
            throw new TokenValidationException("Token validation failed: Expiry not set");
        }
        // 'exp' is in epoch seconds. Read it as a long: after 2038-01-19 the value exceeds
        // Integer.MAX_VALUE, and asInt would truncate it, making valid tokens look expired.
        long expiresMillis = exp.asLong(0) * 1000L;
        if (System.currentTimeMillis() > expiresMillis) {
            throw new TokenExpiredException("Token expired at: " + expiresMillis + " (" + TimeUtil.formatIsoDateTimeUTC(expiresMillis) + " UTC)");
        }

        validateTokenPayload(t);
        if (customClaimMatcher != null && !customClaimMatcher.matches(t)) {
            throw new TokenValidationException("Token validation failed: Custom claim check failed");
        }

        String principal = extractPrincipal(t);
        Set<String> groups = extractGroups(t);
        return new TokenInfo(t, token, principal, groups);
    }

    /**
     * Extracts the principal name.
     * A configured extractor uses its configured claims; otherwise the {@code sub} value from
     * {@code getSub} is used.
     *
     * @throws ValidationException if no principal could be extracted
     */
    private String extractPrincipal(JsonNode tokenJson) {
        String principal = principalExtractor.isConfigured()
                ? principalExtractor.getPrincipal(tokenJson)
                : principalExtractor.getSub(tokenJson);
        if (principal == null) {
            throw new ValidationException("Failed to extract principal - check usernameClaim, fallbackUsernameClaim configuration");
        }
        return principal;
    }

    /**
     * Extracts the groups using the configured groups query.
     * Values are trimmed and blank entries are dropped.
     *
     * @return the non-empty set of groups, or {@code null} if no query is configured or nothing was found
     */
    private Set<String> extractGroups(JsonNode tokenJson) {
        if (groupsQuery == null) {
            return null;
        }
        JsonNode result = groupsQuery.apply(tokenJson);
        if (result == null) {
            return null;
        }
        List<String> groups = JSONUtil.asListOfString(result, groupsDelimiter);
        Set<String> groupSet = groups.stream()
                .map(String::trim)
                .filter(v -> !v.isEmpty())
                .collect(Collectors.toSet());
        return groupSet.isEmpty() ? null : groupSet;
    }

    /**
     * Applies the optional payload checks: issuer, token type ({@code typ}, falling back to
     * {@code token_type}) and audience.
     *
     * @throws TokenValidationException if any enabled check fails
     */
    private void validateTokenPayload(JsonNode token) {
        if (issuerUri != null) {
            JsonNode iss = token.get("iss");
            if (iss == null) {
                throw new TokenValidationException("Token validation failed: Issuer not set");
            }
            String issuer = iss.asText();
            if (!issuerUri.equals(issuer)) {
                throw new TokenValidationException("Token validation failed: Issuer not allowed: " + issuer);
            }
        }

        if (checkAccessTokenType) {
            JsonNode type = token.get("typ");
            if (type == null) {
                type = token.get("token_type");
            }
            if (type == null) {
                throw new TokenValidationException("Token validation failed: Token type not set ('token_type' or 'typ' claim not present)");
            }
            String value = type.asText();
            if (!"Bearer".equals(value)) {
                throw new TokenValidationException("Token validation failed: Token type not allowed: " + value);
            }
        }

        if (audience != null) {
            JsonNode audNode = token.get("aud");
            List<String> aud = audNode == null ? Collections.emptyList() : JSONUtil.asListOfString(audNode);
            if (!aud.contains(audience)) {
                throw new TokenValidationException("Token validation failed: Expected audience not available in the token");
            }
        }
    }

    @Override
    public String getValidatorId() {
        return validatorId;
    }

    /** Records the duration of a successful JWKS HTTP request when metrics are enabled. */
    private void addJwksHttpMetricSuccessTime(long startTimeMs) {
        if (enableMetrics) {
            metrics.addTime(jwksHttpSensorKeyProducer.successKey(), System.currentTimeMillis() - startTimeMs);
        }
    }

    /** Records the duration of a failed JWKS fetch when metrics are enabled. */
    private void addJwksHttpMetricErrorTime(Throwable e, long startTimeMs) {
        if (enableMetrics) {
            metrics.addTime(jwksHttpSensorKeyProducer.errorKey(e), System.currentTimeMillis() - startTimeMs);
        }
    }
}

