/*
 * Decompiled with CFR 0.152.
 */
package com.mulesoft.modules.agent.conductor.internal.operation.loop;

import com.mulesoft.modules.agent.conductor.api.model.Orchestration;
import com.mulesoft.modules.agent.conductor.internal.error.ConductorErrorTypes;
import com.mulesoft.modules.agent.conductor.internal.llm.LLMOrchestrationRequest;
import com.mulesoft.modules.agent.conductor.internal.llm.client.LLMClient;
import com.mulesoft.modules.agent.conductor.internal.prompt.PromptBuilder;
import com.mulesoft.modules.agent.conductor.internal.state.TaskContextService;
import com.mulesoft.modules.agent.conductor.internal.state.model.AdditionalInputRequired;
import com.mulesoft.modules.agent.conductor.internal.state.model.AgentToolContext;
import com.mulesoft.modules.agent.conductor.internal.state.model.ConversationState;
import com.mulesoft.modules.agent.conductor.internal.state.model.Iteration;
import com.mulesoft.modules.agent.conductor.internal.state.model.LLMOutput;
import com.mulesoft.modules.agent.conductor.internal.state.model.TaskContext;
import com.mulesoft.modules.agent.conductor.internal.state.model.ToolExecution;
import com.mulesoft.modules.agent.conductor.internal.tool.ToolHandler;
import com.mulesoft.modules.agent.conductor.internal.tool.ToolRequest;
import com.mulesoft.modules.agent.conductor.internal.tool.ToolResponse;
import com.mulesoft.modules.agent.conductor.internal.tool.ToolType;
import com.mulesoft.modules.agent.conductor.internal.tool.a2a.A2AToolResponse;
import com.mulesoft.modules.agent.conductor.internal.tool.a2a.A2AToolService;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;
import org.mule.runtime.api.exception.MuleRuntimeException;
import org.mule.runtime.api.i18n.I18nMessageFactory;
import org.mule.runtime.api.store.ObjectStore;
import org.mule.runtime.api.store.ObjectStoreException;
import org.mule.runtime.core.api.util.StringUtils;
import org.mule.runtime.core.api.util.func.CheckedFunction;
import org.mule.runtime.extension.api.error.ErrorTypeDefinition;
import org.mule.runtime.extension.api.exception.ModuleException;
import org.mule.runtime.extension.api.runtime.operation.Result;
import org.mule.runtime.extension.api.runtime.process.CompletionCallback;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;

class Loop {
    private static final Logger LOGGER = LoggerFactory.getLogger(Loop.class);
    // MDC keys written by setMDC() so log lines carry the agent, iteration, task and context ids.
    private static final String AGENT_MDC_KEY = "agent";
    private static final String ITERATION_MDC_KEY = "iteration";
    private static final String TASK_ID_MDC_KEY = "taskId";
    private static final String CONTEXT_ID_MDC_KEY = "contextId";
    // Builds the prompt sent to the LLM; receives the iteration number and history on every loop.
    private final PromptBuilder promptBuilder;
    // Asynchronous client used to ask the LLM for the next step.
    private final LLMClient llmClient;
    // Used when resuming an A2A tool that was waiting for additional input.
    private final A2AToolService a2aToolService;
    // Tool handlers looked up by the tool name returned in LLMOutput.getToolToCall().
    private final Map<String, ToolHandler> toolHandlers;
    private final String taskId;
    private final String contextId;
    // Upper bound on iterations; exceeding it fails the operation with MAX_LOOPS.
    private final int maxLoops;
    // Threshold compared against consecutiveErrors to decide when to stop retrying.
    private final int maxConsecutiveErrors;
    // Loads and persists the TaskContext in taskContextStore.
    private final TaskContextService taskContextService;
    private final ObjectStore<TaskContext> taskContextStore;
    // Configuration reference; used as the "agent" MDC value and passed to the task context service.
    private final String configRef;
    // Caller's callback, wrapped by the constructor in a CompletionCallbackDecorator.
    private final CompletionCallback<Orchestration, Void> completionCallback;
    // Number of the iteration in progress; incremented at the top of nextLoop().
    private int currentLoop = 0;
    // Incremented by maxConsecutiveErrorsReached(), zeroed by resetConsecutiveErrors().
    private int consecutiveErrors = 0;
    // Running text log of LLM decisions and tool results, handed to the prompt builder each loop.
    private StringBuilder conversationHistory = new StringBuilder("");
    // Loaded in start() through getOrCreateTaskContext(); null until then.
    private TaskContext taskContext;

    /**
     * Creates a loop for a single task/context pair.
     *
     * <p>The supplied {@code completionCallback} is wrapped in a {@link CompletionCallbackDecorator},
     * which snapshots the MDC of the constructing thread at this point. No work is started here;
     * call {@link #start()} to begin.
     */
    public Loop(PromptBuilder promptBuilder, LLMClient llmClient, A2AToolService a2aToolService, Map<String, ToolHandler> toolHandlers, String taskId, String contextId, int maxLoops, int maxConsecutiveErrors, TaskContextService taskContextService, ObjectStore<TaskContext> taskContextStore, String configRef, CompletionCallback<Orchestration, Void> completionCallback) {
        // Identity of the conversation being driven.
        this.taskId = taskId;
        this.contextId = contextId;
        this.configRef = configRef;

        // Limits.
        this.maxLoops = maxLoops;
        this.maxConsecutiveErrors = maxConsecutiveErrors;

        // Collaborators.
        this.promptBuilder = promptBuilder;
        this.llmClient = llmClient;
        this.a2aToolService = a2aToolService;
        this.toolHandlers = toolHandlers;
        this.taskContextService = taskContextService;
        this.taskContextStore = taskContextStore;

        // Decorated so the original MDC is restored once the operation completes.
        this.completionCallback = new CompletionCallbackDecorator<>(completionCallback);
    }

    /**
     * Entry point of the loop: loads the task context and either starts a fresh run or resumes a
     * conversation that was waiting for additional input.
     *
     * <ul>
     *   <li>If the stored conversation is completed or failed, the context is reset and a new run
     *       starts.</li>
     *   <li>If additional input was requested by an A2A tool or by the LLM, the matching resume
     *       path is taken.</li>
     *   <li>If additional input is pending but its requester type is not recognised, the operation
     *       fails (and the state is marked failed) instead of silently never completing.</li>
     *   <li>Otherwise the loop simply continues with the next iteration.</li>
     * </ul>
     *
     * <p>All failures are reported through the completion callback; nothing is thrown.
     */
    public void start() {
        this.setMDC();
        try {
            this.taskContext = this.getOrCreateTaskContext();
        }
        catch (Exception e) {
            String message = "Error while fetching the task context for contextId %s: %s".formatted(this.contextId, e.getMessage());
            this.completionCallback.error(new MuleRuntimeException(I18nMessageFactory.createStaticMessage(message), e));
            return;
        }
        try {
            ConversationState conversationState = this.taskContext.getConversationState();
            if (conversationState.isCompleted() || conversationState.isFailed()) {
                // A finished conversation is started over from a clean context.
                this.taskContext.reset();
                this.nextLoop();
                return;
            }
            AdditionalInputRequired additionalInputRequired = conversationState.getAdditionalInputRequired();
            if (additionalInputRequired == null) {
                this.nextLoop();
                return;
            }
            var requesterType = additionalInputRequired.getRequesterType();
            if (requesterType == AdditionalInputRequired.RequesterType.A2A_TOOL) {
                this.resumeHITLFromA2ATool();
            } else if (requesterType == AdditionalInputRequired.RequesterType.LLM) {
                this.resumeHITLByLLM();
            } else {
                // Without this branch the callback would never be invoked and the operation would hang.
                this.completeWithErrorAndUpdateState(new MuleRuntimeException(I18nMessageFactory.createStaticMessage("Unsupported additional input requester type '" + requesterType + "' for contextId " + this.contextId)));
            }
        }
        catch (Exception e) {
            this.completionCallback.error(e);
        }
    }

    /**
     * Runs one reasoning iteration: asks the LLM for the next step and acts on its answer.
     *
     * <p>The iteration counter is incremented first; once it exceeds {@code maxLoops} the operation
     * fails with {@code MAX_LOOPS}. Otherwise a new {@link Iteration} is added to the task context,
     * the prompt builder receives the iteration number and current history, and the LLM is called
     * asynchronously. Any exception raised while handling the LLM answer is reported through
     * {@link #completeWithErrorAndUpdateState(Exception)}; previously such exceptions were captured
     * by the discarded future returned from {@code whenComplete} and the callback never completed.
     */
    private void nextLoop() {
        try {
            if (++this.currentLoop > this.maxLoops) {
                this.completionCallback.error(new ModuleException("Maximum loops achieved without completing goal", (ErrorTypeDefinition)ConductorErrorTypes.MAX_LOOPS));
                return;
            }
            this.setMDC();
            this.promptBuilder.setCurrentLoopIteration(this.currentLoop).setConversationHistory(this.conversationHistory.toString());
            this.taskContext.setCurrentIteration(this.currentLoop);
            Iteration iteration = new Iteration();
            iteration.setNumber(this.currentLoop);
            this.taskContext.addIteration(iteration);
            this.llmClient.reasonNextStep(new LLMOrchestrationRequest(this.promptBuilder.build())).whenComplete((llmOutput, llmException) -> {
                // The continuation may run on another thread, so the MDC is set again.
                this.setMDC();
                if (llmException != null) {
                    this.completeWithErrorAndUpdateState(new ModuleException("LLM call failed: " + llmException.getMessage(), (ErrorTypeDefinition)ConductorErrorTypes.LLM_ERROR, llmException));
                    return;
                }
                if (llmOutput == null) {
                    this.completeWithErrorAndUpdateState(new ModuleException("LLM call completed without producing an output", (ErrorTypeDefinition)ConductorErrorTypes.LLM_ERROR));
                    return;
                }
                try {
                    this.processLlmOutput(iteration, (LLMOutput)llmOutput);
                }
                catch (Exception e) {
                    // Exceptions thrown here would otherwise vanish inside the unused future.
                    this.completeWithErrorAndUpdateState(e);
                }
            });
        }
        catch (Exception e) {
            this.completionCallback.error(e);
        }
    }

    /**
     * Acts on a successful LLM answer: completes the goal, asks the client for more input, retries
     * after a reasoning mistake, or invokes the selected tool.
     */
    private void processLlmOutput(Iteration iteration, LLMOutput llmOutput) {
        iteration.setLlmOutput(llmOutput);
        if (llmOutput.isGoalComplete()) {
            this.taskContext.setCurrentState(new ConversationState(true, false, null));
            this.safeUpsertTaskContext();
            this.completionCallback.success(Result.<Orchestration, Void>builder().output(new Orchestration(this.taskId, this.contextId, llmOutput.getResult(), llmOutput.isGoalComplete(), llmOutput.isAdditionalInputNeeded(), this.currentLoop, llmOutput.getReasoning())).build());
            return;
        }
        if (llmOutput.isAdditionalInputNeeded()) {
            try {
                this.updateTaskContextForHITLByLLM(this.taskContext, iteration, llmOutput);
                this.completionCallback.success(Result.<Orchestration, Void>builder().output(new Orchestration(this.taskId, this.contextId, llmOutput.getResult(), false, true, this.currentLoop, llmOutput.getReasoning())).build());
            }
            catch (Exception e) {
                this.completeWithErrorAndUpdateState(new MuleRuntimeException(I18nMessageFactory.createStaticMessage("Could not update context when requesting additional data from client: " + e.getMessage()), e));
            }
            return;
        }
        // From here on the goal is known to be incomplete and no additional input was requested.
        String toolToCall = llmOutput.getToolToCall();
        if (StringUtils.isBlank(toolToCall)) {
            this.addToHistory("LLM didn't select any tool to call next, the goal is not yet complete and no addition data requested. Re-evaluate and either select a tool, requested additional data or mark the goal as completed.");
            this.retryOrFailReasoning("LLM couldn't select any tool to proceed forward. Reasoning: " + llmOutput.getReasoning());
            return;
        }
        ToolHandler tool = this.toolHandlers.get(toolToCall);
        if (tool == null) {
            this.addToHistory("LLM selected tool '" + toolToCall + "' but it doesn't exist. Reevaluate and make sure to select a tool from the provided list");
            this.retryOrFailReasoning("LLM selected a tool that doesn't exist: " + toolToCall);
            return;
        }
        String toolInput = llmOutput.getToolInput();
        String reasoning = llmOutput.getReasoning();
        this.addToHistory("LLM selected tool " + toolToCall + ". Reasoning was: " + this.nullSafe(reasoning) + ".", "Tool Input: " + toolInput);
        ToolRequest toolRequest = new ToolRequest(toolToCall, toolInput, reasoning, this.currentLoop, this.taskContext);
        this.applyToolHandler(tool.getHandler(), toolRequest, toolToCall, toolResponse -> {
            try {
                this.resetConsecutiveErrors();
                // handleToolResponse returns false when the callback was already completed (input required).
                if (this.handleToolResponse(this.taskContext, llmOutput, toolToCall, (ToolResponse)toolResponse)) {
                    this.nextLoop();
                }
            }
            catch (Exception e) {
                this.completeWithErrorAndUpdateState(e);
            }
        });
    }

    /**
     * Counts a reasoning mistake: fails with {@code REASONING_ERROR} once the consecutive-error
     * limit is reached, otherwise gives the LLM another iteration.
     */
    private void retryOrFailReasoning(String failureMessage) {
        if (this.maxConsecutiveErrorsReached()) {
            this.completeWithErrorAndUpdateState(new ModuleException(I18nMessageFactory.createStaticMessage(failureMessage), (ErrorTypeDefinition)ConductorErrorTypes.REASONING_ERROR));
        } else {
            this.nextLoop();
        }
    }

    /**
     * Invokes a tool handler and routes its asynchronous outcome.
     *
     * <p>On success {@code onSuccess} receives the tool result. On failure the error is appended to
     * the conversation history so the LLM can react to it; the loop then either continues with the
     * next iteration or, once the consecutive-error limit is reached, fails with
     * {@code TOOL_ERROR}. The terminal failure goes through
     * {@link #completeWithErrorAndUpdateState(Exception)} so the stored conversation is marked as
     * failed, consistent with the other terminal errors raised inside the loop.
     *
     * @param handler     function producing the future tool result
     * @param toolRequest request passed to the handler
     * @param toolName    tool name used in history entries and error messages
     * @param onSuccess   continuation invoked with the tool result when the future succeeds
     * @param <T>         type of the tool result
     */
    private <T> void applyToolHandler(CheckedFunction<ToolRequest, CompletableFuture<T>> handler, ToolRequest toolRequest, String toolName, Consumer<T> onSuccess) {
        // No raw casts: keeping the generic type lets the compiler check the lambda parameters.
        handler.apply(toolRequest).whenComplete((toolResult, toolException) -> {
            this.setMDC();
            if (toolException == null) {
                onSuccess.accept(toolResult);
                return;
            }
            this.addToHistory("Execution of tool " + toolName + " failed. Error message was: " + toolException.getMessage() + " . Analyze the error and determine whether the tool input needs to be adjusted, or if a different tool can be used instead");
            if (this.maxConsecutiveErrorsReached()) {
                this.completeWithErrorAndUpdateState(new ModuleException("Error executing tool %s: %s".formatted(toolName, toolException.getMessage()), (ErrorTypeDefinition)ConductorErrorTypes.TOOL_ERROR, toolException));
            } else {
                this.nextLoop();
            }
        });
    }

    /**
     * Dispatches a tool response to the A2A or the standard handling path.
     *
     * @return the result of {@code handleA2AToolResponse} for A2A tools; always {@code true} for
     *         every other tool type
     */
    private boolean handleToolResponse(TaskContext taskContext, LLMOutput llmOutput, String toolName, ToolResponse toolResponse) {
        boolean fromA2AAgent = toolResponse.getToolType() == ToolType.A2A;
        if (!fromA2AAgent) {
            this.handleStandardToolResponse(taskContext, llmOutput, toolName, toolResponse);
            return true;
        }
        A2AToolResponse a2aResponse = (A2AToolResponse)toolResponse;
        return this.handleA2AToolResponse(taskContext, toolName, llmOutput.getToolInput(), a2aResponse);
    }

    /**
     * Marks the conversation as failed, attempts to persist it, and reports the error.
     *
     * <p>The callback is invoked from a {@code finally} block, so it fires even if updating the
     * state throws.
     */
    private void completeWithErrorAndUpdateState(Exception exception) {
        try {
            ConversationState failedState = new ConversationState(false, true, exception.getMessage());
            this.taskContext.setCurrentState(failedState);
            this.safeUpsertTaskContext();
        }
        finally {
            this.completionCallback.error(exception);
        }
    }

    /**
     * Handles the answer of an A2A tool.
     *
     * <p>Updates the tool's {@link AgentToolContext} with the task id and, when the response
     * carries one, the context id; then appends the tool output to the conversation history. If the
     * tool requires more input, the task context is updated and persisted and the operation
     * completes successfully with an "additional input needed" orchestration result.
     *
     * @return {@code false} when the completion callback has been invoked because input is
     *         required, {@code true} when the caller should keep looping
     */
    private boolean handleA2AToolResponse(TaskContext taskContext, String toolName, String toolInput, A2AToolResponse toolResponse) {
        boolean needsInput = toolResponse.isInputRequired();
        String output = toolResponse.getResult();

        AgentToolContext agentContext = taskContext.getAgentToolContext(toolName);
        agentContext.setTaskId(taskContext.getTaskId());
        if (toolResponse.getContextId() != null) {
            agentContext.setContextId(toolResponse.getContextId());
        }

        this.addToolResponseToHistory(toolName, output);
        if (!needsInput) {
            return true;
        }

        // The tool is waiting on the client: persist the pending state and hand control back.
        this.updateTaskContextForToolCall(taskContext, toolName, toolInput, output, true);
        this.safeUpsertTaskContext();
        String reason = "Agent " + this.configRef + " requires additional input to proceed";
        Orchestration pending = new Orchestration(this.taskId, this.contextId, output, false, true, this.currentLoop, reason);
        this.completionCallback.success(Result.<Orchestration, Void>builder().output(pending).build());
        return false;
    }

    /** Appends a history entry stating that {@code toolName} ran, followed by its output. */
    private void addToolResponseToHistory(String toolName, String toolOutput) {
        String headline = "Executed tool " + toolName + ".";
        String detail = "Output was: " + toolOutput;
        this.addToHistory(headline, detail);
    }

    /**
     * Handles the answer of a non-A2A tool: appends the output to the conversation history and
     * records the execution on the current iteration.
     *
     * <p>The execution is recorded as (tool name, tool input, tool output), matching the parameter
     * order of {@code updateTaskContextForToolCall} and what the A2A path stores. The previous
     * argument list passed the LLM reasoning as the input and the tool input as the response, so
     * the actual tool output was never stored.
     */
    private void handleStandardToolResponse(TaskContext taskContext, LLMOutput llmOutput, String toolName, ToolResponse toolResponse) {
        String toolOutput = toolResponse.getResult();
        this.addToolResponseToHistory(toolName, toolOutput);
        this.updateTaskContextForToolCall(taskContext, toolName, llmOutput.getToolInput(), toolOutput, false);
    }

    /**
     * Appends one entry to the conversation history and logs it.
     *
     * <p>{@code message} is logged at INFO and each extra fragment at DEBUG. All fragments are
     * appended back to back, and the entry is then terminated with a newline unless it already ends
     * with one, so consecutive entries never run together. (The previous logic was inverted: it
     * added a newline only when the text already ended with one.)
     *
     * @param message    main text of the entry
     * @param additional optional extra fragments appended after {@code message}; {@code null}
     *                   fragments are skipped
     */
    private void addToHistory(String message, String ... additional) {
        LOGGER.info(message);
        this.conversationHistory.append(message);
        if (additional != null) {
            for (String fragment : additional) {
                if (fragment == null) {
                    continue;
                }
                LOGGER.debug(fragment);
                this.conversationHistory.append(fragment);
            }
        }
        // Inspect the buffer itself so empty fragments cannot hide an existing trailing newline.
        int length = this.conversationHistory.length();
        if (length > 0 && this.conversationHistory.charAt(length - 1) != '\n') {
            this.conversationHistory.append('\n');
        }
    }

    /**
     * Resumes a conversation in which the LLM had asked the client for additional input.
     *
     * <p>Restores the conversation history stored in the task context (starting from an empty
     * history when none was stored, instead of failing with a {@code NullPointerException}) and
     * rewinds the loop counter by one so that {@link #nextLoop()} re-runs the stored iteration
     * number.
     */
    private void resumeHITLByLLM() {
        var storedHistory = this.taskContext.getConversationHistory();
        this.conversationHistory = storedHistory == null ? new StringBuilder() : new StringBuilder(storedHistory);
        this.currentLoop = this.taskContext.getCurrentIteration() - 1;
        LOGGER.info("Resuming the conversation from iteration {}", this.taskContext.getCurrentIteration());
        this.nextLoop();
    }

    /**
     * Resumes a conversation in which an A2A tool had asked for additional input.
     *
     * <p>The tool recorded on the last iteration is called again with the current user prompt as
     * its input and its previous output as the reasoning. The loop continues only when the tool no
     * longer requires input: if it asks again, {@code handleA2AToolResponse} has already completed
     * the callback and looping further would complete it a second time.
     *
     * <p>The conversation history stored in the task context is restored first, mirroring
     * {@link #resumeHITLByLLM()}, so the next prompt keeps the earlier steps.
     *
     * @throws MuleRuntimeException if no waiting tool execution is recorded or no A2A client
     *         configuration is registered for it
     */
    private void resumeHITLFromA2ATool() {
        Iteration lastIteration = this.taskContext.getLastIteration();
        ToolExecution waitingTool = lastIteration == null ? null : lastIteration.getToolExecution();
        if (waitingTool == null) {
            throw new MuleRuntimeException(I18nMessageFactory.createStaticMessage("Could not find the tool execution waiting for additional input for contextId " + this.contextId));
        }
        String waitingToolName = waitingTool.getName();
        String toolInput = this.promptBuilder.getUserPrompt();
        String reasoning = waitingTool.getOutput();
        String a2AClientConfigRef = this.a2aToolService.getA2AClientConfigRef(waitingToolName);
        if (a2AClientConfigRef == null) {
            throw new MuleRuntimeException(I18nMessageFactory.createStaticMessage("Could not find A2A agent config to resume the tool - " + String.valueOf(waitingTool)));
        }
        var storedHistory = this.taskContext.getConversationHistory();
        if (storedHistory != null) {
            this.conversationHistory = new StringBuilder(storedHistory);
        }
        this.currentLoop = this.taskContext.getCurrentIteration() - 1;
        ToolRequest toolRequest = new ToolRequest(waitingToolName, toolInput, reasoning, this.currentLoop, this.taskContext);
        CheckedFunction<ToolRequest, CompletableFuture<ToolResponse>> a2aToolHandler = this.a2aToolService.createHandler(a2AClientConfigRef);
        this.applyToolHandler(a2aToolHandler, toolRequest, waitingToolName, toolResult -> {
            try {
                if (this.handleA2AToolResponse(this.taskContext, waitingToolName, toolInput, (A2AToolResponse)toolResult)) {
                    this.nextLoop();
                }
            }
            catch (Exception e) {
                this.completeWithErrorAndUpdateState(e);
            }
        });
    }

    /**
     * Returns the task context for this loop, asking the task context service for it on first use
     * and caching the result in {@code this.taskContext}.
     *
     * @throws ObjectStoreException declared for failures of the underlying lookup
     */
    private TaskContext getOrCreateTaskContext() throws ObjectStoreException {
        if (this.taskContext != null) {
            return this.taskContext;
        }
        TaskContext resolved = this.taskContextService.getOrCreate(this.taskId, this.contextId, this.configRef, this.taskContextStore);
        this.taskContext = resolved;
        return resolved;
    }

    /**
     * Records that the LLM asked the client for additional input and persists the task context.
     *
     * <p>The conversation state is flagged as waiting on an LLM-originated request (neither
     * completed nor failed), the iteration keeps the LLM output and has its tool execution cleared,
     * and the in-memory conversation history is copied into the task context before it is stored.
     * Saving the history here matters because {@code resumeHITLByLLM} reads it back from the task
     * context; otherwise it would only hold whatever the last tool call saved, or nothing at all.
     *
     * @throws ObjectStoreException declared for failures while storing the task context
     */
    private void updateTaskContextForHITLByLLM(TaskContext taskContext, Iteration iteration, LLMOutput response) throws ObjectStoreException {
        ConversationState conversationState = taskContext.getConversationState();
        conversationState.setAdditionalInputRequired(new AdditionalInputRequired(AdditionalInputRequired.RequesterType.LLM, response.getToolToCall()));
        conversationState.setCompleted(false);
        conversationState.setFailed(false);
        iteration.setLlmOutput(response);
        iteration.setToolExecution(null);
        taskContext.setConversationHistory(this.conversationHistory.toString());
        this.taskContextService.upsertTaskContext(taskContext, this.configRef, this.taskContextStore);
    }

    /**
     * Best-effort persistence of the task context: failures are logged, never propagated, so they
     * cannot prevent the completion callback from being invoked by the caller.
     */
    private void safeUpsertTaskContext() {
        try {
            this.taskContextService.upsertTaskContext(this.taskContext, this.configRef, this.taskContextStore);
        }
        catch (Exception e) {
            // Pass the exception to the logger so the cause and stack trace are not lost.
            LOGGER.error("Exception found updating TaskContext", e);
        }
    }

    /**
     * Records a tool execution on the last iteration of the task context and copies the in-memory
     * conversation history into it.
     *
     * <p>When {@code requiresAdditionalInput} is set, the conversation state is also flagged as
     * waiting on an A2A-tool request for {@code toolName}, and as neither completed nor failed.
     * This method does not persist the task context.
     */
    private void updateTaskContextForToolCall(TaskContext taskContext, String toolName, String toolInput, String toolResponse, boolean requiresAdditionalInput) {
        ToolExecution execution = new ToolExecution(toolName, toolInput, toolResponse);
        Iteration latest = taskContext.getLastIteration();
        latest.setToolExecution(execution);

        if (requiresAdditionalInput) {
            ConversationState state = taskContext.getConversationState();
            AdditionalInputRequired pendingInput = new AdditionalInputRequired(AdditionalInputRequired.RequesterType.A2A_TOOL, toolName);
            state.setAdditionalInputRequired(pendingInput);
            state.setCompleted(false);
            state.setFailed(false);
        }

        String historySnapshot = this.conversationHistory.toString();
        taskContext.setConversationHistory(historySnapshot);
    }

    /** Clears the consecutive-error counter; called after a tool invocation succeeds. */
    private void resetConsecutiveErrors() {
        this.consecutiveErrors = 0;
    }

    /**
     * Counts one more consecutive error and reports whether the limit has been hit.
     *
     * <p>Note the side effect: every call increments the counter before comparing.
     *
     * @return {@code true} when the incremented counter is at least {@code maxConsecutiveErrors}
     */
    private boolean maxConsecutiveErrorsReached() {
        this.consecutiveErrors += 1;
        return this.consecutiveErrors >= this.maxConsecutiveErrors;
    }

    /**
     * Publishes the agent, iteration, task id and context id of this loop to the logging MDC of
     * the current thread.
     */
    private void setMDC() {
        String iterationLabel = Integer.toString(this.currentLoop);
        MDC.put(AGENT_MDC_KEY, this.configRef);
        MDC.put(ITERATION_MDC_KEY, iterationLabel);
        MDC.put(TASK_ID_MDC_KEY, this.taskId);
        MDC.put(CONTEXT_ID_MDC_KEY, this.contextId);
    }

    /** Returns {@code value} unchanged, or the empty string when it is {@code null}. */
    private String nullSafe(String value) {
        if (value != null) {
            return value;
        }
        return "";
    }

    private class CompletionCallbackDecorator<T, A>
    implements CompletionCallback<T, A> {
        private final CompletionCallback<T, A> delegate;
        private final Map<String, String> originalMdcContext;

        public CompletionCallbackDecorator(CompletionCallback<T, A> delegate) {
            this.delegate = delegate;
            this.originalMdcContext = MDC.getCopyOfContextMap();
        }

        public void success(Result<T, A> result) {
            try {
                this.delegate.success(result);
            }
            finally {
                this.restoreMdcContext();
            }
        }

        public void error(Throwable throwable) {
            try {
                this.delegate.error(throwable);
            }
            finally {
                this.restoreMdcContext();
            }
        }

        private void restoreMdcContext() {
            MDC.setContextMap(this.originalMdcContext);
        }
    }
}

