/*
 * (c) 2025 MuleSoft, Inc. The software in this package is published under the terms of the Commercial Free Software license V.1 a copy of which has been included with this distribution in the LICENSE.md file.
 */
package com.mulesoft.modules.agent.broker.internal.operation.loop;

import static com.mulesoft.modules.agent.broker.internal.error.BrokerErrorTypes.A2A;
import static com.mulesoft.modules.agent.broker.internal.error.BrokerErrorTypes.LLM_ERROR;
import static com.mulesoft.modules.agent.broker.internal.error.BrokerErrorTypes.MAX_LOOPS;
import static com.mulesoft.modules.agent.broker.internal.error.BrokerErrorTypes.REASONING_ERROR;
import static com.mulesoft.modules.agent.broker.internal.error.BrokerErrorTypes.TOOL_ERROR;
import static com.mulesoft.modules.agent.broker.internal.log.SecureLogger.secureLogger;
import static com.mulesoft.modules.agent.broker.internal.state.model.AdditionalInputRequired.RequesterType.A2A_TOOL;
import static com.mulesoft.modules.agent.broker.internal.state.model.AdditionalInputRequired.RequesterType.LLM;
import static com.mulesoft.modules.agent.broker.internal.util.ConcurrencyUtils.completeAsyncInTCCL;
import static com.mulesoft.modules.agent.broker.internal.util.ExceptionUtils.unwrap;

import static java.util.UUID.randomUUID;

import static io.a2a.spec.Message.Role.AGENT;
import static io.a2a.spec.TaskState.COMPLETED;
import static io.a2a.spec.TaskState.FAILED;
import static io.a2a.spec.TaskState.INPUT_REQUIRED;
import static io.a2a.spec.TaskState.WORKING;
import static org.mule.runtime.api.i18n.I18nMessageFactory.createStaticMessage;
import static org.mule.runtime.core.api.util.StringUtils.isBlank;
import static org.slf4j.LoggerFactory.getLogger;

import org.mule.runtime.api.exception.MuleRuntimeException;
import org.mule.runtime.api.scheduler.SchedulerService;
import org.mule.runtime.api.store.ObjectStoreException;
import org.mule.runtime.core.api.util.func.CheckedRunnable;
import org.mule.runtime.extension.api.client.ExtensionsClient;
import org.mule.runtime.extension.api.exception.ModuleException;
import org.mule.runtime.extension.api.runtime.operation.Result;
import org.mule.runtime.extension.api.runtime.process.CompletionCallback;

import com.mulesoft.modules.agent.broker.internal.extension.AgentsBroker;
import com.mulesoft.modules.agent.broker.internal.extension.connection.LLMClient;
import com.mulesoft.modules.agent.broker.internal.llm.LLMOrchestrationRequest;
import com.mulesoft.modules.agent.broker.internal.prompt.PromptBuilder;
import com.mulesoft.modules.agent.broker.internal.state.ConversationService;
import com.mulesoft.modules.agent.broker.internal.state.model.AdditionalInputRequired;
import com.mulesoft.modules.agent.broker.internal.state.model.Conversation;
import com.mulesoft.modules.agent.broker.internal.state.model.ConversationState;
import com.mulesoft.modules.agent.broker.internal.state.model.Iteration;
import com.mulesoft.modules.agent.broker.internal.state.model.LLMOutput;
import com.mulesoft.modules.agent.broker.internal.state.model.TaskContext;
import com.mulesoft.modules.agent.broker.internal.state.model.ToolExecution;
import com.mulesoft.modules.agent.broker.internal.tool.Tool;
import com.mulesoft.modules.agent.broker.internal.tool.ToolRequest;
import com.mulesoft.modules.agent.broker.internal.tool.ToolResponse;
import com.mulesoft.modules.agent.broker.internal.tool.ToolType;
import com.mulesoft.modules.agent.broker.internal.tool.a2a.A2AService;
import com.mulesoft.modules.agent.broker.internal.tool.a2a.A2AToolResponse;

import java.time.OffsetDateTime;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;

import io.a2a.spec.Message;
import io.a2a.spec.Task;
import io.a2a.spec.TaskStatus;
import io.a2a.spec.TextPart;
import org.slf4j.Logger;
import org.slf4j.MDC;

/**
 * Drives the reasoning loop for a single agent task.
 * <p>
 * On every iteration the LLM is asked for the next step. The selected tool is executed and its outcome is appended to the
 * conversation history fed back to the LLM. This repeats until one of the following happens:
 * <ul>
 * <li>the LLM reports the goal as complete;</li>
 * <li>additional input is required (either by the LLM itself or by an A2A tool);</li>
 * <li>a limit is hit (maximum loops or maximum consecutive errors).</li>
 * </ul>
 * <p>
 * The outcome is always reported through the {@link CompletionCallback} given at construction time. Failures that happen while
 * the loop is running are reported as a <em>successful</em> result carrying a {@code FAILED} A2A task (see
 * {@link #completeWithErrorAndUpdateState(Exception)}). Failures while loading the context, or synchronous failures when
 * starting an iteration, are reported through {@link CompletionCallback#error(Throwable)}.
 * <p>
 * Instances hold mutable per-task state (current loop, error counter, history) and are meant to be used for a single
 * {@link #start()}. NOTE(review): the state is not synchronized; this assumes only one LLM/tool call is outstanding at a time,
 * which is how the loop chains its async callbacks.
 */
class Loop {

  private static final Logger LOGGER = getLogger(Loop.class);

  private static final String AGENT_MDC_KEY = "agent";
  private static final String ITERATION_MDC_KEY = "iteration";
  private static final String TASK_ID_MDC_KEY = "taskId";
  private static final String CONTEXT_ID_MDC_KEY = "contextId";
  private static final String API_INSTANCE_ID_MDC_KEY = "apiInstanceId";

  private final PromptBuilder promptBuilder;
  private final LLMClient llmClient;
  // Tools the LLM may select, keyed by the tool name the LLM returns in its output.
  private final Map<String, Tool> toolExecutors;
  private final String taskId;
  private final String contextId;
  private final Message message;
  private final String apiInstanceId;
  private final int maxLoops;
  private final int maxConsecutiveErrors;
  private final ConversationService conversationService;
  private final A2AService a2aService;
  private final AgentsBroker broker;
  private final ExtensionsClient extensionsClient;
  private final SchedulerService schedulerService;
  private final CompletionCallback<String, Void> completionCallback;


  private int currentLoop = 0;
  private int consecutiveErrors = 0;
  // Text history of the steps taken so far; passed to the prompt builder and persisted in the task context.
  private StringBuilder conversationHistory = new StringBuilder();
  private Conversation conversation;
  private TaskContext taskContext;

  /**
   * Creates a loop for one task. The given callback is wrapped so that the MDC captured at construction time is restored
   * once the callback completes.
   */
  public Loop(PromptBuilder promptBuilder,
              LLMClient llmClient,
              Map<String, Tool> toolExecutors,
              String taskId,
              String contextId,
              Message message,
              String apiInstanceId,
              int maxLoops,
              int maxConsecutiveErrors,
              ConversationService conversationService,
              A2AService a2aService,
              AgentsBroker broker,
              ExtensionsClient extensionsClient,
              SchedulerService schedulerService,
              CompletionCallback<String, Void> completionCallback) {
    this.promptBuilder = promptBuilder;
    this.llmClient = llmClient;
    this.toolExecutors = toolExecutors;
    this.taskId = taskId;
    this.contextId = contextId;
    this.message = message;
    this.apiInstanceId = apiInstanceId;
    this.maxLoops = maxLoops;
    this.maxConsecutiveErrors = maxConsecutiveErrors;
    this.schedulerService = schedulerService;
    this.completionCallback = new CompletionCallbackDecorator<>(completionCallback);
    this.broker = broker;
    this.extensionsClient = extensionsClient;
    this.conversationService = conversationService;
    this.a2aService = a2aService;
  }

  /**
   * Starts, or resumes, processing of the task.
   * <p>
   * The conversation and task context are loaded inside {@code conversationService.synchronizedConversation} (presumably
   * serialized per conversation — confirm in {@link ConversationService}). If the task was waiting for additional input, its
   * state is set back to {@code WORKING} and processing resumes. For an A2A tool request the waiting tool is re-invoked; for
   * an LLM request the loop is re-entered.
   * <p>
   * The action chosen inside that section is executed only after the section returns.
   */
  public void start() {
    setMDC();

    CheckedRunnable nextAction = conversationService.synchronizedConversation(contextId, broker, () -> {
      try {
        fetchContext();
      } catch (Exception e) {
        completionCallback.error(new ModuleException("Error recovering context %s: %s"
            .formatted(message.getContextId(), e.getMessage()), A2A, e));
        return null;
      }

      ConversationState conversationState = taskContext.getConversationState();
      AdditionalInputRequired additionalInputRequired = conversationState.getAdditionalInputRequired();

      try {
        if (additionalInputRequired != null) {
          conversationState.setTaskState(WORKING);
          conversationState.setAdditionalInputRequired(null);
          safeUpsertTaskContext();
          if (additionalInputRequired.getRequesterType() == A2A_TOOL) {
            return resumeHITLFromA2ATool();
          } else if (additionalInputRequired.getRequesterType() == LLM) {
            return resumeHITLByLLM();
          }
        }
      } catch (Exception e) {
        completionCallback.error(e);
        return null;
      }
      return this::nextLoop;
    });

    if (nextAction != null) {
      nextAction.run();
    }
  }

  /** Loads (or creates) the conversation and the task context for this loop through the conversation service. */
  private void fetchContext() {
    conversation = conversationService.getOrCreateConversation(message, contextId, broker);
    taskContext = conversationService.getTaskFor(message, conversation, taskId, broker);
  }

  /**
   * Runs one iteration of the loop.
   * <p>
   * Once {@code maxLoops} is exceeded the task fails with {@code MAX_LOOPS}. Otherwise a new {@link Iteration} is recorded,
   * the prompt is built from the current history, and the LLM is asked asynchronously for the next step; its answer is
   * handled by {@link #onReasoningCompleted(Iteration, LLMOutput, Throwable)}. Exceptions thrown synchronously here are
   * reported through {@link CompletionCallback#error(Throwable)}.
   */
  private void nextLoop() {
    try {
      if (++currentLoop > maxLoops) {
        completeWithErrorAndUpdateState(new ModuleException("Maximum loops achieved without completing goal", MAX_LOOPS));
        return;
      }

      setMDC();
      promptBuilder.setCurrentLoopIteration(currentLoop)
          .setConversationHistory(conversationHistory.toString());

      taskContext.setCurrentIteration(currentLoop);

      final Iteration iteration = new Iteration();
      iteration.setNumber(currentLoop);
      taskContext.addIteration(iteration);

      completeAsyncInTCCL(llmClient.reasonNextStep(new LLMOrchestrationRequest(promptBuilder.build())),
                          schedulerService.ioScheduler(),
                          (llmOutput, llmException) -> onReasoningCompleted(iteration, llmOutput, llmException));
    } catch (Exception e) {
      completionCallback.error(e);
    }
  }

  /**
   * Handles the LLM answer for the given iteration.
   * <p>
   * Possible outcomes:
   * <ul>
   * <li>the task fails because the LLM call failed;</li>
   * <li>the task completes;</li>
   * <li>input is requested from the client;</li>
   * <li>the selected tool is executed and the loop continues;</li>
   * <li>an invalid selection (no tool, unknown tool) is recorded in the history and retried until the consecutive-error
   * limit is reached.</li>
   * </ul>
   *
   * @param iteration    the iteration registered by {@link #nextLoop()} for this LLM call
   * @param llmOutput    the LLM answer; only meaningful when {@code llmException} is {@code null}
   * @param llmException the failure of the LLM call, or {@code null} if it succeeded
   */
  private void onReasoningCompleted(Iteration iteration, LLMOutput llmOutput, Throwable llmException) {
    setMDC();
    if (llmException != null) {
      Throwable cause = unwrap(llmException);
      secureLogger().debug("LLM call failed", cause);
      completeWithErrorAndUpdateState(new ModuleException(
                                                          "Cannot complete task due to temporal issue with the reasoning engine",
                                                          LLM_ERROR, cause));
      return;
    }

    iteration.setLlmOutput(llmOutput);
    if (llmOutput.isGoalComplete()) {
      taskContext.setConversationState(new ConversationState(COMPLETED, null));
      safeUpsertTaskContext();
      completionCallback.success(a2aService.completedJsonResult(conversation, taskContext, llmOutput.getResult()));
      return;
    }

    if (llmOutput.isAdditionalInputRequired()) {
      try {
        updateTaskContextForHITLByLLM(taskContext, iteration, llmOutput);
        completionCallback.success(a2aService.inputRequiredJsonResult(conversation, taskContext, llmOutput.getResult()));
      } catch (Exception e) {
        completeWithErrorAndUpdateState(new MuleRuntimeException(createStaticMessage(
                                                                                     "Could not update context when requesting additional data from client: "
                                                                                         + e.getMessage()),
                                                                 e));
      }
      return;
    }

    final String toolToCall = llmOutput.getToolToCall();

    // The goal is known to be incomplete at this point (handled above), so only the tool selection needs checking.
    if (noToolSelected(toolToCall)) {
      addToHistoryAndLog("LLM didn't select any tool to call next, the goal is not yet complete and no addition data requested. "
          + "Re-evaluate and either select a tool or request additional data.",
                         "Reasoning: " + llmOutput.getReasoning());
      if (maxConsecutiveErrorsReached()) {
        completeWithErrorAndUpdateState(new ModuleException(createStaticMessage("LLM couldn't determine next step to complete the task"),
                                                            REASONING_ERROR));
      } else {
        nextLoop();
      }
      return;
    }

    final String toolInput = llmOutput.getToolInput();
    final String reasoning = llmOutput.getReasoning();

    final Tool toolExecutor = toolExecutors.get(toolToCall);
    if (toolExecutor == null) {
      addToHistoryAndLog("LLM selected tool '" + toolToCall
          + "' but it doesn't exist. Reevaluate and make sure to select a tool from the provided list");

      if (maxConsecutiveErrorsReached()) {
        LOGGER
            .error("LLM selected a tool that doesn't exist: {}. Max consecutive errors limit reached. Task will fail",
                   toolToCall);
        completeWithErrorAndUpdateState(new ModuleException("Could not determine next steps for task completion",
                                                            REASONING_ERROR));
      } else {
        nextLoop();
      }
      return;
    }

    addToHistoryAndLog(
                       "LLM selected tool " + toolToCall,
                       "Reasoning was: " + nullSafe(reasoning) + ". Tool Input: " + toolInput);

    final ToolRequest toolRequest =
        new ToolRequest(toolToCall, toolInput, reasoning, currentLoop, taskContext, apiInstanceId);
    executeTool(toolExecutor, toolRequest, toolToCall, toolResponse -> {
      try {
        resetConsecutiveErrors();
        if (handleToolResponse(taskContext, llmOutput, toolToCall, toolResponse)) {
          nextLoop();
        }
      } catch (Exception e) {
        completeWithErrorAndUpdateState(e);
      }
    });
  }

  /** Returns whether the LLM left the tool name empty, blank or set to the literal {@code "null"} (any case). */
  private boolean noToolSelected(String toolName) {
    return isBlank(toolName) || "null".equalsIgnoreCase(toolName);
  }

  /**
   * Executes a tool asynchronously.
   * <p>
   * A failure or a blank result is recorded in the history and the loop is retried via {@link #nextLoop()}. Once the
   * consecutive-error limit is reached the task fails with {@code TOOL_ERROR} instead. A non-blank result is handed to
   * {@code onSuccess}.
   *
   * @param executor    the tool implementation to invoke
   * @param toolRequest the request passed to the tool
   * @param toolName    the tool name, used in history/log messages
   * @param onSuccess   receives the tool result when it is not blank
   */
  @SuppressWarnings("unchecked") // the tool result is only ever consumed as (a subtype of) ToolResponse by callers
  private <T extends ToolResponse> void executeTool(Tool executor,
                                                    ToolRequest toolRequest,
                                                    String toolName,
                                                    Consumer<T> onSuccess) {

    completeAsyncInTCCL(executor.execute(toolRequest, extensionsClient),
                        schedulerService.ioScheduler(),
                        (toolResult, toolException) -> {
                          setMDC();
                          if (toolException != null) {
                            toolException = unwrap(toolException);
                            addToHistoryAndLog(
                                               "Execution of tool " + toolName + " failed. "
                                                   + "Analyze the error and determine whether the tool input needs to be adjusted, "
                                                   + "or if a different tool can be used instead",
                                               "Error message was: " + toolException.getMessage());

                            if (maxConsecutiveErrorsReached()) {
                              addToHistoryAndLog("Tool '%s' failed and maximum consecutive errors reached. Task '%s' will fail"
                                  .formatted(toolName, taskContext.getTaskId()),
                                                 "Error message was: " + toolException.getMessage());
                              completeWithErrorAndUpdateState(new ModuleException("Could not complete next step due to an error. Task failed.",
                                                                                  TOOL_ERROR, toolException));
                            } else {
                              nextLoop();
                            }
                          } else {
                            if (isBlank(toolResult.getResult())) {
                              addToHistoryAndLog(
                                                 "Execution of tool " + toolName
                                                     + " came back with no result. Determine whether the tool input needs to be adjusted, "
                                                     + "or if a different tool can be used instead");
                              if (maxConsecutiveErrorsReached()) {
                                completeWithErrorAndUpdateState(new ModuleException("Could not complete next step due to an error. Task failed.",
                                                                                    TOOL_ERROR));
                              } else {
                                nextLoop();
                              }
                              return;
                            }
                            onSuccess.accept((T) toolResult);
                          }
                        });
  }

  /**
   * Dispatches a successful tool response to the A2A or the standard handler.
   *
   * @return {@code true} if the loop should continue with the next iteration, {@code false} if the task has already been
   *         completed (input required from the client)
   */
  private boolean handleToolResponse(TaskContext taskContext,
                                     LLMOutput llmOutput,
                                     String toolName,
                                     ToolResponse toolResponse) {

    if (toolResponse.getToolType() == ToolType.A2A) {
      return handleA2AToolResponse(taskContext, toolName, llmOutput.getToolInput(), (A2AToolResponse) toolResponse);
    } else {
      handleStandardToolResponse(taskContext, llmOutput, toolName, toolResponse);
      return true;
    }
  }

  /**
   * Fails the task.
   * <p>
   * The task state is set to {@code FAILED} and persisted on a best-effort basis. The failure is then reported to the caller
   * as a <em>successful</em> result wrapping a {@code FAILED} A2A task. If building or delivering that result throws, the
   * error is reported through {@link CompletionCallback#error(Throwable)}.
   *
   * @param exception the cause; its message becomes the failure text (or its class name when it has no message)
   */
  private void completeWithErrorAndUpdateState(Exception exception) {
    // Exceptions such as NullPointerException may carry no message; never propagate a null text into the stored state or
    // the A2A TextPart.
    final String errorMessage = exception.getMessage() != null ? exception.getMessage() : exception.getClass().getName();
    try {
      taskContext.setConversationState(new ConversationState(FAILED, errorMessage));
      safeUpsertTaskContext();
    } finally {
      try {
        completionCallback.success(a2aService.asJsonResult(new Task.Builder()
            .id(taskContext.getTaskId())
            .contextId(conversation.getId())
            .status(new TaskStatus(FAILED, new Message.Builder()
                .messageId(randomUUID().toString())
                .taskId(taskContext.getTaskId())
                .role(AGENT)
                .parts(List.of(new TextPart(errorMessage)))
                .build(), OffsetDateTime.now()))
            .build()));
      } catch (Exception e) {
        completionCallback.error(e);
      }
    }
  }

  /**
   * Records an A2A tool response.
   * <p>
   * The remote agent's state, context id and task id are copied into the tool context, and the output is appended to the
   * history. If the remote agent requires input, the task is persisted as {@code INPUT_REQUIRED` and the client is asked
   * for it.
   *
   * @return {@code true} if the loop should continue, {@code false} if input was requested from the client
   */
  private boolean handleA2AToolResponse(TaskContext taskContext,
                                        String toolName,
                                        String toolInput,
                                        A2AToolResponse toolResponse) {

    final boolean inputRequired = toolResponse.isInputRequired();
    final String toolOutput = toolResponse.getResult();

    var toolContext = taskContext.getAgentToolContext(toolName);

    toolContext.setTaskState(toolResponse.getState());
    if (toolResponse.getContextId() != null) {
      toolContext.setContextId(toolResponse.getContextId());
    }

    if (toolResponse.getTaskId() != null) {
      toolContext.setTaskId(toolResponse.getTaskId());
    }

    addToolResponseToHistory(toolName, toolOutput);

    if (inputRequired) {
      updateTaskContextForToolCall(taskContext, toolName, toolInput, toolOutput, true);
      safeUpsertTaskContext();
      completionCallback.success(a2aService.inputRequiredJsonResult(conversation, taskContext, toolOutput));
      return false;
    }

    return true;
  }

  /** Appends a tool's output to the conversation history. */
  private void addToolResponseToHistory(String toolName, String toolOutput) {
    addToHistoryAndLog("Executed tool " + toolName, "Output was: " + toolOutput);
  }

  /** Records a non-A2A tool response in the history and in the last iteration's tool execution. */
  private void handleStandardToolResponse(TaskContext taskContext,
                                          LLMOutput llmOutput,
                                          String toolName,
                                          ToolResponse toolResponse) {
    addToolResponseToHistory(toolName, toolResponse.getResult());

    // Arguments follow updateTaskContextForToolCall's (name, input, response) order.
    updateTaskContextForToolCall(taskContext, toolName, llmOutput.getToolInput(), toolResponse.getResult(), false);
  }

  /**
   * Appends a step to the conversation history, one line per part.
   * <p>
   * {@code message} is logged at INFO level. Each {@code secureAppendix} entry may contain user or LLM data, so it is logged
   * only through the secure logger at DEBUG level.
   */
  private void addToHistoryAndLog(String message, String... secureAppendix) {
    LOGGER.info(message);
    appendHistoryLine(message);

    if (secureAppendix != null) {
      for (String appendix : secureAppendix) {
        secureLogger().debug(appendix);
        appendHistoryLine(appendix);
      }
    }
  }

  /** Appends {@code text} to the history, terminating it with a newline unless it already ends with one. */
  private void appendHistoryLine(String text) {
    conversationHistory.append(text);
    if (text == null || !text.endsWith("\n")) {
      conversationHistory.append('\n');
    }
  }

  /**
   * Prepares to resume a task paused by the LLM asking the client for input: restores the persisted history and rewinds the
   * loop counter so that {@link #nextLoop()} re-enters the stored iteration number.
   */
  private CheckedRunnable resumeHITLByLLM() {
    conversationHistory = new StringBuilder(nullSafe(taskContext.getConversationHistory()));

    // nextLoop() pre-increments currentLoop, so rewind by one to re-enter the stored iteration number
    currentLoop = taskContext.getCurrentIteration() - 1;

    LOGGER.info("Resuming conversation {} from iteration {}", conversation.getId(), taskContext.getCurrentIteration());
    return this::nextLoop;
  }

  /**
   * Prepares to resume a task paused by an A2A tool asking for input.
   * <p>
   * The waiting tool is re-invoked with the client's new prompt as its input. If the tool no longer needs input, the loop
   * continues.
   *
   * @throws ModuleException ({@code A2A}) if the pending tool execution or its tool configuration cannot be found
   */
  private CheckedRunnable resumeHITLFromA2ATool() {
    conversationHistory = new StringBuilder(nullSafe(taskContext.getConversationHistory()));

    Iteration lastIteration = taskContext.getLastIteration();
    ToolExecution waitingToolResult = lastIteration == null ? null : lastIteration.getToolExecution();
    if (waitingToolResult == null) {
      throw new ModuleException("Could not find the pending A2A tool execution to resume task " + taskContext.getTaskId(), A2A);
    }
    String waitingToolName = waitingToolResult.getName();
    String toolInput = promptBuilder.getUserPrompt();
    // The previous tool output (the remote agent's request for input) is passed as the reasoning for the resumed call.
    String reasoning = waitingToolResult.getOutput();

    Tool tool = toolExecutors.get(waitingToolName);
    if (tool == null) {
      throw new ModuleException("Could not find A2A agent config to resume the tool - " + waitingToolName, A2A);
    }

    currentLoop = taskContext.getCurrentIteration() - 1;
    var toolRequest = new ToolRequest(waitingToolName, toolInput, reasoning, currentLoop, taskContext, apiInstanceId);

    return () -> executeTool(tool, toolRequest, waitingToolName, toolResult -> {
      try {
        if (handleA2AToolResponse(taskContext, waitingToolName, toolInput, (A2AToolResponse) toolResult)) {
          nextLoop();
        }
      } catch (Exception e) {
        // Same handling as the main loop: the task was set back to WORKING, so it must be marked FAILED.
        completeWithErrorAndUpdateState(e);
      }
    });
  }

  /**
   * Marks the task as waiting for client input requested by the LLM.
   * <p>
   * The current history is stored so that {@link #resumeHITLByLLM()} can restore it, and the task context is then persisted
   * (best effort).
   */
  private void updateTaskContextForHITLByLLM(TaskContext taskContext, Iteration iteration, LLMOutput response)
      throws ObjectStoreException {
    ConversationState conversationState = taskContext.getConversationState();
    conversationState.setAdditionalInputRequired(new AdditionalInputRequired(LLM, response.getToolToCall()));
    conversationState.setTaskState(INPUT_REQUIRED);

    iteration.setLlmOutput(response);
    iteration.setToolExecution(null);

    taskContext.setConversationHistory(conversationHistory.toString());
    safeUpsertTaskContext();
  }

  /** Persists the task context, logging (not propagating) any failure. */
  private void safeUpsertTaskContext() {
    try {
      conversationService.upsert(taskContext, broker);
    } catch (Exception e) {
      LOGGER.error("Exception found updating TaskContext", e);
    }
  }

  /**
   * Records the tool call on the last iteration and snapshots the history into the task context. When
   * {@code requiresAdditionalInput} is set, the task is also marked {@code INPUT_REQUIRED} on behalf of the A2A tool. Does not
   * persist.
   */
  private void updateTaskContextForToolCall(TaskContext taskContext,
                                            String toolName,
                                            String toolInput,
                                            String toolResponse,
                                            boolean requiresAdditionalInput) {

    ToolExecution toolExecution = new ToolExecution(toolName, toolInput, toolResponse);
    taskContext.getLastIteration().setToolExecution(toolExecution);

    if (requiresAdditionalInput) {
      ConversationState conversationState = taskContext.getConversationState();
      conversationState.setAdditionalInputRequired(new AdditionalInputRequired(A2A_TOOL, toolName));
      conversationState.setTaskState(INPUT_REQUIRED);
    }

    taskContext.setConversationHistory(conversationHistory.toString());
  }

  private void resetConsecutiveErrors() {
    consecutiveErrors = 0;
  }

  /**
   * Counts one more consecutive error and returns whether the limit has been reached (with a limit of N, the N-th consecutive
   * error fails the task).
   */
  private boolean maxConsecutiveErrorsReached() {
    return ++consecutiveErrors >= maxConsecutiveErrors;
  }

  /** (Re)populates the MDC of the current thread; called again on every async callback since those may run elsewhere. */
  private void setMDC() {
    MDC.put(AGENT_MDC_KEY, broker.getConfigName());
    MDC.put(ITERATION_MDC_KEY, String.valueOf(currentLoop));
    MDC.put(TASK_ID_MDC_KEY, taskId);
    MDC.put(CONTEXT_ID_MDC_KEY, contextId);
    MDC.put(API_INSTANCE_ID_MDC_KEY, apiInstanceId);
  }

  private String nullSafe(String value) {
    return value == null ? "" : value;
  }

  /**
   * Callback wrapper that, after delegating, replaces the calling thread's MDC with the one captured when the wrapper was
   * created (clearing it if none was present then).
   */
  private static final class CompletionCallbackDecorator<T, A> implements CompletionCallback<T, A> {

    private final CompletionCallback<T, A> delegate;
    // May be null when no MDC context existed at construction time.
    private final Map<String, String> originalMdcContext;

    public CompletionCallbackDecorator(CompletionCallback<T, A> delegate) {
      this.delegate = delegate;
      originalMdcContext = MDC.getCopyOfContextMap();
    }

    @Override
    public void success(Result<T, A> result) {
      try {
        delegate.success(result);
      } finally {
        restoreMdcContext();
      }
    }

    @Override
    public void error(Throwable throwable) {
      try {
        delegate.error(throwable);
      } finally {
        restoreMdcContext();
      }
    }

    private void restoreMdcContext() {
      // Some MDC adapters throw on setContextMap(null); an exception here would escape after the delegate completed.
      if (originalMdcContext == null) {
        MDC.clear();
      } else {
        MDC.setContextMap(originalMdcContext);
      }
    }
  }
}
