/*
 * Decompiled with CFR 0.152.
 */
package io.quarkiverse.langchain4j.ollama;

import com.fasterxml.jackson.core.JsonProcessingException;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequest;
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
import dev.langchain4j.model.chat.listener.ChatModelResponse;
import dev.langchain4j.model.chat.listener.ChatModelResponseContext;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory;
import io.quarkiverse.langchain4j.ollama.ChatRequest;
import io.quarkiverse.langchain4j.ollama.ChatResponse;
import io.quarkiverse.langchain4j.ollama.MessageMapper;
import io.quarkiverse.langchain4j.ollama.OllamaClient;
import io.quarkiverse.langchain4j.ollama.Options;
import io.quarkiverse.langchain4j.ollama.ToolCall;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import org.jboss.logging.Logger;

/**
 * {@link ChatLanguageModel} implementation that talks to an Ollama server through {@link OllamaClient}.
 * <p>
 * Every {@code generate} call sends one non-streaming {@link ChatRequest} and notifies the configured
 * {@link ChatModelListener}s before the request, and afterwards on either a successful response or a
 * {@link RuntimeException}. Exceptions thrown by listeners are logged and never propagated to the caller.
 * <p>
 * Instances are created through {@link #builder()}.
 */
public class OllamaChatLanguageModel
        implements ChatLanguageModel {
    private static final Logger log = Logger.getLogger(OllamaChatLanguageModel.class);
    private final OllamaClient client;
    private final String model;
    private final String format;
    private final Options options;
    private final List<ChatModelListener> listeners;

    private OllamaChatLanguageModel(Builder builder) {
        this.client = new OllamaClient(builder.baseUrl, builder.timeout, builder.logRequests, builder.logResponses,
                builder.configName, builder.tlsConfigurationName);
        this.model = builder.model;
        this.format = builder.format;
        this.options = builder.options;
        // Snapshot the listeners so later changes to the caller's list cannot leak into this model, and treat
        // a null list as "no listeners" instead of failing with an NPE on the first request.
        this.listeners = builder.listeners == null
                ? Collections.emptyList()
                : Collections.unmodifiableList(new ArrayList<>(builder.listeners));
    }

    /**
     * Creates a new builder pre-populated with the defaults declared in {@link Builder}.
     *
     * @return a fresh {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Sends {@code messages} to the model without any tool specifications.
     *
     * @param messages the conversation to send; must not be empty
     * @return the model's reply together with token usage
     */
    @Override
    public Response<AiMessage> generate(List<ChatMessage> messages) {
        return generate(messages, Collections.emptyList());
    }

    /**
     * Sends {@code messages} to the model, offering at most one tool.
     *
     * @param messages the conversation to send; must not be empty
     * @param toolSpecification the single tool to offer, or {@code null} for none
     * @return the model's reply together with token usage
     */
    @Override
    public Response<AiMessage> generate(List<ChatMessage> messages, ToolSpecification toolSpecification) {
        return generate(messages,
                toolSpecification != null ? Collections.singletonList(toolSpecification) : Collections.emptyList());
    }

    /**
     * Sends {@code messages} to the model together with the given tool specifications.
     *
     * @param messages the conversation to send; must not be empty
     * @param toolSpecifications the tools the model may call
     * @return the model's reply (text, or tool execution requests) together with token usage
     * @throws RuntimeException any failure raised by the client or while mapping the response; listeners are
     *         notified through {@link ChatModelListener#onError} before it is rethrown
     */
    @Override
    public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
        ValidationUtils.ensureNotEmpty(messages, "messages");
        ChatRequest request = ChatRequest.builder()
                .model(this.model)
                .messages(MessageMapper.toOllamaMessages(messages))
                .tools(MessageMapper.toTools(toolSpecifications))
                .options(this.options)
                .format(this.format)
                .stream(false)
                .build();

        ChatModelRequest modelListenerRequest = createModelListenerRequest(request, messages, toolSpecifications);
        // One attribute map per call, shared by the request/response/error contexts of that call.
        Map<Object, Object> attributes = new ConcurrentHashMap<>();
        ChatModelRequestContext requestContext = new ChatModelRequestContext(modelListenerRequest, attributes);
        notifyListeners(listener -> listener.onRequest(requestContext));

        try {
            ChatResponse chatResponse = this.client.chat(request);
            Response<AiMessage> response = toResponse(chatResponse);
            ChatModelResponse modelListenerResponse = createModelListenerResponse(null, chatResponse.model(), response);
            ChatModelResponseContext responseContext = new ChatModelResponseContext(modelListenerResponse,
                    modelListenerRequest, attributes);
            notifyListeners(listener -> listener.onResponse(responseContext));
            return response;
        } catch (RuntimeException e) {
            ChatModelErrorContext errorContext = new ChatModelErrorContext(e, modelListenerRequest, null, attributes);
            notifyListeners(listener -> listener.onError(errorContext));
            throw e;
        }
    }

    /**
     * Invokes {@code callback} on every configured listener. Any exception a listener throws is logged at WARN
     * level and otherwise ignored, so a misbehaving listener cannot break the model call.
     */
    private void notifyListeners(Consumer<ChatModelListener> callback) {
        for (ChatModelListener listener : this.listeners) {
            try {
                callback.accept(listener);
            } catch (Exception e) {
                log.warn("Exception while calling model listener", e);
            }
        }
    }

    /**
     * Maps an Ollama {@link ChatResponse} to a LangChain4j {@link Response}.
     * <p>
     * If the response message carries no tool calls, the plain text content is returned. Otherwise each tool
     * call becomes a {@link ToolExecutionRequest} whose arguments are re-serialized to a JSON string.
     */
    private static Response<AiMessage> toResponse(ChatResponse response) {
        TokenUsage tokenUsage = new TokenUsage(response.promptEvalCount(), response.evalCount());
        List<ToolCall> toolCalls = response.message().toolCalls();
        if (toolCalls == null || toolCalls.isEmpty()) {
            return Response.from(AiMessage.from(response.message().content()), tokenUsage);
        }

        List<ToolExecutionRequest> toolExecutionRequests = new ArrayList<>(toolCalls.size());
        for (ToolCall toolCall : toolCalls) {
            ToolCall.FunctionCall functionCall = toolCall.function();
            toolExecutionRequests.add(ToolExecutionRequest.builder()
                    .name(functionCall.name())
                    .arguments(serializeArguments(functionCall.arguments()))
                    .build());
        }
        return Response.from(AiMessage.aiMessage(toolExecutionRequests), tokenUsage);
    }

    /**
     * Serializes a tool call's arguments to JSON with the shared Quarkus object mapper. A Jackson failure is
     * wrapped in an unchecked exception, with the original exception kept as the cause.
     */
    private static String serializeArguments(Object arguments) {
        try {
            return QuarkusJsonCodecFactory.ObjectMapperHolder.MAPPER.writeValueAsString(arguments);
        } catch (JsonProcessingException e) {
            throw new RuntimeException("Unable to parse tool call response", e);
        }
    }

    /**
     * Builds the listener-facing view of the outgoing request. Sampling parameters (temperature, top-p and
     * {@code numPredict} as max tokens) are copied only when the request carries {@link Options}.
     */
    private ChatModelRequest createModelListenerRequest(ChatRequest request, List<ChatMessage> messages,
            List<ToolSpecification> toolSpecifications) {
        Options requestOptions = request.options();
        ChatModelRequest.ChatModelRequestBuilder builder = ChatModelRequest.builder()
                .model(request.model())
                .messages(messages)
                .toolSpecifications(toolSpecifications);
        if (requestOptions != null) {
            builder.temperature(requestOptions.temperature())
                    .topP(requestOptions.topP())
                    .maxTokens(requestOptions.numPredict());
        }
        return builder.build();
    }

    /**
     * Builds the listener-facing view of a model response.
     *
     * @return {@code null} if {@code response} is {@code null}, otherwise a populated {@link ChatModelResponse}
     */
    private ChatModelResponse createModelListenerResponse(String responseId, String responseModel,
            Response<AiMessage> response) {
        if (response == null) {
            return null;
        }
        return ChatModelResponse.builder()
                .id(responseId)
                .model(responseModel)
                .tokenUsage(response.tokenUsage())
                .finishReason(response.finishReason())
                .aiMessage(response.content())
                .build();
    }

    /**
     * Builder for {@link OllamaChatLanguageModel}.
     * <p>
     * Defaults: base URL {@code http://localhost:11434}, a 10-second timeout, request/response logging off, and
     * no listeners.
     */
    public static final class Builder {
        private String baseUrl = "http://localhost:11434";
        private String tlsConfigurationName;
        private Duration timeout = Duration.ofSeconds(10L);
        private String model;
        private String format;
        private Options options;
        private boolean logRequests = false;
        private boolean logResponses = false;
        private String configName;
        private List<ChatModelListener> listeners = Collections.emptyList();

        private Builder() {
        }

        /** Sets the Ollama server base URL. */
        public Builder baseUrl(String val) {
            this.baseUrl = val;
            return this;
        }

        /** Sets the name of the TLS configuration passed to the underlying client. */
        public Builder tlsConfigurationName(String tlsConfigurationName) {
            this.tlsConfigurationName = tlsConfigurationName;
            return this;
        }

        /** Sets the timeout passed to the underlying client. */
        public Builder timeout(Duration val) {
            this.timeout = val;
            return this;
        }

        /** Sets the model name sent with every request. */
        public Builder model(String val) {
            this.model = val;
            return this;
        }

        /** Sets the response format sent with every request. */
        public Builder format(String val) {
            this.format = val;
            return this;
        }

        /** Sets the model options sent with every request. */
        public Builder options(Options val) {
            this.options = val;
            return this;
        }

        /** Enables or disables request logging in the underlying client. */
        public Builder logRequests(boolean logRequests) {
            this.logRequests = logRequests;
            return this;
        }

        /** Enables or disables response logging in the underlying client. */
        public Builder logResponses(boolean logResponses) {
            this.logResponses = logResponses;
            return this;
        }

        /** Sets the configuration name passed to the underlying client. */
        public Builder configName(String configName) {
            this.configName = configName;
            return this;
        }

        /**
         * Sets the listeners to notify around each call. The list is copied when the model is built, so later
         * changes to it have no effect. {@code null} means no listeners.
         */
        public Builder listeners(List<ChatModelListener> listeners) {
            this.listeners = listeners;
            return this;
        }

        /** Creates the model from the current builder state. */
        public OllamaChatLanguageModel build() {
            return new OllamaChatLanguageModel(this);
        }
    }
}

