/*
 * Decompiled with CFR 0.152.
 */
package io.micronaut.http.ssl;

import io.micronaut.core.annotation.Nullable;
import io.micronaut.core.util.ArrayUtils;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.CharacterCodingException;
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.security.Key;
import java.security.KeyFactory;
import java.security.cert.CertificateFactory;
import java.security.spec.PKCS8EncodedKeySpec;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collection;
import java.util.List;
import javax.crypto.Cipher;
import javax.crypto.EncryptedPrivateKeyInfo;
import javax.crypto.SecretKey;
import javax.crypto.SecretKeyFactory;
import javax.crypto.spec.PBEKeySpec;

record PemParser(@Nullable String provider, @Nullable String password) {
    private static final String DASHES = "-----";
    private static final String START = "-----BEGIN ";
    private static final String END = "-----END ";
    private static final String OID_RSA = "1.2.840.113549.1.1.1";
    private static final String OID_EC = "1.2.840.10045.2.1";

    /**
     * Parses PEM data supplied as raw bytes.
     *
     * <p>The bytes are strictly decoded as UTF-8 and then handed to {@link #loadPem(String)}.
     *
     * @param pem the raw file content
     * @return the objects decoded from every PEM section, in file order
     * @throws NotPemException if the bytes are not well-formed UTF-8, or the text is not PEM
     * @throws IllegalArgumentException if the text is PEM but malformed or unsupported
     * @throws GeneralSecurityException if a section's content cannot be decoded
     * @throws NullPointerException if {@code pem} is {@code null}
     */
    List<Object> loadPem(byte[] pem) throws GeneralSecurityException, IllegalArgumentException, NotPemException {
        String text;
        try {
            // A fresh CharsetDecoder reports malformed input, whereas
            // new String(bytes, charset) silently substitutes U+FFFD and never fails.
            text = StandardCharsets.UTF_8.newDecoder().decode(ByteBuffer.wrap(pem)).toString();
        } catch (CharacterCodingException e) {
            NotPemException notPem = new NotPemException("Invalid UTF-8");
            notPem.initCause(e);
            throw notPem;
        }
        return this.loadPem(text);
    }

    /**
     * Parses PEM text that may contain several concatenated sections.
     *
     * <p>Each section has the form {@code -----BEGIN <label>-----}, a Base64 body, and
     * {@code -----END <label>-----}. Whitespace between sections and anywhere inside the
     * Base64 body is ignored. The label selects the decoder via {@link #getDecoder(String)}.
     *
     * @param pem the PEM text
     * @return the objects produced by the decoders for every section, in file order
     * @throws NotPemException if the problem is found before any section has been decoded
     *         and is a missing start tag, an unterminated label, or a missing end tag
     * @throws IllegalArgumentException if one of those problems occurs after at least one
     *         section was decoded, if a label is unsupported, if the body is not valid
     *         Base64, or if the input contains no section at all
     * @throws GeneralSecurityException if a decoder rejects the section content
     */
    List<Object> loadPem(String pem) throws GeneralSecurityException, IllegalArgumentException, NotPemException {
        List<Object> entries = new ArrayList<>();
        int pos = 0;
        while (pos < pem.length()) {
            if (Character.isWhitespace(pem.charAt(pos))) {
                pos++;
                continue;
            }
            if (!pem.startsWith(START, pos)) {
                if (entries.isEmpty()) {
                    throw new NotPemException("Missing start tag");
                }
                throw PemParser.invalidPem(false);
            }
            pos += START.length();
            int labelEnd = pem.indexOf(DASHES, pos);
            if (labelEnd == -1) {
                throw PemParser.invalidPem(entries.isEmpty());
            }
            String label = pem.substring(pos, labelEnd);
            pos = labelEnd + DASHES.length();
            String trailer = END + label + DASHES;
            int sectionEnd = pem.indexOf(trailer, pos);
            if (sectionEnd == -1) {
                throw PemParser.invalidPem(entries.isEmpty());
            }
            // Resolve the decoder first so an unsupported label is reported before Base64 errors.
            Decoder decoder = this.getDecoder(label);
            // Drop every whitespace character (not only CR/LF) so that tabs or trailing
            // spaces in the body do not make the strict Base64 decoder fail.
            StringBuilder base64 = new StringBuilder(sectionEnd - pos);
            for (int k = pos; k < sectionEnd; k++) {
                char c = pem.charAt(k);
                if (!Character.isWhitespace(c)) {
                    base64.append(c);
                }
            }
            pos = sectionEnd + trailer.length();
            byte[] content = Base64.getDecoder().decode(base64.toString());
            entries.addAll(decoder.decode(content));
        }
        if (entries.isEmpty()) {
            throw new IllegalArgumentException("PEM file empty");
        }
        return entries;
    }

    /**
     * Builds or throws the failure for malformed PEM input.
     *
     * @param first {@code true} when no section has been decoded yet
     * @return an {@link IllegalArgumentException} for the caller to throw when {@code first}
     *         is {@code false}
     * @throws NotPemException when {@code first} is {@code true}
     */
    private static IllegalArgumentException invalidPem(boolean first) throws NotPemException {
        if (!first) {
            return new IllegalArgumentException("Invalid PEM");
        }
        throw new NotPemException("Invalid PEM");
    }

    /**
     * Selects the decoder that handles a PEM section label.
     *
     * @param label the text between {@code -----BEGIN } and the closing dashes
     * @return a new decoder for that label
     * @throws IllegalArgumentException if the label is not one of the supported ones
     */
    private Decoder getDecoder(String label) {
        final Decoder decoder;
        switch (label) {
            case "CERTIFICATE", "X509 CERTIFICATE" -> decoder = new CertificateDecoder();
            case "ENCRYPTED PRIVATE KEY" -> decoder = new Pkcs8EncryptedPrivateKey();
            case "PRIVATE KEY" -> decoder = new Pkcs8PrivateKey();
            case "RSA PRIVATE KEY" -> decoder = new Pkcs1PrivateKey(false);
            case "EC PRIVATE KEY" -> decoder = new Pkcs1PrivateKey(true);
            default -> throw new IllegalArgumentException("Unsupported PEM label: " + label);
        }
        return decoder;
    }

    /**
     * Thrown when the input does not look like PEM: the bytes are reported as invalid UTF-8,
     * or the first section is missing its start tag or is malformed.
     * Can only be created inside {@link PemParser}.
     */
    public static final class NotPemException extends Exception {

        /**
         * @param message description of why the input was rejected
         */
        private NotPemException(String message) {
            super(message);
        }
    }

    /*
     * Uses 'sealed' constructs - enable with --sealed true
     */
    private static interface Decoder {
        public Collection<?> decode(byte[] var1) throws GeneralSecurityException;
    }

    /**
     * Decodes certificate sections with an {@code X.509} {@link CertificateFactory},
     * taken from the configured provider when one is set.
     */
    private final class CertificateDecoder implements Decoder {

        /**
         * @param der the binary certificate data
         * @return the certificates the factory generates from the data
         * @throws GeneralSecurityException if the factory is unavailable or rejects the data
         */
        @Override
        public Collection<?> decode(byte[] der) throws GeneralSecurityException {
            String providerName = PemParser.this.provider;
            CertificateFactory x509;
            if (providerName != null) {
                x509 = CertificateFactory.getInstance("X.509", providerName);
            } else {
                x509 = CertificateFactory.getInstance("X.509");
            }
            return x509.generateCertificates(new ByteArrayInputStream(der));
        }
    }

    /**
     * Decodes {@code ENCRYPTED PRIVATE KEY} sections: decrypts the
     * {@link EncryptedPrivateKeyInfo} with the configured password and turns the resulting
     * PKCS#8 key spec into a private key.
     */
    private final class Pkcs8EncryptedPrivateKey implements Decoder {

        /**
         * @param der the binary {@code EncryptedPrivateKeyInfo}
         * @return a single-element list holding the decrypted private key
         * @throws GeneralSecurityException if the data is not valid DER, an algorithm is
         *         unavailable, or decryption fails
         * @throws IllegalArgumentException if no password was configured
         */
        @Override
        public Collection<?> decode(byte[] der) throws GeneralSecurityException {
            EncryptedPrivateKeyInfo keyInfo;
            try {
                keyInfo = new EncryptedPrivateKeyInfo(der);
            } catch (IOException e) {
                throw new GeneralSecurityException("Invalid DER", e);
            }
            String cipherAlg = keyInfo.getAlgName();
            if (cipherAlg.equals("PBES2")) {
                // NOTE(review): relies on toString() of the PBES2 AlgorithmParameters yielding
                // the concrete PBE algorithm name; confirm for non-default providers.
                cipherAlg = keyInfo.getAlgParameters().toString();
            }
            SecretKeyFactory skf = PemParser.this.provider == null
                ? SecretKeyFactory.getInstance(cipherAlg)
                : SecretKeyFactory.getInstance(cipherAlg, PemParser.this.provider);
            if (PemParser.this.password == null) {
                throw new IllegalArgumentException("Encrypted private key found but no password given");
            }
            char[] passwordChars = PemParser.this.password.toCharArray();
            PBEKeySpec passwordSpec = new PBEKeySpec(passwordChars);
            PKCS8EncodedKeySpec keySpec;
            try {
                SecretKey sk = skf.generateSecret(passwordSpec);
                // NOTE(review): the cipher is looked up without the configured provider,
                // unlike the factories in this class; confirm this is intended.
                Cipher cipher = Cipher.getInstance(cipherAlg);
                cipher.init(Cipher.DECRYPT_MODE, sk, keyInfo.getAlgParameters());
                keySpec = keyInfo.getKeySpec(cipher);
            } finally {
                // Wipe the password copies only after decryption has finished, in case the
                // generated key still refers to the spec.
                passwordSpec.clearPassword();
                Arrays.fill(passwordChars, '\0');
            }
            String keyAlg = keySpec.getAlgorithm();
            KeyFactory factory = PemParser.this.provider == null
                ? KeyFactory.getInstance(keyAlg)
                : KeyFactory.getInstance(keyAlg, PemParser.this.provider);
            return List.of(factory.generatePrivate(keySpec));
        }
    }

    /**
     * Decodes unencrypted {@code PRIVATE KEY} sections. The DER is inspected only to find
     * the key algorithm OID; the unmodified bytes are then passed to the matching
     * {@link KeyFactory}.
     */
    private final class Pkcs8PrivateKey implements Decoder {

        /**
         * @param der the binary PKCS#8 structure
         * @return a single-element list holding the private key
         * @throws GeneralSecurityException if the key factory is unavailable or rejects the key
         * @throws IllegalArgumentException if the DER is malformed or the algorithm OID is
         *         not recognised
         */
        @Override
        public Collection<?> decode(byte[] der) throws GeneralSecurityException {
            DerInput privateKeyInfo = new DerInput(der).readSequence();
            // Version field: the bytes 0x02 0x01 0x00, i.e. a one-byte INTEGER with value 0.
            privateKeyInfo.expect(0x02);
            privateKeyInfo.expect(0x01);
            privateKeyInfo.expect(0x00);
            DerInput privateKeyAlgorithm = privateKeyInfo.readSequence();
            String keyAlgorithm = this.algorithmFor(privateKeyAlgorithm.readOid());
            String providerName = PemParser.this.provider;
            KeyFactory keyFactory;
            if (providerName != null) {
                keyFactory = KeyFactory.getInstance(keyAlgorithm, providerName);
            } else {
                keyFactory = KeyFactory.getInstance(keyAlgorithm);
            }
            return List.of(keyFactory.generatePrivate(new PKCS8EncodedKeySpec(der)));
        }

        /** Maps a key algorithm OID in dotted notation to the name given to {@link KeyFactory}. */
        private String algorithmFor(String algOid) {
            switch (algOid) {
                case PemParser.OID_RSA:
                    return "RSA";
                case PemParser.OID_EC:
                    return "EC";
                case "1.3.101.112":
                    return "Ed25519";
                case "1.3.101.113":
                    return "Ed448";
                case "2.16.840.1.101.3.4.3.17":
                case "2.16.840.1.101.3.4.3.18":
                case "2.16.840.1.101.3.4.3.19":
                    return "ML-DSA";
                case "2.16.840.1.101.3.4.4.1":
                case "2.16.840.1.101.3.4.4.2":
                case "2.16.840.1.101.3.4.4.3":
                    return "ML-KEM";
                default:
                    throw new IllegalArgumentException("Unrecognized PKCS#8 key algorithm " + algOid);
            }
        }
    }

    /**
     * Decodes {@code RSA PRIVATE KEY} and {@code EC PRIVATE KEY} sections by wrapping the
     * key bytes in a PKCS#8 envelope and passing the result to a {@link KeyFactory}.
     */
    private final class Pkcs1PrivateKey implements Decoder {
        /** {@code true} for EC keys, {@code false} for RSA keys. */
        private final boolean ec;

        Pkcs1PrivateKey(boolean ec) {
            this.ec = ec;
        }

        /**
         * @param der the binary key structure taken from the PEM section
         * @return a single-element list holding the private key
         * @throws GeneralSecurityException if the key factory is unavailable or rejects the key
         * @throws IllegalArgumentException if an EC key is malformed or has no curve parameters
         */
        @Override
        public Collection<?> decode(byte[] der) throws GeneralSecurityException {
            byte[] pkcs8 = this.wrapInPkcs8(der);
            String algorithm = this.ec ? "EC" : "RSA";
            String providerName = PemParser.this.provider;
            KeyFactory keyFactory;
            if (providerName != null) {
                keyFactory = KeyFactory.getInstance(algorithm, providerName);
            } else {
                keyFactory = KeyFactory.getInstance(algorithm);
            }
            return List.of(keyFactory.generatePrivate(new PKCS8EncodedKeySpec(pkcs8)));
        }

        /**
         * Builds the envelope: SEQUENCE { INTEGER 0, SEQUENCE { algorithm OID, parameters },
         * OCTET STRING containing the original key bytes }.
         */
        private byte[] wrapInPkcs8(byte[] der) {
            DerOutput envelope = new DerOutput();
            try (DerOutput.Value infoSeq = envelope.writeValue(0x30)) {
                try (DerOutput.Value versionInt = envelope.writeValue(0x02)) {
                    envelope.write(0);
                }
                try (DerOutput.Value algorithmSeq = envelope.writeValue(0x30)) {
                    if (!this.ec) {
                        envelope.writeOid(PemParser.OID_RSA);
                        // RSA takes a NULL parameter: tag 0x05 with empty content.
                        envelope.writeValue(0x05).close();
                    } else {
                        DerInput curve = this.extractCurveParams(der);
                        envelope.writeOid(PemParser.OID_EC);
                        // Copy the curve parameters verbatim from the input key.
                        envelope.write(curve.data, curve.i, curve.limit - curve.i);
                    }
                }
                try (DerOutput.Value keyOctets = envelope.writeValue(0x04)) {
                    envelope.write(der, 0, der.length);
                }
            }
            return envelope.finish();
        }

        /**
         * Finds the content of the value tagged {@code 0xA0} inside an EC private key.
         *
         * @throws IllegalArgumentException if the structure is malformed or has no such value
         */
        private DerInput extractCurveParams(byte[] der) {
            DerInput ecPrivateKey = new DerInput(der).readSequence();
            // Version field: the bytes 0x02 0x01 0x01, i.e. a one-byte INTEGER with value 1.
            ecPrivateKey.expect(0x02);
            ecPrivateKey.expect(0x01);
            ecPrivateKey.expect(0x01);
            // Skip the OCTET STRING that follows the version.
            ecPrivateKey.readValue(0x04);
            DerInput curve = null;
            while (ecPrivateKey.i < ecPrivateKey.limit) {
                int tag = ecPrivateKey.peekTag();
                DerInput field = ecPrivateKey.readValue(tag);
                if (tag == 0xA0) {
                    curve = field;
                }
            }
            if (curve == null) {
                throw new IllegalArgumentException("Curve parameters not found for EC private key");
            }
            return curve;
        }
    }

    /**
     * Growable byte buffer for writing DER tag-length-value structures. A value is opened
     * with {@link #writeValue(int)} and its length is filled in when the returned
     * {@link Value} is closed.
     */
    private static final class DerOutput {
        /** Backing storage; replaced by a larger copy when it fills up. */
        private byte[] buf = ArrayUtils.EMPTY_BYTE_ARRAY;
        /** Number of bytes written so far, which is also the next write index. */
        private int pos;

        private DerOutput() {
        }

        /** Grows the buffer (16 bytes first, then doubling) until {@code n} more bytes fit. */
        private void ensureCapacity(int n) {
            while (this.buf.length < this.pos + n) {
                int grown = this.buf.length == 0 ? 16 : this.buf.length * 2;
                this.buf = Arrays.copyOf(this.buf, grown);
            }
        }

        /** Appends the low eight bits of {@code b}. */
        void write(int b) {
            this.ensureCapacity(1);
            this.buf[this.pos] = (byte) b;
            this.pos++;
        }

        /** Appends {@code len} bytes of {@code arr}, starting at {@code start}. */
        void write(byte[] arr, int start, int len) {
            this.ensureCapacity(len);
            System.arraycopy(arr, start, this.buf, this.pos, len);
            this.pos += len;
        }

        /** @return a copy of the bytes written so far */
        byte[] finish() {
            return Arrays.copyOf(this.buf, this.pos);
        }

        /** @return how many 7-bit groups (1 to 5) are used to encode {@code value} */
        private static int varIntLength(int value) {
            int groups = 1;
            // Thresholds are 2^7, 2^14, 2^21 and 2^28.
            while (groups < 5 && value >= (1 << (7 * groups))) {
                groups++;
            }
            return groups;
        }

        /**
         * Writes {@code value} as 7-bit groups, most significant first, with the high bit
         * set on every byte except the last.
         */
        private void writeVarInt(int value) {
            for (int group = DerOutput.varIntLength(value) - 1; group >= 0; group--) {
                int septet = (value >> (group * 7)) & 0x7F;
                this.write(group == 0 ? septet : (septet | 0x80));
            }
        }

        /**
         * Writes an OBJECT IDENTIFIER (tag 6) from dotted notation. The first two arcs share
         * one encoded number, {@code first * 40 + second}.
         */
        void writeOid(String oid) {
            try (Value ignored = this.writeValue(6)) {
                String[] arcs = oid.split("\\.");
                if (arcs.length > 0) {
                    int first = Integer.parseInt(arcs[0]);
                    int second = Integer.parseInt(arcs[1]);
                    this.writeVarInt(first * 40 + second);
                    for (int k = 2; k < arcs.length; k++) {
                        this.writeVarInt(Integer.parseInt(arcs[k]));
                    }
                }
            }
        }

        /**
         * Writes {@code tag} and opens a value whose content is whatever is written until
         * the returned handle is closed.
         */
        Value writeValue(int tag) {
            this.write(tag);
            return new Value();
        }

        /** Handle for an open value; closing it writes the length of the content. */
        final class Value implements AutoCloseable {
            /** Index of the placeholder length byte. */
            private final int lengthOffset;

            private Value() {
                this.lengthOffset = DerOutput.this.pos;
                // Reserve one byte; enough when the final length is below 128.
                DerOutput.this.write(0);
            }

            @Override
            public void close() {
                int contentStart = this.lengthOffset + 1;
                int length = DerOutput.this.pos - contentStart;
                if (length < 128) {
                    DerOutput.this.buf[this.lengthOffset] = (byte) length;
                    return;
                }
                int lengthLength;
                if (length < 256) {
                    lengthLength = 1;
                } else if (length < 65536) {
                    lengthLength = 2;
                } else if (length < 0x1000000) {
                    lengthLength = 3;
                } else {
                    lengthLength = 4;
                }
                DerOutput.this.buf[this.lengthOffset] = (byte) (0x80 | lengthLength);
                // Make room for the extra length bytes, then shift the content right.
                DerOutput.this.ensureCapacity(lengthLength);
                DerOutput.this.pos += lengthLength;
                System.arraycopy(DerOutput.this.buf, contentStart,
                    DerOutput.this.buf, contentStart + lengthLength, length);
                // Store the length big-endian in the gap that was opened.
                for (int k = 0; k < lengthLength; k++) {
                    int shift = (lengthLength - 1 - k) * 8;
                    DerOutput.this.buf[contentStart + k] = (byte) ((length >> shift) & 0xFF);
                }
            }
        }
    }

    /**
     * Cursor over a window of DER-encoded bytes. Nested values are exposed as sub-cursors
     * that share the same backing array.
     *
     * <p>Every parse failure is reported as an {@link IllegalArgumentException} with the
     * message {@code "Invalid PKCS#8"}.
     */
    private static final class DerInput {
        /** Backing array, shared with sub-cursors. */
        final byte[] data;
        /** Exclusive end index of this cursor's window. */
        final int limit;
        /** Index of the next byte to read. */
        int i;

        DerInput(byte[] data) {
            this(data, 0, data.length);
        }

        private DerInput(byte[] data, int start, int limit) {
            this.data = data;
            this.i = start;
            this.limit = limit;
        }

        /**
         * @return the next byte, advancing the cursor
         * @throws IllegalArgumentException if the window is exhausted
         */
        byte read() {
            if (this.i >= this.limit) {
                throw DerInput.invalidDer();
            }
            return this.data[this.i++];
        }

        /**
         * Consumes one byte and checks that its unsigned value equals {@code value}.
         *
         * @throws IllegalArgumentException on mismatch or if the window is exhausted
         */
        void expect(int value) {
            if ((this.read() & 0xFF) != value) {
                throw DerInput.invalidDer();
            }
        }

        /**
         * Reads a length field in short form (one byte below 0x80) or long form
         * (0x80 | n followed by n big-endian length bytes).
         *
         * @return the length, never more than the bytes remaining in this window
         * @throws IllegalArgumentException if the length overflows or exceeds the remaining bytes
         */
        private int readLength() {
            byte first = this.read();
            if (first >= 0) {
                // Short form. It must be bounds-checked too, otherwise a truncated input
                // produces a sub-cursor that reads past the window or the array.
                if (first > this.limit - this.i) {
                    throw DerInput.invalidDer();
                }
                return first;
            }
            int remainingLengthBytes = first & 0x7F;
            int length = 0;
            while (remainingLengthBytes-- > 0) {
                length = (length << 8) | (this.read() & 0xFF);
                if (length < 0 || length > this.limit - this.i) {
                    throw DerInput.invalidDer();
                }
            }
            return length;
        }

        /** @return the unsigned value of the next byte, without consuming it */
        int peekTag() {
            int tag = this.read() & 0xFF;
            --this.i;
            return tag;
        }

        /**
         * Consumes a value with the given tag and moves this cursor past it.
         *
         * @param tag the expected tag byte
         * @return a sub-cursor over the value's content
         * @throws IllegalArgumentException if the tag differs or the length is invalid
         */
        DerInput readValue(int tag) {
            this.expect(tag);
            int n = this.readLength();
            int end = this.i + n;
            DerInput content = new DerInput(this.data, this.i, end);
            this.i = end;
            return content;
        }

        /** Consumes a SEQUENCE (tag 0x30) and returns a sub-cursor over its content. */
        DerInput readSequence() {
            return this.readValue(48);
        }

        /**
         * Consumes an OBJECT IDENTIFIER (tag 6) and returns it in dotted notation.
         *
         * <p>Each number is a run of 7-bit groups ended by a byte with the high bit clear.
         * The first number holds the first two arcs: 80 or more means {@code 2.(n - 80)},
         * anything else means {@code (n / 40).(n % 40)}.
         */
        String readOid() {
            DerInput helper = this.readValue(6);
            StringBuilder builder = new StringBuilder();
            while (helper.i < helper.limit) {
                byte b;
                long value = 0L;
                do {
                    b = helper.read();
                    value = (value << 7) | (long) (b & 0x7F);
                } while (b < 0);
                if (!builder.isEmpty()) {
                    builder.append('.').append(value);
                } else if (value >= 80L) {
                    builder.append("2.").append(value - 80L);
                } else {
                    builder.append(value / 40L).append('.').append(value % 40L);
                }
            }
            return builder.toString();
        }

        private static RuntimeException invalidDer() {
            return new IllegalArgumentException("Invalid PKCS#8");
        }
    }
}

