/*
 * (c) 2025 MuleSoft, Inc. The software in this package is published under the terms of the Commercial Free Software license V.1 a copy of which has been included with this distribution in the LICENSE.md file.
 */
package com.mulesoft.modules.agent.conductor.internal.tool.mcp;

import static com.mulesoft.modules.agent.conductor.internal.error.ConductorErrorTypes.TOOL_ERROR;

import static java.util.concurrent.CompletableFuture.completedFuture;

import static org.mule.runtime.api.metadata.DataType.JSON_STRING;
import static org.mule.runtime.api.metadata.DataType.fromObject;
import static org.mule.runtime.api.metadata.DataType.fromType;

import org.mule.runtime.api.lifecycle.Initialisable;
import org.mule.runtime.api.lifecycle.InitialisationException;
import org.mule.runtime.api.metadata.DataType;
import org.mule.runtime.api.metadata.TypedValue;
import org.mule.runtime.api.scheduler.SchedulerService;
import org.mule.runtime.core.api.el.ExpressionManager;
import org.mule.runtime.core.api.util.func.CheckedFunction;
import org.mule.runtime.extension.api.client.ExtensionsClient;
import org.mule.runtime.extension.api.exception.ModuleException;
import org.mule.runtime.extension.api.runtime.operation.Result;

import com.mulesoft.modules.agent.conductor.api.model.mcp.McpServer;
import com.mulesoft.modules.agent.conductor.api.model.mcp.ToolFilter;
import com.mulesoft.modules.agent.conductor.internal.tool.ToolRequest;
import com.mulesoft.modules.agent.conductor.internal.tool.ToolType;
import com.mulesoft.modules.agent.conductor.internal.serializer.DwConverter;
import com.mulesoft.modules.agent.conductor.internal.tool.ToolHandler;
import com.mulesoft.modules.agent.conductor.internal.tool.ToolResponse;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

import javax.inject.Inject;

/**
 * Discovers the tools published through MCP client configurations and exposes each of them as a
 * {@link ToolHandler}.
 * <p>
 * Tool listings are obtained through the {@code listTools} operation of the {@code MCP} extension and are
 * cached per MCP client config reference. The cache holds the <em>unfiltered</em> listing; each
 * {@link McpServer}'s {@link ToolFilter} is applied when the tools are collected, so that two servers
 * pointing at the same config reference with different filters each get their own view.
 */
public class McpService implements Initialisable {

  private static final DataType MAP_DATA_TYPE = fromType(Map.class);
  private static final DataType MCP_TOOL_INFO_DATA_TYPE = fromType(McpToolInfo.class);
  private static final String MCP = "MCP";

  @Inject
  private ExtensionsClient extensionsClient;

  @Inject
  private ExpressionManager expressionManager;

  @Inject
  private SchedulerService schedulerService;

  /** Unfiltered tool listings, keyed by MCP client config reference. Only non-empty listings are stored. */
  private final Map<String, List<McpToolInfo>> cache = new ConcurrentHashMap<>();

  /** Converts a JSON string payload into a Java {@link Map}. */
  private DwConverter mapWriter;
  /** Converts one entry of the {@code listTools} output into a {@link McpToolInfo}. */
  private DwConverter toolInfoReader;
  /** Serialises a {@code callTool} output to a JSON string. */
  private DwConverter jsonWriter;

  /**
   * Builds the DataWeave converters used to translate between MCP payloads and the module's own model.
   *
   * @throws InitialisationException declared by {@link Initialisable}
   */
  @Override
  public void initialise() throws InitialisationException {
    mapWriter = new DwConverter(expressionManager, "%dw 2.0 output application/java --- payload",
                                (value, builder) -> builder.addBinding("payload", new TypedValue(value, fromObject(value))));

    toolInfoReader = new DwConverter(expressionManager,
                                     """
                                         %dw 2.0
                                         output application/java
                                         ---
                                           {
                                             toolName: payload.name,
                                             description: payload.description,
                                             "input": payload.inputSchema,
                                             "output": "A standard MCP response, such as Text, Blob or image type. Blobs and image responses will be base64 encoded"
                                           }
                                         """,
                                     (value, builder) -> builder.addBinding("payload", new TypedValue(value, MAP_DATA_TYPE)));

    jsonWriter = new DwConverter(expressionManager, """
        %dw 2.0
        output application/java
        ---
        write(payload, "json") as String
        """, (value, builder) -> builder.addBinding("payload", new TypedValue(value, fromObject(value))));
  }

  /**
   * Resolves the tools offered by the given MCP servers.
   *
   * @param mcpServers the servers to query; {@code null} or empty yields an already completed future
   *        holding an empty map
   * @return a future completed with the handlers keyed by tool id ({@code <configRef>.<toolName>}), or
   *         completed exceptionally with a {@link ModuleException} of type {@code TOOL_ERROR} when a
   *         listing cannot be obtained or parsed
   */
  public CompletableFuture<Map<String, ToolHandler>> getMcpTools(List<McpServer> mcpServers) {
    if (null == mcpServers || mcpServers.isEmpty()) {
      return completedFuture(Map.of());
    }

    return new McpDiscovery(mcpServers).getTools();
  }

  /**
   * One discovery round over a list of servers. The result future completes once every server has been
   * accounted for, whether its tools came from the cache or from a {@code listTools} call.
   */
  private class McpDiscovery {

    private final List<McpServer> mcpServers;
    private final Map<String, ToolHandler> tools = new ConcurrentHashMap<>();
    /** Number of servers still pending; forced negative once the discovery has failed. */
    private final AtomicInteger countDown;
    private final CompletableFuture<Map<String, ToolHandler>> future = new CompletableFuture<>();

    private McpDiscovery(List<McpServer> mcpServers) {
      this.mcpServers = mcpServers;
      countDown = new AtomicInteger(mcpServers.size());
    }

    /**
     * Starts the discovery: cached listings are collected on the calling thread, the rest are fetched
     * through the IO scheduler.
     *
     * @return the future that will hold the collected handlers
     */
    public CompletableFuture<Map<String, ToolHandler>> getTools() {
      try {
        for (McpServer mcpServer : mcpServers) {
          final String mcpConfigRef = mcpServer.getMcpClientConfigRef();
          List<McpToolInfo> cachedTools = cache.get(mcpConfigRef);
          if (cachedTools != null) {
            collectServerTools(mcpServer, cachedTools);
          } else {
            // NOTE(review): ioScheduler() is invoked once per uncached server and the returned scheduler is
            // never stopped here - confirm against the SchedulerService contract whether it must be
            // obtained once and released by this service.
            schedulerService.ioScheduler().submit(() -> fetchTools(mcpServer));
          }
        }
      } catch (Exception ex) {
        future.completeExceptionally(ex);
      }
      return future;
    }

    /**
     * Invokes {@code listTools} for the given server, caches a non-empty listing and collects the tools
     * the server's filter allows. Any failure, synchronous or asynchronous, fails the discovery.
     */
    private void fetchTools(McpServer mcpServer) {
      final String mcpConfigRef = mcpServer.getMcpClientConfigRef();
      try {
        extensionsClient.execute(MCP, "listTools", params -> params.withConfigRef(mcpConfigRef))
            .whenComplete((result, t) -> {
              if (t != null) {
                handleToolParsingException(t, mcpConfigRef);
                return;
              }
              try {
                List<McpToolInfo> discovered = parseTools(result, mcpConfigRef);
                if (!discovered.isEmpty()) {
                  cache.put(mcpConfigRef, List.copyOf(discovered));
                }
                // Always account for the server, even when it exposes no (allowed) tools; otherwise the
                // result future would never complete.
                collectServerTools(mcpServer, discovered);
              } catch (Exception e) {
                handleToolParsingException(e, mcpConfigRef);
              }
            });
      } catch (Exception e) {
        // This runs inside a submitted task whose Future is discarded, so a synchronous failure has to be
        // reported through the discovery future or it would be lost.
        handleToolParsingException(e, mcpConfigRef);
      }
    }

    /**
     * Converts the {@code listTools} output into tool descriptors, assigning each its id and config
     * reference. No filtering happens here so that the listing can be cached independently of any filter.
     */
    private List<McpToolInfo> parseTools(Result<Object, Object> result, String mcpConfigRef) {
      // The original code relies on the listTools output being an Iterator; the cast cannot be checked.
      @SuppressWarnings("unchecked")
      Iterator<Object> iterator = (Iterator<Object>) result.getOutput();
      List<McpToolInfo> parsed = new ArrayList<>();
      iterator.forEachRemaining(value -> {
        McpToolInfo toolInfo = toolInfoReader.evaluate(value, MCP_TOOL_INFO_DATA_TYPE);
        toolInfo.setToolId(mcpConfigRef + "." + toolInfo.getToolName());
        toolInfo.setConfigRef(mcpConfigRef);
        parsed.add(toolInfo);
      });
      return parsed;
    }

    /**
     * Registers the tools of one server that pass its filter and marks that server as done. The result
     * future is completed when the last pending server reports in.
     */
    private void collectServerTools(McpServer mcpServer, List<McpToolInfo> serverTools) {
      final ToolFilter filter = mcpServer.getToolsFilter();
      for (McpToolInfo toolInfo : serverTools) {
        if (isAllowed(toolInfo.getToolName(), filter)) {
          collect(toolInfo);
        }
      }

      // One decrement per server (the counter starts at the number of servers). After a failure the
      // counter is negative, so it can never hit zero again.
      if (countDown.decrementAndGet() == 0) {
        future.complete(tools);
      }
    }

    /** Wraps a tool descriptor into a {@link ToolHandler} and stores it under its tool id. */
    private void collect(McpToolInfo toolInfo) {
      tools.put(toolInfo.getToolId(), new ToolHandler(
                                                      toolInfo.getToolId(),
                                                      toolInfo.getDescription(),
                                                      toolInfo.getInput(),
                                                      toolInfo.getOutput(),
                                                      ToolType.MCP,
                                                      createHandler(toolInfo.getConfigRef(), toolInfo.getToolName())));
    }

    /**
     * Applies a tool filter: a non-empty allow list takes precedence, then a non-empty disallow list;
     * with no filter, or with both lists empty, every tool is allowed.
     */
    private boolean isAllowed(String toolName, ToolFilter filter) {
      if (filter == null) {
        return true;
      }

      List<String> values = filter.getAllowedTools();
      if (values != null && !values.isEmpty()) {
        return values.contains(toolName);
      }

      values = filter.getDisallowedTools();
      if (values != null && !values.isEmpty()) {
        return !values.contains(toolName);
      }

      return true;
    }

    /**
     * Creates the function that executes a tool: the request's JSON input is converted to a map, passed as
     * the {@code arguments} of the {@code callTool} operation, and the output is serialised back to JSON.
     */
    private CheckedFunction<ToolRequest, CompletableFuture<ToolResponse>> createHandler(String configRef, String toolName) {
      return request -> extensionsClient.execute(MCP, "callTool", params -> params.withConfigRef(configRef)
          .withParameter("toolName", toolName)
          .withParameter("arguments",
                         mapWriter.evaluate(new TypedValue<>(request.getToolInput(), JSON_STRING), MAP_DATA_TYPE)))
          .thenApply(result -> new McpToolResponse(jsonWriter.evaluateAsString(result.getOutput(), JSON_STRING)));
    }

    /**
     * Fails the discovery with a {@code TOOL_ERROR} that keeps the original cause, and disables the
     * countdown so that servers finishing later cannot complete the future.
     */
    private void handleToolParsingException(Throwable t, String mcpConfigRef) {
      future.completeExceptionally(
                                   new ModuleException(
                                                       "Exception obtaining toolList from MCP client config %s: %s"
                                                           .formatted(mcpConfigRef, t.getMessage()),
                                                       TOOL_ERROR, t));
      countDown.set(-1);
    }
  }
}
