/*
 * (c) 2025 MuleSoft, Inc. The software in this package is published under the terms of the Commercial Free Software license V.1 a copy of which has been included with this distribution in the LICENSE.md file.
 */
package com.mulesoft.modules.agent.conductor.internal.operation.loop;

import static com.mulesoft.modules.agent.conductor.internal.tool.ToolType.CUSTOM;

import static java.util.UUID.randomUUID;
import static java.util.concurrent.CompletableFuture.failedFuture;

import static org.mule.runtime.api.meta.ExpressionSupport.REQUIRED;
import static org.mule.runtime.api.meta.model.operation.ExecutionType.CPU_LITE;
import static org.mule.runtime.core.api.util.StringUtils.isBlank;
import static org.mule.sdk.api.annotation.route.ChainExecutionOccurrence.MULTIPLE_OR_NONE;

import org.mule.runtime.api.scheduler.SchedulerService;
import org.mule.runtime.core.api.el.ExpressionManager;
import org.mule.runtime.core.api.util.func.CheckedFunction;
import org.mule.runtime.extension.api.annotation.Alias;
import org.mule.runtime.extension.api.annotation.Expression;
import org.mule.runtime.extension.api.annotation.dsl.xml.ParameterDsl;
import org.mule.runtime.extension.api.annotation.error.Throws;
import org.mule.runtime.extension.api.annotation.execution.Execution;
import org.mule.runtime.extension.api.annotation.param.Config;
import org.mule.runtime.extension.api.annotation.param.ConfigOverride;
import org.mule.runtime.extension.api.annotation.param.Content;
import org.mule.runtime.extension.api.annotation.param.NullSafe;
import org.mule.runtime.extension.api.annotation.param.Optional;
import org.mule.runtime.extension.api.runtime.process.CompletionCallback;
import org.mule.sdk.api.annotation.metadata.ChainInputResolver;
import org.mule.sdk.api.annotation.route.ExecutionOccurrence;

import com.mulesoft.modules.agent.conductor.api.model.Orchestration;
import com.mulesoft.modules.agent.conductor.api.model.a2a.A2AClient;
import com.mulesoft.modules.agent.conductor.api.model.llm.LLMSettings;
import com.mulesoft.modules.agent.conductor.api.model.mcp.McpServer;
import com.mulesoft.modules.agent.conductor.api.model.tool.Tool;
import com.mulesoft.modules.agent.conductor.internal.datasense.ToolInputResolver;
import com.mulesoft.modules.agent.conductor.internal.extension.AgentConductor;
import com.mulesoft.modules.agent.conductor.internal.llm.client.LLMClientFactory;
import com.mulesoft.modules.agent.conductor.internal.prompt.PromptBuilder;
import com.mulesoft.modules.agent.conductor.internal.serializer.LLMSerializer;
import com.mulesoft.modules.agent.conductor.internal.state.TaskContextService;
import com.mulesoft.modules.agent.conductor.internal.tool.ToolHandler;
import com.mulesoft.modules.agent.conductor.internal.tool.ToolRequest;
import com.mulesoft.modules.agent.conductor.internal.tool.ToolResponse;
import com.mulesoft.modules.agent.conductor.internal.tool.a2a.A2AToolService;
import com.mulesoft.modules.agent.conductor.internal.tool.custom.CustomToolResponse;
import com.mulesoft.modules.agent.conductor.internal.tool.mcp.McpService;

import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;

import javax.inject.Inject;

/**
 * Mule operations that run the agent conductor loop.
 */
public class LoopOperation {

  @Inject
  private SchedulerService schedulerService;

  @Inject
  private ExpressionManager expressionManager;

  /** Used to turn the result of an inline tool route into the text of a {@link CustomToolResponse}. */
  @Inject
  private LLMSerializer serializer;

  @Inject
  private LLMClientFactory llmClientFactory;

  @Inject
  private McpService mcpService;

  @Inject
  private A2AToolService a2aToolService;

  @Inject
  private TaskContextService taskContextService;

  /**
   * Runs an agent loop for the given prompt.
   * <p>
   * Tool handlers are gathered from the inline {@code tools} routes, the {@code mcpServers} and the {@code a2aClients}.
   * Once they are available, a {@link PromptBuilder} and a {@link Loop} are created and {@link Loop#start} is submitted
   * to an IO scheduler. Failures while collecting tools or while setting up the loop are reported through
   * {@code completionCallback.error}. After submission, the {@link Loop} owns the callback (presumably completing it
   * when the loop ends — see {@code Loop}).
   *
   * @param config               connector configuration; provides the conversation state object store and the config name
   * @param taskId               task identifier; a random UUID is used when blank
   * @param contextId            context identifier; a random UUID is used when blank
   * @param prompt               the user prompt
   * @param llm                  settings used to obtain the LLM client from {@link LLMClientFactory}
   * @param instructions         optional user instructions added to the prompt
   * @param groundings           optional groundings added to the prompt
   * @param mcpServers           MCP servers whose tools are made available to the loop
   * @param a2aClients           A2A clients whose tools are made available to the loop
   * @param tools                inline tool routes, registered as {@code CUSTOM} tools
   * @param maxNumberOfLoops     loop limit, forwarded to both the prompt builder and the {@link Loop}
   * @param maxConsecutiveErrors consecutive-error limit, forwarded to the {@link Loop}
   * @param completionCallback   receives the resulting {@link Orchestration} or the failure
   */
  @Execution(CPU_LITE)
  @Throws(LoopErrorTypeProvider.class)
  public void agentLoop(
                        @Config AgentConductor config,
                        @Optional @Expression(REQUIRED) String taskId,
                        @Optional @Expression(REQUIRED) String contextId,
                        @Content(primary = true) String prompt,
                        LLMSettings llm,
                        @Optional @Content String instructions,
                        @Optional @Content String groundings,
                        @ParameterDsl(allowReferences = false) @Optional @NullSafe List<McpServer> mcpServers,
                        @ParameterDsl(
                            allowReferences = false) @Alias("a2aClients") @Optional @NullSafe List<A2AClient> a2aClients,
                        @Optional @NullSafe @ExecutionOccurrence(MULTIPLE_OR_NONE) @ChainInputResolver(ToolInputResolver.class) List<Tool> tools,
                        @ConfigOverride Integer maxNumberOfLoops,
                        @ConfigOverride Integer maxConsecutiveErrors,
                        CompletionCallback<Orchestration, Void> completionCallback) {

    collectTools(tools, mcpServers, a2aClients).whenComplete((toolHandlers, e) -> {
      if (e != null) {
        // Stages derived via thenCombine wrap the original failure in a CompletionException; report the real cause.
        completionCallback.error(unwrap(e));
        return;
      }

      // An exception escaping this lambda would only complete the (unobserved) future returned by whenComplete,
      // leaving the callback never invoked and the flow waiting forever. Report it explicitly instead.
      // This is an async boundary, so everything is caught.
      try {
        PromptBuilder promptBuilder = new PromptBuilder(expressionManager)
            .setUserPrompt(prompt)
            .setUserInstructions(instructions)
            .setGroundings(groundings)
            .setTools(toolHandlers.values())
            .setMaxLoops(maxNumberOfLoops);

        Loop loop = new Loop(promptBuilder,
                             llmClientFactory.getClient(llm),
                             a2aToolService,
                             toolHandlers,
                             nonBlank(taskId),
                             nonBlank(contextId),
                             maxNumberOfLoops,
                             maxConsecutiveErrors,
                             taskContextService,
                             config.getConversationStateObjectStore(),
                             config.getConfigName(),
                             completionCallback);

        // NOTE(review): SchedulerService#ioScheduler() may build a new Scheduler on every call; confirm whether it
        // should be obtained once for the operation's lifetime and stopped on shutdown.
        schedulerService.ioScheduler().submit(loop::start);
      } catch (Throwable t) {
        completionCallback.error(t);
      }
    });
  }

  /**
   * Builds the handler that runs an inline tool route.
   * <p>
   * The route's chain is processed with the tool request as payload and no attributes. On success, the result is
   * serialized to a string and wrapped in a {@link CustomToolResponse}. On error, the returned future completes
   * exceptionally with the chain's failure.
   */
  private CheckedFunction<ToolRequest, CompletableFuture<ToolResponse>> routeToolHandler(Tool tool) {
    return request -> {
      CompletableFuture<ToolResponse> future = new CompletableFuture<>();
      tool.getChain().process(request, null,
                              result -> {
                                // If serialization throws inside the success callback, the future might otherwise
                                // never complete; surface the failure to the caller instead.
                                try {
                                  future.complete(new CustomToolResponse(serializer.asString(result)));
                                } catch (Throwable t) {
                                  future.completeExceptionally(t);
                                }
                              },
                              (t, r) -> future.completeExceptionally(t));

      return future;
    };
  }

  /**
   * Collects every tool handler available to the loop, keyed by tool name.
   * <p>
   * Inline route tools are registered first. The MCP and A2A handlers are requested concurrently, and their results are
   * merged in that order. The map keeps insertion order. A later handler with an already-registered name replaces the
   * earlier value but keeps its original position.
   *
   * @return a future with the merged handlers; it completes exceptionally if any source fails, including synchronous
   *         failures while starting the collection
   */
  private CompletableFuture<Map<String, ToolHandler>> collectTools(List<Tool> routeTools, List<McpServer> mcpServers,
                                                                   List<A2AClient> a2aClients) {
    try {
      final Map<String, ToolHandler> toolHandlers = new LinkedHashMap<>();

      if (routeTools != null) {
        for (var tool : routeTools) {
          toolHandlers.put(tool.getName(),
                           new ToolHandler(tool.getName(), tool.getDescription(), tool.getInput(), tool.getOutput(),
                                           CUSTOM, routeToolHandler(tool)));
        }
      }

      // Start both MCP and A2A tool collection in parallel
      CompletableFuture<Map<String, ToolHandler>> mcpHandlers = mcpService.getMcpTools(mcpServers);
      CompletableFuture<Map<String, ToolHandler>> a2aHandlers = a2aToolService.getA2AToolHandlers(a2aClients);

      // Combine both futures and merge the results; a null result is treated as "no tools" rather than an NPE
      return mcpHandlers.thenCombine(a2aHandlers, (mcpTools, a2aTools) -> {
        if (mcpTools != null) {
          toolHandlers.putAll(mcpTools);
        }
        if (a2aTools != null) {
          toolHandlers.putAll(a2aTools);
        }
        return toolHandlers;
      });
    } catch (Exception e) {
      return failedFuture(e);
    }
  }

  /**
   * Returns {@code input}, or a fresh random UUID string when {@code input} is blank.
   */
  private String nonBlank(String input) {
    return isBlank(input) ? randomUUID().toString() : input;
  }

  /**
   * Strips {@link CompletionException} wrappers added by {@link CompletableFuture} stages and returns the underlying
   * failure. A wrapper without a cause is returned as-is.
   */
  private static Throwable unwrap(Throwable failure) {
    Throwable current = failure;
    while (current instanceof CompletionException && current.getCause() != null) {
      current = current.getCause();
    }
    return current;
  }
}
