/*
 * Decompiled with CFR 0.152.
 */
package com.mulesoft.connectors.inference.internal.helpers.mcp;

import com.fasterxml.jackson.core.JacksonException;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JavaType;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.mulesoft.connectors.inference.api.mcp.McpConfig;
import com.mulesoft.connectors.inference.api.request.Function;
import com.mulesoft.connectors.inference.api.request.FunctionSchema;
import com.mulesoft.connectors.inference.api.response.ToolCall;
import com.mulesoft.connectors.inference.api.response.ToolResult;
import com.mulesoft.connectors.inference.internal.dto.mcp.McpServerToolDTO;
import com.mulesoft.connectors.inference.internal.dto.mcp.McpToolRecord;
import com.mulesoft.connectors.inference.internal.error.InferenceErrorType;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.StreamSupport;
import org.mule.runtime.api.scheduler.Scheduler;
import org.mule.runtime.extension.api.client.ExtensionsClient;
import org.mule.runtime.extension.api.client.OperationParameterizer;
import org.mule.runtime.extension.api.error.ErrorTypeDefinition;
import org.mule.runtime.extension.api.exception.ModuleException;
import org.mule.runtime.extension.api.runtime.operation.Result;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Discovers and executes MCP tools by delegating to the {@code MCP} extension through an
 * {@link ExtensionsClient}.
 *
 * <p>Discovery invokes the extension's {@code listTools} operation once per MCP client config.
 * Non-empty results are cached per config reference for the lifetime of this instance. Each
 * discovered tool's {@link Function} is named {@code <configRef>__<originalToolName>}, and collected
 * tools are keyed by {@link McpToolRecord#getName()}. Execution invokes the extension's
 * {@code callTool} operation.
 */
public class McpHelper {
    private static final Logger logger = LoggerFactory.getLogger(McpHelper.class);
    /** Extension name used for every {@link ExtensionsClient} invocation. */
    private static final String MCP = "MCP";
    private final ObjectMapper objectMapper;
    /**
     * Non-empty tool lists keyed by MCP client config reference. Entries are never evicted here, so
     * a config whose tools are cached is not queried again by this instance.
     */
    private final Map<String, List<McpToolRecord>> toolsByServer = new ConcurrentHashMap<>();

    /**
     * @param objectMapper mapper used to parse tool-call arguments and convert tool descriptors
     */
    public McpHelper(ObjectMapper objectMapper) {
        this.objectMapper = objectMapper;
    }

    /**
     * Discovers the tools of every given MCP client config asynchronously.
     *
     * @param mcpConfigs       configs to query; must not be null or empty
     * @param scheduler        scheduler on which uncached {@code listTools} calls are submitted
     * @param extensionsClient client used to invoke the MCP extension
     * @return a future that completes with all discovered tools keyed by
     *         {@link McpToolRecord#getName()}, or completes exceptionally as soon as discovery
     *         for any config fails
     * @throws ModuleException ({@code MCP_TOOLS_OPERATION_FAILURE}) when {@code mcpConfigs} is null
     *                         or empty
     */
    public CompletableFuture<Map<String, McpToolRecord>> getTools(List<McpConfig> mcpConfigs, Scheduler scheduler, ExtensionsClient extensionsClient) {
        if (mcpConfigs == null || mcpConfigs.isEmpty()) {
            throw new ModuleException("MCP configuration is required", InferenceErrorType.MCP_TOOLS_OPERATION_FAILURE);
        }
        return new McpDiscovery(mcpConfigs, scheduler, extensionsClient).getDiscoveredTools();
    }

    /**
     * Executes the given tool calls one after another. Each call blocks until its MCP
     * {@code callTool} invocation finishes.
     *
     * @param collectedTools   tools previously returned by {@link #getTools}, keyed by name
     * @param toolCalls        calls requested by the model; may be null or empty
     * @param extensionsClient client used to invoke the MCP extension
     * @return one result per call in the same order; an empty list when there are no calls
     * @throws ModuleException ({@code MCP_SERVER_ERROR}) if a tool is unknown, its arguments are
     *                         not valid JSON, or its execution fails
     */
    public List<ToolResult> executeTools(Map<String, McpToolRecord> collectedTools, List<ToolCall> toolCalls, ExtensionsClient extensionsClient) {
        if (toolCalls == null || toolCalls.isEmpty()) {
            return new ArrayList<>();
        }
        return toolCalls.stream()
                .map(toolCall -> executeToolCall(toolCall, collectedTools, extensionsClient))
                .toList();
    }

    /** Resolves the tool named in {@code toolCall} and executes it. */
    private ToolResult executeToolCall(ToolCall toolCall, Map<String, McpToolRecord> collectedTools, ExtensionsClient extensionsClient) {
        String toolName = toolCall.function().name();
        McpToolRecord tool = collectedTools.get(toolName);
        if (tool == null) {
            logger.error("Tool '{}' not found in collected tools", toolName);
            throw new ModuleException("Tool '" + toolName + "' not found in collected tools", InferenceErrorType.MCP_SERVER_ERROR);
        }
        return parseArgumentsAndExecute(tool, toolCall, extensionsClient);
    }

    /** Parses the call's JSON argument string into a map and executes the tool with it. */
    private ToolResult parseArgumentsAndExecute(McpToolRecord tool, ToolCall toolCall, ExtensionsClient extensionsClient) {
        logger.debug("Executing tool call:{}", tool);
        String rawArguments = toolCall.function().arguments();
        Map<String, Object> args;
        if (rawArguments == null || rawArguments.isBlank()) {
            // A missing or blank argument string is not valid JSON; treat it as an empty argument object.
            args = Map.of();
        } else {
            try {
                args = objectMapper.readValue(rawArguments,
                        objectMapper.getTypeFactory().constructMapType(Map.class, String.class, Object.class));
            } catch (JsonProcessingException e) {
                throw new ModuleException("Failed to execute tool '" + tool.getName() + "': " + e.getMessage(), InferenceErrorType.MCP_SERVER_ERROR, e);
            }
        }
        // The JSON literal "null" parses to a null map; normalize it like a missing argument string.
        return executeToolWithErrorHandling(tool, args == null ? Map.of() : args, extensionsClient);
    }

    /**
     * Invokes the tool and waits for the result, converting any asynchronous failure into a
     * {@link ModuleException} thrown directly to the caller (not wrapped in a
     * {@link CompletionException} by {@code join()}).
     */
    private ToolResult executeToolWithErrorHandling(McpToolRecord tool, Map<String, Object> args, ExtensionsClient extensionsClient) {
        try {
            return invokeMcpCallTool(tool, args, extensionsClient).join();
        } catch (CompletionException | CancellationException e) {
            Throwable cause = unwrapCompletionCause(e);
            throw new ModuleException("Error executing tool '" + tool.getName() + "': " + cause.getMessage(), InferenceErrorType.MCP_SERVER_ERROR, cause);
        }
    }

    /** Calls the MCP extension's {@code callTool} operation with the tool's original name and arguments. */
    private CompletableFuture<ToolResult> invokeMcpCallTool(McpToolRecord tool, Map<String, Object> args, ExtensionsClient extensionsClient) {
        return extensionsClient.execute(MCP, "callTool", params -> params
                        .withConfigRef(tool.configRef())
                        .withParameter("toolName", tool.originalName())
                        .withParameter("arguments", args))
                .thenApply(result -> new ToolResult(tool.originalName(), result.getOutput(), tool.configRef(), Instant.now()));
    }

    /** Strips {@link CompletionException} wrappers so messages and causes refer to the underlying failure. */
    private static Throwable unwrapCompletionCause(Throwable throwable) {
        Throwable current = throwable;
        while (current instanceof CompletionException && current.getCause() != null) {
            current = current.getCause();
        }
        return current;
    }

    /**
     * One discovery run over a fixed list of MCP configs.
     *
     * <p>Each config is served from the cache when possible. Otherwise its {@code listTools} call
     * is submitted to the scheduler. {@link #countDown} starts at the number of configs and is
     * decremented once per config that reports. The future completes normally when it reaches zero
     * and exceptionally as soon as any config fails.
     */
    private class McpDiscovery {
        private final List<McpConfig> mcpConfigs;
        private final Map<String, McpToolRecord> discoveredTools = new ConcurrentHashMap<>();
        private final AtomicInteger countDown;
        private final CompletableFuture<Map<String, McpToolRecord>> future = new CompletableFuture<>();
        private final Scheduler scheduler;
        private final ExtensionsClient extensionsClient;

        private McpDiscovery(List<McpConfig> mcpConfigs, Scheduler scheduler, ExtensionsClient extensionsClient) {
            this.mcpConfigs = mcpConfigs;
            this.countDown = new AtomicInteger(mcpConfigs.size());
            this.scheduler = scheduler;
            this.extensionsClient = extensionsClient;
        }

        /** Starts discovery for every config and returns the future that tracks the whole run. */
        public CompletableFuture<Map<String, McpToolRecord>> getDiscoveredTools() {
            try {
                for (McpConfig mcpConfig : mcpConfigs) {
                    List<McpToolRecord> cachedTools = toolsByServer.get(mcpConfig.getMcpClientConfigRef());
                    if (cachedTools != null) {
                        collect(cachedTools);
                    } else {
                        scheduler.submit(() -> invokeMcpListTools(mcpConfig));
                    }
                }
            } catch (Exception ex) {
                future.completeExceptionally(ex);
            }
            return future;
        }

        /** Runs on the scheduler: invokes {@code listTools} for one config and routes the outcome. */
        private void invokeMcpListTools(McpConfig mcpConfig) {
            String mcpConfigRef = mcpConfig.getMcpClientConfigRef();
            try {
                extensionsClient.execute(MCP, "listTools", params -> params.withConfigRef(mcpConfigRef))
                        .whenComplete((result, failure) -> {
                            if (failure != null) {
                                handleToolParsingException(failure, mcpConfigRef);
                            } else {
                                processToolsResult(result, mcpConfigRef);
                            }
                        });
            } catch (Exception e) {
                // A synchronous failure here would otherwise be lost in the scheduler's unread Future,
                // leaving the discovery future incomplete forever.
                handleToolParsingException(e, mcpConfigRef);
            }
        }

        /** Converts a {@code listTools} result, caches non-empty results, and counts this config as done. */
        private void processToolsResult(Result<Object, Object> result, String mcpConfigRef) {
            try {
                List<McpToolRecord> mcpTools = createTools(result, mcpConfigRef);
                if (!mcpTools.isEmpty()) {
                    toolsByServer.put(mcpConfigRef, mcpTools);
                }
                // Always report this config, even with zero tools; otherwise countDown never reaches zero.
                collect(mcpTools);
            } catch (Exception e) {
                handleToolParsingException(e, mcpConfigRef);
            }
        }

        /**
         * Maps each element of the {@code listTools} output iterator to a {@link McpToolRecord}.
         * Elements that cannot be converted are logged and skipped. A non-Iterator output yields an
         * empty list.
         */
        private List<McpToolRecord> createTools(Result<Object, Object> result, String mcpConfigRef) {
            Object resultOutput = result.getOutput();
            if (!(resultOutput instanceof Iterator<?> iterator)) {
                logger.error("Expected Iterator but got: {}", resultOutput == null ? "null" : resultOutput.getClass().getName());
                return List.of();
            }
            return StreamSupport.stream(Spliterators.spliteratorUnknownSize(iterator, Spliterator.ORDERED), false)
                    .map(obj -> toToolRecord(obj, mcpConfigRef))
                    .filter(Objects::nonNull)
                    .toList();
        }

        /** Converts one server tool descriptor, or returns null (after logging) if it cannot be converted. */
        private McpToolRecord toToolRecord(Object obj, String mcpConfigRef) {
            try {
                McpServerToolDTO serverToolDTO = objectMapper.convertValue(obj, McpServerToolDTO.class);
                String originalToolName = serverToolDTO.name();
                String description = serverToolDTO.description();
                logger.debug("Server:{}, Tool details from server: {}", mcpConfigRef, serverToolDTO);
                FunctionSchema toolSchema = serverToolDTO.inputSchema() != null
                        ? objectMapper.readValue(serverToolDTO.inputSchema(), FunctionSchema.class)
                        : null;
                return new McpToolRecord(originalToolName, description, mcpConfigRef,
                        new Function(mcpConfigRef + "__" + originalToolName, description, toolSchema));
            } catch (JacksonException | IllegalArgumentException e) {
                // convertValue signals failure with IllegalArgumentException; treat it like a schema parse error.
                logger.error("Failed to convert object to McpToolRecord: {}", obj, e);
                return null;
            }
        }

        /** Adds one config's tools and completes the future once every config has reported. */
        private void collect(Collection<McpToolRecord> mcpTools) {
            mcpTools.forEach(mcpTool -> discoveredTools.put(mcpTool.getName(), mcpTool));
            if (countDown.decrementAndGet() <= 0) {
                future.complete(discoveredTools);
            }
        }

        /** Fails the whole discovery run. Later reports are then ignored because the future is already done. */
        private void handleToolParsingException(Throwable t, String mcpConfigRef) {
            Throwable cause = unwrapCompletionCause(t);
            future.completeExceptionally(new ModuleException("Exception obtaining toolList from MCP client config %s: %s".formatted(mcpConfigRef, cause.getMessage()), InferenceErrorType.MCP_SERVER_ERROR, cause));
            countDown.set(-1);
        }
    }
}

