/*
 * (c) 2025 MuleSoft, Inc. The software in this package is published under the terms of the Commercial Free Software license V.1 a copy of which has been included with this distribution in the LICENSE.md file.
 */
package com.mulesoft.modules.agent.broker.internal.operation.loop;

import static com.mulesoft.modules.agent.broker.internal.error.BrokerErrorTypes.A2A;
import static com.mulesoft.modules.agent.broker.internal.log.SecureLogger.secureLogger;
import static com.mulesoft.modules.agent.broker.internal.state.model.AdditionalInputRequired.RequesterType.A2A_TOOL;
import static com.mulesoft.modules.agent.broker.internal.state.model.AdditionalInputRequired.RequesterType.LLM;
import static com.mulesoft.modules.agent.broker.internal.util.ConcurrencyUtils.completeAsyncInTCCL;
import static com.mulesoft.modules.agent.broker.internal.util.ExceptionUtils.unwrap;

import static java.util.UUID.randomUUID;

import static io.a2a.spec.Message.Role.AGENT;
import static io.a2a.spec.TaskState.COMPLETED;
import static io.a2a.spec.TaskState.FAILED;
import static io.a2a.spec.TaskState.INPUT_REQUIRED;
import static io.a2a.spec.TaskState.WORKING;
import static org.slf4j.LoggerFactory.getLogger;

import org.mule.runtime.api.scheduler.SchedulerService;
import org.mule.runtime.core.api.util.func.CheckedRunnable;
import org.mule.runtime.extension.api.client.ExtensionsClient;
import org.mule.runtime.extension.api.exception.ModuleException;
import org.mule.runtime.extension.api.runtime.operation.Result;
import org.mule.runtime.extension.api.runtime.process.CompletionCallback;

import com.mulesoft.modules.agent.broker.internal.extension.AgentsBroker;
import com.mulesoft.modules.agent.broker.internal.extension.connection.LLMClient.LLMSession;
import com.mulesoft.modules.agent.broker.internal.state.ConversationService;
import com.mulesoft.modules.agent.broker.internal.state.model.AdditionalInputRequired;
import com.mulesoft.modules.agent.broker.internal.state.model.Conversation;
import com.mulesoft.modules.agent.broker.internal.state.model.ConversationState;
import com.mulesoft.modules.agent.broker.internal.state.model.Iteration;
import com.mulesoft.modules.agent.broker.internal.state.model.LLMOutput;
import com.mulesoft.modules.agent.broker.internal.state.model.TaskContext;
import com.mulesoft.modules.agent.broker.internal.state.model.ToolSelection;
import com.mulesoft.modules.agent.broker.internal.tool.Tool;
import com.mulesoft.modules.agent.broker.internal.tool.ToolRequest;
import com.mulesoft.modules.agent.broker.internal.tool.ToolResponse;
import com.mulesoft.modules.agent.broker.internal.tool.ToolType;
import com.mulesoft.modules.agent.broker.internal.tool.a2a.A2AService;
import com.mulesoft.modules.agent.broker.internal.tool.a2a.A2AToolResponse;

import java.time.OffsetDateTime;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;

import io.a2a.spec.Message;
import io.a2a.spec.Task;
import io.a2a.spec.TaskStatus;
import io.a2a.spec.TextPart;
import org.slf4j.Logger;
import org.slf4j.MDC;

/**
 * Drives the reasoning loop for a single task of a conversation.
 * <p>
 * Each loop asks the {@link LLMSession} for its next step, which is either an {@link LLMOutput} or a
 * {@link ToolSelection}. Tool selections are executed through the matching {@link Tool} and their responses are fed back
 * into the session before the next loop starts. The loop ends by completing the {@link CompletionCallback} handed to the
 * constructor, either with a result or with an error.
 * <p>
 * The number of loops is bounded by {@code maxLoops} and the number of consecutive errors by {@code maxConsecutiveErrors}.
 */
class Loop {

  private static final Logger LOGGER = getLogger(Loop.class);

  // MDC keys populated by setMDC()
  private static final String AGENT_MDC_KEY = "agent";
  private static final String ITERATION_MDC_KEY = "iteration";
  private static final String TASK_ID_MDC_KEY = "taskId";
  private static final String CONTEXT_ID_MDC_KEY = "contextId";
  private static final String API_INSTANCE_ID_MDC_KEY = "apiInstanceId";

  // Session asked for the next step on every loop; iterations and tool responses are added to it
  private final LLMSession llmSession;
  // Available tools, looked up by the tool id carried in a ToolSelection
  private final Map<String, Tool> toolExecutors;
  private final String taskId;
  private final String contextId;
  // Incoming message, used to obtain the conversation and the task context
  private final Message message;
  private final String apiInstanceId;
  // The task fails once currentLoop exceeds this value
  private final int maxLoops;
  // The task fails once consecutiveErrors reaches this value
  private final int maxConsecutiveErrors;
  private final ConversationService conversationService;
  private final A2AService a2aService;
  private final AgentsBroker broker;
  private final ExtensionsClient extensionsClient;
  // Provides the IO scheduler passed along when continuing asynchronous LLM and tool calls
  private final SchedulerService schedulerService;
  // Decorated callback used to report the final outcome of the loop
  private final CompletionCallback<String, Void> completionCallback;

  // Number of the loop in progress; incremented at the beginning of every loop
  private int currentLoop = 0;
  // Errors seen since the last successful tool execution
  private int consecutiveErrors = 0;
  // Iteration currently being processed
  private Iteration iteration;
  // Populated by fetchContext()
  private Conversation conversation;
  // Populated by fetchContext()
  private TaskContext taskContext;

  /**
   * Creates a loop for one task. Nothing is executed until {@link #start()} is invoked.
   * <p>
   * The given {@code completionCallback} is wrapped in a {@link CompletionCallbackDecorator}, which captures the MDC
   * context of the thread running this constructor.
   *
   * @param llmSession           session used to obtain the next step of every loop
   * @param toolExecutors        available tools, keyed by tool id
   * @param taskId               id of the task, published to the MDC and used to obtain the task context
   * @param contextId            id of the conversation context
   * @param message              incoming message
   * @param apiInstanceId        API instance id, published to the MDC and forwarded on every tool request
   * @param maxLoops             maximum number of loops before the task is failed
   * @param maxConsecutiveErrors number of consecutive errors at which the task is failed
   * @param conversationService  service used to read and persist conversation state
   * @param a2aService           service used to build the results handed to the callback
   * @param broker               owning broker configuration
   * @param extensionsClient     client handed to tools upon execution
   * @param schedulerService     provider of the IO scheduler used for asynchronous continuations
   * @param completionCallback   callback notified with the final outcome
   */
  Loop(LLMSession llmSession, Map<String, Tool> toolExecutors, String taskId, String contextId, Message message,
       String apiInstanceId, int maxLoops, int maxConsecutiveErrors, ConversationService conversationService,
       A2AService a2aService, AgentsBroker broker, ExtensionsClient extensionsClient, SchedulerService schedulerService,
       CompletionCallback<String, Void> completionCallback) {
    // Request identity
    this.taskId = taskId;
    this.contextId = contextId;
    this.message = message;
    this.apiInstanceId = apiInstanceId;

    // Limits
    this.maxLoops = maxLoops;
    this.maxConsecutiveErrors = maxConsecutiveErrors;

    // Collaborators
    this.llmSession = llmSession;
    this.toolExecutors = toolExecutors;
    this.conversationService = conversationService;
    this.a2aService = a2aService;
    this.broker = broker;
    this.extensionsClient = extensionsClient;
    this.schedulerService = schedulerService;

    // Outcome reporting
    this.completionCallback = new CompletionCallbackDecorator<>(completionCallback);
  }

  /**
   * Starts the loop.
   * <p>
   * Inside {@code conversationService.synchronizedConversation(...)} the conversation and task context are fetched and
   * every stored iteration is added to the LLM session. If the stored state carries an {@link AdditionalInputRequired},
   * the loop number is restored, the state is set to {@code WORKING} and the matching resume action is selected. Otherwise
   * the selected action starts a fresh iteration using the prompt of the session's initial request.
   * <p>
   * The selected action is executed after {@code synchronizedConversation(...)} returns. When recovery or resume
   * preparation fails, the callback is completed with the error and no action is executed.
   */
  public void start() {
    setMDC();

    final CheckedRunnable pendingAction = conversationService.synchronizedConversation(contextId, broker, () -> {
      ConversationState recoveredState;
      try {
        fetchContext();
        taskContext.getIterations().forEach(llmSession::addIteration);
        recoveredState = taskContext.getConversationState();
      } catch (Exception e) {
        final String description = "Error recovering context %s: %s".formatted(message.getContextId(), e.getMessage());
        completionCallback.error(new ModuleException(description, A2A, e));
        return null;
      }

      final AdditionalInputRequired pendingInput = recoveredState.getAdditionalInputRequired();
      if (pendingInput != null) {
        try {
          currentLoop = pendingInput.getLoopNumber();
          taskContext.setConversationState(new ConversationState(WORKING));

          final var requester = pendingInput.getRequesterType();
          if (requester == A2A_TOOL) {
            return resumeHITLFromA2ATool();
          }
          if (requester == LLM) {
            return resumeHITLByLLM(pendingInput);
          }
        } catch (Exception e) {
          completionCallback.error(e);
          return null;
        }
      }

      // Nothing to resume (or an unrecognized requester type): run a regular first iteration
      return () -> {
        var firstIteration = taskContext.newIteration();
        firstIteration.setUserPrompt(llmSession.getInitialRequest().getPrompt());

        nextLoop(firstIteration);
      };
    });

    if (pendingAction != null) {
      pendingAction.run();
    }
  }

  /**
   * Populates {@code conversation} and {@code taskContext} from the {@link ConversationService}. The conversation is
   * resolved first because it is needed to obtain the task.
   */
  private void fetchContext() {
    this.conversation = conversationService.getOrCreateConversation(message, contextId, broker);
    this.taskContext = conversationService.getTaskFor(message, this.conversation, taskId, broker);
  }

  /**
   * Runs the next loop on a brand new iteration obtained from the task context.
   */
  private void nextLoop() {
    nextLoop(taskContext.newIteration());
  }

  /**
   * Runs the next loop on the given iteration.
   * <p>
   * The loop counter is incremented first. If it exceeds {@code maxLoops} the task is failed. Otherwise the iteration
   * becomes the current one, it is added to the LLM session and the session is asked for its next step. The outcome is
   * handled asynchronously: an LLM failure fails the task, otherwise the result is dispatched to
   * {@link #onLLMOutput(LLMOutput)} or {@link #onToolSelection(ToolSelection)}.
   * <p>
   * Any unexpected exception, whether raised while scheduling the call or while dispatching its result, is reported
   * through {@code completionCallback.error(...)}.
   *
   * @param iteration the iteration to process in this loop
   */
  private void nextLoop(final Iteration iteration) {
    try {
      if (++currentLoop > maxLoops) {
        updateStateAndCompleteWithError("Maximum loops achieved without completing goal");
        return;
      }

      setMDC();
      this.iteration = iteration;
      llmSession.addIteration(iteration);

      completeAsyncInTCCL(llmSession.getNext(), schedulerService.ioScheduler(), (either, llmException) -> {
        setMDC();
        if (llmException != null) {
          llmException = unwrap(llmException);
          secureLogger().debug("LLM call failed", llmException);
          updateStateAndCompleteWithError("Cannot complete task due to issue accessing reasoning engine");
          return;
        }

        // The enclosing try/catch has already exited by the time this continuation runs, so it cannot see failures
        // raised here. NOTE(review): presumably nobody observes exceptions escaping this callback - confirm against
        // completeAsyncInTCCL. Guarding the dispatch ensures the operation is always completed.
        try {
          either.apply(this::onLLMOutput, this::onToolSelection);
        } catch (Exception e) {
          LOGGER.error("Unexpected error processing the reasoning engine response", e);
          completionCallback.error(e);
        }
      });
    } catch (Exception e) {
      completionCallback.error(e);
    }
  }

  /**
   * Handles a tool selection made by the LLM: records it on the current iteration and either executes the matching tool
   * or, when no tool is registered under the selected id, reports that back to the LLM.
   *
   * @param selection the tool selected by the LLM
   */
  private void onToolSelection(ToolSelection selection) {
    iteration.setToolSelection(selection);

    final String selectedToolId = selection.getToolId();
    final String selectedToolInput = selection.getInput();
    reasoningLog("LLM selected tool " + selectedToolId, "Tool Input: " + selectedToolInput);

    final Tool executor = toolExecutors.get(selectedToolId);
    if (executor == null) {
      onUnknownToolSelected(selection, selectedToolId);
      return;
    }

    final ToolRequest request = new ToolRequest(selection, taskContext, apiInstanceId);
    executeTool(executor, request, selectedToolId, response -> onToolExecuted(selectedToolId, response));
  }

  /**
   * Records an error response for a selection that does not match any known tool. The error counts towards the
   * consecutive errors limit; the task is failed if the limit is reached, otherwise a new loop is started.
   */
  private void onUnknownToolSelected(ToolSelection selection, String toolId) {
    ToolResponse notFound = new ToolResponse("Selected function '" + toolId + "' doesn't exist", selection, null);
    iteration.setToolResponse(notFound);
    llmSession.addToolResponse(notFound);
    reasoningLog(notFound.getResult());

    if (!maxConsecutiveErrorsReached()) {
      nextLoop();
    } else {
      updateStateAndCompleteWithError("I'm sorry, but I was unable to determine next step. Task cannot be completed");
    }
  }

  /**
   * Processes the response of a successfully executed tool. The consecutive errors counter is reset and, unless
   * {@link #handleToolResponse(String, ToolResponse)} signals that the loop must stop, the response is recorded and the
   * next loop starts. Any exception fails the task as an internal error.
   */
  private void onToolExecuted(String toolId, ToolResponse toolResponse) {
    try {
      resetConsecutiveErrors();
      if (!handleToolResponse(toolId, toolResponse)) {
        return;
      }

      iteration.setToolResponse(toolResponse);
      llmSession.addToolResponse(toolResponse);
      nextLoop();
    } catch (Exception e) {
      LOGGER.error("Error processing response for tool {}", toolId, e);
      completeAsInternalError();
    }
  }

  /**
   * Handles a final (non tool-selection) output from the LLM.
   * <ul>
   * <li>Goal complete: the state becomes {@code COMPLETED}, it is persisted and the callback receives the result.</li>
   * <li>Additional input required: the state is updated accordingly and the callback receives an input-required result.</li>
   * <li>Goal failed: the task is failed with the output's result as message.</li>
   * <li>None of the above: counts as an error; the LLM is asked to re-evaluate unless the consecutive errors limit has
   * been reached, in which case the task fails as an internal error.</li>
   * </ul>
   *
   * @param output the output produced by the LLM
   */
  private void onLLMOutput(LLMOutput output) {
    iteration.setLlmOutput(output);

    if (output.isGoalComplete()) {
      taskContext.setConversationState(new ConversationState(COMPLETED));
      safeUpsertTaskContext();
      completionCallback.success(a2aService.completedJsonResult(conversation, taskContext, output.getResult()));
      return;
    }

    if (output.isAdditionalInputRequired()) {
      try {
        updateTaskContextForHITLByLLM();
        completionCallback.success(a2aService.inputRequiredJsonResult(conversation, taskContext,
                                                                      List.of(new TextPart(output.getResult())), List.of()));
      } catch (Exception e) {
        LOGGER.error("Could not update context when requesting additional data from client", e);
        completeAsInternalError();
      }
      return;
    }

    if (output.isGoalFailed()) {
      updateStateAndCompleteWithError(output.getResult());
      return;
    }

    final String invalidOutput =
        "LLM output is logically invalid. Goal is neither completed, failed or requesting additional input";
    reasoningLog(invalidOutput);
    if (!maxConsecutiveErrorsReached()) {
      var retry = taskContext.newIteration();
      retry.setUserPrompt(invalidOutput + ". Re-evaluate, provide a valid response or select a tool");
      nextLoop(retry);
    } else {
      completeAsInternalError();
    }
  }

  /**
   * Fails the task with a generic internal error message.
   */
  private void completeAsInternalError() {
    final String reason = "Could not complete task due to an internal error";
    updateStateAndCompleteWithError(reason);
  }

  /**
   * Executes a tool and handles its outcome asynchronously.
   * <p>
   * A successful response is handed to {@code onSuccess}. A failure is turned into an error {@link ToolResponse} that is
   * recorded and fed back to the LLM; see {@link #onToolFailure(Tool, ToolRequest, String, Throwable)}.
   *
   * @param executor    the tool to execute
   * @param toolRequest the request handed to the tool
   * @param toolName    name of the tool, used for logging and error messages
   * @param onSuccess   consumer invoked with the tool's response when the execution succeeds
   * @param <T>         the response type expected by {@code onSuccess}
   */
  private <T extends ToolResponse> void executeTool(Tool executor, ToolRequest toolRequest, String toolName,
                                                    Consumer<T> onSuccess) {

    completeAsyncInTCCL(executor.execute(toolRequest, extensionsClient), schedulerService.ioScheduler(),
                        (toolResponse, toolException) -> {
                          setMDC();
                          if (toolException == null) {
                            onSuccess.accept((T) toolResponse);
                            return;
                          }

                          onToolFailure(executor, toolRequest, toolName, unwrap(toolException));
                        });
  }

  /**
   * Records a failed tool execution on the current iteration, persists the task context and informs the LLM. The failure
   * counts towards the consecutive errors limit; the task is failed if the limit is reached, otherwise a new loop is
   * started.
   */
  private void onToolFailure(Tool executor, ToolRequest toolRequest, String toolName, Throwable failure) {
    final String summary = "Execution of tool " + toolName + " failed.";
    final String secureDetail = "Error message was: " + failure.getMessage();

    reasoningLog(summary, secureDetail);

    var failResponse = new ToolResponse(summary + " " + secureDetail, toolRequest.selection(), executor.getToolType());
    iteration.setToolResponse(failResponse);
    safeUpsertTaskContext();
    llmSession.addToolResponse(failResponse);

    if (!maxConsecutiveErrorsReached()) {
      nextLoop();
    } else {
      updateStateAndCompleteWithError("Could not complete next step due to an error. Task failed.");
    }
  }

  /**
   * Logs the outcome of a tool execution and applies tool-type specific handling.
   *
   * @param toolName     name of the executed tool
   * @param toolResponse the response produced by the tool
   * @return {@code true} if the loop should continue with this response; for A2A tools, whatever
   *         {@link #handleA2AToolResponse(String, A2AToolResponse)} returns
   */
  private boolean handleToolResponse(String toolName, ToolResponse toolResponse) {
    reasoningLog("Executed tool " + toolName, "Output was: " + toolResponse.getResult());

    if (toolResponse.getToolType() != ToolType.A2A) {
      return true;
    }

    return handleA2AToolResponse(toolName, (A2AToolResponse) toolResponse);
  }

  /**
   * Marks the task as {@code FAILED}, persists it and completes the callback with a failed {@link Task} whose status
   * message carries the given text.
   * <p>
   * The callback is completed even when updating the state throws. If building the result fails, the callback is
   * completed with that error instead.
   *
   * @param message the failure description included in the task status
   */
  private void updateStateAndCompleteWithError(String message) {
    reasoningLog("Task failed: " + message);

    try {
      taskContext.setConversationState(new ConversationState(FAILED, message));
      safeUpsertTaskContext();
    } finally {
      try {
        final Message failureMessage = new Message.Builder()
            .messageId(randomUUID().toString())
            .taskId(taskContext.getTaskId())
            .role(AGENT)
            .parts(List.of(new TextPart(message)))
            .build();
        final TaskStatus failedStatus = new TaskStatus(FAILED, failureMessage, OffsetDateTime.now());
        final Task failedTask = new Task.Builder()
            .id(taskContext.getTaskId())
            .contextId(conversation.getId())
            .status(failedStatus)
            .build();

        completionCallback.success(a2aService.asJsonResult(failedTask));
      } catch (Exception e) {
        completionCallback.error(e);
      }
    }
  }

  /**
   * Updates the agent tool context with the state, context id and task id reported by an A2A tool (ids are only
   * overwritten when present). If the tool requires additional input, the conversation moves to {@code INPUT_REQUIRED},
   * the task context is persisted and the callback is completed with an input-required result.
   *
   * @param toolName     name of the A2A tool
   * @param toolResponse the response produced by the tool
   * @return {@code false} if the callback was completed because input is required, {@code true} if the loop should go on
   */
  private boolean handleA2AToolResponse(String toolName, A2AToolResponse toolResponse) {
    final boolean needsInput = toolResponse.isInputRequired();

    var agentContext = taskContext.getAgentToolContext(toolName);
    agentContext.setTaskState(toolResponse.getState());

    final var reportedContextId = toolResponse.getContextId();
    if (reportedContextId != null) {
      agentContext.setContextId(reportedContextId);
    }

    final var reportedTaskId = toolResponse.getTaskId();
    if (reportedTaskId != null) {
      agentContext.setTaskId(reportedTaskId);
    }

    if (!needsInput) {
      return true;
    }

    var awaitingInput = new AdditionalInputRequired(A2A_TOOL, iteration.getNumber(), toolName);
    taskContext.setConversationState(new ConversationState(INPUT_REQUIRED, awaitingInput));
    safeUpsertTaskContext();
    completionCallback.success(a2aService.inputRequiredJsonResult(conversation, taskContext, toolResponse.getMessageParts(),
                                                                  toolResponse.getArtifacts()));
    return false;
  }

  /**
   * Logs a reasoning step. The message goes to the regular logger at INFO level, while each appendix entry goes to the
   * secure logger at DEBUG level.
   *
   * @param message        the text for the regular log
   * @param secureAppendix extra entries written only to the secure logger; may be {@code null}
   */
  private void reasoningLog(String message, String... secureAppendix) {
    LOGGER.info(message);

    if (secureAppendix == null) {
      return;
    }

    for (int i = 0; i < secureAppendix.length; i++) {
      secureLogger().debug(secureAppendix[i]);
    }
  }

  /**
   * Builds the action that resumes a conversation in which the LLM had requested additional input. When run, it starts
   * a new iteration whose user prompt is the prompt of the session's initial request.
   *
   * @param additionalInputRequired the pending input request; its loop number is only used for logging here
   * @return the action to run
   */
  private CheckedRunnable resumeHITLByLLM(AdditionalInputRequired additionalInputRequired) {
    return () -> {
      LOGGER.info("Resuming conversation {} from iteration {}", conversation.getId(), additionalInputRequired.getLoopNumber());

      var resumedIteration = taskContext.newIteration();
      var userReply = llmSession.getInitialRequest().getPrompt();
      resumedIteration.setUserPrompt(userReply);
      nextLoop(resumedIteration);
    };
  }

  /**
   * Builds the action that resumes a conversation in which an A2A tool had requested additional input.
   * <p>
   * The last stored iteration becomes the current one and its tool selection is reused, with its input replaced by the
   * prompt of the session's initial request. When run, the action executes that tool again and, unless the tool asks for
   * more input, records the response and continues with the next loop.
   *
   * @return the action to run
   * @throws ModuleException if no tool is registered under the id of the pending selection
   */
  private CheckedRunnable resumeHITLFromA2ATool() {
    iteration = taskContext.getLastIteration();

    final ToolSelection pendingSelection = iteration.getToolSelection();
    final String pendingToolId = pendingSelection.getToolId();

    final Tool pendingTool = toolExecutors.get(pendingToolId);
    if (pendingTool == null) {
      throw new ModuleException("Could not find A2A agent config to resume the tool - " + pendingToolId, A2A);
    }

    pendingSelection.setInput(llmSession.getInitialRequest().getPrompt());
    final ToolRequest resumeRequest = new ToolRequest(pendingSelection, taskContext, apiInstanceId);

    return () -> executeTool(pendingTool, resumeRequest, pendingToolId, response -> {
      try {
        if (!handleA2AToolResponse(pendingToolId, (A2AToolResponse) response)) {
          return;
        }

        iteration.setToolResponse(response);
        safeUpsertTaskContext();
        llmSession.addToolResponse(response);
        nextLoop();
      } catch (Exception e) {
        completionCallback.error(e);
      }
    });
  }

  /**
   * Moves the conversation to {@code INPUT_REQUIRED}, recording that the LLM asked for the input during the current
   * loop, and persists the task context.
   */
  private void updateTaskContextForHITLByLLM() {
    var awaitingInput = new AdditionalInputRequired(LLM, currentLoop);
    taskContext.setConversationState(new ConversationState(INPUT_REQUIRED, awaitingInput));
    safeUpsertTaskContext();
  }

  /**
   * Persists the task context on a best-effort basis: a failure is logged and otherwise ignored, so it never interrupts
   * the loop.
   */
  private void safeUpsertTaskContext() {
    try {
      conversationService.upsert(taskContext, broker);
    } catch (Exception upsertFailure) {
      LOGGER.error("Exception found updating TaskContext", upsertFailure);
    }
  }

  /**
   * Clears the consecutive errors counter.
   */
  private void resetConsecutiveErrors() {
    this.consecutiveErrors = 0;
  }

  /**
   * Registers one more consecutive error and tells whether the configured limit has been reached.
   *
   * @return {@code true} if, after counting this error, the limit has been reached; {@code false} otherwise
   */
  private boolean maxConsecutiveErrorsReached() {
    consecutiveErrors++;
    if (consecutiveErrors < maxConsecutiveErrors) {
      return false;
    }

    reasoningLog("Maximum consecutive errors reached. Task '%s' will fail".formatted(taskContext.getTaskId()));
    return true;
  }

  /**
   * Publishes the broker config name, current loop number, task id, context id and API instance id into the MDC of the
   * calling thread.
   */
  private void setMDC() {
    final String agentName = broker.getConfigName();
    final String loopNumber = Integer.toString(currentLoop);

    MDC.put(AGENT_MDC_KEY, agentName);
    MDC.put(ITERATION_MDC_KEY, loopNumber);
    MDC.put(TASK_ID_MDC_KEY, taskId);
    MDC.put(CONTEXT_ID_MDC_KEY, contextId);
    MDC.put(API_INSTANCE_ID_MDC_KEY, apiInstanceId);
  }

  private class CompletionCallbackDecorator<T, A> implements CompletionCallback<T, A> {

    private final CompletionCallback<T, A> delegate;
    private final Map<String, String> originalMdcContext;

    public CompletionCallbackDecorator(CompletionCallback<T, A> delegate) {
      this.delegate = delegate;
      originalMdcContext = MDC.getCopyOfContextMap();
    }

    @Override
    public void success(Result<T, A> result) {
      try {
        delegate.success(result);
      } finally {
        restoreMdcContext();
      }
    }

    @Override
    public void error(Throwable throwable) {
      try {
        delegate.error(throwable);
      } finally {
        restoreMdcContext();
      }
    }

    private void restoreMdcContext() {
      MDC.setContextMap(originalMdcContext);
    }
  }
}
