/*
 * (c) 2025 MuleSoft, Inc. The software in this package is published under the terms of the Commercial Free Software license V.1 a copy of which has been included with this distribution in the LICENSE.md file.
 */
package com.mulesoft.modules.agent.conductor.internal.operation.loop;

import static com.mulesoft.modules.agent.conductor.internal.tool.ToolType.A2A;
import static com.mulesoft.modules.agent.conductor.internal.error.ConductorErrorTypes.LLM_ERROR;
import static com.mulesoft.modules.agent.conductor.internal.error.ConductorErrorTypes.MAX_LOOPS;
import static com.mulesoft.modules.agent.conductor.internal.error.ConductorErrorTypes.REASONING_ERROR;
import static com.mulesoft.modules.agent.conductor.internal.error.ConductorErrorTypes.TOOL_ERROR;
import static com.mulesoft.modules.agent.conductor.internal.state.model.AdditionalInputRequired.RequesterType.A2A_TOOL;
import static com.mulesoft.modules.agent.conductor.internal.state.model.AdditionalInputRequired.RequesterType.LLM;

import static org.mule.runtime.api.i18n.I18nMessageFactory.createStaticMessage;
import static org.mule.runtime.core.api.util.StringUtils.isBlank;
import static org.slf4j.LoggerFactory.getLogger;

import org.mule.runtime.api.exception.MuleRuntimeException;
import org.mule.runtime.api.store.ObjectStore;
import org.mule.runtime.api.store.ObjectStoreException;
import org.mule.runtime.core.api.util.func.CheckedFunction;
import org.mule.runtime.extension.api.exception.ModuleException;
import org.mule.runtime.extension.api.runtime.operation.Result;
import org.mule.runtime.extension.api.runtime.process.CompletionCallback;

import com.mulesoft.modules.agent.conductor.api.model.Orchestration;
import com.mulesoft.modules.agent.conductor.internal.tool.ToolRequest;
import com.mulesoft.modules.agent.conductor.internal.llm.LLMOrchestrationRequest;
import com.mulesoft.modules.agent.conductor.internal.llm.client.LLMClient;
import com.mulesoft.modules.agent.conductor.internal.prompt.PromptBuilder;
import com.mulesoft.modules.agent.conductor.internal.state.TaskContextService;
import com.mulesoft.modules.agent.conductor.internal.state.model.AdditionalInputRequired;
import com.mulesoft.modules.agent.conductor.internal.state.model.ConversationState;
import com.mulesoft.modules.agent.conductor.internal.state.model.Iteration;
import com.mulesoft.modules.agent.conductor.internal.state.model.LLMOutput;
import com.mulesoft.modules.agent.conductor.internal.state.model.TaskContext;
import com.mulesoft.modules.agent.conductor.internal.state.model.ToolExecution;
import com.mulesoft.modules.agent.conductor.internal.tool.ToolHandler;
import com.mulesoft.modules.agent.conductor.internal.tool.ToolResponse;
import com.mulesoft.modules.agent.conductor.internal.tool.a2a.A2AToolResponse;
import com.mulesoft.modules.agent.conductor.internal.tool.a2a.A2AToolService;

import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.function.Consumer;

import org.slf4j.Logger;
import org.slf4j.MDC;

/**
 * Orchestration loop for a single agent task.
 * <p>
 * Each iteration asks the LLM for the next step ({@link LLMClient#reasonNextStep}), optionally executes the selected
 * tool, appends the outcome to a textual conversation history that is fed back into the next prompt, and records
 * progress in the {@link TaskContext}. The run ends when the LLM reports the goal as complete, when additional input is
 * required from the client (requested either by the LLM or by an A2A tool), or when an error or the {@code maxLoops} /
 * {@code maxConsecutiveErrors} limits abort it. The outcome is reported through the {@link CompletionCallback} given to
 * the constructor.
 * <p>
 * Iterations are chained through the completion of asynchronous LLM and tool calls, so the steps of one run execute
 * one after another, but may do so on the threads that complete those futures.
 */
class Loop {

  private static final Logger LOGGER = getLogger(Loop.class);

  private static final String AGENT_MDC_KEY = "agent";
  private static final String ITERATION_MDC_KEY = "iteration";
  private static final String TASK_ID_MDC_KEY = "taskId";
  private static final String CONTEXT_ID_MDC_KEY = "contextId";

  private final PromptBuilder promptBuilder;
  private final LLMClient llmClient;
  private final A2AToolService a2aToolService;
  private final Map<String, ToolHandler> toolHandlers;
  private final String taskId;
  private final String contextId;
  private final int maxLoops;
  private final int maxConsecutiveErrors;
  private final TaskContextService taskContextService;
  private final ObjectStore<TaskContext> taskContextStore;
  private final String configRef;
  private final CompletionCallback<Orchestration, Void> completionCallback;

  private int currentLoop = 0;
  private int consecutiveErrors = 0;
  private StringBuilder conversationHistory = new StringBuilder();
  private TaskContext taskContext;

  /**
   * @param promptBuilder        builder for the LLM prompt; receives the current iteration and conversation history
   * @param llmClient            client used to ask the LLM for the next step
   * @param a2aToolService       used to locate and invoke the A2A agent when resuming a tool that requested input
   * @param toolHandlers         available tools, keyed by the name the LLM uses to select them
   * @param taskId               id of the task being orchestrated
   * @param contextId            id of the conversation context
   * @param maxLoops             maximum number of iterations before failing with {@code MAX_LOOPS}
   * @param maxConsecutiveErrors number of consecutive reasoning/tool errors after which the run fails
   * @param taskContextService   service used to load and persist the {@link TaskContext}
   * @param taskContextStore     object store backing the task contexts
   * @param configRef            name of the agent configuration; used for persistence and logging
   * @param completionCallback   callback completed with the outcome of the run; it is wrapped so that the MDC context
   *                             captured when this instance is created is restored after it is invoked
   */
  public Loop(PromptBuilder promptBuilder,
              LLMClient llmClient,
              A2AToolService a2aToolService,
              Map<String, ToolHandler> toolHandlers,
              String taskId,
              String contextId,
              int maxLoops,
              int maxConsecutiveErrors,
              TaskContextService taskContextService,
              ObjectStore<TaskContext> taskContextStore,
              String configRef,
              CompletionCallback<Orchestration, Void> completionCallback) {
    this.promptBuilder = promptBuilder;
    this.llmClient = llmClient;
    this.a2aToolService = a2aToolService;
    this.toolHandlers = toolHandlers;
    this.taskId = taskId;
    this.contextId = contextId;
    this.maxLoops = maxLoops;
    this.maxConsecutiveErrors = maxConsecutiveErrors;
    this.completionCallback = new CompletionCallbackDecorator<>(completionCallback);
    this.taskContextStore = taskContextStore;
    this.configRef = configRef;
    this.taskContextService = taskContextService;
  }

  /**
   * Starts or resumes the orchestration for the configured task/context.
   * <p>
   * A task whose previous run completed or failed is reset and handled as a follow-up conversation. A task that was
   * waiting for additional input is resumed through the requester (A2A tool or LLM) that asked for it. Otherwise a
   * fresh loop is started.
   */
  public void start() {
    setMDC();
    try {
      taskContext = getOrCreateTaskContext();
    } catch (Exception e) {
      completionCallback.error(new MuleRuntimeException(createStaticMessage(
                                                                            "Error while fetching the task context for contextId %s: %s"
                                                                                .formatted(contextId, e.getMessage())),
                                                        e));
      return;
    }

    try {
      ConversationState conversationState = taskContext.getConversationState();
      if (conversationState.isCompleted() || conversationState.isFailed()) {
        // this is a follow-up conversation to a task we thought completed
        taskContext.reset();
        nextLoop();
        return;
      }

      AdditionalInputRequired additionalInputRequired = conversationState.getAdditionalInputRequired();
      if (additionalInputRequired == null) {
        // this is a fresh conversation
        nextLoop();
      } else if (additionalInputRequired.getRequesterType() == A2A_TOOL) {
        resumeHITLFromA2ATool();
      } else if (additionalInputRequired.getRequesterType() == LLM) {
        resumeHITLByLLM();
      } else {
        // Without this branch an unexpected requester type would never complete the callback, leaving the caller hanging.
        completionCallback.error(new MuleRuntimeException(createStaticMessage(
            "Unsupported requester type for additional input: " + additionalInputRequired.getRequesterType())));
      }
    } catch (Exception e) {
      completionCallback.error(e);
    }
  }

  /**
   * Runs the next iteration: checks the loop limit, registers a new {@link Iteration} in the task context and asks the
   * LLM for the next step. The LLM response is processed asynchronously by {@link #onLlmResponse}.
   */
  private void nextLoop() {
    try {
      if (++currentLoop > maxLoops) {
        completionCallback.error(new ModuleException("Maximum loops achieved without completing goal", MAX_LOOPS));
        return;
      }

      setMDC();
      promptBuilder.setCurrentLoopIteration(currentLoop)
          .setConversationHistory(conversationHistory.toString());

      taskContext.setCurrentIteration(currentLoop);

      final Iteration iteration = new Iteration();
      iteration.setNumber(currentLoop);
      taskContext.addIteration(iteration);

      llmClient.reasonNextStep(new LLMOrchestrationRequest(promptBuilder.build()))
          .whenComplete((llmOutput, llmException) -> {
            setMDC();
            try {
              onLlmResponse(iteration, llmOutput, llmException);
            } catch (Exception e) {
              // Anything thrown by a whenComplete action is silently stored in the (unobserved) dependent future,
              // so it must be routed to the callback explicitly or the caller would never be notified.
              completeWithErrorAndUpdateState(e);
            }
          });
    } catch (Exception e) {
      completionCallback.error(e);
    }
  }

  /**
   * Processes the outcome of one LLM reasoning call for the given iteration: completes the run when the goal is
   * achieved or input is requested, re-prompts when the LLM made no usable choice, or executes the selected tool.
   *
   * @param iteration    the iteration registered for this LLM call
   * @param llmOutput    the LLM decision, or {@code null} when the call failed
   * @param llmException the failure of the LLM call, or {@code null} when it succeeded
   */
  private void onLlmResponse(Iteration iteration, LLMOutput llmOutput, Throwable llmException) {
    if (llmException != null) {
      Throwable cause = unwrap(llmException);
      completeWithErrorAndUpdateState(new ModuleException("LLM call failed: " + cause.getMessage(), LLM_ERROR, cause));
      return;
    }
    if (llmOutput == null) {
      completeWithErrorAndUpdateState(new ModuleException("LLM call failed: no output was returned", LLM_ERROR));
      return;
    }

    iteration.setLlmOutput(llmOutput);
    if (llmOutput.isGoalComplete()) {
      taskContext.setCurrentState(new ConversationState(true, false, null));
      safeUpsertTaskContext();
      completionCallback.success(Result.<Orchestration, Void>builder()
          .output(new Orchestration(taskId, contextId, llmOutput.getResult(), llmOutput.isGoalComplete(),
                                    llmOutput.isAdditionalInputNeeded(),
                                    currentLoop, llmOutput.getReasoning()))
          .build());
      return;
    }

    if (llmOutput.isAdditionalInputNeeded()) {
      try {
        updateTaskContextForHITLByLLM(taskContext, iteration, llmOutput);
        completionCallback.success(Result.<Orchestration, Void>builder()
            .output(new Orchestration(taskId, contextId, llmOutput.getResult(), false, true, currentLoop,
                                      llmOutput.getReasoning()))
            .build());
      } catch (Exception e) {
        completeWithErrorAndUpdateState(new MuleRuntimeException(createStaticMessage(
                                                                                     "Could not update context when requesting additional data from client: "
                                                                                         + e.getMessage()),
                                                                 e));
      }
      return;
    }

    final String toolToCall = llmOutput.getToolToCall();

    // The goal is known to be incomplete at this point (handled above), so a blank tool means the LLM made no choice.
    if (isBlank(toolToCall)) {
      addToHistory("LLM didn't select any tool to call next, the goal is not yet complete and no addition data requested. "
          + "Re-evaluate and either select a tool, requested additional data or mark the goal as completed.");
      if (maxConsecutiveErrorsReached()) {
        completeWithErrorAndUpdateState(new ModuleException(createStaticMessage(
                                                                                "LLM couldn't select any tool to proceed forward. Reasoning: "
                                                                                    + llmOutput.getReasoning()),
                                                            REASONING_ERROR));
      } else {
        nextLoop();
      }
      return;
    }

    String toolInput = llmOutput.getToolInput();
    String reasoning = llmOutput.getReasoning();

    var tool = toolHandlers.get(toolToCall);
    if (tool == null) {
      addToHistory("LLM selected tool '" + toolToCall
          + "' but it doesn't exist. Reevaluate and make sure to select a tool from the provided list");

      if (maxConsecutiveErrorsReached()) {
        completeWithErrorAndUpdateState(new ModuleException(createStaticMessage(
                                                                                "LLM selected a tool that doesn't exist: "
                                                                                    + toolToCall),
                                                            REASONING_ERROR));
      } else {
        nextLoop();
      }
      return;
    }

    addToHistory(
                 "LLM selected tool " + toolToCall + ". Reasoning was: " + nullSafe(reasoning) + ".",
                 "Tool Input: " + toolInput);

    var toolRequest = new ToolRequest(toolToCall, toolInput, reasoning, currentLoop, taskContext);
    applyToolHandler(tool.getHandler(), toolRequest, toolToCall, toolResponse -> {
      try {
        resetConsecutiveErrors();
        if (handleToolResponse(taskContext, llmOutput, toolToCall, toolResponse)) {
          nextLoop();
        }
      } catch (Exception e) {
        completeWithErrorAndUpdateState(e);
      }
    });
  }

  /**
   * Invokes a tool handler and dispatches its outcome. Failures, whether thrown synchronously by the handler or
   * signalled through the returned future, are recorded in the history and either retried through a new LLM iteration
   * or, once the consecutive-error limit is reached, reported as {@code TOOL_ERROR}.
   *
   * @param handler     the tool implementation
   * @param toolRequest the request to pass to the tool
   * @param toolName    tool name, used in history and error messages
   * @param onSuccess   invoked with the tool result when the tool completes normally
   */
  private <T> void applyToolHandler(CheckedFunction<ToolRequest, CompletableFuture<T>> handler,
                                    ToolRequest toolRequest,
                                    String toolName,
                                    Consumer<T> onSuccess) {

    final CompletableFuture<T> toolFuture;
    try {
      toolFuture = handler.apply(toolRequest);
    } catch (RuntimeException e) {
      // CheckedFunction rethrows checked failures unchecked; treat them like an asynchronous tool failure
      handleToolFailure(toolName, e);
      return;
    }

    toolFuture.whenComplete((toolResult, toolException) -> {
      setMDC();
      if (toolException != null) {
        handleToolFailure(toolName, unwrap(toolException));
      } else {
        onSuccess.accept(toolResult);
      }
    });
  }

  /**
   * Records a failed tool execution in the history and either gives the LLM another chance or, when the
   * consecutive-error limit is reached, fails the run and marks the task as failed.
   */
  private void handleToolFailure(String toolName, Throwable toolException) {
    addToHistory(
                 "Execution of tool " + toolName + " failed. Error message was: " + toolException.getMessage()
                     + " . Analyze the error and determine whether the tool input needs to be adjusted, " +
                     "or if a different tool can be used instead");

    if (maxConsecutiveErrorsReached()) {
      completeWithErrorAndUpdateState(new ModuleException("Error executing tool %s: %s"
          .formatted(toolName, toolException.getMessage()), TOOL_ERROR, toolException));
    } else {
      nextLoop();
    }
  }

  /**
   * Applies a successful tool response to the history and task context.
   *
   * @return {@code true} if the loop should continue with another iteration; {@code false} if the run has already been
   *         completed (an A2A tool requested additional input)
   */
  private boolean handleToolResponse(TaskContext taskContext,
                                     LLMOutput llmOutput,
                                     String toolName,
                                     ToolResponse toolResponse) {

    if (toolResponse.getToolType() == A2A) {
      return handleA2AToolResponse(taskContext, toolName, llmOutput.getToolInput(), (A2AToolResponse) toolResponse);
    } else {
      handleStandardToolResponse(taskContext, llmOutput, toolName, toolResponse);
      return true;
    }
  }

  /**
   * Marks the conversation as failed (best-effort persistence), then always reports the exception to the callback.
   */
  private void completeWithErrorAndUpdateState(Exception exception) {
    try {
      taskContext.setCurrentState(new ConversationState(false, true, exception.getMessage()));
      safeUpsertTaskContext();
    } finally {
      completionCallback.error(exception);
    }
  }

  /**
   * Applies an A2A tool response: stores the remote context id, records the output in the history and, when the remote
   * agent requires more input, persists the waiting state and completes the run asking the client for that input.
   *
   * @return {@code true} if the loop should continue; {@code false} if the run was completed waiting for input
   */
  private boolean handleA2AToolResponse(TaskContext taskContext,
                                        String toolName,
                                        String toolInput,
                                        A2AToolResponse toolResponse) {

    final boolean inputRequired = toolResponse.isInputRequired();
    final String toolOutput = toolResponse.getResult();

    var toolContext = taskContext.getAgentToolContext(toolName);

    toolContext.setTaskId(taskContext.getTaskId());
    if (toolResponse.getContextId() != null) {
      toolContext.setContextId(toolResponse.getContextId());
    }

    addToolResponseToHistory(toolName, toolOutput);

    if (inputRequired) {
      updateTaskContextForToolCall(taskContext, toolName, toolInput, toolOutput, true);
      safeUpsertTaskContext();
      completionCallback.success(Result.<Orchestration, Void>builder()
          .output(new Orchestration(taskId, contextId, toolOutput, false, true, currentLoop,
                                    "Agent " + configRef + " requires additional input to proceed"))
          .build());
      return false;
    }

    return true;
  }

  private void addToolResponseToHistory(String toolName, String toolOutput) {
    addToHistory("Executed tool " + toolName + ".", "Output was: " + toolOutput);
  }

  /**
   * Records a non-A2A tool result in the history and stores the execution (name, input, output) on the last iteration.
   */
  private void handleStandardToolResponse(TaskContext taskContext,
                                          LLMOutput llmOutput,
                                          String toolName,
                                          ToolResponse toolResponse) {
    final String toolOutput = toolResponse.getResult();
    addToolResponseToHistory(toolName, toolOutput);

    // Argument order must match updateTaskContextForToolCall(name, input, response, ...)
    updateTaskContextForToolCall(taskContext, toolName, llmOutput.getToolInput(), toolOutput, false);
  }

  /**
   * Logs the message (info) and any additional parts (debug), and appends each of them to the conversation history as
   * its own line.
   */
  private void addToHistory(String message, String... additional) {
    LOGGER.info(message);
    appendHistoryLine(message);

    if (additional != null) {
      for (String s : additional) {
        LOGGER.debug(s);
        appendHistoryLine(s);
      }
    }
  }

  /** Appends a line to the history, adding a trailing newline unless the text already ends with one. */
  private void appendHistoryLine(String line) {
    conversationHistory.append(line);
    if (!line.endsWith("\n")) {
      conversationHistory.append('\n');
    }
  }

  /**
   * Resumes a conversation in which the LLM asked the client for more input, by restoring the persisted history and
   * re-running the iteration that made the request.
   */
  private void resumeHITLByLLM() {
    conversationHistory = new StringBuilder(nullSafe(taskContext.getConversationHistory()));

    // nextLoop() pre-increments currentLoop, so step back one to re-enter the iteration that requested input
    currentLoop = taskContext.getCurrentIteration() - 1;

    LOGGER.info("Resuming the conversation from iteration {}", taskContext.getCurrentIteration());
    nextLoop();
  }

  /**
   * Resumes a conversation in which an A2A tool asked for more input: the client's new prompt is sent to the waiting
   * A2A agent, and the loop continues only if that agent does not ask for input again.
   *
   * @throws MuleRuntimeException if the waiting tool execution or its A2A client configuration cannot be found
   */
  private void resumeHITLFromA2ATool() {
    Iteration lastIteration = taskContext.getLastIteration();
    ToolExecution waitingTool = lastIteration == null ? null : lastIteration.getToolExecution();
    if (waitingTool == null) {
      throw new MuleRuntimeException(createStaticMessage(
          "Could not find the tool execution waiting for additional input in task " + taskId));
    }

    String waitingToolName = waitingTool.getName();
    String toolInput = promptBuilder.getUserPrompt();
    // the A2A agent's last output (its request for more input) is passed as the reasoning of the resumed call
    String reasoning = waitingTool.getOutput();

    String a2AClientConfigRef = a2aToolService.getA2AClientConfigRef(waitingToolName);
    if (a2AClientConfigRef == null) {
      throw new MuleRuntimeException(createStaticMessage("Could not find A2A agent config to resume the tool - " + waitingToolName));
    }

    // Restore the history persisted when the tool asked for input; otherwise the prior context would be lost
    conversationHistory = new StringBuilder(nullSafe(taskContext.getConversationHistory()));
    currentLoop = taskContext.getCurrentIteration() - 1;
    var toolRequest = new ToolRequest(waitingToolName, toolInput, reasoning, currentLoop, taskContext);

    CheckedFunction<ToolRequest, CompletableFuture<ToolResponse>> a2aToolHandler =
        a2aToolService.createHandler(a2AClientConfigRef);

    applyToolHandler(a2aToolHandler, toolRequest, waitingToolName, toolResult -> {
      try {
        // When the agent asks for input again the callback has already been completed, so the loop must stop here
        if (handleA2AToolResponse(taskContext, waitingToolName, toolInput, (A2AToolResponse) toolResult)) {
          nextLoop();
        }
      } catch (Exception e) {
        completionCallback.error(e);
      }
    });
  }

  private TaskContext getOrCreateTaskContext() throws ObjectStoreException {
    if (taskContext == null) {
      taskContext = taskContextService.getOrCreate(taskId, contextId, configRef, taskContextStore);
    }

    return taskContext;
  }

  /**
   * Marks the conversation as waiting for client input requested by the LLM and persists the task context, including
   * the current conversation history so that {@link #resumeHITLByLLM()} can restore it.
   *
   * @throws ObjectStoreException if the task context cannot be persisted
   */
  private void updateTaskContextForHITLByLLM(TaskContext taskContext, Iteration iteration, LLMOutput response)
      throws ObjectStoreException {
    ConversationState conversationState = taskContext.getConversationState();
    conversationState.setAdditionalInputRequired(new AdditionalInputRequired(LLM, response.getToolToCall()));
    conversationState.setCompleted(false);
    conversationState.setFailed(false);

    iteration.setLlmOutput(response);
    iteration.setToolExecution(null);

    taskContext.setConversationHistory(conversationHistory.toString());
    taskContextService.upsertTaskContext(taskContext, configRef, taskContextStore);
  }

  /** Persists the task context; failures are logged and otherwise ignored (best-effort). */
  private void safeUpsertTaskContext() {
    try {
      taskContextService.upsertTaskContext(taskContext, configRef, taskContextStore);
    } catch (Exception e) {
      LOGGER.error("Exception found updating TaskContext", e);
    }
  }

  /**
   * Records a tool execution on the last iteration, snapshots the conversation history into the task context and, if
   * requested, marks the conversation as waiting for input for the given A2A tool. Does not persist the context.
   */
  private void updateTaskContextForToolCall(TaskContext taskContext,
                                            String toolName,
                                            String toolInput,
                                            String toolResponse,
                                            boolean requiresAdditionalInput) {

    ToolExecution toolExecution = new ToolExecution(toolName, toolInput, toolResponse);
    taskContext.getLastIteration().setToolExecution(toolExecution);

    if (requiresAdditionalInput) {
      ConversationState conversationState = taskContext.getConversationState();
      conversationState.setAdditionalInputRequired(new AdditionalInputRequired(A2A_TOOL, toolName));
      conversationState.setCompleted(false);
      conversationState.setFailed(false);
    }

    taskContext.setConversationHistory(conversationHistory.toString());
  }

  private void resetConsecutiveErrors() {
    consecutiveErrors = 0;
  }

  /** Counts one more consecutive error and reports whether the configured limit has been reached. */
  private boolean maxConsecutiveErrorsReached() {
    return ++consecutiveErrors >= maxConsecutiveErrors;
  }

  private void setMDC() {
    MDC.put(AGENT_MDC_KEY, configRef);
    MDC.put(ITERATION_MDC_KEY, String.valueOf(currentLoop));
    MDC.put(TASK_ID_MDC_KEY, taskId);
    MDC.put(CONTEXT_ID_MDC_KEY, contextId);
  }

  private String nullSafe(String value) {
    return value == null ? "" : value;
  }

  /** Strips {@link CompletionException} wrappers added by dependent CompletableFuture stages. */
  private static Throwable unwrap(Throwable throwable) {
    Throwable current = throwable;
    while (current instanceof CompletionException && current.getCause() != null) {
      current = current.getCause();
    }
    return current;
  }

  /**
   * Delegating callback that restores, after completion, the MDC context that was current when it was created.
   */
  private static final class CompletionCallbackDecorator<T, A> implements CompletionCallback<T, A> {

    private final CompletionCallback<T, A> delegate;
    // may be null: MDC.getCopyOfContextMap() returns null when no context was set
    private final Map<String, String> originalMdcContext;

    CompletionCallbackDecorator(CompletionCallback<T, A> delegate) {
      this.delegate = delegate;
      this.originalMdcContext = MDC.getCopyOfContextMap();
    }

    @Override
    public void success(Result<T, A> result) {
      try {
        delegate.success(result);
      } finally {
        restoreMdcContext();
      }
    }

    @Override
    public void error(Throwable throwable) {
      try {
        delegate.error(throwable);
      } finally {
        restoreMdcContext();
      }
    }

    private void restoreMdcContext() {
      if (originalMdcContext == null) {
        MDC.clear();
      } else {
        MDC.setContextMap(originalMdcContext);
      }
    }
  }
}
