/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.model.openai;

import com.knuddels.jtokkit.Encodings;
import com.knuddels.jtokkit.api.Encoding;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.internal.Exceptions;
import dev.langchain4j.model.Tokenizer;
import dev.langchain4j.model.openai.InternalOpenAiHelper;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;

/**
 * {@link Tokenizer} implementation that counts, encodes and decodes tokens using the jtokkit
 * {@link Encoding} registered for a given OpenAI model name.
 *
 * <p>The encoding is looked up once, in the constructor. An unknown model name does not make
 * construction fail; instead, every method that needs the encoding throws
 * {@link IllegalArgumentException} when it is called.
 *
 * <p>NOTE(review): the fixed numeric overheads used by the {@code estimateTokenCountIn...} methods
 * are not explained anywhere in this file. They look like empirically calibrated constants for the
 * OpenAI chat format — confirm against the upstream project before changing any of them.
 */
public class OpenAiTokenizer
implements Tokenizer {
    /** The only model name for which the per-message and per-name overheads differ. */
    private static final String GPT_3_5_TURBO_0301 = "gpt-3.5-turbo-0301";

    private final String modelName;
    private final Optional<Encoding> encoding;

    /**
     * Creates a tokenizer for the given model.
     *
     * @param modelName name of the OpenAI model; used to look up the jtokkit encoding and to
     *                  select model-specific overheads
     */
    public OpenAiTokenizer(String modelName) {
        this.modelName = modelName;
        this.encoding = Encodings.newLazyEncodingRegistry().getEncodingForModel(modelName);
    }

    /**
     * Counts the tokens of a piece of text using the model's encoding.
     *
     * @param text the text to tokenize
     * @return the number of tokens reported by {@link Encoding#countTokensOrdinary(String)}
     * @throws IllegalArgumentException if jtokkit has no encoding for the model name
     */
    public int estimateTokenCountInText(String text) {
        return this.requireEncoding().countTokensOrdinary(text);
    }

    /**
     * Estimates the token count of a single chat message: a model-dependent per-message overhead,
     * the message text, the role string, and extra contributions depending on the message type.
     *
     * @param message the message to measure
     * @return the estimated number of tokens
     * @throws IllegalArgumentException if jtokkit has no encoding for the model name
     */
    public int estimateTokenCountInMessage(ChatMessage message) {
        int tokenCount = this.extraTokensPerMessage();
        tokenCount += this.estimateTokenCountInText(message.text());
        tokenCount += this.estimateTokenCountInText(InternalOpenAiHelper.roleFrom(message).toString());

        if (message instanceof UserMessage) {
            UserMessage userMessage = (UserMessage) message;
            // A user name is optional; it only contributes when present.
            if (userMessage.name() != null) {
                tokenCount += this.extraTokensPerName();
                tokenCount += this.estimateTokenCountInText(userMessage.name());
            }
        }

        if (message instanceof AiMessage) {
            ToolExecutionRequest toolExecutionRequest = ((AiMessage) message).toolExecutionRequest();
            if (toolExecutionRequest != null) {
                // Fixed overhead for a tool execution request, plus its name and arguments.
                tokenCount += 4;
                tokenCount += this.estimateTokenCountInText(toolExecutionRequest.name());
                tokenCount += this.estimateTokenCountInText(toolExecutionRequest.arguments());
            }
        }

        if (message instanceof ToolExecutionResultMessage) {
            ToolExecutionResultMessage toolExecutionResultMessage = (ToolExecutionResultMessage) message;
            // Tool results are one token cheaper than the generic estimate, plus the tool name.
            tokenCount -= 1;
            tokenCount += this.estimateTokenCountInText(toolExecutionResultMessage.toolName());
        }

        return tokenCount;
    }

    /**
     * Estimates the token count of a whole conversation: a fixed base of 3 tokens plus the
     * estimate of every message.
     *
     * @param messages the messages to measure
     * @return the estimated number of tokens
     * @throws IllegalArgumentException if jtokkit has no encoding for the model name
     */
    public int estimateTokenCountInMessages(Iterable<ChatMessage> messages) {
        int tokenCount = 3;
        for (ChatMessage message : messages) {
            tokenCount += this.estimateTokenCountInMessage(message);
        }
        return tokenCount;
    }

    /**
     * Estimates the token count of a set of tool specifications. Each specification contributes
     * its name, its description, the recognised entries ({@code type}, {@code description},
     * {@code enum}) of each of its parameter properties, and a fixed overhead of 12 tokens; a
     * further 12 tokens are added once for the whole set.
     *
     * @param toolSpecifications the tool specifications to measure
     * @return the estimated number of tokens
     * @throws IllegalArgumentException if jtokkit has no encoding for the model name
     */
    public int estimateTokenCountInToolSpecifications(Iterable<ToolSpecification> toolSpecifications) {
        int tokenCount = 0;
        for (ToolSpecification toolSpecification : toolSpecifications) {
            tokenCount += this.estimateTokenCountInText(toolSpecification.name());
            tokenCount += this.estimateTokenCountInText(toolSpecification.description());

            // Wildcards keep this independent of the exact type parameters declared by
            // properties(); each value is expected to be the map describing one property.
            Map<?, ?> properties = toolSpecification.parameters().properties();
            for (Object propertySchema : properties.values()) {
                for (Map.Entry<?, ?> entry : ((Map<?, ?>) propertySchema).entrySet()) {
                    tokenCount += this.estimateTokenCountInPropertyEntry(entry);
                }
            }

            tokenCount += 12;
        }
        return tokenCount + 12;
    }

    /**
     * Estimates the tokens contributed by one entry of a property description. Only the keys
     * {@code type}, {@code description} and {@code enum} contribute; any other key counts as zero.
     */
    private int estimateTokenCountInPropertyEntry(Map.Entry<?, ?> entry) {
        Object key = entry.getKey();
        if ("type".equals(key) || "description".equals(key)) {
            return 3 + this.estimateTokenCountInText(entry.getValue().toString());
        }
        if ("enum".equals(key)) {
            // An enum starts 3 tokens below zero and then pays 3 tokens plus its text per value.
            int tokenCount = -3;
            for (Object enumValue : (Object[]) entry.getValue()) {
                tokenCount += 3;
                tokenCount += this.estimateTokenCountInText(enumValue.toString());
            }
            return tokenCount;
        }
        return 0;
    }

    /** Returns the per-message overhead: 4 for {@code gpt-3.5-turbo-0301}, 3 for any other model. */
    private int extraTokensPerMessage() {
        // Constant-first comparison so a null model name cannot cause a NullPointerException here.
        return GPT_3_5_TURBO_0301.equals(this.modelName) ? 4 : 3;
    }

    /**
     * Returns the overhead applied when a user message carries a name: -1 for
     * {@code gpt-3.5-turbo-0301}, 1 for any other model.
     */
    private int extraTokensPerName() {
        // NOTE(review): the negative value presumably reflects that model dropping the role
        // when a name is present (see the OpenAI cookbook token-counting guide) — confirm.
        return GPT_3_5_TURBO_0301.equals(this.modelName) ? -1 : 1;
    }

    /**
     * Encodes text into token ids.
     *
     * @param text the text to encode
     * @return the token ids returned by {@link Encoding#encodeOrdinary(String)}
     * @throws IllegalArgumentException if jtokkit has no encoding for the model name
     */
    public List<Integer> encode(String text) {
        return this.requireEncoding().encodeOrdinary(text);
    }

    /**
     * Encodes text into token ids, passing a token limit to the encoding.
     *
     * @param text              the text to encode
     * @param maxTokensToEncode the limit handed to the encoding
     * @return the token ids of the encoding result
     * @throws IllegalArgumentException if jtokkit has no encoding for the model name
     */
    public List<Integer> encode(String text, int maxTokensToEncode) {
        return this.requireEncoding().encodeOrdinary(text, maxTokensToEncode).getTokens();
    }

    /**
     * Decodes token ids back into text.
     *
     * @param tokens the token ids to decode
     * @return the decoded text
     * @throws IllegalArgumentException if jtokkit has no encoding for the model name
     */
    public String decode(List<Integer> tokens) {
        return this.requireEncoding().decode(tokens);
    }

    /** Returns the model's encoding, or throws if jtokkit does not know the model name. */
    private Encoding requireEncoding() {
        return this.encoding.orElseThrow(this.unknownModelException());
    }

    /** Supplies the exception thrown when jtokkit has no encoding for the model name. */
    private Supplier<IllegalArgumentException> unknownModelException() {
        return () -> Exceptions.illegalArgument("Model '%s' is unknown to jtokkit", new Object[]{this.modelName});
    }
}

