/*
 * Decompiled with CFR 0.152.
 */
package com.mulesoft.modules.agent.broker.internal.operation.loop;

import com.mulesoft.modules.agent.broker.internal.error.BrokerErrorTypes;
import com.mulesoft.modules.agent.broker.internal.extension.AgentsBroker;
import com.mulesoft.modules.agent.broker.internal.extension.connection.LLMClient;
import com.mulesoft.modules.agent.broker.internal.llm.LLMOrchestrationRequest;
import com.mulesoft.modules.agent.broker.internal.log.SecureLogger;
import com.mulesoft.modules.agent.broker.internal.prompt.PromptBuilder;
import com.mulesoft.modules.agent.broker.internal.state.ConversationService;
import com.mulesoft.modules.agent.broker.internal.state.model.AdditionalInputRequired;
import com.mulesoft.modules.agent.broker.internal.state.model.AgentToolContext;
import com.mulesoft.modules.agent.broker.internal.state.model.Conversation;
import com.mulesoft.modules.agent.broker.internal.state.model.ConversationState;
import com.mulesoft.modules.agent.broker.internal.state.model.Iteration;
import com.mulesoft.modules.agent.broker.internal.state.model.LLMOutput;
import com.mulesoft.modules.agent.broker.internal.state.model.TaskContext;
import com.mulesoft.modules.agent.broker.internal.state.model.ToolExecution;
import com.mulesoft.modules.agent.broker.internal.tool.Tool;
import com.mulesoft.modules.agent.broker.internal.tool.ToolRequest;
import com.mulesoft.modules.agent.broker.internal.tool.ToolResponse;
import com.mulesoft.modules.agent.broker.internal.tool.ToolType;
import com.mulesoft.modules.agent.broker.internal.tool.a2a.A2AService;
import com.mulesoft.modules.agent.broker.internal.tool.a2a.A2AToolResponse;
import com.mulesoft.modules.agent.broker.internal.util.ConcurrencyUtils;
import com.mulesoft.modules.agent.broker.internal.util.ExceptionUtils;
import io.a2a.spec.Message;
import io.a2a.spec.Task;
import io.a2a.spec.TaskState;
import io.a2a.spec.TaskStatus;
import io.a2a.spec.TextPart;
import java.time.OffsetDateTime;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.Executor;
import java.util.function.Consumer;
import org.mule.runtime.api.exception.MuleRuntimeException;
import org.mule.runtime.api.i18n.I18nMessageFactory;
import org.mule.runtime.api.scheduler.SchedulerService;
import org.mule.runtime.api.store.ObjectStoreException;
import org.mule.runtime.core.api.util.StringUtils;
import org.mule.runtime.core.api.util.func.CheckedRunnable;
import org.mule.runtime.extension.api.client.ExtensionsClient;
import org.mule.runtime.extension.api.error.ErrorTypeDefinition;
import org.mule.runtime.extension.api.exception.ModuleException;
import org.mule.runtime.extension.api.runtime.operation.Result;
import org.mule.runtime.extension.api.runtime.process.CompletionCallback;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;

/**
 * Drives a single agent task through an iterative "reason, then act" cycle.
 *
 * <p>On each iteration the prompt produced by the {@link PromptBuilder} is sent to the {@link LLMClient}.
 * Depending on the returned {@link LLMOutput}, the loop does one of three things:
 * <ul>
 *   <li>completes the task;</li>
 *   <li>pauses the task waiting for additional input (human-in-the-loop, "HITL");</li>
 *   <li>executes the selected {@link Tool} and schedules the next iteration.</li>
 * </ul>
 *
 * <p>The loop fails the task when {@code maxLoops} iterations are exceeded or when {@code maxConsecutiveErrors}
 * consecutive recoverable errors (unknown tool, no tool selected, tool failure, empty tool result) accumulate.
 * Progress is persisted through the {@link ConversationService}.
 *
 * <p>LLM and tool completions are handed to {@code ConcurrencyUtils.completeAsyncInTCCL} together with the IO
 * scheduler. Each asynchronous callback therefore re-populates the logging MDC before doing any work.
 */
class Loop {
    private static final Logger LOGGER = LoggerFactory.getLogger(Loop.class);
    private static final String AGENT_MDC_KEY = "agent";
    private static final String ITERATION_MDC_KEY = "iteration";
    private static final String TASK_ID_MDC_KEY = "taskId";
    private static final String CONTEXT_ID_MDC_KEY = "contextId";
    private static final String API_INSTANCE_ID_MDC_KEY = "apiInstanceId";
    private final PromptBuilder promptBuilder;
    private final LLMClient llmClient;
    private final Map<String, Tool> toolExecutors;
    private final String taskId;
    private final String contextId;
    private final Message message;
    private final String apiInstanceId;
    private final int maxLoops;
    private final int maxConsecutiveErrors;
    private final ConversationService conversationService;
    private final A2AService a2aService;
    private final AgentsBroker broker;
    private final ExtensionsClient extensionsClient;
    private final SchedulerService schedulerService;
    private final CompletionCallback<String, Void> completionCallback;
    // 1-based number of the iteration currently executing (0 before the first iteration starts).
    private int currentLoop = 0;
    private int consecutiveErrors = 0;
    // Running log of decisions and tool outputs. It is fed back to the LLM through the prompt builder and stored
    // on the task context so that a paused task can be resumed with the same history.
    private StringBuilder conversationHistory = new StringBuilder();
    private Conversation conversation;
    private TaskContext taskContext;

    /**
     * Creates a loop for one task. No work is done until {@link #start()} is called.
     *
     * @param completionCallback receives the final JSON result. It is wrapped so that the caller's MDC context is
     *                           restored after it is invoked.
     */
    public Loop(PromptBuilder promptBuilder, LLMClient llmClient, Map<String, Tool> toolExecutors, String taskId, String contextId, Message message, String apiInstanceId, int maxLoops, int maxConsecutiveErrors, ConversationService conversationService, A2AService a2aService, AgentsBroker broker, ExtensionsClient extensionsClient, SchedulerService schedulerService, CompletionCallback<String, Void> completionCallback) {
        this.promptBuilder = promptBuilder;
        this.llmClient = llmClient;
        this.toolExecutors = toolExecutors;
        this.taskId = taskId;
        this.contextId = contextId;
        this.message = message;
        this.apiInstanceId = apiInstanceId;
        this.maxLoops = maxLoops;
        this.maxConsecutiveErrors = maxConsecutiveErrors;
        this.schedulerService = schedulerService;
        this.completionCallback = new CompletionCallbackDecorator<>(completionCallback);
        this.broker = broker;
        this.extensionsClient = extensionsClient;
        this.conversationService = conversationService;
        this.a2aService = a2aService;
    }

    /**
     * Starts (or resumes) the task.
     *
     * <p>The conversation and task context are loaded inside
     * {@code ConversationService#synchronizedConversation}. If the task was waiting for additional input, its
     * state is switched back to WORKING and the matching resume action is chosen. Otherwise the first iteration
     * is chosen. The chosen action runs after the synchronized section has returned. Failures while loading or
     * resuming are reported through the completion callback.
     */
    public void start() {
        this.setMDC();
        CheckedRunnable nextAction = this.conversationService.synchronizedConversation(this.contextId, this.broker, () -> {
            try {
                this.fetchContext();
            } catch (Exception e) {
                this.completionCallback.error(new ModuleException("Error recovering context %s: %s".formatted(this.message.getContextId(), e.getMessage()), (ErrorTypeDefinition) BrokerErrorTypes.A2A, e));
                return null;
            }
            ConversationState conversationState = this.taskContext.getConversationState();
            AdditionalInputRequired additionalInputRequired = conversationState.getAdditionalInputRequired();
            try {
                if (additionalInputRequired != null) {
                    // The incoming message is the answer to the pending question: clear the pause marker first.
                    conversationState.setTaskState(TaskState.WORKING);
                    conversationState.setAdditionalInputRequired(null);
                    this.safeUpsertTaskContext();
                    if (additionalInputRequired.getRequesterType() == AdditionalInputRequired.RequesterType.A2A_TOOL) {
                        return this.resumeHITLFromA2ATool();
                    }
                    if (additionalInputRequired.getRequesterType() == AdditionalInputRequired.RequesterType.LLM) {
                        return this.resumeHITLByLLM();
                    }
                }
            } catch (Exception e) {
                this.completionCallback.error(e);
                return null;
            }
            return this::nextLoop;
        });
        if (nextAction != null) {
            nextAction.run();
        }
    }

    /** Loads (or creates) the conversation and the task context this loop operates on. */
    private void fetchContext() {
        this.conversation = this.conversationService.getOrCreateConversation(this.message, this.contextId, this.broker);
        this.taskContext = this.conversationService.getTaskFor(this.message, this.conversation, this.taskId, this.broker);
    }

    /**
     * Runs one iteration.
     *
     * <p>The iteration counter is advanced and checked against {@code maxLoops}. A new {@link Iteration} is
     * recorded, and the LLM is asked for the next step. The LLM response is handled asynchronously by
     * {@link #onReasoningResult}.
     */
    private void nextLoop() {
        try {
            if (++this.currentLoop > this.maxLoops) {
                this.completeWithErrorAndUpdateState(new ModuleException("Maximum loops achieved without completing goal", (ErrorTypeDefinition) BrokerErrorTypes.MAX_LOOPS));
                return;
            }
            this.setMDC();
            this.promptBuilder.setCurrentLoopIteration(this.currentLoop).setConversationHistory(this.conversationHistory.toString());
            this.taskContext.setCurrentIteration(this.currentLoop);
            Iteration iteration = new Iteration();
            iteration.setNumber(this.currentLoop);
            this.taskContext.addIteration(iteration);
            ConcurrencyUtils.completeAsyncInTCCL(
                    this.llmClient.reasonNextStep(new LLMOrchestrationRequest(this.promptBuilder.build())),
                    (Executor) this.schedulerService.ioScheduler(),
                    (llmOutput, llmException) -> this.onReasoningResult(iteration, llmOutput, llmException));
        } catch (Exception e) {
            this.completionCallback.error(e);
        }
    }

    /**
     * Handles the LLM's decision for the given iteration.
     *
     * <p>The branches are mutually exclusive, and each one ends the handling of this response:
     * <ol>
     *   <li>LLM failure: fail the task.</li>
     *   <li>Goal complete: persist COMPLETED and report success.</li>
     *   <li>Additional input required: persist INPUT_REQUIRED and report it.</li>
     *   <li>No tool or unknown tool: record the problem in the history and retry, unless the consecutive-error
     *       limit is reached.</li>
     *   <li>Otherwise: execute the selected tool.</li>
     * </ol>
     */
    private void onReasoningResult(Iteration iteration, LLMOutput llmOutput, Throwable llmException) {
        this.setMDC();
        if (llmException != null) {
            Throwable cause = ExceptionUtils.unwrap(llmException);
            SecureLogger.secureLogger().debug("LLM call failed", cause);
            this.completeWithErrorAndUpdateState(new ModuleException("Cannot complete task due to temporal issue with the reasoning engine", (ErrorTypeDefinition) BrokerErrorTypes.LLM_ERROR, cause));
            return;
        }
        iteration.setLlmOutput(llmOutput);
        if (llmOutput.isGoalComplete()) {
            this.taskContext.setConversationState(new ConversationState(TaskState.COMPLETED, null));
            this.safeUpsertTaskContext();
            this.completionCallback.success(this.a2aService.completedJsonResult(this.conversation, this.taskContext, llmOutput.getResult()));
            // The task is finished: the tool-selection logic below must not run after success was signalled.
            return;
        }
        if (llmOutput.isAdditionalInputRequired()) {
            try {
                this.updateTaskContextForHITLByLLM(this.taskContext, iteration, llmOutput);
                this.completionCallback.success(this.a2aService.inputRequiredJsonResult(this.conversation, this.taskContext, llmOutput.getResult()));
            } catch (Exception e) {
                this.completeWithErrorAndUpdateState(new MuleRuntimeException(I18nMessageFactory.createStaticMessage("Could not update context when requesting additional data from client: " + e.getMessage()), e));
            }
            return;
        }
        String toolToCall = llmOutput.getToolToCall();
        if (this.noToolSelected(toolToCall)) {
            this.addToHistoryAndLog("LLM didn't select any tool to call next, the goal is not yet complete and no addition data requested. Re-evaluate and either select a tool or request additional data.", "Reasoning: " + llmOutput.getReasoning());
            if (this.maxConsecutiveErrorsReached()) {
                this.completeWithErrorAndUpdateState(new ModuleException(I18nMessageFactory.createStaticMessage("LLM couldn't determine next step to complete the task"), (ErrorTypeDefinition) BrokerErrorTypes.REASONING_ERROR));
            } else {
                this.nextLoop();
            }
            return;
        }
        Tool toolExecutor = this.toolExecutors.get(toolToCall);
        if (toolExecutor == null) {
            this.addToHistoryAndLog("LLM selected tool '" + toolToCall + "' but it doesn't exist. Reevaluate and make sure to select a tool from the provided list");
            if (this.maxConsecutiveErrorsReached()) {
                LOGGER.error("LLM selected a tool that doesn't exist: {}. Max consecutive errors limit reached. Task will fail", toolToCall);
                this.completeWithErrorAndUpdateState(new ModuleException("Could not determine next steps for task completion", (ErrorTypeDefinition) BrokerErrorTypes.REASONING_ERROR));
            } else {
                this.nextLoop();
            }
            return;
        }
        String toolInput = llmOutput.getToolInput();
        String reasoning = llmOutput.getReasoning();
        this.addToHistoryAndLog("LLM selected tool " + toolToCall, "Reasoning was: " + this.nullSafe(reasoning) + ". Tool Input: " + toolInput);
        ToolRequest toolRequest = new ToolRequest(toolToCall, toolInput, reasoning, this.currentLoop, this.taskContext, this.apiInstanceId);
        this.executeTool(toolExecutor, toolRequest, toolToCall, toolResponse -> {
            try {
                this.resetConsecutiveErrors();
                if (this.handleToolResponse(this.taskContext, llmOutput, toolToCall, toolResponse)) {
                    this.nextLoop();
                }
            } catch (Exception e) {
                this.completeWithErrorAndUpdateState(e);
            }
        });
    }

    /** Returns true when the LLM did not name a tool (blank, or the literal string "null" in any case). */
    private boolean noToolSelected(String toolName) {
        return StringUtils.isBlank(toolName) || "null".equalsIgnoreCase(toolName);
    }

    /**
     * Executes a tool asynchronously.
     *
     * <p>A failed execution or a blank result counts as a consecutive error. It leads either to a retry
     * ({@link #nextLoop()}) or, once the limit is reached, to task failure. A non-blank result is passed to
     * {@code onSuccess}.
     */
    private void executeTool(Tool executor, ToolRequest toolRequest, String toolName, Consumer<ToolResponse> onSuccess) {
        ConcurrencyUtils.completeAsyncInTCCL(executor.execute(toolRequest, this.extensionsClient), (Executor) this.schedulerService.ioScheduler(), (toolResult, toolException) -> {
            this.setMDC();
            if (toolException != null) {
                Throwable cause = ExceptionUtils.unwrap(toolException);
                this.addToHistoryAndLog("Execution of tool " + toolName + " failed. Analyze the error and determine whether the tool input needs to be adjusted, or if a different tool can be used instead", "Error message was: " + cause.getMessage());
                if (this.maxConsecutiveErrorsReached()) {
                    this.addToHistoryAndLog("Tool '%s' failed and maximum consecutive errors reached. Task '%s' will fail".formatted(toolName, this.taskContext.getTaskId()), "Error message was: " + cause.getMessage());
                    this.completeWithErrorAndUpdateState(new ModuleException("Could not complete next step due to an error. Task failed.", (ErrorTypeDefinition) BrokerErrorTypes.TOOL_ERROR, cause));
                } else {
                    this.nextLoop();
                }
                return;
            }
            if (StringUtils.isBlank(toolResult.getResult())) {
                this.addToHistoryAndLog("Execution of tool " + toolName + " came back with no result. Determine whether the tool input needs to be adjusted, or if a different tool can be used instead");
                if (this.maxConsecutiveErrorsReached()) {
                    this.completeWithErrorAndUpdateState(new ModuleException("Could not complete next step due to an error. Task failed.", (ErrorTypeDefinition) BrokerErrorTypes.TOOL_ERROR));
                } else {
                    this.nextLoop();
                }
                return;
            }
            onSuccess.accept(toolResult);
        });
    }

    /**
     * Records a successful tool response.
     *
     * @return true if the loop should continue with the next iteration, or false if the task was paused waiting
     *         for input requested by an A2A tool
     */
    private boolean handleToolResponse(TaskContext taskContext, LLMOutput llmOutput, String toolName, ToolResponse toolResponse) {
        if (toolResponse.getToolType() == ToolType.A2A) {
            return this.handleA2AToolResponse(taskContext, toolName, llmOutput.getToolInput(), (A2AToolResponse) toolResponse);
        }
        this.handleStandardToolResponse(taskContext, llmOutput, toolName, toolResponse);
        return true;
    }

    /**
     * Marks the task as FAILED, persists that state best-effort, and reports a FAILED A2A task as the successful
     * result of the operation.
     *
     * <p>The exception message is used as the status text. If it is null, the exception class name is used
     * instead. If building or reporting the result throws, the error is forwarded to the callback.
     */
    private void completeWithErrorAndUpdateState(Exception exception) {
        String errorMessage = describe(exception);
        try {
            this.taskContext.setConversationState(new ConversationState(TaskState.FAILED, errorMessage));
            this.safeUpsertTaskContext();
        } finally {
            try {
                this.completionCallback.success(this.a2aService.asJsonResult(this.failedTask(errorMessage)));
            } catch (Exception e) {
                this.completionCallback.error(e);
            }
        }
    }

    /** Builds the A2A task sent back to the client when this task fails. */
    private Task failedTask(String errorMessage) {
        String failedTaskId = this.taskContext.getTaskId();
        Message statusMessage = new Message.Builder()
                .messageId(UUID.randomUUID().toString())
                .taskId(failedTaskId)
                .role(Message.Role.AGENT)
                .parts(List.of(new TextPart(errorMessage)))
                .build();
        return new Task.Builder()
                .id(failedTaskId)
                .contextId(this.conversation.getId())
                .status(new TaskStatus(TaskState.FAILED, statusMessage, OffsetDateTime.now()))
                .build();
    }

    /** Returns the exception message, or the exception class name when the message is null. */
    private static String describe(Exception exception) {
        String text = exception.getMessage();
        return text != null ? text : exception.getClass().getName();
    }

    /**
     * Records the outcome of an A2A (agent-to-agent) tool call on its agent tool context.
     *
     * <p>If the remote agent needs more input, the task is paused: the pending tool call is persisted and an
     * input-required result is reported.
     *
     * @return true if the loop should continue, or false if the task was paused for input
     */
    private boolean handleA2AToolResponse(TaskContext taskContext, String toolName, String toolInput, A2AToolResponse toolResponse) {
        String toolOutput = toolResponse.getResult();
        AgentToolContext toolContext = taskContext.getAgentToolContext(toolName);
        toolContext.setTaskState(toolResponse.getState());
        if (toolResponse.getContextId() != null) {
            toolContext.setContextId(toolResponse.getContextId());
        }
        if (toolResponse.getTaskId() != null) {
            toolContext.setTaskId(toolResponse.getTaskId());
        }
        this.addToolResponseToHistory(toolName, toolOutput);
        if (toolResponse.isInputRequired()) {
            this.updateTaskContextForToolCall(taskContext, toolName, toolInput, toolOutput, true);
            this.safeUpsertTaskContext();
            this.completionCallback.success(this.a2aService.inputRequiredJsonResult(this.conversation, taskContext, toolOutput));
            return false;
        }
        return true;
    }

    private void addToolResponseToHistory(String toolName, String toolOutput) {
        this.addToHistoryAndLog("Executed tool " + toolName, "Output was: " + toolOutput);
    }

    /** Records a non-A2A tool response in the history and on the last iteration of the task context. */
    private void handleStandardToolResponse(TaskContext taskContext, LLMOutput llmOutput, String toolName, ToolResponse toolResponse) {
        String toolOutput = toolResponse.getResult();
        this.addToolResponseToHistory(toolName, toolOutput);
        // Argument order is (name, input, output), matching the A2A path in handleA2AToolResponse.
        this.updateTaskContextForToolCall(taskContext, toolName, llmOutput.getToolInput(), toolOutput, false);
    }

    /**
     * Appends one entry to the conversation history and logs it.
     *
     * <p>{@code message} is logged at INFO. Each {@code secureAppendix} part is logged only through the secure
     * logger at DEBUG, because it may contain sensitive data. Every entry starts on its own line and ends with a
     * newline. Appendix parts are separated from the preceding text by a single space.
     */
    private void addToHistoryAndLog(String message, String ... secureAppendix) {
        LOGGER.info(message);
        this.startNewHistoryLine();
        this.conversationHistory.append(message);
        if (secureAppendix != null) {
            for (String detail : secureAppendix) {
                SecureLogger.secureLogger().debug(detail);
                if (!this.historyEndsWithWhitespace()) {
                    this.conversationHistory.append(' ');
                }
                this.conversationHistory.append(detail);
            }
        }
        this.startNewHistoryLine();
    }

    /** Appends a newline unless the history is empty or already ends with one. */
    private void startNewHistoryLine() {
        int length = this.conversationHistory.length();
        if (length > 0 && this.conversationHistory.charAt(length - 1) != '\n') {
            this.conversationHistory.append('\n');
        }
    }

    private boolean historyEndsWithWhitespace() {
        int length = this.conversationHistory.length();
        return length == 0 || Character.isWhitespace(this.conversationHistory.charAt(length - 1));
    }

    /**
     * Prepares to resume a task that the LLM paused to ask for more input.
     *
     * <p>The stored history is restored, and the counter is rewound so that the next {@link #nextLoop()} re-runs
     * the paused iteration number.
     */
    private CheckedRunnable resumeHITLByLLM() {
        this.conversationHistory = new StringBuilder(this.nullSafe(this.taskContext.getConversationHistory()));
        this.currentLoop = this.taskContext.getCurrentIteration() - 1;
        LOGGER.info("Resuming conversation {} from iteration {}", this.conversation.getId(), this.taskContext.getCurrentIteration());
        return this::nextLoop;
    }

    /**
     * Prepares to resume a task that an A2A tool paused to ask for more input.
     *
     * <p>The user prompt (taken from the prompt builder) is forwarded as the new input to the tool call recorded
     * on the last iteration. The returned action re-executes that tool and then continues the loop.
     *
     * @throws ModuleException if the pending tool call or its tool configuration cannot be found
     */
    private CheckedRunnable resumeHITLFromA2ATool() {
        this.conversationHistory = new StringBuilder(this.nullSafe(this.taskContext.getConversationHistory()));
        Iteration lastIteration = this.taskContext.getLastIteration();
        ToolExecution waitingToolResult = lastIteration == null ? null : lastIteration.getToolExecution();
        if (waitingToolResult == null) {
            throw new ModuleException("Could not find the pending A2A tool execution to resume", (ErrorTypeDefinition) BrokerErrorTypes.A2A);
        }
        String waitingToolName = waitingToolResult.getName();
        String toolInput = this.promptBuilder.getUserPrompt();
        // The previous tool output (typically the remote agent's question) is passed along as the reasoning.
        String reasoning = waitingToolResult.getOutput();
        Tool tool = this.toolExecutors.get(waitingToolName);
        if (tool == null) {
            throw new ModuleException("Could not find A2A agent config to resume the tool - " + String.valueOf(waitingToolResult), (ErrorTypeDefinition) BrokerErrorTypes.A2A);
        }
        this.currentLoop = this.taskContext.getCurrentIteration() - 1;
        ToolRequest toolRequest = new ToolRequest(waitingToolName, toolInput, reasoning, this.currentLoop, this.taskContext, this.apiInstanceId);
        return () -> this.executeTool(tool, toolRequest, waitingToolName, toolResult -> {
            try {
                if (this.handleA2AToolResponse(this.taskContext, waitingToolName, toolInput, (A2AToolResponse) toolResult)) {
                    this.nextLoop();
                }
            } catch (Exception e) {
                // Same handling as the regular tool path, so the persisted state does not stay WORKING.
                this.completeWithErrorAndUpdateState(e);
            }
        });
    }

    /**
     * Pauses the task because the LLM asked for more input.
     *
     * <p>The requester is recorded as LLM and the state is set to INPUT_REQUIRED. The current history is stored
     * so that {@link #resumeHITLByLLM()} can restore it. The context is then persisted best-effort.
     */
    private void updateTaskContextForHITLByLLM(TaskContext taskContext, Iteration iteration, LLMOutput response) {
        ConversationState conversationState = taskContext.getConversationState();
        conversationState.setAdditionalInputRequired(new AdditionalInputRequired(AdditionalInputRequired.RequesterType.LLM, response.getToolToCall()));
        conversationState.setTaskState(TaskState.INPUT_REQUIRED);
        iteration.setLlmOutput(response);
        iteration.setToolExecution(null);
        taskContext.setConversationHistory(this.conversationHistory.toString());
        this.safeUpsertTaskContext();
    }

    /** Persists the task context. Failures are logged and otherwise ignored. */
    private void safeUpsertTaskContext() {
        try {
            this.conversationService.upsert(this.taskContext, this.broker);
        } catch (Exception e) {
            LOGGER.error("Exception found updating TaskContext", e);
        }
    }

    /**
     * Stores the tool call on the last iteration and snapshots the history onto the task context.
     *
     * <p>When {@code requiresAdditionalInput} is true, the task is also marked INPUT_REQUIRED, with the tool
     * recorded as an A2A requester. This method does not persist anything.
     */
    private void updateTaskContextForToolCall(TaskContext taskContext, String toolName, String toolInput, String toolResponse, boolean requiresAdditionalInput) {
        ToolExecution toolExecution = new ToolExecution(toolName, toolInput, toolResponse);
        taskContext.getLastIteration().setToolExecution(toolExecution);
        if (requiresAdditionalInput) {
            ConversationState conversationState = taskContext.getConversationState();
            conversationState.setAdditionalInputRequired(new AdditionalInputRequired(AdditionalInputRequired.RequesterType.A2A_TOOL, toolName));
            conversationState.setTaskState(TaskState.INPUT_REQUIRED);
        }
        taskContext.setConversationHistory(this.conversationHistory.toString());
    }

    private void resetConsecutiveErrors() {
        this.consecutiveErrors = 0;
    }

    /** Counts one more consecutive error and reports whether the configured limit has been reached. */
    private boolean maxConsecutiveErrorsReached() {
        return ++this.consecutiveErrors >= this.maxConsecutiveErrors;
    }

    /** Puts this loop's identifiers into the logging MDC of the current thread. */
    private void setMDC() {
        MDC.put(AGENT_MDC_KEY, this.broker.getConfigName());
        MDC.put(ITERATION_MDC_KEY, String.valueOf(this.currentLoop));
        MDC.put(TASK_ID_MDC_KEY, this.taskId);
        MDC.put(CONTEXT_ID_MDC_KEY, this.contextId);
        MDC.put(API_INSTANCE_ID_MDC_KEY, this.apiInstanceId);
    }

    private String nullSafe(String value) {
        return value == null ? "" : value;
    }

    /**
     * Completion callback wrapper.
     *
     * <p>After the delegate is invoked, it restores the MDC context that was present when the wrapper was
     * created. If there was no MDC context at that time, the MDC is cleared instead.
     */
    private static final class CompletionCallbackDecorator<T, A>
    implements CompletionCallback<T, A> {
        private final CompletionCallback<T, A> delegate;
        // May be null: MDC.getCopyOfContextMap() returns null when no context is set.
        private final Map<String, String> originalMdcContext;

        CompletionCallbackDecorator(CompletionCallback<T, A> delegate) {
            this.delegate = delegate;
            this.originalMdcContext = MDC.getCopyOfContextMap();
        }

        @Override
        public void success(Result<T, A> result) {
            try {
                this.delegate.success(result);
            } finally {
                this.restoreMdcContext();
            }
        }

        @Override
        public void error(Throwable throwable) {
            try {
                this.delegate.error(throwable);
            } finally {
                this.restoreMdcContext();
            }
        }

        private void restoreMdcContext() {
            if (this.originalMdcContext == null) {
                MDC.clear();
            } else {
                MDC.setContextMap(this.originalMdcContext);
            }
        }
    }
}

