/*
 * Decompiled with CFR 0.152.
 */
package com.didalgo.gpt3;

import com.didalgo.gpt3.ChatFormatDescriptor;
import com.didalgo.gpt3.CompletionType;
import com.didalgo.gpt3.Encoding;
import com.didalgo.gpt3.EncodingType;
import com.didalgo.gpt3.GPT3Tokenizer;
import java.lang.ref.SoftReference;
import java.util.Collections;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.regex.Pattern;

/**
 * Catalogue of OpenAI models known to this library, together with the token encoding, token
 * limit and completion style associated with each model.
 *
 * <p>{@link #GPT_3_5_TURBO_LEGACY} shares the model name {@code "gpt-3.5-turbo"} with
 * {@link #GPT_3_5_TURBO}. An exact lookup of that name resolves to {@link #GPT_3_5_TURBO}, the
 * first matching constant in declaration order. The legacy constant is therefore only reached
 * through the explicit dated aliases registered in {@code SPECIAL_VARIANTS}.
 */
public enum ModelType {
    GPT_4_TURBO("gpt-4-turbo-preview", EncodingType.CL100K_BASE, 128000, CompletionType.CHAT),
    GPT_4("gpt-4", EncodingType.CL100K_BASE, 8192, CompletionType.CHAT),
    GPT_4_32K("gpt-4-32k", EncodingType.CL100K_BASE, 32768, CompletionType.CHAT),
    GPT_3_5_TURBO("gpt-3.5-turbo", EncodingType.CL100K_BASE, 16384, CompletionType.CHAT),
    GPT_3_5_TURBO_LEGACY("gpt-3.5-turbo", EncodingType.CL100K_BASE, 4096, CompletionType.CHAT),
    GPT_3_5_TURBO_16K("gpt-3.5-turbo-16k", EncodingType.CL100K_BASE, 16384, CompletionType.CHAT),
    GPT_3_5_TURBO_INSTRUCT("gpt-3.5-turbo-instruct", EncodingType.CL100K_BASE, 4097, CompletionType.TEXT),
    TEXT_DAVINCI_003("text-davinci-003", EncodingType.P50K_BASE, 4097, CompletionType.TEXT),
    TEXT_DAVINCI_002("text-davinci-002", EncodingType.P50K_BASE, 4097, CompletionType.TEXT),
    TEXT_DAVINCI_001("text-davinci-001", EncodingType.R50K_BASE, 2049, CompletionType.TEXT),
    TEXT_CURIE_001("text-curie-001", EncodingType.R50K_BASE, 2049, CompletionType.TEXT),
    TEXT_BABBAGE_001("text-babbage-001", EncodingType.R50K_BASE, 2049, CompletionType.TEXT),
    TEXT_ADA_001("text-ada-001", EncodingType.R50K_BASE, 2049, CompletionType.TEXT),
    DAVINCI("davinci", EncodingType.R50K_BASE, 2049, CompletionType.TEXT),
    CURIE("curie", EncodingType.R50K_BASE, 2049, CompletionType.TEXT),
    BABBAGE("babbage", EncodingType.R50K_BASE, 2049, CompletionType.TEXT),
    ADA("ada", EncodingType.R50K_BASE, 2049, CompletionType.TEXT),
    CODE_DAVINCI_002("code-davinci-002", EncodingType.P50K_BASE, 8001, CompletionType.TEXT),
    TEXT_DAVINCI_EDIT_001("text-davinci-edit-001", EncodingType.P50K_EDIT, 2049, CompletionType.TEXT),
    CODE_DAVINCI_EDIT_001("code-davinci-edit-001", EncodingType.P50K_EDIT, 2049, CompletionType.TEXT),
    TEXT_EMBEDDING_ADA_002("text-embedding-ada-002", EncodingType.CL100K_BASE, 8192, CompletionType.TEXT);

    /**
     * Matches a whole model name ending in a dash followed by exactly four digits,
     * e.g. {@code gpt-4-0613}.
     */
    private static final Pattern DATED_SUFFIX = Pattern.compile(".*-\\d{4}$");

    /** Number of characters in the suffix recognised by {@link #DATED_SUFFIX}: '-' plus 4 digits. */
    private static final int DATED_SUFFIX_LENGTH = 5;

    /**
     * Explicit name-to-constant aliases, consulted before the per-constant name match.
     * Immutable after class initialisation.
     */
    private static final Map<String, ModelType> SPECIAL_VARIANTS;

    static {
        Map<String, ModelType> variants = new HashMap<>();
        variants.put("gpt-3.5-turbo-0301", GPT_3_5_TURBO_LEGACY);
        variants.put("gpt-3.5-turbo-0613", GPT_3_5_TURBO_LEGACY);
        variants.put("gpt-4-turbo-preview", GPT_4_TURBO);
        variants.put("gpt-4-1106-preview", GPT_4_TURBO);
        variants.put("gpt-4-0125-preview", GPT_4_TURBO);
        SPECIAL_VARIANTS = Collections.unmodifiableMap(variants);
    }

    private final String modelName;
    private final EncodingType encodingType;
    private final int maxTokens;
    private final CompletionType completionType;

    ModelType(String modelName, EncodingType encodingType, int maxTokens, CompletionType completionType) {
        this.modelName = modelName;
        this.encodingType = encodingType;
        this.maxTokens = maxTokens;
        this.completionType = completionType;
    }

    /** Returns the model name this constant was declared with, e.g. {@code "gpt-4"}. */
    public String modelName() {
        return this.modelName;
    }

    /** Returns the token encoding type declared for this model. */
    public EncodingType encodingType() {
        return this.encodingType;
    }

    /**
     * Returns the token limit declared for this model (presumably the context-window size;
     * confirm against the consuming code).
     */
    public int maxTokens() {
        return this.maxTokens;
    }

    /** Returns whether this model is declared as a CHAT or a TEXT completion model. */
    public CompletionType completionType() {
        return this.completionType;
    }

    /**
     * Resolves a model name to its {@code ModelType}.
     *
     * <p>The lookup proceeds in three steps:
     * <ol>
     *   <li>An exact match against the explicit aliases, then against each constant's
     *       {@link #modelName()} in declaration order.</li>
     *   <li>If that fails and the name ends in {@code -NNNN} (four digits), the suffix is
     *       stripped and the exact lookup is retried, e.g. {@code gpt-4-0613} resolves to
     *       {@link #GPT_4}.</li>
     *   <li>Otherwise an exception is thrown.</li>
     * </ol>
     *
     * <p>Despite the {@code Optional} return type, this method never returns an empty
     * {@code Optional}; an unknown name raises an exception instead.
     *
     * @param modelName the model name to resolve; must not be {@code null}
     * @return a non-empty {@code Optional} holding the resolved model
     * @throws NullPointerException if {@code modelName} is {@code null}
     * @throws IllegalArgumentException if no model matches; the message names the input exactly
     *     as supplied by the caller
     */
    public static Optional<ModelType> forModel(String modelName) throws IllegalArgumentException {
        Objects.requireNonNull(modelName, "modelName");
        Optional<ModelType> modelType = forModelExact(modelName);
        if (modelType.isPresent()) {
            return modelType;
        }
        if (DATED_SUFFIX.matcher(modelName).matches()) {
            // Use a separate local so the error message below still reports the caller's input.
            String baseName = modelName.substring(0, modelName.length() - DATED_SUFFIX_LENGTH);
            modelType = forModelExact(baseName);
            if (modelType.isPresent()) {
                return modelType;
            }
        }
        throw new IllegalArgumentException("Model `" + modelName + "` not found");
    }

    /**
     * Exact-name lookup: explicit aliases first, then the first constant (in declaration order)
     * whose {@link #modelName()} equals {@code modelName}.
     *
     * @return the matching model, or an empty {@code Optional} if none matches
     */
    private static Optional<ModelType> forModelExact(String modelName) {
        ModelType alias = SPECIAL_VARIANTS.get(modelName);
        if (alias != null) {
            return Optional.of(alias);
        }
        for (ModelType candidate : values()) {
            if (candidate.modelName.equals(modelName)) {
                return Optional.of(candidate);
            }
        }
        return Optional.empty();
    }

    /** Returns the {@link Encoding} looked up by this model's encoding name. */
    public Encoding getEncoding() {
        return Encoding.forName(this.encodingType().encodingName());
    }

    /**
     * Returns a tokenizer for this model's encoding.
     *
     * <p>Tokenizers are cached per model behind soft references. Repeated calls normally return
     * the same instance, but a new one may be built after the garbage collector clears the
     * cached one.
     */
    public GPT3Tokenizer getTokenizer() {
        return Cache.getTokenizer(this);
    }

    /** Returns the chat format descriptor looked up by this model's name. */
    public ChatFormatDescriptor getChatFormatDescriptor() {
        return ChatFormatDescriptor.forModel(this.modelName());
    }

    /** Lazily-populated, soft-referenced tokenizer cache, keyed by model. */
    private static final class Cache {
        /**
         * Synchronized map; the map object itself is also the lock used for check-then-create
         * in {@link #getTokenizer(ModelType)}.
         */
        private static final Map<ModelType, SoftReference<GPT3Tokenizer>> gptTokenizersCache =
                Collections.synchronizedMap(new EnumMap<>(ModelType.class));

        private Cache() {
        }

        /**
         * Returns the cached tokenizer for {@code model}, creating and caching one if it is absent
         * or has been cleared by the garbage collector.
         */
        private static GPT3Tokenizer getTokenizer(ModelType model) {
            GPT3Tokenizer tokenizer = cached(model);
            if (tokenizer != null) {
                return tokenizer;
            }
            synchronized (gptTokenizersCache) {
                // Re-check under the lock so concurrent callers that both missed the fast path
                // do not each build a tokenizer and overwrite one another's cache entry.
                tokenizer = cached(model);
                if (tokenizer == null) {
                    tokenizer = new GPT3Tokenizer(model.getEncoding());
                    gptTokenizersCache.put(model, new SoftReference<>(tokenizer));
                }
            }
            return tokenizer;
        }

        /** Returns the still-reachable cached tokenizer for {@code model}, or {@code null}. */
        private static GPT3Tokenizer cached(ModelType model) {
            SoftReference<GPT3Tokenizer> ref = gptTokenizersCache.get(model);
            return ref == null ? null : ref.get();
        }
    }
}

