/*
 * (c) 2025 MuleSoft, Inc. The software in this package is published under the terms of the Commercial Free Software license V.1 a copy of which has been included with this distribution in the LICENSE.md file.
 */
package com.mulesoft.modules.agent.broker.internal.tool.mcp;

import static com.mulesoft.modules.agent.broker.internal.error.BrokerErrorTypes.TOOL_ERROR;
import static com.mulesoft.modules.agent.broker.internal.util.ExceptionUtils.unwrap;

import static java.util.concurrent.CompletableFuture.completedFuture;
import static java.util.concurrent.CompletableFuture.failedFuture;

import static org.mule.runtime.api.metadata.DataType.JSON_STRING;
import static org.mule.runtime.api.metadata.DataType.fromObject;
import static org.mule.runtime.api.metadata.DataType.fromType;

import org.mule.runtime.api.lifecycle.Initialisable;
import org.mule.runtime.api.lifecycle.InitialisationException;
import org.mule.runtime.api.metadata.DataType;
import org.mule.runtime.api.metadata.TypedValue;
import org.mule.runtime.api.scheduler.SchedulerService;
import org.mule.runtime.core.api.el.ExpressionManager;
import org.mule.runtime.extension.api.client.ExtensionsClient;
import org.mule.runtime.extension.api.exception.ModuleException;
import org.mule.runtime.extension.api.runtime.operation.Result;

import com.mulesoft.modules.agent.broker.api.model.mcp.McpServer;
import com.mulesoft.modules.agent.broker.api.model.mcp.ToolFilter;
import com.mulesoft.modules.agent.broker.internal.serializer.DwConverter;
import com.mulesoft.modules.agent.broker.internal.tool.Tool;
import com.mulesoft.modules.agent.broker.internal.tool.ToolRequest;
import com.mulesoft.modules.agent.broker.internal.tool.ToolResponse;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

import javax.inject.Inject;

/**
 * Discovers the tools exposed by MCP servers (through the {@code MCP} extension's {@code listTools} operation) and wraps
 * each of them in a {@link Tool} that invokes the {@code callTool} operation.
 * <p>
 * Non-empty tool lists are cached per MCP client config reference for the lifetime of this service. Servers whose
 * discovery failed, or that exposed no allowed tools, are queried again on the next discovery.
 */
public class McpService implements Initialisable {

  private static final DataType MAP_DATA_TYPE = fromType(Map.class);
  private static final DataType MCP_TOOL_INFO_DATA_TYPE = fromType(McpToolInfo.class);
  private static final String MCP = "MCP";

  @Inject
  private ExpressionManager expressionManager;

  @Inject
  private SchedulerService schedulerService;

  /** Discovered tools, keyed by MCP client config reference. */
  private final Map<String, List<Tool>> toolsByServer = new ConcurrentHashMap<>();

  /** Converts a value to its Java representation; used to turn the JSON tool input into {@code callTool} arguments. */
  private DwConverter mapWriter;
  /** Maps one raw {@code listTools} entry into an {@link McpToolInfo}. */
  private DwConverter toolInfoReader;
  /** Serializes a {@code callTool} result into a JSON string. */
  private DwConverter jsonWriter;

  @Override
  public void initialise() throws InitialisationException {
    mapWriter = new DwConverter(expressionManager, "%dw 2.0 output application/java --- payload",
                                (value, builder) -> builder.addBinding("payload", new TypedValue(value, fromObject(value))));

    toolInfoReader = new DwConverter(expressionManager,
                                     """
                                         %dw 2.0
                                         output application/java
                                         ---
                                           {
                                             toolName: payload.name,
                                             description: payload.description,
                                             "input": payload.inputSchema,
                                             "output": "A standard MCP response, such as Text, Blob or image type. Blobs and image responses will be base64 encoded"
                                           }
                                         """,
                                     (value, builder) -> builder.addBinding("payload", new TypedValue(value, MAP_DATA_TYPE)));

    jsonWriter = new DwConverter(expressionManager, """
        %dw 2.0
        output application/java
        ---
        write(payload, "json") as String
        """, (value, builder) -> builder.addBinding("payload", new TypedValue(value, fromObject(value))));
  }

  /**
   * Discovers the tools exposed by the given MCP servers.
   *
   * @param mcpServers       the servers to query; {@code null} or an empty list yields an empty map
   * @param extensionsClient client used to execute the MCP {@code listTools} operation
   * @return a future completed with the discovered tools, keyed by tool id ({@code <configRef>.<toolName>}). It is
   *         completed exceptionally if discovery fails. Errors listing or parsing a server's tools are wrapped in a
   *         {@link ModuleException} with {@code TOOL_ERROR}.
   */
  public CompletableFuture<Map<String, Tool>> getTools(List<McpServer> mcpServers, ExtensionsClient extensionsClient) {
    if (null == mcpServers || mcpServers.isEmpty()) {
      return completedFuture(Map.of());
    }

    return new McpDiscovery(mcpServers, extensionsClient).getDiscoveredTools();
  }

  /**
   * A single discovery round over a list of MCP servers.
   * <p>
   * Tools of cached servers are collected immediately; the remaining servers are queried on the IO scheduler. The future
   * completes once every server has reported, or exceptionally on the first failure.
   */
  private class McpDiscovery {

    private final List<McpServer> mcpServers;
    private final Map<String, Tool> discoveredTools = new ConcurrentHashMap<>();
    // Number of servers that have not reported yet; the future is completed when it drops to zero.
    private final AtomicInteger countDown;
    private final CompletableFuture<Map<String, Tool>> future = new CompletableFuture<>();
    private final ExtensionsClient extensionsClient;

    private McpDiscovery(List<McpServer> mcpServers, ExtensionsClient extensionsClient) {
      this.mcpServers = mcpServers;
      countDown = new AtomicInteger(mcpServers.size());
      this.extensionsClient = extensionsClient;
    }

    /**
     * Starts the discovery round.
     *
     * @return the future that will hold every discovered tool, keyed by tool id
     */
    public CompletableFuture<Map<String, Tool>> getDiscoveredTools() {
      try {
        for (McpServer mcpServer : mcpServers) {
          final String mcpConfigRef = mcpServer.getMcpClientConfigRef();
          List<Tool> mcpTools = toolsByServer.get(mcpConfigRef);
          if (mcpTools != null) {
            collect(mcpTools);
          } else {
            // NOTE(review): each ioScheduler() call appears to obtain a new Scheduler that is never stopped here;
            // confirm against the SchedulerService contract and consider a single service-scoped scheduler.
            schedulerService.ioScheduler().submit(() -> fetchTools(mcpServer));
          }
        }
      } catch (Exception ex) {
        future.completeExceptionally(ex);
      }
      return future;
    }

    /**
     * Lists the tools of one server, caches them when non-empty, and always reports the server to {@link #collect}.
     */
    private void fetchTools(McpServer mcpServer) {
      final String mcpConfigRef = mcpServer.getMcpClientConfigRef();
      try {
        extensionsClient.execute(MCP, "listTools", params -> params.withConfigRef(mcpConfigRef))
            .whenComplete((result, t) -> {
              if (t != null) {
                handleToolParsingException(unwrap(t), mcpConfigRef);
              } else {
                try {
                  List<Tool> tools = createTools(mcpServer, result, mcpConfigRef);

                  if (!tools.isEmpty()) {
                    // Cache first so the cache is already populated by the time the future can complete.
                    toolsByServer.put(mcpConfigRef, tools);
                  }
                  // Always collect, even an empty list: otherwise this server is never counted down and the
                  // discovery future never completes.
                  collect(tools);
                } catch (Exception e) {
                  handleToolParsingException(e, mcpConfigRef);
                }
              }
            });
      } catch (Exception e) {
        // A synchronous failure would otherwise be lost inside the scheduler task, leaving the future pending forever.
        handleToolParsingException(e, mcpConfigRef);
      }
    }

    /**
     * Converts the {@code listTools} output into tools, keeping only those allowed by the server's tool filter.
     */
    @SuppressWarnings("unchecked") // listTools output is consumed as an iterator of raw tool descriptors
    private List<Tool> createTools(McpServer mcpServer, Result<Object, Object> result, String mcpConfigRef) {
      var iterator = (Iterator<Object>) result.getOutput();
      List<Tool> tools = new ArrayList<>();
      iterator.forEachRemaining(value -> {
        McpToolInfo toolInfo = toolInfoReader.evaluate(value, MCP_TOOL_INFO_DATA_TYPE);
        String toolName = toolInfo.getToolName();

        if (isAllowed(toolName, mcpServer.getToolsFilter())) {
          tools.add(new McpTool(mcpConfigRef + "." + toolName,
                                toolName,
                                toolInfo.getDescription(),
                                toolInfo.getInput(),
                                toolInfo.getOutput(),
                                mcpConfigRef));
        }
      });
      return tools;
    }

    /**
     * Records one server's tools and completes the future once every server has reported.
     */
    private void collect(Collection<Tool> tools) {
      tools.forEach(tool -> discoveredTools.put(tool.getId(), tool));

      if (countDown.decrementAndGet() <= 0) {
        future.complete(discoveredTools);
      }
    }

    /**
     * Tells whether a tool passes the filter.
     * <p>
     * A non-empty allow list takes precedence. Otherwise a non-empty deny list is applied. A {@code null} filter, or
     * one with neither list, allows every tool.
     */
    private boolean isAllowed(String toolName, ToolFilter filter) {
      if (filter == null) {
        return true;
      }

      List<String> values = filter.getAllowedTools();
      if (values != null && !values.isEmpty()) {
        return values.contains(toolName);
      }

      values = filter.getDisallowedTools();
      if (values != null && !values.isEmpty()) {
        return !values.contains(toolName);
      }

      return true;
    }

    /**
     * Fails the discovery round with a {@link ModuleException} and disables any further completion by the counter.
     */
    private void handleToolParsingException(Throwable t, String mcpConfigRef) {
      future.completeExceptionally(
                                   new ModuleException(
                                                       "Exception obtaining toolList from MCP client config %s: %s"
                                                           .formatted(mcpConfigRef, t.getMessage()),
                                                       TOOL_ERROR, t));
      countDown.set(-1);
    }
  }

  /**
   * A tool backed by the MCP {@code callTool} operation of a given MCP client config.
   * <p>
   * It is a member of the service rather than of a discovery round. Cached instances therefore do not keep a finished
   * discovery's state alive.
   */
  private class McpTool extends Tool {

    private final String configRef;

    private McpTool(String id, String name, String description, String input, String output, String configRef) {
      super(id, name, description, input, output);
      this.configRef = configRef;
    }

    /**
     * Invokes {@code callTool}, passing the JSON tool input as the arguments map and the collected headers as request
     * properties. The result is serialized to a JSON string.
     *
     * @return a future with the tool response; a failed future with a {@link ModuleException} if the invocation could
     *         not be started
     */
    @Override
    public CompletableFuture<ToolResponse> execute(ToolRequest request, ExtensionsClient extensionsClient) {
      try {
        return extensionsClient.execute(MCP, "callTool", params -> params.withConfigRef(configRef)
            .withParameter("toolName", getName())
            .withParameter("arguments",
                           mapWriter.evaluate(new TypedValue<>(request.getToolInput(), JSON_STRING), MAP_DATA_TYPE))
            .withParameter("Request", "additionalProperties",
                           collectHeaders(request)))
            .thenApply(result -> new McpToolResponse(jsonWriter.evaluateAsString(result.getOutput(), JSON_STRING)));
      } catch (Exception e) {
        return failedFuture(new ModuleException("Failed to invoke MCP tool " + request.getToolName(), TOOL_ERROR, e));
      }
    }
  }
}
