/*
 * (c) 2025 MuleSoft, Inc. The software in this package is published under the terms of the Commercial Free Software license V.1 a copy of which has been included with this distribution in the LICENSE.md file.
 */
package com.mulesoft.modules.agent.broker.internal.operation.loop;

import static com.mulesoft.modules.agent.broker.internal.error.BrokerErrorTypes.A2A;
import static com.mulesoft.modules.agent.broker.internal.tracing.TracingUtils.traceBroker;
import static com.mulesoft.modules.agent.broker.internal.util.ExceptionUtils.unwrap;

import static java.util.concurrent.CompletableFuture.failedFuture;

import static org.mule.runtime.api.meta.ExpressionSupport.REQUIRED;
import static org.mule.runtime.api.meta.model.operation.ExecutionType.CPU_LITE;
import static org.mule.runtime.extension.api.annotation.param.MediaType.APPLICATION_JSON;
import static org.mule.sdk.api.annotation.route.ChainExecutionOccurrence.MULTIPLE_OR_NONE;

import org.mule.runtime.api.scheduler.SchedulerService;
import org.mule.runtime.core.api.el.ExpressionManager;
import org.mule.runtime.extension.api.annotation.Alias;
import org.mule.runtime.extension.api.annotation.Expression;
import org.mule.runtime.extension.api.annotation.dsl.xml.ParameterDsl;
import org.mule.runtime.extension.api.annotation.error.Throws;
import org.mule.runtime.extension.api.annotation.execution.Execution;
import org.mule.runtime.extension.api.annotation.param.Config;
import org.mule.runtime.extension.api.annotation.param.ConfigOverride;
import org.mule.runtime.extension.api.annotation.param.Connection;
import org.mule.runtime.extension.api.annotation.param.Content;
import org.mule.runtime.extension.api.annotation.param.MediaType;
import org.mule.runtime.extension.api.annotation.param.NullSafe;
import org.mule.runtime.extension.api.annotation.param.Optional;
import org.mule.runtime.extension.api.client.ExtensionsClient;
import org.mule.runtime.extension.api.exception.ModuleException;
import org.mule.runtime.extension.api.runtime.process.CompletionCallback;
import org.mule.sdk.api.annotation.metadata.ChainInputResolver;
import org.mule.sdk.api.annotation.route.ExecutionOccurrence;
import org.mule.sdk.api.runtime.source.DistributedTraceContextManager;

import com.mulesoft.modules.agent.broker.api.model.a2a.A2AClient;
import com.mulesoft.modules.agent.broker.api.model.mcp.McpServer;
import com.mulesoft.modules.agent.broker.api.model.tool.CustomToolRoute;
import com.mulesoft.modules.agent.broker.internal.datasense.ToolInputResolver;
import com.mulesoft.modules.agent.broker.internal.extension.AgentsBroker;
import com.mulesoft.modules.agent.broker.internal.extension.connection.LLMClient;
import com.mulesoft.modules.agent.broker.internal.prompt.PromptBuilder;
import com.mulesoft.modules.agent.broker.internal.serializer.SerializationService;
import com.mulesoft.modules.agent.broker.internal.state.ConversationService;
import com.mulesoft.modules.agent.broker.internal.tool.Tool;
import com.mulesoft.modules.agent.broker.internal.tool.ToolRequest;
import com.mulesoft.modules.agent.broker.internal.tool.ToolResponse;
import com.mulesoft.modules.agent.broker.internal.tool.a2a.A2AService;
import com.mulesoft.modules.agent.broker.internal.tool.custom.CustomToolResponse;
import com.mulesoft.modules.agent.broker.internal.tool.mcp.McpService;

import java.io.InputStream;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;

import javax.inject.Inject;

import io.a2a.spec.Message;
import io.a2a.spec.MessageSendParams;

/**
 * Mule operation that runs an agent loop for a single incoming A2A request.
 * <p>
 * The operation:
 * <ol>
 * <li>parses the A2A {@code MessageSendParams} payload and extracts its {@link Message};</li>
 * <li>collects the tools the agent may use: inline custom route tools, tools obtained from the configured MCP servers,
 * and tools obtained from the configured A2A clients;</li>
 * <li>builds the prompt and hands execution to a {@link Loop} submitted to the runtime's IO scheduler.</li>
 * </ol>
 * Results and failures are reported through the operation's {@link CompletionCallback}.
 */
public class LoopOperation {

  @Inject
  private SchedulerService schedulerService;

  @Inject
  private ExpressionManager expressionManager;

  // Used by CustomTool to turn a chain result into the text returned to the agent.
  @Inject
  private SerializationService serializer;

  @Inject
  private McpService mcpService;

  @Inject
  private A2AService a2AService;

  @Inject
  private ConversationService conversationService;


  /**
   * Runs the agent loop for one A2A request.
   * <p>
   * This method does not wait for the loop to finish. Setup failures are reported directly via
   * {@code completionCallback}. Setup failures are:
   * <ul>
   * <li>an unparseable A2A payload, reported as an {@code A2A} {@link ModuleException};</li>
   * <li>a failure while collecting tools;</li>
   * <li>any exception raised while building the prompt or the loop, or while submitting it to the IO scheduler.</li>
   * </ul>
   * Once the {@link Loop} has been submitted, it receives {@code completionCallback} and is responsible for reporting
   * the final outcome.
   *
   * @param config               the broker configuration
   * @param llmClient            connection handed to the loop
   * @param taskId               optional task identifier forwarded to the loop
   * @param contextId            optional context identifier forwarded to the loop
   * @param a2aMessageSendParams primary content: serialized A2A {@code MessageSendParams}; only its message is used
   * @param apiInstanceId        optional API instance id, used for tracing and forwarded to the loop
   * @param prompt               the user prompt
   * @param instructions         optional additional instructions for the prompt
   * @param groundings           optional groundings for the prompt
   * @param mcpServers           MCP servers whose tools are offered to the agent (never null)
   * @param a2aClients           A2A clients whose tools are offered to the agent (never null)
   * @param tools                inline custom tool routes offered to the agent (never null)
   * @param maxNumberOfLoops     maximum number of loop iterations, also included in the prompt
   * @param maxConsecutiveErrors maximum number of consecutive errors, forwarded to the loop
   * @param traceContextManager  distributed tracing context
   * @param extensionsClient     client used to execute MCP/A2A operations and by the loop
   * @param completionCallback   callback that receives the result or error
   */
  @Execution(CPU_LITE)
  @Throws(LoopErrorTypeProvider.class)
  @MediaType(value = APPLICATION_JSON, strict = false)
  public void agentLoop(
                        @Config AgentsBroker config,
                        @Connection LLMClient llmClient,
                        @Optional @Expression(REQUIRED) String taskId,
                        @Optional @Expression(REQUIRED) String contextId,
                        @Content(primary = true) InputStream a2aMessageSendParams,
                        @Optional String apiInstanceId,
                        @Content String prompt,
                        @Optional @Content String instructions,
                        @Optional @Content String groundings,
                        @ParameterDsl(allowReferences = false) @Optional @NullSafe List<McpServer> mcpServers,
                        @ParameterDsl(
                            allowReferences = false) @Alias("a2aClients") @Optional @NullSafe List<A2AClient> a2aClients,
                        @Optional @NullSafe @ExecutionOccurrence(MULTIPLE_OR_NONE) @ChainInputResolver(ToolInputResolver.class) List<CustomToolRoute> tools,
                        @ConfigOverride Integer maxNumberOfLoops,
                        @ConfigOverride Integer maxConsecutiveErrors,
                        DistributedTraceContextManager traceContextManager,
                        ExtensionsClient extensionsClient,
                        CompletionCallback<String, Void> completionCallback) {

    traceBroker(config, apiInstanceId, traceContextManager);

    final Message message;
    try {
      message = a2AService.unmarshall(a2aMessageSendParams, MessageSendParams.class).message();
    } catch (Exception e) {
      completionCallback.error(new ModuleException("Could not parse A2A request", A2A, e));
      return;
    }

    collectTools(tools, mcpServers, a2aClients, extensionsClient).whenComplete((collectedTools, collectFailure) -> {
      if (collectFailure != null) {
        completionCallback.error(unwrap(collectFailure));
        return;
      }

      try {
        PromptBuilder promptBuilder = new PromptBuilder(expressionManager)
            .setUserPrompt(prompt)
            .setUserInstructions(instructions)
            .setGroundings(groundings)
            .setTools(collectedTools.values())
            .setMaxLoops(maxNumberOfLoops);

        Loop loop = new Loop(promptBuilder,
                             llmClient,
                             collectedTools,
                             taskId,
                             contextId,
                             message,
                             apiInstanceId,
                             maxNumberOfLoops,
                             maxConsecutiveErrors,
                             conversationService,
                             a2AService,
                             config,
                             extensionsClient,
                             schedulerService,
                             completionCallback);

        // NOTE(review): ioScheduler() is called per invocation and the returned Scheduler is never stopped here.
        // Mule's SchedulerService docs expect Schedulers to be stopped when no longer needed. Confirm this does not
        // accumulate schedulers (e.g. reuse one created at startup, or stop it when the loop ends).
        schedulerService.ioScheduler().submit(loop::start);
      } catch (Exception setupFailure) {
        // An exception escaping this lambda would only complete the (unobserved) future returned by whenComplete.
        // The callback would then never be invoked and the operation would hang, so report it explicitly.
        completionCallback.error(setupFailure);
      }
    });
  }

  /**
   * Tool backed by an inline chain from the operation's {@code tools} routes.
   * <p>
   * Executing it runs the route's chain with the {@link ToolRequest} as payload and {@code null} attributes. It then
   * wraps the serialized chain result in a {@link CustomToolResponse}. Inner (non-static) because it needs the
   * enclosing operation's {@code serializer}.
   */
  private class CustomTool extends Tool {

    private final CustomToolRoute toolRoute;

    public CustomTool(String name, String description, String input, String output, CustomToolRoute toolRoute) {
      super(name, name, description, input, output);
      this.toolRoute = toolRoute;
    }

    /**
     * Runs the route's chain.
     * <p>
     * The returned future is always completed. It completes with the serialized result on success, or exceptionally
     * in three cases: the chain fails, the result cannot be serialized, or starting the chain throws synchronously.
     */
    @Override
    public CompletableFuture<ToolResponse> execute(ToolRequest request, ExtensionsClient extensionsClient) {
      CompletableFuture<ToolResponse> future = new CompletableFuture<>();
      try {
        toolRoute.getChain().process(request, null,
                                     result -> {
                                       try {
                                         future.complete(new CustomToolResponse(serializer.asString(result)));
                                       } catch (Exception serializationFailure) {
                                         // Otherwise the future would never complete and the agent would stall.
                                         future.completeExceptionally(serializationFailure);
                                       }
                                     },
                                     (t, r) -> future.completeExceptionally(t));
      } catch (Exception startFailure) {
        // Surface synchronous failures through the future, like collectTools does with failedFuture.
        future.completeExceptionally(startFailure);
      }

      return future;
    }
  }

  /**
   * Gathers every tool available to the agent, keyed by tool name.
   * <p>
   * Custom route tools are inserted first, then MCP tools, then A2A tools, into a {@link LinkedHashMap}. A later
   * entry with the same name replaces an earlier one. The MCP and A2A lookups are both started before either result
   * is awaited. Synchronous failures are returned as a failed future (after {@code unwrap}) rather than thrown.
   *
   * @param routeTools       inline custom tool routes; may be {@code null}
   * @param mcpServers       MCP servers to query for tools
   * @param a2aClients       A2A clients to obtain tools from
   * @param extensionsClient client used by the MCP/A2A services
   * @return a future with the merged, insertion-ordered tool map
   */
  private CompletableFuture<Map<String, Tool>> collectTools(List<CustomToolRoute> routeTools,
                                                            List<McpServer> mcpServers,
                                                            List<A2AClient> a2aClients,
                                                            ExtensionsClient extensionsClient) {
    try {
      final Map<String, Tool> tools = new LinkedHashMap<>();

      if (routeTools != null) {
        for (var tool : routeTools) {
          tools.put(tool.getName(),
                    new CustomTool(tool.getName(), tool.getDescription(), tool.getInput(), tool.getOutput(), tool));
        }
      }

      // Kick off both lookups before waiting on either, so they can proceed concurrently if the services are async.
      CompletableFuture<Map<String, Tool>> mcpFuture = mcpService.getTools(mcpServers, extensionsClient);
      CompletableFuture<Map<String, Tool>> a2aFuture = a2AService.getTools(a2aClients, extensionsClient);

      // Merge once both are done; A2A tools are added last and win on name clashes.
      return mcpFuture.thenCombine(a2aFuture, (mcpTools, a2aTools) -> {
        tools.putAll(mcpTools);
        tools.putAll(a2aTools);
        return tools;
      });
    } catch (Exception e) {
      return failedFuture(unwrap(e));
    }
  }
}
