/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.model.embedding.onnx;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.TokenCountEstimator;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.HashMap;
import java.util.Map;

/**
 * Estimates token counts for text and chat messages using a
 * {@link HuggingFaceTokenizer}.
 *
 * <p>The tokenizer definition is read either from the {@code /bert-tokenizer.json}
 * classpath resource (no-argument constructor) or from a caller-supplied file.
 */
public class HuggingFaceTokenCountEstimator implements TokenCountEstimator {
    private final HuggingFaceTokenizer tokenizer;

    /**
     * Creates an estimator from the {@code /bert-tokenizer.json} classpath resource,
     * passing the tokenizer options {@code padding=false} and {@code truncation=false}.
     *
     * @throws RuntimeException if the tokenizer cannot be created
     */
    public HuggingFaceTokenCountEstimator() {
        Map<String, String> options = new HashMap<>();
        options.put("padding", "false");
        options.put("truncation", "false");
        // Resolve the resource against this class rather than getClass(), so that a
        // subclass living in another module or class loader still finds it.
        this.tokenizer = createFrom(
                HuggingFaceTokenCountEstimator.class.getResourceAsStream("/bert-tokenizer.json"), options);
    }

    /**
     * Creates an estimator from a tokenizer file, passing {@code null} as the options.
     *
     * @param pathToTokenizer path to the tokenizer file
     * @throws RuntimeException if the file cannot be opened or the tokenizer cannot be created
     */
    public HuggingFaceTokenCountEstimator(Path pathToTokenizer) {
        this(pathToTokenizer, null);
    }

    /**
     * Creates an estimator from a tokenizer file with explicit tokenizer options.
     *
     * @param pathToTokenizer path to the tokenizer file
     * @param options         tokenizer options forwarded to
     *                        {@code HuggingFaceTokenizer.newInstance}; the single-argument
     *                        constructor passes {@code null} here
     * @throws RuntimeException if the file cannot be opened or the tokenizer cannot be created
     */
    public HuggingFaceTokenCountEstimator(Path pathToTokenizer, Map<String, String> options) {
        try {
            // createFrom closes the stream once the tokenizer has been built.
            this.tokenizer = createFrom(Files.newInputStream(pathToTokenizer), options);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Creates an estimator from a tokenizer file, passing {@code null} as the options.
     *
     * @param pathToTokenizer path string of the tokenizer file
     * @throws RuntimeException if the file cannot be opened or the tokenizer cannot be created
     */
    public HuggingFaceTokenCountEstimator(String pathToTokenizer) {
        this(pathToTokenizer, null);
    }

    /**
     * Creates an estimator from a tokenizer file with explicit tokenizer options.
     *
     * @param pathToTokenizer path string of the tokenizer file
     * @param options         tokenizer options forwarded to
     *                        {@code HuggingFaceTokenizer.newInstance}; the single-argument
     *                        constructor passes {@code null} here
     * @throws RuntimeException if the path is invalid, the file cannot be opened, or the
     *                          tokenizer cannot be created
     */
    public HuggingFaceTokenCountEstimator(String pathToTokenizer, Map<String, String> options) {
        try {
            // Paths.get stays inside the try so an invalid path is wrapped like any other failure.
            this.tokenizer = createFrom(Files.newInputStream(Paths.get(pathToTokenizer)), options);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Builds a tokenizer from the given stream and closes the stream afterwards.
     *
     * <p>This method takes ownership of {@code tokenizer}: the stream is closed whether or
     * not tokenizer creation succeeds, so callers must not use it again.
     *
     * @param tokenizer stream containing the tokenizer definition
     * @param options   tokenizer options, forwarded unchanged
     * @return the created tokenizer
     * @throws RuntimeException wrapping any exception raised while creating the tokenizer
     *                          or closing the stream
     */
    private static HuggingFaceTokenizer createFrom(InputStream tokenizer, Map<String, String> options) {
        // try-with-resources skips close() for a null resource, so a null stream still
        // reaches newInstance exactly as before.
        try (InputStream in = tokenizer) {
            return HuggingFaceTokenizer.newInstance(in, options);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Returns the number of tokens the tokenizer produces for the given text.
     *
     * @param text the text to tokenize
     * @return length of the token array of the resulting encoding
     */
    public int estimateTokenCountInText(String text) {
        // NOTE(review): the two flags are presumably addSpecialTokens=false and
        // withOverflowingTokens=true - confirm against the DJL encode signature.
        Encoding encoding = this.tokenizer.encode(text, false, true);
        return encoding.getTokens().length;
    }

    /**
     * Returns the token count of the text carried by a single chat message.
     *
     * <p>Supported types are {@link SystemMessage}, {@link UserMessage} (via
     * {@code singleText()}), {@link AiMessage} and {@link ToolExecutionResultMessage}.
     *
     * @param message the message to measure
     * @return estimated token count of the message text
     * @throws IllegalArgumentException if the message is of any other type (including {@code null})
     */
    public int estimateTokenCountInMessage(ChatMessage message) {
        if (message instanceof SystemMessage) {
            return estimateTokenCountInText(((SystemMessage) message).text());
        }
        if (message instanceof UserMessage) {
            return estimateTokenCountInText(((UserMessage) message).singleText());
        }
        if (message instanceof AiMessage) {
            return estimateTokenCountInText(((AiMessage) message).text());
        }
        if (message instanceof ToolExecutionResultMessage) {
            return estimateTokenCountInText(((ToolExecutionResultMessage) message).text());
        }
        throw new IllegalArgumentException("Unknown message type: " + message);
    }

    /**
     * Returns the sum of the per-message token counts.
     *
     * @param messages the messages to measure
     * @return total estimated token count; {@code 0} when {@code messages} is empty
     * @throws IllegalArgumentException if any message is of an unsupported type
     */
    public int estimateTokenCountInMessages(Iterable<ChatMessage> messages) {
        int tokens = 0;
        for (ChatMessage message : messages) {
            tokens += estimateTokenCountInMessage(message);
        }
        return tokens;
    }
}

