/*
 * (c) 2025 MuleSoft, Inc. The software in this package is published under the terms of the Commercial Free Software license V.1 a copy of which has been included with this distribution in the LICENSE.md file.
 */
package com.mulesoft.modules.agent.conductor.internal.tool.a2a;

import static com.mulesoft.modules.agent.conductor.internal.error.ConductorErrorTypes.TOOL_ERROR;

import static java.util.concurrent.CompletableFuture.failedFuture;

import static org.mule.runtime.api.metadata.DataType.JSON_STRING;
import static org.mule.runtime.api.metadata.DataType.STRING;
import static org.mule.runtime.api.metadata.DataType.fromObject;
import static org.mule.runtime.api.metadata.DataType.fromType;
import static org.slf4j.LoggerFactory.getLogger;

import org.mule.runtime.api.lifecycle.Initialisable;
import org.mule.runtime.api.lifecycle.InitialisationException;
import org.mule.runtime.api.metadata.DataType;
import org.mule.runtime.api.metadata.TypedValue;
import org.mule.runtime.api.scheduler.SchedulerService;
import org.mule.runtime.core.api.el.ExpressionManager;
import org.mule.runtime.core.api.util.func.CheckedFunction;
import org.mule.runtime.extension.api.client.ExtensionsClient;
import org.mule.runtime.extension.api.exception.ModuleException;

import com.mulesoft.modules.agent.conductor.api.model.a2a.A2AClient;
import com.mulesoft.modules.agent.conductor.internal.serializer.DwConverter;
import com.mulesoft.modules.agent.conductor.internal.tool.ToolHandler;
import com.mulesoft.modules.agent.conductor.internal.tool.ToolRequest;
import com.mulesoft.modules.agent.conductor.internal.tool.ToolResponse;
import com.mulesoft.modules.agent.conductor.internal.tool.ToolType;

import java.io.ByteArrayInputStream;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicInteger;

import javax.inject.Inject;

import org.slf4j.Logger;

/**
 * Service responsible for managing A2A tools in the Agent Conductor. This service discovers agent capabilities and creates tool
 * handlers for A2A tools.
 * <p>
 * Agent cards are fetched through the {@code A2A} extension ({@code getCard} operation), summarised with DataWeave into an
 * {@link AgentSummary} and cached per A2A client config reference. Tool invocations are sent through the {@code sendMessage}
 * operation of the same extension.
 */
public class A2AToolService implements Initialisable {

  private static final Logger LOGGER = getLogger(A2AToolService.class);
  private static final String A2A = "A2A";
  private static final DataType AGENT_SUMMARY_DATA_TYPE = fromType(AgentSummary.class);
  private static final DataType A2A_TOOL_RESPONSE_DATA_TYPE = fromType(A2AToolResponse.class);

  // Cache of agent summaries keyed by A2A client config reference, to avoid fetching and parsing the same agent card twice.
  private final Map<String, AgentSummary> agentDataCache = new ConcurrentHashMap<>();

  @Inject
  private ExtensionsClient extensionsClient;

  @Inject
  private SchedulerService schedulerService;

  @Inject
  private ExpressionManager expressionManager;

  private DwConverter a2aMessageWriter;
  private DwConverter a2aToolResponseReader;
  private DwConverter agentSummaryWriter;

  /**
   * Builds the three DataWeave converters used by this service: the outgoing A2A message writer, the A2A response reader and
   * the agent card summariser.
   *
   * @throws InitialisationException declared by {@link Initialisable}
   */
  @Override
  public void initialise() throws InitialisationException {
    // Initialize DataWeave converter for A2A send message request building
    a2aMessageWriter = new DwConverter(expressionManager, """
        %dw 2.0
        output application/json
        ---
        {
          message: {
              role: "agent",
              parts: [{
                  "type": "text",
                   text: payload.userInput
              }],
              messageId: uuid(),
              (referenceTaskIds: [payload.taskId]) if (payload.taskId != null),
              (contextId: payload.contextId) if (payload.contextId != null)
          }
        }
        """, (value, builder) -> builder.addBinding("payload", new TypedValue(value, fromObject(value))));

    // Initialize DataWeave converter for parsing A2A response
    a2aToolResponseReader = new DwConverter(expressionManager, """
        %dw 2.0
        output application/java
        ---
         {
            result: payload.artifacts
               flatMap((artifact) ->
                 artifact.parts filter ($.'type' == "text") map ((part) -> part.text)
               )
               joinBy "\n\n",
            status: payload.status.state,
            (taskId: payload.id) if (payload.id?),
            (contextId: payload.contextId) if (payload.contextId?)

         }
          """, (value, builder) -> builder.addBinding("payload", new TypedValue(value, JSON_STRING)));

    // Initialize DataWeave converter for parsing agent card JSON into an AgentSummary
    agentSummaryWriter = new DwConverter(expressionManager, """
        %dw 2.0
        output application/java
        var parsedPayload = read(payload, "application/json")
        var baseDescription = parsedPayload.description default ""
        var skills = parsedPayload.skills default []
        var skillsSummary = "\n\nAvailable skills:\n" ++ ((skills map ("• " ++ $.name ++ ": " ++ $.description)) joinBy "\n")
        var inputModes = parsedPayload.defaultInputModes default []
        var outputModes = parsedPayload.defaultOutputModes default []
        var inputDescription = "Accepts input in formats: " ++ (inputModes joinBy ", ")
        var outputDescription = "Returns output in formats: " ++ (outputModes joinBy ", ")
        ---
        {
            name: parsedPayload.name,
            description: baseDescription ++ skillsSummary,
            inputDescription: inputDescription,
            outputDescription: outputDescription
        }
        """, (value, builder) -> builder.addBinding("payload", new TypedValue(value, STRING)));
  }

  /**
   * Retrieves A2A servers and creates tool handlers for each A2A server provided.
   * <p>
   * Clients whose agent summary is already cached are handled on the calling thread; the rest are discovered on the IO
   * scheduler. The returned future completes with the handlers (keyed by agent name) once every client has been processed, or
   * exceptionally as soon as the discovery of any client fails.
   *
   * @param a2aClients List of A2A client configurations
   * @return CompletableFuture containing a map of tool handlers; an already completed empty map when {@code a2aClients} is
   *         {@code null} or empty
   */
  public CompletableFuture<Map<String, ToolHandler>> getA2AToolHandlers(List<A2AClient> a2aClients) {
    if (null == a2aClients || a2aClients.isEmpty()) {
      return CompletableFuture.completedFuture(Map.of());
    }

    Map<String, ToolHandler> tools = new ConcurrentHashMap<>();
    AtomicInteger countDown = new AtomicInteger(a2aClients.size());
    CompletableFuture<Map<String, ToolHandler>> toolHandlersFuture = new CompletableFuture<>();

    try {
      for (A2AClient a2aClient : a2aClients) {
        final String a2aConfigRef = a2aClient.getA2AClientConfigRef();
        var agentSummary = agentDataCache.get(a2aConfigRef);
        if (agentSummary != null) {
          createToolHandler(agentSummary, a2aClient, tools, countDown, toolHandlersFuture);
        } else {
          schedulerService.ioScheduler().submit(() -> {
            try {
              // computeIfAbsent rethrows whatever fetchAgentData throws and records no mapping in that case, so a failed
              // discovery is reported exactly once by the catch below and is retried on the next request.
              var calculatedAgentSummary = agentDataCache.computeIfAbsent(a2aConfigRef, this::fetchAgentData);
              createToolHandler(calculatedAgentSummary, a2aClient, tools, countDown, toolHandlersFuture);
            } catch (Exception e) {
              handleDiscoveryException(countDown, toolHandlersFuture, e, a2aConfigRef);
            }
          });
        }
      }
    } catch (Exception e) {
      toolHandlersFuture.completeExceptionally(e);
    }

    return toolHandlersFuture;
  }

  /**
   * Fetches agent data from A2A client and parses the agent card.
   * <p>
   * This method blocks the calling thread until the {@code getCard} operation completes.
   *
   * @param a2aConfigRef The A2A configuration reference
   * @return the {@link AgentSummary} parsed from the agent card
   * @throws ModuleException with {@code TOOL_ERROR} if the card cannot be fetched or the calling thread is interrupted while
   *                         waiting for it
   */
  private AgentSummary fetchAgentData(String a2aConfigRef) {
    LOGGER.debug("Cache miss for config: {}. Attempting to fetch agent card...", a2aConfigRef);

    try {
      var result = extensionsClient.execute(A2A, "getCard", params -> params.withConfigRef(a2aConfigRef)).get();
      String cardJson = (String) result.getOutput();

      AgentSummary parsedAgentData =
          agentSummaryWriter.evaluate(new TypedValue(cardJson, STRING), AGENT_SUMMARY_DATA_TYPE);

      LOGGER.debug("Successfully parsed AgentSummary for config {}: name={}, description={}",
                   a2aConfigRef, parsedAgentData.getName(), parsedAgentData.getDescription());

      return parsedAgentData;
    } catch (ExecutionException e) {
      // Unwrap the operation failure; fall back to the wrapper itself so a cause is never lost.
      final Throwable cause = e.getCause() != null ? e.getCause() : e;
      throw new ModuleException("Failed to fetch agent card for config '%s'".formatted(a2aConfigRef), TOOL_ERROR, cause);
    } catch (InterruptedException e) {
      // Restore the interrupt flag so the scheduler thread's owner can still observe the interruption.
      Thread.currentThread().interrupt();
      throw new ModuleException("Fetch agent card for config %s was interrupted.".formatted(a2aConfigRef), TOOL_ERROR, e);
    }
  }

  /**
   * Creates the {@link ToolHandler} for one agent, registers it under the agent name and completes the overall future when it
   * was the last pending client.
   *
   * @param agentData          summary of the agent the handler is created for
   * @param a2aClient          client configuration used to reach the agent
   * @param tools              map collecting the handlers, keyed by agent name
   * @param countDown          number of clients still pending
   * @param toolHandlersFuture future completed with {@code tools} once the counter reaches zero
   */
  private void createToolHandler(AgentSummary agentData,
                                 A2AClient a2aClient,
                                 Map<String, ToolHandler> tools,
                                 AtomicInteger countDown,
                                 CompletableFuture<Map<String, ToolHandler>> toolHandlersFuture) {
    final String a2aConfigRef = a2aClient.getA2AClientConfigRef();
    ToolHandler toolHandler = new ToolHandler(
                                              agentData.getName(),
                                              agentData.getDescription(),
                                              agentData.getInputDescription(),
                                              agentData.getOutputDescription(),
                                              ToolType.A2A,
                                              createHandler(a2aConfigRef));

    tools.put(agentData.getName(), toolHandler);

    // "<= 0" rather than "== 0": after a failure the counter is forced to -1, and completing an already failed future is a
    // no-op.
    if (countDown.decrementAndGet() <= 0) {
      toolHandlersFuture.complete(tools);
    }
  }

  /**
   * Creates the function that invokes the agent behind the given A2A client config: it serialises the request into an A2A
   * message, sends it with the {@code sendMessage} operation and converts the operation output into a tool response.
   *
   * @param configRef the A2A client config reference to send messages through
   * @return a function producing a future with the tool response; failures raised while building or dispatching the message
   *         yield a future failed with a {@code TOOL_ERROR} {@link ModuleException}
   */
  public CheckedFunction<ToolRequest, CompletableFuture<ToolResponse>> createHandler(String configRef) {
    return request -> {
      try {
        String a2AMessage = writeA2AMessage(request);

        // Encode explicitly as UTF-8: the no-arg getBytes() depends on the platform default charset and could corrupt
        // non-ASCII user input in the JSON payload.
        return extensionsClient.execute(A2A, "sendMessage", params -> params.withConfigRef(configRef)
            .withParameter("message", new ByteArrayInputStream(a2AMessage.getBytes(StandardCharsets.UTF_8))))
            .thenApply(result -> a2aToolResponseReader.evaluate(result.getOutput(), A2A_TOOL_RESPONSE_DATA_TYPE));
      } catch (Exception e) {
        return failedFuture(new ModuleException("Failed to invoke Agent " + request.getToolName(), TOOL_ERROR, e));
      }
    };
  }

  /**
   * Fails the overall discovery future with a {@code TOOL_ERROR} {@link ModuleException} wrapping {@code t}, and forces the
   * pending counter negative.
   *
   * @param countDown          number of clients still pending
   * @param toolHandlersFuture future to complete exceptionally
   * @param t                  the failure that aborted the discovery
   * @param a2aConfigRef       config reference of the client whose discovery failed
   */
  private void handleDiscoveryException(AtomicInteger countDown, CompletableFuture<?> toolHandlersFuture, Throwable t,
                                        String a2aConfigRef) {
    toolHandlersFuture.completeExceptionally(
                                             new ModuleException(
                                                                 "Exception discovering agent from A2A Client config %s: %s"
                                                                     .formatted(a2aConfigRef, t.getMessage()),
                                                                 TOOL_ERROR, t));
    countDown.set(-1);
  }

  /**
   * Builds A2A message parameters using DataWeave for clean and maintainable JSON generation.
   *
   * @param request a {@link ToolRequest}
   * @return JSON string containing the A2A message
   */
  private String writeA2AMessage(ToolRequest request) {
    final String toolName = request.getToolName();
    final String toolInput = request.getToolInput();
    final var toolContext = request.getTaskContext().getAgentToolContext(toolName);
    final String contextId = toolContext.getContextId();
    final String taskId = toolContext.getTaskId();

    LOGGER.debug("Building A2A message for agent {}: userInput={}, contextId={} taskId={}",
                 toolName, toolInput, contextId, taskId);

    A2AMessage message = new A2AMessage(toolInput, contextId, taskId);
    return a2aMessageWriter.evaluateAsString(message, JSON_STRING);
  }

  /**
   * Looks up the cached summary of the agent with the given name.
   *
   * @param agentName name of the agent, as reported in its agent card
   * @return the cached {@link AgentSummary}, or {@code null} if no cached agent has that name
   */
  public AgentSummary getSummaryForAgent(String agentName) {
    return agentDataCache.values().stream()
        .filter(summary -> summary != null && agentName.equals(summary.getName()))
        .findFirst()
        .orElse(null);
  }

  /**
   * Looks up the A2A client config reference under which the agent with the given name was cached.
   *
   * @param agentName name of the agent, as reported in its agent card
   * @return the config reference, or {@code null} if no cached agent has that name
   */
  public String getA2AClientConfigRef(String agentName) {
    return agentDataCache.entrySet().stream()
        .filter(entry -> entry.getValue() != null && agentName.equals(entry.getValue().getName()))
        .map(Map.Entry::getKey)
        .findFirst()
        .orElse(null);
  }


}
