/*
 * Decompiled with CFR 0.152.
 */
package io.trino.server.security.jwt;

import com.google.common.base.CharMatcher;
import com.google.common.io.Files;
import com.google.inject.Inject;
import io.airlift.security.pem.PemReader;
import io.jsonwebtoken.Header;
import io.jsonwebtoken.JweHeader;
import io.jsonwebtoken.JwsHeader;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.Locator;
import io.jsonwebtoken.UnsupportedJwtException;
import io.jsonwebtoken.security.Keys;
import io.jsonwebtoken.security.MacAlgorithm;
import io.jsonwebtoken.security.SecureDigestAlgorithm;
import io.jsonwebtoken.security.SecurityException;
import io.trino.server.security.jwt.JwtAuthenticatorConfig;
import java.io.File;
import java.io.IOException;
import java.lang.runtime.SwitchBootstraps;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.security.Key;
import java.security.PublicKey;
import java.util.Base64;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import javax.crypto.SecretKey;

public class FileSigningKeyLocator
implements Locator<Key> {
    private static final String DEFAULT_KEY = "default-key";
    private static final CharMatcher INVALID_KID_CHARS = CharMatcher.inRange((char)'a', (char)'z').or(CharMatcher.inRange((char)'A', (char)'Z')).or(CharMatcher.inRange((char)'0', (char)'9')).or(CharMatcher.anyOf((CharSequence)"_-")).negate();
    private static final String KEY_ID_VARIABLE = "${KID}";
    private final String keyFile;
    private final LoadedKey staticKey;
    private final ConcurrentMap<String, LoadedKey> keys = new ConcurrentHashMap<String, LoadedKey>();

    @Inject
    public FileSigningKeyLocator(JwtAuthenticatorConfig config) {
        this(config.getKeyFile());
    }

    public FileSigningKeyLocator(String keyFile) {
        this.keyFile = Objects.requireNonNull(keyFile, "keyFile is null");
        this.staticKey = keyFile.contains(KEY_ID_VARIABLE) ? null : FileSigningKeyLocator.loadKeyFile(new File(keyFile));
    }

    public Key locate(Header header) {
        Header header2 = header;
        Objects.requireNonNull(header2);
        Header header3 = header2;
        int n = 0;
        return switch (SwitchBootstraps.typeSwitch("typeSwitch", new Object[]{JwsHeader.class, JweHeader.class}, (Object)header3, n)) {
            case 0 -> {
                JwsHeader jwsHeader = (JwsHeader)header3;
                yield this.getKey(jwsHeader.getKeyId(), jwsHeader.getAlgorithm());
            }
            case 1 -> {
                JweHeader jweHeader = (JweHeader)header3;
                yield this.getKey(jweHeader.getKeyId(), jweHeader.getAlgorithm());
            }
            default -> throw new UnsupportedJwtException("Cannot locate key for header: %s".formatted(header.getType()));
        };
    }

    private Key getKey(String keyId, String algorithm) {
        SecureDigestAlgorithm secureDigestAlgorithm = (SecureDigestAlgorithm)Jwts.SIG.get().forKey((Object)algorithm);
        if (this.staticKey != null) {
            return this.staticKey.getKey(secureDigestAlgorithm);
        }
        LoadedKey key = this.keys.computeIfAbsent(FileSigningKeyLocator.getKeyId(keyId), this::loadKey);
        return key.getKey(secureDigestAlgorithm);
    }

    private static String getKeyId(String keyId) {
        if (keyId == null) {
            return DEFAULT_KEY;
        }
        keyId = INVALID_KID_CHARS.replaceFrom((CharSequence)keyId, '_');
        return keyId;
    }

    private LoadedKey loadKey(String keyId) {
        return FileSigningKeyLocator.loadKeyFile(new File(this.keyFile.replace(KEY_ID_VARIABLE, keyId)));
    }

    private static LoadedKey loadKeyFile(File file) {
        String data;
        if (!file.canRead()) {
            throw new SecurityException("Unknown signing key ID");
        }
        try {
            data = Files.asCharSource((File)file, (Charset)StandardCharsets.US_ASCII).read();
        }
        catch (IOException e) {
            throw new SecurityException("Unable to read signing key", (Throwable)e);
        }
        if (PemReader.isPem((String)data)) {
            try {
                return new LoadedKey(PemReader.loadPublicKey((String)data));
            }
            catch (RuntimeException | GeneralSecurityException e) {
                throw new SecurityException("Unable to decode PEM signing key id", (Throwable)e);
            }
        }
        try {
            SecretKey hmacKey = Keys.hmacShaKeyFor((byte[])Base64.getMimeDecoder().decode(data.getBytes(StandardCharsets.US_ASCII)));
            return new LoadedKey(hmacKey);
        }
        catch (RuntimeException e) {
            throw new SecurityException("Unable to decode HMAC signing key", (Throwable)e);
        }
    }

    private static class LoadedKey {
        private final PublicKey publicKey;
        private final SecretKey secretKey;

        public LoadedKey(PublicKey publicKey) {
            this.publicKey = Objects.requireNonNull(publicKey, "publicKey is null");
            this.secretKey = null;
        }

        public LoadedKey(SecretKey secretKey) {
            this.secretKey = Objects.requireNonNull(secretKey, "secretKey is null");
            this.publicKey = null;
        }

        public Key getKey(SecureDigestAlgorithm<?, ?> algorithm) {
            if (algorithm instanceof MacAlgorithm) {
                if (this.secretKey == null) {
                    throw new UnsupportedJwtException(String.format("JWT is signed with %s, but no HMAC key is configured", algorithm));
                }
                return this.secretKey;
            }
            if (this.publicKey == null) {
                throw new UnsupportedJwtException(String.format("JWT is signed with %s, but no key is configured", algorithm));
            }
            return this.publicKey;
        }
    }
}

