/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.model.jlama;

import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.functions.Generator;
import com.github.tjake.jlama.safetensors.prompt.PromptContext;
import com.github.tjake.jlama.safetensors.prompt.PromptSupport;
import com.github.tjake.jlama.safetensors.prompt.Tool;
import com.github.tjake.jlama.safetensors.prompt.ToolCall;
import com.github.tjake.jlama.safetensors.prompt.ToolResult;
import com.github.tjake.jlama.util.JsonSupport;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.Content;
import dev.langchain4j.data.message.ContentType;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.TextContent;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.internal.Json;
import dev.langchain4j.internal.RetryUtils;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.jlama.JlamaLanguageModel;
import dev.langchain4j.model.jlama.JlamaModel;
import dev.langchain4j.model.jlama.JlamaModelRegistry;
import dev.langchain4j.model.jlama.spi.JlamaChatModelBuilderFactory;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.spi.ServiceHelper;
import java.nio.file.Path;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;

public class JlamaChatModel
implements ChatLanguageModel {
    private final AbstractModel model;
    private final Float temperature;
    private final Integer maxTokens;

    public JlamaChatModel(Path modelCachePath, String modelName, String authToken, Integer threadCount, Boolean quantizeModelAtRuntime, Path workingDirectory, Float temperature, Integer maxTokens) {
        JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(modelCachePath);
        JlamaModel jlamaModel = (JlamaModel)RetryUtils.withRetry(() -> registry.downloadModel(modelName, Optional.ofNullable(authToken)), (int)3);
        JlamaModel.Loader loader = jlamaModel.loader();
        if (quantizeModelAtRuntime != null && quantizeModelAtRuntime.booleanValue()) {
            loader = loader.quantized();
        }
        if (threadCount != null) {
            loader = loader.threadCount(threadCount);
        }
        if (workingDirectory != null) {
            loader = loader.workingDirectory(workingDirectory);
        }
        this.model = loader.load();
        this.temperature = Float.valueOf(temperature == null ? 0.3f : temperature.floatValue());
        this.maxTokens = maxTokens == null ? this.model.getConfig().contextLength : maxTokens;
    }

    public static JlamaChatModelBuilder builder() {
        Iterator iterator = ServiceHelper.loadFactories(JlamaChatModelBuilderFactory.class).iterator();
        if (iterator.hasNext()) {
            JlamaChatModelBuilderFactory factory = (JlamaChatModelBuilderFactory)iterator.next();
            return (JlamaChatModelBuilder)factory.get();
        }
        return new JlamaChatModelBuilder();
    }

    public Response<AiMessage> generate(List<ChatMessage> messages) {
        return this.generate(messages, List.of());
    }

    public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
        if (this.model.promptSupport().isEmpty()) {
            throw new UnsupportedOperationException("This model does not support chat generation");
        }
        PromptSupport.Builder promptBuilder = ((PromptSupport)this.model.promptSupport().get()).builder();
        block6: for (ChatMessage message : messages) {
            switch (message.type()) {
                case SYSTEM: {
                    promptBuilder.addSystemMessage(((SystemMessage)message).text());
                    break;
                }
                case USER: {
                    StringBuilder finalMessage = new StringBuilder();
                    UserMessage userMessage = (UserMessage)message;
                    for (Content content : userMessage.contents()) {
                        if (content.type() != ContentType.TEXT) {
                            throw new UnsupportedOperationException("Unsupported content type: " + String.valueOf(content.type()));
                        }
                        finalMessage.append(((TextContent)content).text());
                    }
                    promptBuilder.addUserMessage(finalMessage.toString());
                    break;
                }
                case AI: {
                    AiMessage aiMessage = (AiMessage)message;
                    if (aiMessage.text() != null) {
                        promptBuilder.addAssistantMessage(aiMessage.text());
                    }
                    if (!aiMessage.hasToolExecutionRequests()) continue block6;
                    for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) {
                        ToolCall toolCall = new ToolCall(toolExecutionRequest.name(), toolExecutionRequest.id(), (Map)Json.fromJson((String)toolExecutionRequest.arguments(), LinkedHashMap.class));
                        promptBuilder.addToolCall(toolCall);
                    }
                    continue block6;
                }
                case TOOL_EXECUTION_RESULT: {
                    ToolExecutionResultMessage toolMessage = (ToolExecutionResultMessage)message;
                    ToolResult result = ToolResult.from((String)toolMessage.toolName(), (String)toolMessage.id(), (Object)toolMessage.text());
                    promptBuilder.addToolResult(result);
                    break;
                }
                default: {
                    throw new IllegalArgumentException("Unsupported message type: " + String.valueOf(message.type()));
                }
            }
        }
        List<Tool> tools = toolSpecifications.stream().map(JlamaModel::toTool).toList();
        PromptContext promptContext = tools.isEmpty() ? promptBuilder.build() : promptBuilder.build(tools);
        Generator.Response r = this.model.generate(UUID.randomUUID(), promptContext, this.temperature.floatValue(), this.maxTokens.intValue(), (token, time) -> {});
        if (r.finishReason == Generator.FinishReason.TOOL_CALL) {
            List<ToolExecutionRequest> toolCalls = r.toolCalls.stream().map(f -> ToolExecutionRequest.builder().name(f.getName()).id(f.getId()).arguments(JsonSupport.toJson((Object)f.getParameters())).build()).toList();
            return Response.from((Object)AiMessage.from(toolCalls), (TokenUsage)new TokenUsage(Integer.valueOf(r.promptTokens), Integer.valueOf(r.generatedTokens)), (FinishReason)JlamaLanguageModel.toFinishReason(r.finishReason));
        }
        return Response.from((Object)AiMessage.from((String)r.responseText), (TokenUsage)new TokenUsage(Integer.valueOf(r.promptTokens), Integer.valueOf(r.generatedTokens)), (FinishReason)JlamaLanguageModel.toFinishReason(r.finishReason));
    }

    public Response<AiMessage> generate(List<ChatMessage> messages, ToolSpecification toolSpecification) {
        return this.generate(messages, List.of(toolSpecification));
    }

    public static class JlamaChatModelBuilder {
        private Path modelCachePath;
        private String modelName;
        private String authToken;
        private Integer threadCount;
        private Boolean quantizeModelAtRuntime;
        private Path workingDirectory;
        private Float temperature;
        private Integer maxTokens;

        public JlamaChatModelBuilder modelCachePath(Path modelCachePath) {
            this.modelCachePath = modelCachePath;
            return this;
        }

        public JlamaChatModelBuilder modelName(String modelName) {
            this.modelName = modelName;
            return this;
        }

        public JlamaChatModelBuilder authToken(String authToken) {
            this.authToken = authToken;
            return this;
        }

        public JlamaChatModelBuilder threadCount(Integer threadCount) {
            this.threadCount = threadCount;
            return this;
        }

        public JlamaChatModelBuilder quantizeModelAtRuntime(Boolean quantizeModelAtRuntime) {
            this.quantizeModelAtRuntime = quantizeModelAtRuntime;
            return this;
        }

        public JlamaChatModelBuilder workingDirectory(Path workingDirectory) {
            this.workingDirectory = workingDirectory;
            return this;
        }

        public JlamaChatModelBuilder temperature(Float temperature) {
            this.temperature = temperature;
            return this;
        }

        public JlamaChatModelBuilder maxTokens(Integer maxTokens) {
            this.maxTokens = maxTokens;
            return this;
        }

        public JlamaChatModel build() {
            return new JlamaChatModel(this.modelCachePath, this.modelName, this.authToken, this.threadCount, this.quantizeModelAtRuntime, this.workingDirectory, this.temperature, this.maxTokens);
        }

        public String toString() {
            return "JlamaChatModel.JlamaChatModelBuilder(modelCachePath=" + String.valueOf(this.modelCachePath) + ", modelName=" + this.modelName + ", authToken=" + this.authToken + ", threadCount=" + this.threadCount + ", quantizeModelAtRuntime=" + this.quantizeModelAtRuntime + ", workingDirectory=" + String.valueOf(this.workingDirectory) + ", temperature=" + this.temperature + ", maxTokens=" + this.maxTokens + ")";
        }
    }
}

