/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.ollama;

import com.fasterxml.jackson.core.type.TypeReference;
import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationConvention;
import io.micrometer.observation.ObservationRegistry;
import java.time.Duration;
import java.util.Base64;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.MessageAggregator;
import org.springframework.ai.chat.observation.ChatModelObservationContext;
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.ToolExecutionResult;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaModel;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.ollama.api.common.OllamaApiConstants;
import org.springframework.ai.ollama.management.ModelManagementOptions;
import org.springframework.ai.ollama.management.OllamaModelManager;
import org.springframework.ai.ollama.management.PullModelStrategy;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.util.json.JsonParser;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import reactor.core.publisher.Flux;
import reactor.core.scheduler.Schedulers;

/**
 * {@link ChatModel} implementation backed by the Ollama chat API ({@link OllamaApi}).
 * <p>
 * Supports blocking ({@link #call(Prompt)}) and streaming ({@link #stream(Prompt)}) requests,
 * wraps every model invocation in a Micrometer observation, and performs tool calling: when the
 * configured {@link ToolExecutionEligibilityPredicate} reports that a response requires tool
 * execution, the tools are executed through the {@link ToolCallingManager} and the resulting
 * conversation is sent back to the model. Token usage and timing metadata are accumulated across
 * those tool-calling rounds (see {@link #from(OllamaApi.ChatResponse, ChatResponse)}).
 */
public class OllamaChatModel implements ChatModel {

    private static final Logger logger = LoggerFactory.getLogger(OllamaChatModel.class);

    private static final String DONE = "done";

    private static final String METADATA_PROMPT_EVAL_COUNT = "prompt-eval-count";

    private static final String METADATA_EVAL_COUNT = "eval-count";

    private static final String METADATA_CREATED_AT = "created-at";

    private static final String METADATA_TOTAL_DURATION = "total-duration";

    private static final String METADATA_LOAD_DURATION = "load-duration";

    private static final String METADATA_PROMPT_EVAL_DURATION = "prompt-eval-duration";

    private static final String METADATA_EVAL_DURATION = "eval-duration";

    /** Reactor context key under which the current Micrometer observation is propagated. */
    private static final String OBSERVATION_CONTEXT_KEY = "micrometer.observation";

    private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention();

    private static final ToolCallingManager DEFAULT_TOOL_CALLING_MANAGER = ToolCallingManager.builder().build();

    private final OllamaApi chatApi;

    private final OllamaOptions defaultOptions;

    private final ObservationRegistry observationRegistry;

    private final OllamaModelManager modelManager;

    private final ToolCallingManager toolCallingManager;

    private final ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate;

    private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

    /**
     * Creates a chat model that decides tool execution with a
     * {@link DefaultToolExecutionEligibilityPredicate}.
     * @param ollamaApi client used to talk to the Ollama server
     * @param defaultOptions options applied to every request unless overridden by the prompt
     * @param toolCallingManager resolves tool definitions and executes tool calls
     * @param observationRegistry registry used to record chat model observations
     * @param modelManagementOptions controls whether the default model is pulled on startup
     */
    public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ToolCallingManager toolCallingManager,
            ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
        this(ollamaApi, defaultOptions, toolCallingManager, observationRegistry, modelManagementOptions,
                new DefaultToolExecutionEligibilityPredicate());
    }

    /**
     * Creates a chat model. Unless the pull strategy is {@code null} or
     * {@link PullModelStrategy#NEVER}, the model named in {@code defaultOptions} is pulled through
     * an {@link OllamaModelManager} during construction.
     * @param ollamaApi client used to talk to the Ollama server
     * @param defaultOptions options applied to every request unless overridden by the prompt
     * @param toolCallingManager resolves tool definitions and executes tool calls
     * @param observationRegistry registry used to record chat model observations
     * @param modelManagementOptions controls whether the default model is pulled on startup
     * @param toolExecutionEligibilityPredicate decides whether a response triggers tool execution
     * @throws IllegalArgumentException if any argument is {@code null}
     */
    public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ToolCallingManager toolCallingManager,
            ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions,
            ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) {
        Assert.notNull(ollamaApi, "ollamaApi must not be null");
        Assert.notNull(defaultOptions, "defaultOptions must not be null");
        Assert.notNull(toolCallingManager, "toolCallingManager must not be null");
        Assert.notNull(observationRegistry, "observationRegistry must not be null");
        Assert.notNull(modelManagementOptions, "modelManagementOptions must not be null");
        Assert.notNull(toolExecutionEligibilityPredicate, "toolExecutionEligibilityPredicate must not be null");
        this.chatApi = ollamaApi;
        this.defaultOptions = defaultOptions;
        this.toolCallingManager = toolCallingManager;
        this.observationRegistry = observationRegistry;
        this.modelManager = new OllamaModelManager(this.chatApi, modelManagementOptions);
        this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate;
        this.initializeModel(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy());
    }

    /**
     * @return a new {@link Builder} for {@link OllamaChatModel}
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Builds the response metadata for an Ollama response, adding the token counts and durations
     * recorded on {@code previousChatResponse} (the response of an earlier tool-calling round),
     * if any.
     * <p>
     * Values absent on either side are treated as "nothing to add". This matters for streaming:
     * intermediate chunks carry no timing data, while the previous round's response usually does.
     * @param response the raw Ollama response or stream chunk; must not be {@code null}
     * @param previousChatResponse the response of the previous tool-calling round, or {@code null}
     * @return metadata with aggregated usage, durations, model name, creation time and done flag
     * @throws IllegalArgumentException if {@code response} is {@code null}
     */
    static ChatResponseMetadata from(OllamaApi.ChatResponse response, ChatResponse previousChatResponse) {
        Assert.notNull(response, "OllamaApi.ChatResponse must not be null");
        DefaultUsage newUsage = getDefaultUsage(response);
        Integer promptTokens = newUsage.getPromptTokens();
        Integer generationTokens = newUsage.getCompletionTokens();
        int totalTokens = newUsage.getTotalTokens();
        Duration evalDuration = response.getEvalDuration();
        Duration promptEvalDuration = response.getPromptEvalDuration();
        Duration loadDuration = response.getLoadDuration();
        Duration totalDuration = response.getTotalDuration();

        ChatResponseMetadata previousMetadata = previousChatResponse != null ? previousChatResponse.getMetadata()
                : null;
        if (previousMetadata != null) {
            evalDuration = addDurations(evalDuration, (Duration) previousMetadata.get(METADATA_EVAL_DURATION));
            promptEvalDuration = addDurations(promptEvalDuration,
                    (Duration) previousMetadata.get(METADATA_PROMPT_EVAL_DURATION));
            loadDuration = addDurations(loadDuration, (Duration) previousMetadata.get(METADATA_LOAD_DURATION));
            totalDuration = addDurations(totalDuration, (Duration) previousMetadata.get(METADATA_TOTAL_DURATION));

            Usage previousUsage = previousMetadata.getUsage();
            if (previousUsage != null) {
                promptTokens = addTokens(promptTokens, previousUsage.getPromptTokens());
                generationTokens = addTokens(generationTokens, previousUsage.getCompletionTokens());
                totalTokens = addTokens(totalTokens, previousUsage.getTotalTokens());
            }
        }

        DefaultUsage aggregatedUsage = new DefaultUsage(promptTokens, generationTokens, totalTokens);
        return ChatResponseMetadata.builder()
            .usage(aggregatedUsage)
            .model(response.model())
            .keyValue(METADATA_CREATED_AT, response.createdAt())
            .keyValue(METADATA_EVAL_DURATION, evalDuration)
            .keyValue(METADATA_EVAL_COUNT, aggregatedUsage.getCompletionTokens())
            .keyValue(METADATA_LOAD_DURATION, loadDuration)
            .keyValue(METADATA_PROMPT_EVAL_DURATION, promptEvalDuration)
            .keyValue(METADATA_PROMPT_EVAL_COUNT, aggregatedUsage.getPromptTokens())
            .keyValue(METADATA_TOTAL_DURATION, totalDuration)
            .keyValue(DONE, response.done())
            .build();
    }

    /** Null-safe sum of two durations; returns whichever is non-null, or {@code null} if both are. */
    private static Duration addDurations(Duration current, Duration previous) {
        if (previous == null) {
            return current;
        }
        return current == null ? previous : current.plus(previous);
    }

    /** Null-safe sum of two token counts, treating {@code null} as zero. */
    private static int addTokens(Integer current, Integer previous) {
        return (current != null ? current : 0) + (previous != null ? previous : 0);
    }

    /** Converts the Ollama prompt/eval counts into usage, defaulting missing counts to zero. */
    private static DefaultUsage getDefaultUsage(OllamaApi.ChatResponse response) {
        return new DefaultUsage(Optional.ofNullable(response.promptEvalCount()).orElse(0),
                Optional.ofNullable(response.evalCount()).orElse(0));
    }

    /**
     * Converts one Ollama response (or streaming chunk) into a single-generation
     * {@link ChatResponse}. Shared by the blocking and streaming paths so both treat a missing
     * message the same way (empty content, no tool calls).
     */
    private static ChatResponse toChatResponse(OllamaApi.ChatResponse ollamaResponse,
            ChatResponse previousChatResponse) {
        OllamaApi.Message message = ollamaResponse.message();
        String content = message != null ? message.content() : "";
        List<AssistantMessage.ToolCall> toolCalls = List.of();
        if (message != null && message.toolCalls() != null) {
            // Ollama does not return tool call ids, hence the empty id.
            toolCalls = message.toolCalls()
                .stream()
                .map(toolCall -> new AssistantMessage.ToolCall("", "function", toolCall.function().name(),
                        ModelOptionsUtils.toJsonString(toolCall.function().arguments())))
                .toList();
        }
        AssistantMessage assistantMessage = new AssistantMessage(content, Map.of(), toolCalls);

        ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.NULL;
        if (ollamaResponse.promptEvalCount() != null && ollamaResponse.evalCount() != null) {
            generationMetadata = ChatGenerationMetadata.builder().finishReason(ollamaResponse.doneReason()).build();
        }
        Generation generation = new Generation(assistantMessage, generationMetadata);
        return new ChatResponse(List.of(generation), from(ollamaResponse, previousChatResponse));
    }

    /**
     * Sends the prompt to Ollama and blocks for the full response, executing requested tools
     * (and re-querying the model) as long as the eligibility predicate requires it.
     */
    @Override
    public ChatResponse call(Prompt prompt) {
        Prompt requestPrompt = this.buildRequestPrompt(prompt);
        return this.internalCall(requestPrompt, null);
    }

    /** One observed blocking round trip, recursing after tool execution unless it returns directly. */
    private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {
        OllamaApi.ChatRequest request = this.ollamaChatRequest(prompt, false);
        ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
            .prompt(prompt)
            .provider(OllamaApiConstants.PROVIDER_NAME)
            .build();

        ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
            .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
                    this.observationRegistry)
            .observe(() -> {
                ChatResponse chatResponse = toChatResponse(this.chatApi.chat(request), previousChatResponse);
                observationContext.setResponse(chatResponse);
                return chatResponse;
            });

        if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
            ToolExecutionResult toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
            if (toolExecutionResult.returnDirect()) {
                // Hand the tool results straight back to the caller instead of re-querying the model.
                return ChatResponse.builder()
                    .from(response)
                    .generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
                    .build();
            }
            return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
                    response);
        }
        return response;
    }

    /**
     * Streams the model response chunk by chunk. Tool execution, when required, runs on the
     * bounded-elastic scheduler and its follow-up stream is spliced into the returned flux.
     */
    @Override
    public Flux<ChatResponse> stream(Prompt prompt) {
        Prompt requestPrompt = this.buildRequestPrompt(prompt);
        return this.internalStream(requestPrompt, null);
    }

    /** One observed streaming round trip; see {@link #stream(Prompt)}. */
    private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
        return Flux.deferContextual(contextView -> {
            OllamaApi.ChatRequest request = this.ollamaChatRequest(prompt, true);
            ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
                .prompt(prompt)
                .provider(OllamaApiConstants.PROVIDER_NAME)
                .build();
            Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
                    this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
                    this.observationRegistry);
            // Nest under the caller's observation, if one was propagated through the Reactor context.
            observation.parentObservation(contextView.getOrDefault(OBSERVATION_CONTEXT_KEY, null)).start();

            Flux<ChatResponse> chatResponses = this.chatApi.streamingChat(request)
                .map(chunk -> toChatResponse(chunk, previousChatResponse));

            Flux<ChatResponse> chatResponseFlux = chatResponses.flatMap(response -> {
                if (!this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
                    return Flux.just(response);
                }
                // Tool execution may block, so keep it off the streaming thread.
                return Flux.defer(() -> {
                    ToolExecutionResult toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt,
                            response);
                    if (toolExecutionResult.returnDirect()) {
                        return Flux.just(ChatResponse.builder()
                            .from(response)
                            .generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
                            .build());
                    }
                    return this.internalStream(
                            new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), response);
                }).subscribeOn(Schedulers.boundedElastic());
            })
                .doOnError(observation::error)
                .doFinally(signal -> observation.stop())
                .contextWrite(ctx -> ctx.put(OBSERVATION_CONTEXT_KEY, observation));

            return new MessageAggregator().aggregate(chatResponseFlux, observationContext::setResponse);
        });
    }

    /**
     * Merges the prompt's runtime options with this model's default options (runtime values win)
     * and returns a new prompt carrying the merged {@link OllamaOptions}.
     * @throws IllegalArgumentException if the merged options have no model name
     */
    Prompt buildRequestPrompt(Prompt prompt) {
        OllamaOptions runtimeOptions = null;
        ChatOptions promptOptions = prompt.getOptions();
        if (promptOptions instanceof ToolCallingChatOptions toolCallingChatOptions) {
            runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class,
                    OllamaOptions.class);
        }
        else if (promptOptions != null) {
            runtimeOptions = ModelOptionsUtils.copyToTarget(promptOptions, ChatOptions.class, OllamaOptions.class);
        }

        OllamaOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions,
                OllamaOptions.class);

        // Tool-related settings are merged explicitly, field by field.
        if (runtimeOptions != null) {
            requestOptions.setInternalToolExecutionEnabled(ModelOptionsUtils.mergeOption(
                    runtimeOptions.getInternalToolExecutionEnabled(),
                    this.defaultOptions.getInternalToolExecutionEnabled()));
            requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(),
                    this.defaultOptions.getToolNames()));
            requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(
                    runtimeOptions.getToolCallbacks(), this.defaultOptions.getToolCallbacks()));
            requestOptions.setToolContext(ToolCallingChatOptions.mergeToolContext(runtimeOptions.getToolContext(),
                    this.defaultOptions.getToolContext()));
        }
        else {
            requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled());
            requestOptions.setToolNames(this.defaultOptions.getToolNames());
            requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks());
            requestOptions.setToolContext(this.defaultOptions.getToolContext());
        }

        if (!StringUtils.hasText(requestOptions.getModel())) {
            throw new IllegalArgumentException("model cannot be null or empty");
        }
        ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks());
        return new Prompt(prompt.getInstructions(), requestOptions);
    }

    /**
     * Translates a prompt (whose options must be {@link OllamaOptions}, as produced by
     * {@link #buildRequestPrompt(Prompt)}) into an Ollama chat request.
     * @param prompt the prompt to translate
     * @param stream whether the request asks for a streamed response
     * @return the Ollama request, including resolved tool definitions if any
     * @throws IllegalArgumentException for unsupported message or media types
     */
    OllamaApi.ChatRequest ollamaChatRequest(Prompt prompt, boolean stream) {
        List<OllamaApi.Message> ollamaMessages = prompt.getInstructions().stream().map(message -> {
            if (message instanceof UserMessage userMessage) {
                OllamaApi.Message.Builder messageBuilder = OllamaApi.Message.builder(OllamaApi.Message.Role.USER)
                    .content(message.getText());
                if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
                    messageBuilder.images(userMessage.getMedia()
                        .stream()
                        .map(media -> this.fromMediaData(media.getData()))
                        .toList());
                }
                return List.of(messageBuilder.build());
            }
            if (message instanceof SystemMessage systemMessage) {
                return List.of(OllamaApi.Message.builder(OllamaApi.Message.Role.SYSTEM)
                    .content(systemMessage.getText())
                    .build());
            }
            if (message instanceof AssistantMessage assistantMessage) {
                List<OllamaApi.Message.ToolCall> toolCalls = null;
                if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
                    // Tool call arguments travel as JSON text in Spring AI but as an object in Ollama.
                    toolCalls = assistantMessage.getToolCalls().stream().map(toolCall -> {
                        OllamaApi.Message.ToolCallFunction function = new OllamaApi.Message.ToolCallFunction(
                                toolCall.name(),
                                JsonParser.fromJson(toolCall.arguments(), new TypeReference<Map<String, Object>>() {
                                }));
                        return new OllamaApi.Message.ToolCall(function);
                    }).toList();
                }
                return List.of(OllamaApi.Message.builder(OllamaApi.Message.Role.ASSISTANT)
                    .content(assistantMessage.getText())
                    .toolCalls(toolCalls)
                    .build());
            }
            if (message instanceof ToolResponseMessage toolMessage) {
                // Each tool response becomes its own TOOL message.
                return toolMessage.getResponses()
                    .stream()
                    .map(tr -> OllamaApi.Message.builder(OllamaApi.Message.Role.TOOL).content(tr.responseData()).build())
                    .toList();
            }
            throw new IllegalArgumentException("Unsupported message type: " + String.valueOf(message.getMessageType()));
        }).flatMap(Collection::stream).toList();

        OllamaOptions requestOptions = (OllamaOptions) prompt.getOptions();
        OllamaApi.ChatRequest.Builder requestBuilder = OllamaApi.ChatRequest.builder(requestOptions.getModel())
            .stream(stream)
            .messages(ollamaMessages)
            .options(requestOptions);
        if (requestOptions.getFormat() != null) {
            requestBuilder.format(requestOptions.getFormat());
        }
        if (requestOptions.getKeepAlive() != null) {
            requestBuilder.keepAlive(requestOptions.getKeepAlive());
        }
        List<ToolDefinition> toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions);
        if (!CollectionUtils.isEmpty(toolDefinitions)) {
            requestBuilder.tools(this.getTools(toolDefinitions));
        }
        return requestBuilder.build();
    }

    /**
     * Converts media payloads to the string form sent as an Ollama image: raw bytes are
     * Base64-encoded, strings are passed through unchanged.
     * @throws IllegalArgumentException if the payload is {@code null} or of another type
     */
    private String fromMediaData(Object mediaData) {
        if (mediaData instanceof byte[] bytes) {
            return Base64.getEncoder().encodeToString(bytes);
        }
        if (mediaData instanceof String text) {
            return text;
        }
        String typeName = mediaData == null ? "null" : mediaData.getClass().getSimpleName();
        throw new IllegalArgumentException("Unsupported media data type: " + typeName);
    }

    /** Maps Spring AI tool definitions to Ollama function-tool declarations. */
    private List<OllamaApi.ChatRequest.Tool> getTools(List<ToolDefinition> toolDefinitions) {
        return toolDefinitions.stream().map(toolDefinition -> {
            OllamaApi.ChatRequest.Tool.Function tool = new OllamaApi.ChatRequest.Tool.Function(toolDefinition.name(),
                    toolDefinition.description(), toolDefinition.inputSchema());
            return new OllamaApi.ChatRequest.Tool(tool);
        }).toList();
    }

    /** @return a copy of this model's default options */
    @Override
    public ChatOptions getDefaultOptions() {
        return OllamaOptions.fromOptions(this.defaultOptions);
    }

    /** Pulls {@code model} unless the strategy is {@code null} or {@link PullModelStrategy#NEVER}. */
    private void initializeModel(String model, PullModelStrategy pullModelStrategy) {
        if (pullModelStrategy != null && !PullModelStrategy.NEVER.equals(pullModelStrategy)) {
            this.modelManager.pullModel(model, pullModelStrategy);
        }
    }

    /**
     * Replaces the observation convention used for chat model observations.
     * @param observationConvention the convention to use; must not be {@code null}
     */
    public void setObservationConvention(ChatModelObservationConvention observationConvention) {
        Assert.notNull(observationConvention, "observationConvention cannot be null");
        this.observationConvention = observationConvention;
    }

    /**
     * Builder for {@link OllamaChatModel}. The defaults are:
     * <ul>
     * <li>the {@link OllamaModel#MISTRAL} model;</li>
     * <li>a {@link DefaultToolExecutionEligibilityPredicate};</li>
     * <li>{@link ObservationRegistry#NOOP};</li>
     * <li>{@link ModelManagementOptions#defaults()};</li>
     * <li>a shared default {@link ToolCallingManager} when none is set.</li>
     * </ul>
     */
    public static final class Builder {

        private OllamaApi ollamaApi;

        private OllamaOptions defaultOptions = OllamaOptions.builder().model(OllamaModel.MISTRAL.id()).build();

        private ToolCallingManager toolCallingManager;

        private ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate = new DefaultToolExecutionEligibilityPredicate();

        private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;

        private ModelManagementOptions modelManagementOptions = ModelManagementOptions.defaults();

        private Builder() {
        }

        public Builder ollamaApi(OllamaApi ollamaApi) {
            this.ollamaApi = ollamaApi;
            return this;
        }

        public Builder defaultOptions(OllamaOptions defaultOptions) {
            this.defaultOptions = defaultOptions;
            return this;
        }

        public Builder toolCallingManager(ToolCallingManager toolCallingManager) {
            this.toolCallingManager = toolCallingManager;
            return this;
        }

        public Builder toolExecutionEligibilityPredicate(
                ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) {
            this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate;
            return this;
        }

        public Builder observationRegistry(ObservationRegistry observationRegistry) {
            this.observationRegistry = observationRegistry;
            return this;
        }

        public Builder modelManagementOptions(ModelManagementOptions modelManagementOptions) {
            this.modelManagementOptions = modelManagementOptions;
            return this;
        }

        /**
         * @return a new {@link OllamaChatModel}; falls back to the shared default
         * {@link ToolCallingManager} when none was configured
         * @throws IllegalArgumentException if a required component (e.g. the API client) is missing
         */
        public OllamaChatModel build() {
            ToolCallingManager manager = this.toolCallingManager != null ? this.toolCallingManager
                    : DEFAULT_TOOL_CALLING_MANAGER;
            return new OllamaChatModel(this.ollamaApi, this.defaultOptions, manager, this.observationRegistry,
                    this.modelManagementOptions, this.toolExecutionEligibilityPredicate);
        }

    }

}

