/*
 * Decompiled with CFR 0.152.
 */
package com.didalgo.gpt3;

import com.didalgo.gpt3.ByteSequence;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.regex.Pattern;

public interface Encoding {
    /** Literal text of the end-of-text special token. */
    String ENDOFTEXT = "<|endoftext|>";

    /** Literal text of the fill-in-the-middle "prefix" special token. */
    String FIM_PREFIX = "<|fim_prefix|>";

    /** Literal text of the fill-in-the-middle "middle" special token. */
    String FIM_MIDDLE = "<|fim_middle|>";

    /** Literal text of the fill-in-the-middle "suffix" special token. */
    String FIM_SUFFIX = "<|fim_suffix|>";

    /** Literal text of the end-of-prompt special token. */
    String ENDOFPROMPT = "<|endofprompt|>";

    /**
     * The {@code cl100k_base} encoding, backed by the {@code cl100k_base.tiktoken} resource.
     * Its mergeable ranks start out empty and are filled in on first access.
     */
    Encoding CL100K_BASE = new Of(
            "cl100k_base.tiktoken",
            new HashMap<>(),
            Map.of(
                    ENDOFTEXT, 100257,
                    FIM_PREFIX, 100258,
                    FIM_MIDDLE, 100259,
                    FIM_SUFFIX, 100260,
                    ENDOFPROMPT, 100276),
            Pattern.compile(
                    "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
                    Pattern.UNICODE_CHARACTER_CLASS));

    /**
     * The {@code p50k_base} encoding, backed by the {@code p50k_base.tiktoken} resource.
     * Its mergeable ranks start out empty and are filled in on first access.
     */
    Encoding P50K_BASE = new Of(
            "p50k_base.tiktoken",
            new HashMap<>(),
            Map.of(ENDOFTEXT, 50256),
            Pattern.compile(
                    "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+",
                    Pattern.UNICODE_CHARACTER_CLASS));

    /**
     * The {@code p50k_edit} encoding. It reads the same {@code p50k_base.tiktoken} resource as
     * {@link #P50K_BASE} but additionally declares the fill-in-the-middle special tokens.
     */
    Encoding P50K_EDIT = new Of(
            "p50k_base.tiktoken",
            new HashMap<>(),
            Map.of(
                    ENDOFTEXT, 50256,
                    FIM_PREFIX, 50281,
                    FIM_MIDDLE, 50282,
                    FIM_SUFFIX, 50283),
            Pattern.compile(
                    "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+",
                    Pattern.UNICODE_CHARACTER_CLASS));

    /**
     * The {@code r50k_base} encoding, backed by the {@code r50k_base.tiktoken} resource.
     * Its mergeable ranks start out empty and are filled in on first access.
     */
    Encoding R50K_BASE = new Of(
            "r50k_base.tiktoken",
            new HashMap<>(),
            Map.of(ENDOFTEXT, 50256),
            Pattern.compile(
                    "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+",
                    Pattern.UNICODE_CHARACTER_CLASS));

    /**
     * Returns the table that maps byte sequences to their integer ranks for this encoding.
     *
     * @return the byte-sequence-to-rank map
     */
    public Map<ByteSequence, Integer> mergeableRanks();

    /**
     * Returns the special tokens of this encoding, keyed by their literal text
     * (for example {@code "<|endoftext|>"}) and mapped to their integer ids.
     *
     * @return the special-token map
     */
    public Map<String, Integer> specialTokens();

    /**
     * Returns the compiled regular expression associated with this encoding.
     * NOTE(review): presumably used to split input text into pieces before
     * byte-pair merging; the consumer is not visible in this file, so confirm.
     *
     * @return the encoding's pattern
     */
    public Pattern pattern();

    /**
     * Looks up one of the built-in encodings by name, ignoring case.
     *
     * @param encodingName the encoding name, e.g. {@code "cl100k_base"}; must not be {@code null}
     * @return the matching built-in encoding
     * @throws IllegalArgumentException if the name does not denote a built-in encoding
     * @throws NullPointerException if {@code encodingName} is {@code null}
     */
    public static Encoding forName(String encodingName) {
        // Locale.ROOT keeps the case folding independent of the JVM's default locale; with a
        // Turkish default locale, "P50K_EDIT" would otherwise lower-case to a dotless 'ı'
        // and no longer match "p50k_edit".
        return switch (encodingName.toLowerCase(Locale.ROOT)) {
            case "cl100k_base" -> CL100K_BASE;
            case "p50k_base" -> P50K_BASE;
            case "p50k_edit" -> P50K_EDIT;
            case "r50k_base" -> R50K_BASE;
            default -> throw new IllegalArgumentException("Unknown encoding: " + encodingName);
        };
    }

    /**
     * Resolves the encoding used by the given model.
     *
     * <p>An exact match in the model table wins; otherwise the first registered
     * model-name prefix that {@code modelName} starts with decides the encoding.
     *
     * @param modelName the model name, e.g. {@code "gpt-4"} or {@code "text-davinci-003"}
     * @return the encoding registered for the model
     * @throws IllegalArgumentException if neither an exact nor a prefix match exists
     */
    public static Encoding forModel(String modelName) {
        String exactMatch = Lookup.modelToEncoding.get(modelName);
        if (exactMatch != null) {
            return Encoding.forName(exactMatch);
        }
        for (Map.Entry<String, String> candidate : Lookup.modelPrefixToEncoding.entrySet()) {
            if (modelName.startsWith(candidate.getKey())) {
                return Encoding.forName(candidate.getValue());
            }
        }
        throw new IllegalArgumentException("Unknown model name: " + modelName);
    }

    /**
     * Static lookup tables (model name to encoding name) and the loader for
     * {@code .tiktoken} rank files found on the classpath.
     */
    public static final class Lookup {
        /** Model-name prefixes mapped to encoding names; matched with {@code startsWith} by {@code forModel}. */
        private static final Map<String, String> modelPrefixToEncoding;
        /** Exact model names mapped to encoding names. */
        private static final Map<String, String> modelToEncoding;

        /**
         * Reads a {@code .tiktoken} resource into a map.
         *
         * <p>Each line that contains a space at an index greater than zero is split at its
         * first space: the left part is Base64-decoded into the key, the right part is parsed
         * as the integer value. Empty lines and lines without such a space are skipped.
         *
         * @param filename resource name, resolved via {@code Lookup.class.getResourceAsStream}
         * @param resultMap map to add the entries to, or {@code null} to have a new
         *                  {@link HashMap} created
         * @return {@code resultMap} if it was non-null, otherwise the newly created map
         * @throws NullPointerException if the resource cannot be found
         * @throws UncheckedIOException if reading or closing the resource fails
         * @throws IllegalArgumentException if a key is not valid Base64 or a value is not an
         *                                  integer ({@link NumberFormatException})
         */
        public static Map<ByteSequence, Integer> loadTiktokenBase(String filename, Map<ByteSequence, Integer> resultMap) {
            InputStream in = Lookup.class.getResourceAsStream(filename);
            if (in == null) {
                // A missing resource previously surfaced as a message-less NullPointerException
                // from the reader constructor; keep the exception type but say what is missing.
                throw new NullPointerException("tiktoken resource not found: " + filename);
            }
            Map<ByteSequence, Integer> result = resultMap == null ? new HashMap<>() : resultMap;
            // Closing the reader also closes the underlying resource stream.
            try (BufferedReader reader = new BufferedReader(new InputStreamReader(in, StandardCharsets.US_ASCII))) {
                String line;
                while ((line = reader.readLine()) != null) {
                    int spaceIdx = line.indexOf(' ');
                    if (spaceIdx > 0) {
                        ByteSequence key = ByteSequence.of(Base64.getDecoder().decode(line.substring(0, spaceIdx)));
                        int value = Integer.parseInt(line.substring(spaceIdx + 1));
                        result.put(key, value);
                    }
                }
            } catch (IOException e) {
                throw new UncheckedIOException(e);
            }
            return result;
        }

        static {
            // Keys end with '-' so that only suffixed variants of a model name match by prefix.
            HashMap<String, String> mp2e = new HashMap<String, String>();
            mp2e.put("gpt-4-", "cl100k_base");
            mp2e.put("gpt-3.5-turbo-", "cl100k_base");
            modelPrefixToEncoding = mp2e;

            HashMap<String, String> m2e = new HashMap<String, String>();
            m2e.put("gpt-4", "cl100k_base");
            m2e.put("gpt-3.5-turbo", "cl100k_base");
            m2e.put("text-davinci-003", "p50k_base");
            m2e.put("text-davinci-002", "p50k_base");
            m2e.put("text-davinci-001", "r50k_base");
            m2e.put("text-curie-001", "r50k_base");
            m2e.put("text-babbage-001", "r50k_base");
            m2e.put("text-ada-001", "r50k_base");
            m2e.put("davinci", "r50k_base");
            m2e.put("curie", "r50k_base");
            m2e.put("babbage", "r50k_base");
            m2e.put("ada", "r50k_base");
            m2e.put("code-davinci-002", "p50k_base");
            m2e.put("code-davinci-001", "p50k_base");
            m2e.put("code-cushman-002", "p50k_base");
            m2e.put("code-cushman-001", "p50k_base");
            m2e.put("davinci-codex", "p50k_base");
            m2e.put("cushman-codex", "p50k_base");
            m2e.put("text-davinci-edit-001", "p50k_edit");
            m2e.put("code-davinci-edit-001", "p50k_edit");
            m2e.put("text-embedding-ada-002", "cl100k_base");
            m2e.put("text-similarity-davinci-001", "r50k_base");
            m2e.put("text-similarity-curie-001", "r50k_base");
            m2e.put("text-similarity-babbage-001", "r50k_base");
            m2e.put("text-similarity-ada-001", "r50k_base");
            m2e.put("text-search-davinci-doc-001", "r50k_base");
            m2e.put("text-search-curie-doc-001", "r50k_base");
            m2e.put("text-search-babbage-doc-001", "r50k_base");
            m2e.put("text-search-ada-doc-001", "r50k_base");
            m2e.put("code-search-babbage-code-001", "r50k_base");
            m2e.put("code-search-ada-code-001", "r50k_base");
            modelToEncoding = m2e;
        }
    }

    public record Of(String tiktokenFilename, Map<ByteSequence, Integer> mergeableRanks, Map<String, Integer> specialTokens, Pattern pattern) implements Encoding
    {
        private final Map<ByteSequence, Integer> mergeableRanks;

        public Of {
            specialTokens = Collections.unmodifiableMap(new HashMap<String, Integer>(specialTokens));
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        @Override
        public Map<ByteSequence, Integer> mergeableRanks() {
            if (this.mergeableRanks.isEmpty()) {
                Map<ByteSequence, Integer> map = this.mergeableRanks;
                synchronized (map) {
                    if (this.mergeableRanks.isEmpty()) {
                        Lookup.loadTiktokenBase(this.tiktokenFilename, this.mergeableRanks);
                    }
                }
            }
            return Collections.unmodifiableMap(this.mergeableRanks);
        }
    }
}

