/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.model.googleai;

import com.google.gson.Gson;
import com.google.gson.JsonParseException;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.internal.RetryUtils;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.model.Tokenizer;
import dev.langchain4j.model.googleai.FunctionMapper;
import dev.langchain4j.model.googleai.GeminiContent;
import dev.langchain4j.model.googleai.GeminiCountTokensRequest;
import dev.langchain4j.model.googleai.GeminiCountTokensResponse;
import dev.langchain4j.model.googleai.GeminiError;
import dev.langchain4j.model.googleai.GeminiErrorContainer;
import dev.langchain4j.model.googleai.GeminiGenerateContentRequest;
import dev.langchain4j.model.googleai.GeminiPart;
import dev.langchain4j.model.googleai.GeminiService;
import dev.langchain4j.model.googleai.PartsAndContentsMapper;
import java.io.IOException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import okhttp3.ResponseBody;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import retrofit2.Call;
import retrofit2.Response;

public class GoogleAiGeminiTokenizer
implements Tokenizer {
    private static final Logger log = LoggerFactory.getLogger(GoogleAiGeminiTokenizer.class);
    private static final Gson GSON = new Gson();
    private final GeminiService geminiService;
    private final String modelName;
    private final String apiKey;
    private final Integer maxRetries;

    GoogleAiGeminiTokenizer(String modelName, String apiKey, Boolean logRequestsAndResponses, Duration timeout, Integer maxRetries) {
        this.modelName = ValidationUtils.ensureNotBlank((String)modelName, (String)"modelName");
        this.apiKey = ValidationUtils.ensureNotBlank((String)apiKey, (String)"apiKey");
        this.maxRetries = (Integer)Utils.getOrDefault((Object)maxRetries, (Object)3);
        this.geminiService = GeminiService.getGeminiService((Logger)(logRequestsAndResponses != false ? log : null), timeout != null ? timeout : Duration.ofSeconds(60L));
    }

    public int estimateTokenCountInText(String text) {
        return this.estimateTokenCountInMessages(Collections.singletonList(UserMessage.from((String)text)));
    }

    public int estimateTokenCountInMessage(ChatMessage message) {
        return this.estimateTokenCountInMessages(Collections.singletonList(message));
    }

    public int estimateTokenCountInToolExecutionRequests(Iterable<ToolExecutionRequest> toolExecutionRequests) {
        LinkedList allToolRequests = new LinkedList();
        toolExecutionRequests.forEach(allToolRequests::add);
        return this.estimateTokenCountInMessage((ChatMessage)AiMessage.from(allToolRequests));
    }

    public int estimateTokenCountInMessages(Iterable<ChatMessage> messages) {
        LinkedList<ChatMessage> allMessages = new LinkedList<ChatMessage>();
        messages.forEach(allMessages::add);
        List<GeminiContent> geminiContentList = PartsAndContentsMapper.fromMessageToGContent(allMessages, null);
        GeminiCountTokensRequest countTokensRequest = new GeminiCountTokensRequest();
        countTokensRequest.setContents(geminiContentList);
        return this.estimateTokenCount(countTokensRequest);
    }

    public int estimateTokenCountInToolSpecifications(Iterable<ToolSpecification> toolSpecifications) {
        LinkedList<ToolSpecification> allTools = new LinkedList<ToolSpecification>();
        toolSpecifications.forEach(allTools::add);
        GeminiContent dummyContent = GeminiContent.builder().parts(Collections.singletonList(GeminiPart.builder().text("Dummy content").build())).build();
        GeminiCountTokensRequest countTokensRequestWithDummyContent = new GeminiCountTokensRequest();
        countTokensRequestWithDummyContent.setGenerateContentRequest(GeminiGenerateContentRequest.builder().model("models/" + this.modelName).contents(Collections.singletonList(dummyContent)).tools(FunctionMapper.fromToolSepcsToGTool(allTools, false)).build());
        return this.estimateTokenCount(countTokensRequestWithDummyContent) - 2;
    }

    private int estimateTokenCount(GeminiCountTokensRequest countTokensRequest) {
        GeminiCountTokensResponse countTokensResponse;
        block7: {
            Call responseCall = (Call)RetryUtils.withRetry(() -> this.geminiService.countTokens(this.modelName, this.apiKey, countTokensRequest), (int)this.maxRetries);
            try {
                Response executed = responseCall.execute();
                countTokensResponse = (GeminiCountTokensResponse)executed.body();
                if (executed.code() < 300) break block7;
                ResponseBody errorBody = executed.errorBody();
                try {
                    GeminiError error = ((GeminiErrorContainer)GSON.fromJson(errorBody.string(), GeminiErrorContainer.class)).getError();
                    throw new RuntimeException(String.format("%s (code %d) %s", error.getStatus(), error.getCode(), error.getMessage()));
                }
                catch (Throwable throwable) {
                    if (errorBody != null) {
                        try {
                            errorBody.close();
                        }
                        catch (Throwable throwable2) {
                            throwable.addSuppressed(throwable2);
                        }
                    }
                    throw throwable;
                }
            }
            catch (IOException e) {
                throw new RuntimeException("An error occurred when calling the Gemini API endpoint to calculate tokens count", e);
            }
        }
        return countTokensResponse.getTotalTokens();
    }

    public static GoogleAiGeminiTokenizerBuilder builder() {
        return new GoogleAiGeminiTokenizerBuilder();
    }

    public static class GoogleAiGeminiTokenizerBuilder {
        private String modelName;
        private String apiKey;
        private Boolean logRequestsAndResponses;
        private Duration timeout;
        private Integer maxRetries;

        GoogleAiGeminiTokenizerBuilder() {
        }

        public GoogleAiGeminiTokenizerBuilder modelName(String modelName) {
            this.modelName = modelName;
            return this;
        }

        public GoogleAiGeminiTokenizerBuilder apiKey(String apiKey) {
            this.apiKey = apiKey;
            return this;
        }

        public GoogleAiGeminiTokenizerBuilder logRequestsAndResponses(Boolean logRequestsAndResponses) {
            this.logRequestsAndResponses = logRequestsAndResponses;
            return this;
        }

        public GoogleAiGeminiTokenizerBuilder timeout(Duration timeout) {
            this.timeout = timeout;
            return this;
        }

        public GoogleAiGeminiTokenizerBuilder maxRetries(Integer maxRetries) {
            this.maxRetries = maxRetries;
            return this;
        }

        public GoogleAiGeminiTokenizer build() {
            return new GoogleAiGeminiTokenizer(this.modelName, this.apiKey, this.logRequestsAndResponses, this.timeout, this.maxRetries);
        }

        public String toString() {
            return "GoogleAiGeminiTokenizer.GoogleAiGeminiTokenizerBuilder(modelName=" + this.modelName + ", apiKey=" + this.apiKey + ", logRequestsAndResponses=" + this.logRequestsAndResponses + ", timeout=" + this.timeout + ", maxRetries=" + this.maxRetries + ")";
        }
    }
}

