/*
 * Decompiled with CFR 0.152.
 */
package com.unfbx.chatgpt.utils;

import cn.hutool.core.util.StrUtil;
import com.knuddels.jtokkit.Encodings;
import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.EncodingRegistry;
import com.knuddels.jtokkit.api.EncodingType;
import com.knuddels.jtokkit.api.ModelType;
import com.unfbx.chatgpt.entity.chat.ChatCompletion;
import com.unfbx.chatgpt.entity.chat.Message;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Token encoding / decoding / counting helpers built on top of jtokkit.
 *
 * <p>Every operation is offered in three flavours:
 * <ul>
 *   <li>with an explicit {@link Encoding},</li>
 *   <li>with an {@link EncodingType}, or</li>
 *   <li>with a model name, which is mapped to a jtokkit {@link ModelType}.</li>
 * </ul>
 * Blank (or {@code null}) text always encodes to an empty token list, i.e. 0 tokens.
 */
public class TikTokensUtil {
    private static final Logger log = LoggerFactory.getLogger(TikTokensUtil.class);

    /**
     * Lazily-initialised shared encoding registry.
     *
     * <p>Previously a new default registry was built on every lookup. It is now created once, on
     * first use, and reused. The holder idiom defers creation until a lookup actually needs it.
     */
    private static final class RegistryHolder {
        private static final EncodingRegistry REGISTRY = Encodings.newDefaultEncodingRegistry();
    }

    /**
     * Encodes {@code text} with the given encoding.
     *
     * @param enc  the encoding to use
     * @param text the text to encode; may be {@code null} or blank
     * @return the token ids, or a new empty (mutable) list when {@code text} is blank
     */
    public static List<Integer> encode(@NotNull Encoding enc, String text) {
        return StrUtil.isBlank(text) ? new ArrayList<>() : enc.encode(text);
    }

    /**
     * Counts the tokens of {@code text} with the given encoding.
     *
     * @param enc  the encoding to use
     * @param text the text to measure; may be {@code null} or blank (counts as 0)
     * @return the number of tokens
     */
    public static int tokens(@NotNull Encoding enc, String text) {
        return encode(enc, text).size();
    }

    /**
     * Decodes token ids back to text with the given encoding.
     *
     * @param enc     the encoding to use
     * @param encoded the token ids to decode
     * @return the decoded text
     */
    public static String decode(@NotNull Encoding enc, @NotNull List<Integer> encoded) {
        return enc.decode(encoded);
    }

    /**
     * Looks up the encoding for an encoding type in the shared registry.
     *
     * @param encodingType the encoding type
     * @return the matching encoding
     */
    public static Encoding getEncoding(@NotNull EncodingType encodingType) {
        return RegistryHolder.REGISTRY.getEncoding(encodingType);
    }

    /**
     * Encodes {@code text} with the encoding of the given type.
     *
     * @param encodingType the encoding type
     * @param text         the text to encode; may be {@code null} or blank
     * @return the token ids, or a new empty (mutable) list when {@code text} is blank
     */
    public static List<Integer> encode(@NotNull EncodingType encodingType, String text) {
        if (StrUtil.isBlank(text)) {
            // Skip the registry lookup entirely for blank input.
            return new ArrayList<>();
        }
        return getEncoding(encodingType).encode(text);
    }

    /**
     * Counts the tokens of {@code text} with the encoding of the given type.
     *
     * @param encodingType the encoding type
     * @param text         the text to measure; may be {@code null} or blank (counts as 0)
     * @return the number of tokens
     */
    public static int tokens(@NotNull EncodingType encodingType, String text) {
        return encode(encodingType, text).size();
    }

    /**
     * Decodes token ids with the encoding of the given type.
     *
     * @param encodingType the encoding type
     * @param encoded      the token ids to decode
     * @return the decoded text
     */
    public static String decode(@NotNull EncodingType encodingType, @NotNull List<Integer> encoded) {
        return getEncoding(encodingType).decode(encoded);
    }

    /**
     * Resolves the encoding used by a model.
     *
     * @param modelName the model name, e.g. an {@code ChatCompletion.Model} name
     * @return the model's encoding, or {@code null} when the model is unknown or unsupported
     */
    public static Encoding getEncoding(@NotNull String modelName) {
        ModelType modelType = getModelTypeByName(modelName);
        if (modelType == null) {
            return null;
        }
        return RegistryHolder.REGISTRY.getEncodingForModel(modelType);
    }

    /**
     * Encodes {@code text} with the encoding of the given model.
     *
     * @param modelName the model name
     * @param text      the text to encode; may be {@code null} or blank
     * @return the token ids; a new empty (mutable) list when {@code text} is blank or the model is
     *         unsupported
     */
    public static List<Integer> encode(@NotNull String modelName, String text) {
        if (StrUtil.isBlank(text)) {
            return new ArrayList<>();
        }
        Encoding enc = getEncoding(modelName);
        if (enc == null) {
            // "[model] does not exist or token counting is not supported yet; returning tokens==0"
            log.warn("[{}]\u6a21\u578b\u4e0d\u5b58\u5728\u6216\u8005\u6682\u4e0d\u652f\u6301\u8ba1\u7b97tokens\uff0c\u76f4\u63a5\u8fd4\u56detokens==0", modelName);
            return new ArrayList<>();
        }
        return enc.encode(text);
    }

    /**
     * Counts the tokens of {@code text} for the given model.
     *
     * @param modelName the model name
     * @param text      the text to measure; may be {@code null} or blank (counts as 0)
     * @return the number of tokens, or 0 when the model is unsupported
     */
    public static int tokens(@NotNull String modelName, String text) {
        return encode(modelName, text).size();
    }

    /**
     * Estimates the prompt tokens consumed by a list of chat messages.
     *
     * <p>For each message, the estimate adds the tokens of its content, role and name. It also adds
     * a fixed per-message overhead, plus a per-name adjustment when a name is set. Finally, 3
     * tokens are added for the primed assistant reply. Models without a known overhead get 0/0
     * for the per-message and per-name terms.
     *
     * @param modelName the model name
     * @param messages  the messages to measure
     * @return the estimated token count, or 0 when the model is unknown or unsupported (same
     *         policy as {@link #encode(String, String)})
     */
    public static int tokens(@NotNull String modelName, @NotNull List<Message> messages) {
        Encoding encoding = getEncoding(modelName);
        if (encoding == null) {
            // Previously this path threw an NPE on the first non-blank message.
            log.warn("[{}]\u6a21\u578b\u4e0d\u5b58\u5728\u6216\u8005\u6682\u4e0d\u652f\u6301\u8ba1\u7b97tokens\uff0c\u76f4\u63a5\u8fd4\u56detokens==0", modelName);
            return 0;
        }
        int tokensPerMessage;
        int tokensPerName;
        switch (modelName) {
            case "gpt-3.5-turbo-0301":
            case "gpt-3.5-turbo":
                // Every message follows <|start|>{role/name}\n{content}<|end|>\n;
                // if a name is present the role is omitted.
                tokensPerMessage = 4;
                tokensPerName = -1;
                break;
            case "gpt-4":
            case "gpt-4-0314":
            case "gpt-4-32k":
            case "gpt-4-32k-0314":
                tokensPerMessage = 3;
                tokensPerName = 1;
                break;
            default:
                tokensPerMessage = 0;
                tokensPerName = 0;
                break;
        }
        int sum = 0;
        for (Message msg : messages) {
            sum += tokensPerMessage;
            sum += tokens(encoding, msg.getContent());
            sum += tokens(encoding, msg.getRole());
            sum += tokens(encoding, msg.getName());
            if (StrUtil.isNotBlank(msg.getName())) {
                sum += tokensPerName;
            }
        }
        // Every reply is primed with <|start|>assistant<|message|>.
        return sum + 3;
    }

    /**
     * Decodes token ids with the encoding of the given model.
     *
     * @param modelName the model name
     * @param encoded   the token ids to decode
     * @return the decoded text
     * @throws NullPointerException if the model is unknown or unsupported
     */
    public static String decode(@NotNull String modelName, @NotNull List<Integer> encoded) {
        Encoding enc = Objects.requireNonNull(getEncoding(modelName),
                "Unsupported model for token decoding: " + modelName);
        return enc.decode(encoded);
    }

    /**
     * Maps a model name to a jtokkit {@link ModelType}.
     *
     * <p>Some chat-completion model names are mapped explicitly to a base {@link ModelType}.
     * All other names are matched against {@link ModelType#getName()}.
     *
     * @param name the model name; may be {@code null}
     * @return the matching model type, or {@code null} (after logging a warning) when none matches
     */
    private static ModelType getModelTypeByName(String name) {
        if (ChatCompletion.Model.GPT_3_5_TURBO_0301.getName().equals(name)) {
            return ModelType.GPT_3_5_TURBO;
        }
        if (ChatCompletion.Model.GPT_4.getName().equals(name)
                || ChatCompletion.Model.GPT_4_32K.getName().equals(name)
                || ChatCompletion.Model.GPT_4_32K_0314.getName().equals(name)
                || ChatCompletion.Model.GPT_4_0314.getName().equals(name)) {
            return ModelType.GPT_4;
        }
        for (ModelType modelType : ModelType.values()) {
            if (modelType.getName().equals(name)) {
                return modelType;
            }
        }
        // "[model] does not exist or token counting is not supported yet"
        log.warn("[{}]\u6a21\u578b\u4e0d\u5b58\u5728\u6216\u8005\u6682\u4e0d\u652f\u6301\u8ba1\u7b97tokens", name);
        return null;
    }
}

