/*
 * Decompiled with CFR 0.152.
 */
package io.strimzi.kafka.oauth.validator;

import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import io.strimzi.kafka.oauth.common.HttpUtil;
import io.strimzi.kafka.oauth.common.TimeUtil;
import io.strimzi.kafka.oauth.common.TokenInfo;
import io.strimzi.kafka.oauth.validator.TokenExpiredException;
import io.strimzi.kafka.oauth.validator.TokenSignatureException;
import io.strimzi.kafka.oauth.validator.TokenValidationException;
import io.strimzi.kafka.oauth.validator.TokenValidator;
import java.net.URI;
import java.net.URISyntaxException;
import java.security.PublicKey;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLSocketFactory;
import org.apache.kafka.common.utils.Time;
import org.keycloak.TokenVerifier;
import org.keycloak.exceptions.TokenSignatureInvalidException;
import org.keycloak.jose.jwk.JSONWebKeySet;
import org.keycloak.jose.jwk.JWK;
import org.keycloak.representations.AccessToken;
import org.keycloak.util.JWKSUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link TokenValidator} that verifies JWT access tokens against public signing keys
 * downloaded from a JWKS endpoint.
 *
 * <p>The keys are fetched once in the constructor and then re-fetched periodically on a
 * single daemon thread. Keys are looked up by the {@code kid} header of the token being
 * validated. If the last successful fetch is older than the configured expiry period the
 * cached keys are treated as unusable.
 *
 * <p>The key cache and the time of the last successful fetch are written by the refresh
 * thread and read by whichever thread calls {@link #validate(String)}; both fields are
 * {@code volatile} and the key map is replaced as a whole, never modified in place.
 */
public class JWTSignatureValidator
implements TokenValidator {
    private static final Logger log = LoggerFactory.getLogger(JWTSignatureValidator.class);
    private final ScheduledExecutorService scheduler;
    private final URI keysUri;
    private final String issuerUri;
    private final int maxStaleSeconds;
    private final boolean defaultChecks;
    private final String audience;
    private final SSLSocketFactory socketFactory;
    private final HostnameVerifier hostnameVerifier;
    // Written by the refresh thread, read by validating threads: volatile guarantees
    // visibility and makes the 64-bit write of the timestamp atomic.
    private volatile long lastFetchTime;
    private volatile Map<String, PublicKey> cache = new ConcurrentHashMap<String, PublicKey>();

    /**
     * Creates the validator, performs an initial key fetch and schedules periodic refreshes.
     *
     * @param keysEndpointUri URI of the JWKS endpoint; must not be {@code null} and must be a valid URI
     * @param socketFactory   optional SSL socket factory; only allowed when the endpoint scheme is {@code https}
     * @param verifier        optional hostname verifier; only allowed when the endpoint scheme is {@code https}
     * @param validIssuerUri  issuer URI passed to the token verifier when default checks are enabled; must not be {@code null}
     * @param refreshSeconds  delay before the first refresh and period between refreshes, in seconds
     * @param expirySeconds   maximum age of the cached keys, in seconds; must be at least {@code refreshSeconds + 60}
     * @param defaultChecks   whether to enable the token verifier's default checks and issuer check
     * @param audience        optional audience to require; {@code null} disables the audience check
     * @throws IllegalArgumentException if any argument violates the constraints above
     * @throws RuntimeException         if the initial key fetch fails
     */
    public JWTSignatureValidator(String keysEndpointUri, SSLSocketFactory socketFactory, HostnameVerifier verifier, String validIssuerUri, int refreshSeconds, int expirySeconds, boolean defaultChecks, String audience) {
        if (keysEndpointUri == null) {
            throw new IllegalArgumentException("keysEndpointUri == null");
        }
        try {
            this.keysUri = new URI(keysEndpointUri);
        }
        catch (URISyntaxException e) {
            throw new IllegalArgumentException("Invalid keysEndpointUri: " + keysEndpointUri, e);
        }
        boolean https = "https".equals(this.keysUri.getScheme());
        if (socketFactory != null && !https) {
            throw new IllegalArgumentException("SSL socket factory set but keysEndpointUri not 'https'");
        }
        this.socketFactory = socketFactory;
        if (verifier != null && !https) {
            throw new IllegalArgumentException("Certificate hostname verifier set but keysEndpointUri not 'https'");
        }
        this.hostnameVerifier = verifier;
        if (validIssuerUri == null) {
            throw new IllegalArgumentException("validIssuerUri == null");
        }
        this.issuerUri = validIssuerUri;
        // Compare in long arithmetic so a very large refreshSeconds cannot overflow and slip through.
        if ((long) expirySeconds < (long) refreshSeconds + 60L) {
            throw new IllegalArgumentException("expirySeconds has to be at least 60 seconds longer than refreshSeconds");
        }
        this.maxStaleSeconds = expirySeconds;
        this.defaultChecks = defaultChecks;
        this.audience = audience;

        // Fail fast: a failure of the initial fetch propagates out of the constructor.
        this.fetchKeys();

        this.scheduler = Executors.newSingleThreadScheduledExecutor(new DaemonThreadFactory());
        // scheduleAtFixedRate suppresses all subsequent runs once a run throws, so the periodic
        // task must never let an exception escape.
        this.scheduler.scheduleAtFixedRate(this::refreshKeysQuietly, refreshSeconds, refreshSeconds, TimeUnit.SECONDS);
    }

    /**
     * Returns the cached public key for the given key id.
     *
     * @param id key id taken from the token header
     * @return the key, or {@code null} if it is unknown or the cache is stale
     */
    private PublicKey getPublicKey(String id) {
        return this.getKeyUnlessStale(id);
    }

    /**
     * Looks up a key in the cache, but only while the last successful fetch is younger than
     * the configured maximum age. A warning is logged when no key is returned.
     *
     * @param id key id taken from the token header
     * @return the key, or {@code null} if it is unknown or the cache is stale
     */
    private PublicKey getKeyUnlessStale(String id) {
        long staleAfterMillis = this.lastFetchTime + (long) this.maxStaleSeconds * 1000L;
        if (staleAfterMillis > System.currentTimeMillis()) {
            PublicKey result = this.cache.get(id);
            if (result == null) {
                log.warn("No public key for id: {}", id);
            }
            return result;
        }
        log.warn("The cached public key with id '{}' is expired!", id);
        return null;
    }

    /**
     * Periodic refresh task. Logs a failed fetch instead of propagating it so that the
     * scheduler keeps running the task; the previously fetched keys stay in use until
     * they exceed the configured maximum age.
     */
    private void refreshKeysQuietly() {
        try {
            this.fetchKeys();
        }
        catch (Exception e) {
            log.error("Failed to refresh public keys; will retry at the next scheduled refresh", e);
        }
    }

    /**
     * Downloads the key set from the keys endpoint, replaces the cache with its signature
     * keys and records the fetch time.
     *
     * @throws RuntimeException if the download or the key extraction fails; the cache and
     *                          the fetch time are left untouched in that case
     */
    private void fetchKeys() {
        try {
            JSONWebKeySet jwks = HttpUtil.get(this.keysUri, this.socketFactory, this.hostnameVerifier, null, JSONWebKeySet.class);
            this.cache = JWKSUtils.getKeysForUse(jwks, JWK.Use.SIG);
            this.lastFetchTime = System.currentTimeMillis();
        }
        catch (Exception ex) {
            throw new RuntimeException("Failed to fetch public keys needed to validate JWT signatures: " + this.keysUri, ex);
        }
    }

    /**
     * Parses and verifies the given JWT access token.
     *
     * <p>The key used for verification is selected by the token's {@code kid} header. After
     * verification succeeds the token's expiration time is additionally compared with the
     * current system time.
     *
     * @param token the raw token string
     * @return token information built from the parsed token and the raw token string
     * @throws TokenValidationException if the token header cannot be read (status
     *                                  {@code INVALID_TOKEN}) or verification fails for a
     *                                  reason other than an invalid signature
     * @throws TokenSignatureException  if the signature is invalid
     * @throws TokenExpiredException    if the token's expiration time has passed
     */
    @Override
    @SuppressFBWarnings(value={"BC_UNCONFIRMED_CAST_OF_RETURN_VALUE"}, justification="We tell TokenVerifier to parse AccessToken. It will return AccessToken or fail.")
    public TokenInfo validate(String token) {
        TokenVerifier<AccessToken> tokenVerifier = TokenVerifier.create(token, AccessToken.class);
        if (this.defaultChecks) {
            tokenVerifier.withDefaultChecks().realmUrl(this.issuerUri);
        }

        String kid;
        try {
            kid = tokenVerifier.getHeader().getKeyId();
        }
        catch (Exception e) {
            // The raw token is a bearer credential, so it is deliberately kept out of the message.
            throw new TokenValidationException("Token signature validation failed: could not read token header", e).status(TokenValidationException.Status.INVALID_TOKEN);
        }

        // May be null (unknown kid or stale cache); verification below then fails.
        tokenVerifier.publicKey(this.getPublicKey(kid));
        if (this.audience != null) {
            tokenVerifier.audience(this.audience);
        }

        AccessToken t;
        try {
            tokenVerifier.verify();
            t = tokenVerifier.getToken();
        }
        catch (TokenSignatureInvalidException e) {
            throw new TokenSignatureException("Signature check failed:", e);
        }
        catch (Exception e) {
            throw new TokenValidationException("Token validation failed:", e);
        }

        // Expiration is multiplied by 1000 to compare against a millisecond clock.
        long expiresMillis = (long) t.getExpiration() * 1000L;
        if (Time.SYSTEM.milliseconds() > expiresMillis) {
            throw new TokenExpiredException("Token expired at: " + expiresMillis + " (" + TimeUtil.formatIsoDateTimeUTC(expiresMillis) + ")");
        }
        return new TokenInfo(t, token);
    }

    /**
     * Thread factory that delegates to {@link Executors#defaultThreadFactory()} and marks
     * every created thread as a daemon thread.
     */
    static class DaemonThreadFactory
    implements ThreadFactory {
        DaemonThreadFactory() {
        }

        @Override
        public Thread newThread(Runnable r) {
            Thread t = Executors.defaultThreadFactory().newThread(r);
            t.setDaemon(true);
            return t;
        }
    }
}

