/*
 * (c) 2025 MuleSoft, Inc. The software in this package is published under the terms of the Commercial Free Software license V.1 a copy of which has been included with this distribution in the LICENSE.md file.
 */
package com.mulesoft.modules.agent.broker.internal.tool.a2a;

import static com.mulesoft.modules.agent.broker.internal.error.BrokerErrorTypes.TOOL_ERROR;
import static com.mulesoft.modules.agent.broker.internal.log.SecureLogger.secureLogger;
import static com.mulesoft.modules.agent.broker.internal.tool.a2a.A2AUtils.collectParts;
import static com.mulesoft.modules.agent.broker.internal.util.ToolUtils.generateToolId;
import static com.mulesoft.modules.agent.broker.internal.util.ToolUtils.newToolCache;

import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.UUID.randomUUID;
import static java.util.concurrent.CompletableFuture.completedFuture;
import static java.util.concurrent.CompletableFuture.failedFuture;

import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_ABSENT;
import static io.a2a.spec.Message.Role.AGENT;
import static io.a2a.spec.Message.Role.USER;
import static io.a2a.spec.TaskState.COMPLETED;
import static io.a2a.spec.TaskState.INPUT_REQUIRED;
import static io.a2a.util.Utils.OBJECT_MAPPER;
import static org.mule.runtime.api.metadata.DataType.STRING;
import static org.mule.runtime.api.metadata.DataType.fromType;
import static org.slf4j.LoggerFactory.getLogger;

import org.mule.runtime.api.lifecycle.Disposable;
import org.mule.runtime.api.lifecycle.Initialisable;
import org.mule.runtime.api.lifecycle.InitialisationException;
import org.mule.runtime.api.metadata.DataType;
import org.mule.runtime.api.metadata.MediaType;
import org.mule.runtime.api.metadata.TypedValue;
import org.mule.runtime.api.scheduler.Scheduler;
import org.mule.runtime.api.scheduler.SchedulerService;
import org.mule.runtime.core.api.el.ExpressionManager;
import org.mule.runtime.core.api.util.IOUtils;
import org.mule.runtime.extension.api.client.ExtensionsClient;
import org.mule.runtime.extension.api.exception.ModuleException;
import org.mule.runtime.extension.api.runtime.operation.Result;

import com.mulesoft.modules.agent.broker.api.model.a2a.A2AClient;
import com.mulesoft.modules.agent.broker.internal.error.BrokerErrorTypes;
import com.mulesoft.modules.agent.broker.internal.serializer.DwConverter;
import com.mulesoft.modules.agent.broker.internal.state.model.Conversation;
import com.mulesoft.modules.agent.broker.internal.state.model.TaskContext;
import com.mulesoft.modules.agent.broker.internal.tool.Tool;
import com.mulesoft.modules.agent.broker.internal.tool.ToolRequest;
import com.mulesoft.modules.agent.broker.internal.tool.ToolResponse;
import com.mulesoft.modules.agent.broker.internal.tool.ToolType;
import com.mulesoft.modules.agent.broker.internal.tool.ToolVisitor;

import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.time.OffsetDateTime;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicInteger;

import javax.inject.Inject;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.github.benmanes.caffeine.cache.Cache;
import io.a2a.spec.Artifact;
import io.a2a.spec.EventKind;
import io.a2a.spec.Message;
import io.a2a.spec.MessageSendParams;
import io.a2a.spec.Part;
import io.a2a.spec.Task;
import io.a2a.spec.TaskStatus;
import io.a2a.spec.TextPart;
import org.slf4j.Logger;

/**
 * Service responsible for managing A2A tools in the Agent Broker. This service discovers agent capabilities and creates tool
 * handlers for A2A tools.
 */
public class A2AService implements Initialisable, Disposable {

  private static final Logger LOGGER = getLogger(A2AService.class);
  private static final String A2A = "A2A";
  private static final DataType AGENT_SUMMARY_DATA_TYPE = fromType(AgentSummary.class);

  // Discovered tools keyed by A2A client config reference, so an agent card is not fetched again once its tool is cached
  private final Cache<String, Tool> tools = newToolCache();

  @Inject
  private SchedulerService schedulerService;

  @Inject
  private ExpressionManager expressionManager;

  private DwConverter agentSummaryWriter;

  // A single IO scheduler shared by all discovery tasks. SchedulerService#ioScheduler() builds a new Scheduler on every call,
  // and each one must be stopped when no longer needed, so it is obtained once in initialise() and stopped in dispose().
  private Scheduler discoveryScheduler;

  /**
   * Compiles the DataWeave script that turns an agent card JSON into an {@link AgentSummary} and obtains the IO scheduler
   * used for agent discovery.
   *
   * @throws InitialisationException if initialisation fails
   */
  @Override
  public void initialise() throws InitialisationException {
    // Initialize DataWeave converter for parsing agent card JSON into an AgentSummary
    agentSummaryWriter = new DwConverter(expressionManager, """
        %dw 2.0
        output application/java
        var parsedPayload = read(payload, "application/json")
        var baseDescription = parsedPayload.description default ""
        var skills = parsedPayload.skills default []
        var skillsSummary = "\n\nAvailable skills:\n" ++ ((skills map ("• " ++ $.name ++ ": " ++ $.description)) joinBy "\n")
        var inputModes = parsedPayload.defaultInputModes default []
        var outputModes = parsedPayload.defaultOutputModes default []
        var inputDescription = "Accepts input in formats: " ++ (inputModes joinBy ", ")
        var outputDescription = "Returns output in formats: " ++ (outputModes joinBy ", ")
        ---
        {
            name: parsedPayload.name,
            description: baseDescription ++ skillsSummary,
            inputDescription: inputDescription,
            outputDescription: outputDescription
        }
        """, (value, builder) -> builder.addBinding("payload", new TypedValue(value, STRING)));

    discoveryScheduler = schedulerService.ioScheduler();
  }

  /**
   * Stops the discovery scheduler obtained in {@link #initialise()}. Discovery requests submitted after this point are
   * rejected by the stopped scheduler, and the returned future completes exceptionally.
   */
  @Override
  public void dispose() {
    if (discoveryScheduler != null) {
      discoveryScheduler.stop();
    }
  }

  /**
   * Retrieves A2A servers and creates tools for each A2A server provided.
   * <p>
   * Tools already present in the cache are collected immediately. The others are discovered on the IO scheduler by fetching
   * the agent card. The returned future completes once every client has been collected. It completes exceptionally as soon as
   * any discovery fails.
   *
   * @param a2aClients       List of A2A client configurations
   * @param extensionsClient The ExtensionClient to hit the A2A connector
   * @return CompletableFuture containing a map of tool handlers keyed by tool id
   */
  public CompletableFuture<Map<String, Tool>> getTools(List<A2AClient> a2aClients,
                                                       ExtensionsClient extensionsClient) {
    if (null == a2aClients || a2aClients.isEmpty()) {
      return completedFuture(Map.of());
    }

    Map<String, Tool> collectedTools = new ConcurrentHashMap<>();
    AtomicInteger countDown = new AtomicInteger(a2aClients.size());
    CompletableFuture<Map<String, Tool>> future = new CompletableFuture<>();

    try {
      for (A2AClient a2aClient : a2aClients) {
        final String a2aConfigRef = a2aClient.getA2AClientConfigRef();
        var tool = tools.getIfPresent(a2aConfigRef);
        if (tool != null) {
          collect(tool, collectedTools, countDown, future);
        } else {
          discoveryScheduler.submit(() -> {
            try {
              collect(tools.get(a2aConfigRef, key -> createTool(key, extensionsClient)),
                      collectedTools, countDown, future);
            } catch (Exception e) {
              // Failures thrown by the cache loader surface here and fail the aggregated future exactly once per client
              handleDiscoveryException(countDown, future, e, a2aConfigRef);
            }
          });
        }
      }
    } catch (Exception e) {
      future.completeExceptionally(e);
    }

    return future;
  }

  /**
   * Builds a JSON {@link Result} holding a {@code COMPLETED} A2A task whose single artifact contains {@code output} as text.
   *
   * @param conversation the conversation whose id becomes the task's context id
   * @param taskContext  supplies the task id
   * @param output       the text placed in the task artifact
   * @return a {@code application/json} result with the serialized task
   */
  public Result<String, Void> completedJsonResult(Conversation conversation, TaskContext taskContext, String output) {
    var builder = new Task.Builder()
        .id(taskContext.getTaskId())
        .contextId(conversation.getId())
        .status(new TaskStatus(COMPLETED))
        .artifacts(List.of(new Artifact.Builder()
            .artifactId(UUID.randomUUID().toString())
            .parts(List.of(new TextPart(output)))
            .build()));

    return asJsonResult(builder.build());
  }

  /**
   * Builds a JSON {@link Result} holding an {@code INPUT_REQUIRED} A2A task. The status message is authored by the agent and
   * carries the given parts.
   *
   * @param conversation the conversation whose id becomes the task's context id
   * @param taskContext  supplies the task id
   * @param messageParts parts of the agent message asking for input
   * @param artifacts    artifacts to attach to the task
   * @return a {@code application/json} result with the serialized task
   */
  public Result<String, Void> inputRequiredJsonResult(Conversation conversation,
                                                      TaskContext taskContext,
                                                      List<Part<?>> messageParts,
                                                      List<Artifact> artifacts) {
    var builder = new Task.Builder()
        .id(taskContext.getTaskId())
        .contextId(conversation.getId())
        .status(new TaskStatus(INPUT_REQUIRED, new Message.Builder()
            .role(AGENT)
            .parts(messageParts)
            .taskId(taskContext.getTaskId())
            .messageId(randomUUID().toString())
            .build(), OffsetDateTime.now()))
        .artifacts(artifacts);

    return asJsonResult(builder.build());
  }

  /**
   * Serializes an A2A task into a {@code application/json} {@link Result}.
   *
   * @param task the task to serialize
   * @return the result with the task JSON as output
   */
  public Result<String, Void> asJsonResult(Task task) {
    return Result.<String, Void>builder()
        .output(toJson(task))
        .mediaType(MediaType.APPLICATION_JSON)
        .build();
  }

  /**
   * Reads the whole stream and deserializes it as JSON into {@code type}.
   *
   * @param json the payload stream
   * @param type the target class
   * @param <T>  the target type
   * @return the deserialized value
   * @throws ModuleException with error type {@code A2A} if the payload cannot be read or does not match the schema
   */
  public <T> T unmarshall(InputStream json, Class<T> type) {
    String consumedPayload = null;
    try {
      // MG says: workaround for what seems to be a bug in mule. I should be able to receive this as a string in the operation
      consumedPayload = IOUtils.toString(json);
      return OBJECT_MAPPER.readValue(consumedPayload, type);
    } catch (JsonProcessingException e) {
      secureLogger().debug("Received A2A payload does not comply with the schema. Payload was:\n{}", consumedPayload, e);
      throw new ModuleException("Received A2A payload does not comply with the schema", BrokerErrorTypes.A2A, e);
    } catch (Exception e) {
      throw new ModuleException("Error unmarshalling A2A payload", BrokerErrorTypes.A2A, e);
    }
  }

  /**
   * Cache loader: fetches the agent card for {@code a2aConfigRef} and builds the corresponding {@link A2ATool}.
   *
   * @param a2aConfigRef     the A2A client config reference
   * @param extensionsClient client used to call the A2A connector
   * @return the new tool
   * @throws ModuleException if the agent card cannot be fetched or parsed
   */
  private A2ATool createTool(String a2aConfigRef, ExtensionsClient extensionsClient) {
    var agentSummary = fetchAgentSummary(a2aConfigRef, extensionsClient);
    return new A2ATool(
                       generateToolId(a2aConfigRef, agentSummary.getName()),
                       agentSummary.getName(),
                       agentSummary.getDescription(),
                       agentSummary.getInputDescription(),
                       agentSummary.getOutputDescription(),
                       a2aConfigRef);
  }

  /**
   * Fetches the agent card from the A2A client and parses it into an {@link AgentSummary}.
   *
   * @param a2aConfigRef     The A2A configuration reference
   * @param extensionsClient client used to invoke the connector's {@code getCard} operation
   * @return the parsed AgentSummary, never {@code null}
   * @throws ModuleException with {@code TOOL_ERROR} if the fetch fails, is interrupted, or returns a non-String card
   */
  private AgentSummary fetchAgentSummary(String a2aConfigRef, ExtensionsClient extensionsClient) {
    LOGGER.debug("Cache miss for config: {}. Attempting to fetch agent card...", a2aConfigRef);

    Object output;
    try {
      var result = extensionsClient.execute(A2A, "getCard", params -> params.withConfigRef(a2aConfigRef)).get();
      output = result.getOutput();
    } catch (ExecutionException e) {
      throw new ModuleException("Failed to fetch agent card for config '%s'".formatted(a2aConfigRef),
                                TOOL_ERROR, e.getCause());
    } catch (InterruptedException e) {
      // Restore the interrupt flag so the executing thread can still observe the interruption
      Thread.currentThread().interrupt();
      throw new ModuleException("Fetch agent card for config %s was interrupted.".formatted(a2aConfigRef),
                                TOOL_ERROR, e);
    }

    if (!(output instanceof String cardJson)) {
      throw new ModuleException("Agent card for config '%s' is of unexpected type: %s"
          .formatted(a2aConfigRef, output == null ? "null" : output.getClass().getName()), TOOL_ERROR);
    }

    AgentSummary parsedAgentData =
        agentSummaryWriter.evaluate(new TypedValue(cardJson, STRING), AGENT_SUMMARY_DATA_TYPE);

    LOGGER.debug("Successfully parsed AgentSummary for config {}: name={}, description={}",
                 a2aConfigRef, parsedAgentData.getName(), parsedAgentData.getDescription());

    return parsedAgentData;
  }

  /**
   * Adds {@code tool} to the collected map and completes {@code future} with that map once the countdown reaches zero (or
   * below).
   */
  private void collect(Tool tool,
                       Map<String, Tool> tools,
                       AtomicInteger countDown,
                       CompletableFuture<Map<String, Tool>> future) {

    tools.put(tool.getId(), tool);

    if (countDown.decrementAndGet() <= 0) {
      future.complete(tools);
    }
  }

  /**
   * Tool that delegates execution to a remote agent through the A2A connector's {@code sendMessage} operation, using the A2A
   * client config it was discovered from.
   */
  public class A2ATool extends Tool {

    private final String configRef;

    private A2ATool(String id, String name, String description, String input, String output, String configRef) {
      super(id, name, description, input, output);
      this.configRef = configRef;
    }

    /**
     * Sends the request input to the remote agent as an A2A message and maps the reply to a {@link ToolResponse}.
     *
     * @return a future with the parsed response, or a failed future if the message could not be built or sent
     */
    @Override
    public CompletableFuture<ToolResponse> execute(ToolRequest request, ExtensionsClient extensionsClient) {
      try {
        String a2AMessage = toA2ARequest(request);
        secureLogger().debug("Sending A2A message to agent {}:\n {}", request.selection().getToolId(), a2AMessage);
        return extensionsClient.execute(A2AService.A2A, "sendMessage", params -> params.withConfigRef(configRef)
            // JSON is sent as UTF-8 explicitly instead of relying on the platform default charset
            .withParameter("message", new ByteArrayInputStream(a2AMessage.getBytes(UTF_8)))
            .withParameter("Request", "additionalProperties", collectHeaders(request)))
            .thenApply(result -> parseResponse(request, result.getOutput()));
      } catch (Exception e) {
        return failedFuture(new ModuleException("Failed to invoke Agent " + request.selection().getToolId(), TOOL_ERROR, e));
      }
    }

    /**
     * Parses the connector output as an A2A {@link EventKind}. A {@link Task} maps to a response carrying its state and
     * artifacts. A {@link Message} maps to a {@code COMPLETED} response with no artifacts.
     *
     * @throws ModuleException with {@code TOOL_ERROR} if the output is missing, not a String, or not a Task/Message
     */
    private ToolResponse parseResponse(ToolRequest request, Object rawResponse) {
      if (rawResponse == null) {
        throw new ModuleException("Tool did not return any response", TOOL_ERROR);
      }

      if (!(rawResponse instanceof String)) {
        throw new ModuleException("Tool response is of unexpected type: " + rawResponse.getClass().getName(), TOOL_ERROR);
      }

      try {
        var response = OBJECT_MAPPER.readValue((String) rawResponse, EventKind.class);

        A2AToolOutput output;

        if (response instanceof Task task) {
          output = new A2AToolOutput(collectParts(task.getStatus()), task.getArtifacts());
          return new A2AToolResponse(toJson(output), request.selection(), task.getId(), task.getContextId(),
                                     task.getStatus().state(), output.messageParts(), task.getArtifacts());
        } else if (response instanceof Message message) {
          output = new A2AToolOutput(collectParts(message), null);
          return new A2AToolResponse(toJson(output), request.selection(), message.getTaskId(), message.getContextId(), COMPLETED,
                                     output.messageParts(), List.of());
        } else {
          throw new ModuleException("Tool response is of unexpected type: " + response.getKind(), TOOL_ERROR);
        }
      } catch (ModuleException e) {
        // Already describes the failure; do not wrap it again in a generic parse error
        throw e;
      } catch (Exception e) {
        throw new ModuleException("Failed to parse tool response: " + e.getMessage(), TOOL_ERROR, e);
      }
    }

    private String toJson(A2AToolOutput output) {
      try {
        return OBJECT_MAPPER.writeValueAsString(output);
      } catch (Exception e) {
        throw new ModuleException("Failed to serialize tool response", TOOL_ERROR, e);
      }
    }

    /**
     * JSON view of an agent's reply: the status/message parts plus any artifacts. Absent values are omitted from the JSON.
     */
    @JsonInclude(NON_ABSENT)
    @JsonIgnoreProperties(ignoreUnknown = true)
    public record A2AToolOutput(List<Part<?>> messageParts, List<Artifact> artifacts) {
    }

    @Override
    public void accept(ToolVisitor visitor) {
      visitor.visit(this);
    }

    @Override
    public ToolType getToolType() {
      return ToolType.A2A;
    }
  }

  /**
   * Fails the aggregated discovery future with a {@code TOOL_ERROR} naming the config, and forces the countdown negative.
   */
  private void handleDiscoveryException(AtomicInteger countDown, CompletableFuture<?> toolHandlersFuture, Throwable t,
                                        String a2aConfigRef) {
    toolHandlersFuture.completeExceptionally(
                                             new ModuleException(
                                                                 "Exception discovering agent from A2A Client config %s: %s"
                                                                     .formatted(a2aConfigRef, t.getMessage()),
                                                                 TOOL_ERROR, t));
    countDown.set(-1);
  }

  /**
   * Builds A2A message in Json format.
   * <p>
   * The message reuses the tool's stored context id. If a task id is stored, a final task is only referenced through
   * {@code referenceTaskIds}. A non-final task is continued through {@code taskId}.
   *
   * @param request a {@link ToolRequest}
   * @return JSON string containing the A2A message send params
   */
  private String toA2ARequest(ToolRequest request) {
    final var toolContext = request.taskContext().getAgentToolContext(request.selection().getToolId());

    Message.Builder builder = new Message.Builder()
        .messageId(randomUUID().toString())
        .role(USER)
        .parts(new TextPart(request.selection().getInput()));

    builder.contextId(toolContext.getContextId());
    if (toolContext.getTaskId() != null) {
      if (toolContext.getTaskState().isFinal()) {
        builder.referenceTaskIds(List.of(toolContext.getTaskId()));
      } else {
        builder.taskId(toolContext.getTaskId());
      }
    }

    return toJson(new MessageSendParams.Builder().message(builder.build()).build());
  }

  /**
   * Serializes {@code value} with the A2A object mapper.
   *
   * @throws ModuleException with error type {@code A2A} if serialization fails
   */
  private String toJson(Object value) {
    try {
      return OBJECT_MAPPER.writeValueAsString(value);
    } catch (Exception e) {
      throw new ModuleException("Error marshalling A2A payload", BrokerErrorTypes.A2A, e);
    }
  }
}
