/*
 * Decompiled with CFR 0.152.
 */
package com.mulesoft.modules.agent.broker.internal.operation.loop;

import com.mulesoft.modules.agent.broker.internal.error.BrokerErrorTypes;
import com.mulesoft.modules.agent.broker.internal.extension.AgentsBroker;
import com.mulesoft.modules.agent.broker.internal.extension.connection.LLMClient;
import com.mulesoft.modules.agent.broker.internal.log.SecureLogger;
import com.mulesoft.modules.agent.broker.internal.state.ConversationService;
import com.mulesoft.modules.agent.broker.internal.state.model.AdditionalInputRequired;
import com.mulesoft.modules.agent.broker.internal.state.model.AgentToolContext;
import com.mulesoft.modules.agent.broker.internal.state.model.Conversation;
import com.mulesoft.modules.agent.broker.internal.state.model.ConversationState;
import com.mulesoft.modules.agent.broker.internal.state.model.Iteration;
import com.mulesoft.modules.agent.broker.internal.state.model.LLMOutput;
import com.mulesoft.modules.agent.broker.internal.state.model.TaskContext;
import com.mulesoft.modules.agent.broker.internal.state.model.ToolSelection;
import com.mulesoft.modules.agent.broker.internal.tool.Tool;
import com.mulesoft.modules.agent.broker.internal.tool.ToolRequest;
import com.mulesoft.modules.agent.broker.internal.tool.ToolResponse;
import com.mulesoft.modules.agent.broker.internal.tool.ToolType;
import com.mulesoft.modules.agent.broker.internal.tool.a2a.A2AService;
import com.mulesoft.modules.agent.broker.internal.tool.a2a.A2AToolResponse;
import com.mulesoft.modules.agent.broker.internal.util.ConcurrencyUtils;
import com.mulesoft.modules.agent.broker.internal.util.ExceptionUtils;
import io.a2a.spec.Message;
import io.a2a.spec.Task;
import io.a2a.spec.TaskState;
import io.a2a.spec.TaskStatus;
import io.a2a.spec.TextPart;
import java.time.OffsetDateTime;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.Executor;
import java.util.function.Consumer;
import org.mule.runtime.api.scheduler.SchedulerService;
import org.mule.runtime.core.api.util.func.CheckedRunnable;
import org.mule.runtime.extension.api.client.ExtensionsClient;
import org.mule.runtime.extension.api.error.ErrorTypeDefinition;
import org.mule.runtime.extension.api.exception.ModuleException;
import org.mule.runtime.extension.api.runtime.operation.Result;
import org.mule.runtime.extension.api.runtime.process.CompletionCallback;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;

/**
 * Drives a single agent task through an iterative reason/act loop.
 *
 * <p>Each iteration asks the {@link LLMClient.LLMSession} for its next step. The LLM either produces
 * an output (goal completed, goal failed, additional input required) or selects a tool; selected
 * tools are executed asynchronously and their responses are fed back to the session before the
 * next iteration starts.
 *
 * <p>Task-level outcomes, including failures such as exceeding {@code maxLoops} or
 * {@code maxConsecutiveErrors}, are reported through {@code completionCallback.success(...)} with a
 * result built by {@link A2AService}; {@code completionCallback.error(...)} is used for unexpected
 * exceptions.
 *
 * <p>NOTE(review): instances hold mutable per-task state ({@code currentLoop}, {@code iteration},
 * {@code consecutiveErrors}, ...) that is updated from scheduler callbacks; this appears to assume
 * the steps of one task never run concurrently — confirm before sharing an instance.
 */
class Loop {
    private static final Logger LOGGER = LoggerFactory.getLogger(Loop.class);
    private static final String AGENT_MDC_KEY = "agent";
    private static final String ITERATION_MDC_KEY = "iteration";
    private static final String TASK_ID_MDC_KEY = "taskId";
    private static final String CONTEXT_ID_MDC_KEY = "contextId";
    private static final String API_INSTANCE_ID_MDC_KEY = "apiInstanceId";
    private final LLMClient.LLMSession llmSession;
    /** Tools the LLM may select, keyed by the tool id reported in a {@link ToolSelection}. */
    private final Map<String, Tool> toolExecutors;
    private final String taskId;
    private final String contextId;
    private final Message message;
    private final String apiInstanceId;
    private final int maxLoops;
    private final int maxConsecutiveErrors;
    private final ConversationService conversationService;
    private final A2AService a2aService;
    private final AgentsBroker broker;
    private final ExtensionsClient extensionsClient;
    private final SchedulerService schedulerService;
    /** Caller's callback, wrapped so the caller's MDC is restored after completion. */
    private final CompletionCallback<String, Void> completionCallback;
    /** Number of loops started so far; compared against {@code maxLoops}. */
    private int currentLoop = 0;
    /** Errors seen since the last successful tool execution. */
    private int consecutiveErrors = 0;
    /** Iteration currently being processed; the LLM output / tool response is recorded on it. */
    private Iteration iteration;
    private Conversation conversation;
    private TaskContext taskContext;

    Loop(LLMClient.LLMSession llmSession, Map<String, Tool> toolExecutors, String taskId, String contextId, Message message, String apiInstanceId, int maxLoops, int maxConsecutiveErrors, ConversationService conversationService, A2AService a2aService, AgentsBroker broker, ExtensionsClient extensionsClient, SchedulerService schedulerService, CompletionCallback<String, Void> completionCallback) {
        this.llmSession = llmSession;
        this.toolExecutors = toolExecutors;
        this.taskId = taskId;
        this.contextId = contextId;
        this.message = message;
        this.apiInstanceId = apiInstanceId;
        this.maxLoops = maxLoops;
        this.maxConsecutiveErrors = maxConsecutiveErrors;
        this.schedulerService = schedulerService;
        // Captures the constructing thread's MDC so it can be restored once the task completes.
        this.completionCallback = new CompletionCallbackDecorator<String, Void>(completionCallback);
        this.broker = broker;
        this.extensionsClient = extensionsClient;
        this.conversationService = conversationService;
        this.a2aService = a2aService;
    }

    /**
     * Starts, or resumes, processing of the task.
     *
     * <p>Context recovery and the choice of the first action run inside
     * {@code conversationService.synchronizedConversation(...)}; the chosen action is executed
     * afterwards, outside that call. If the stored conversation state records that the task was
     * waiting for additional input, the pending A2A tool call or the LLM is resumed; otherwise a
     * fresh iteration is started with the session's initial prompt.
     */
    public void start() {
        this.setMDC();
        CheckedRunnable nextAction = (CheckedRunnable) this.conversationService.synchronizedConversation(this.contextId, this.broker, () -> {
            ConversationState conversationState;
            try {
                this.fetchContext();
                // Replay previously persisted iterations into the LLM session.
                this.taskContext.getIterations().forEach(this.llmSession::addIteration);
                conversationState = this.taskContext.getConversationState();
            } catch (Exception e) {
                // Report the context id this loop actually uses for locking, lookup and MDC.
                this.completionCallback.error(new ModuleException("Error recovering context %s: %s".formatted(this.contextId, e.getMessage()), (ErrorTypeDefinition) BrokerErrorTypes.A2A, e));
                return null;
            }
            AdditionalInputRequired additionalInputRequired = conversationState.getAdditionalInputRequired();
            try {
                if (additionalInputRequired != null) {
                    this.currentLoop = additionalInputRequired.getLoopNumber();
                    this.taskContext.setConversationState(new ConversationState(TaskState.WORKING));
                    if (additionalInputRequired.getRequesterType() == AdditionalInputRequired.RequesterType.A2A_TOOL) {
                        return this.resumeHITLFromA2ATool();
                    }
                    if (additionalInputRequired.getRequesterType() == AdditionalInputRequired.RequesterType.LLM) {
                        return this.resumeHITLByLLM(additionalInputRequired);
                    }
                }
            } catch (Exception e) {
                this.completionCallback.error(e);
                return null;
            }
            return () -> {
                Iteration firstIteration = this.taskContext.newIteration();
                firstIteration.setUserPrompt(this.llmSession.getInitialRequest().getPrompt());
                this.nextLoop(firstIteration);
            };
        });
        if (nextAction != null) {
            nextAction.run();
        }
    }

    /** Loads (or creates) the conversation and the task context for this loop's task. */
    private void fetchContext() {
        this.conversation = this.conversationService.getOrCreateConversation(this.message, this.contextId, this.broker);
        this.taskContext = this.conversationService.getTaskFor(this.message, this.conversation, this.taskId, this.broker);
    }

    /** Starts the next loop with a new, empty iteration. */
    private void nextLoop() {
        this.nextLoop(this.taskContext.newIteration());
    }

    /**
     * Starts the next loop with the given iteration, or fails the task once {@code maxLoops} has
     * been exceeded.
     *
     * <p>The LLM call completes asynchronously on the scheduler's IO executor; its result is
     * dispatched to {@link #onLLMOutput} or {@link #onToolSelection}.
     *
     * @param iteration the iteration to record this loop's LLM output / tool selection on
     */
    private void nextLoop(Iteration iteration) {
        try {
            if (++this.currentLoop > this.maxLoops) {
                this.updateStateAndCompleteWithError("Maximum loops achieved without completing goal");
                return;
            }
            this.setMDC();
            this.iteration = iteration;
            this.llmSession.addIteration(iteration);
            ConcurrencyUtils.completeAsyncInTCCL(this.llmSession.getNext(), (Executor) this.schedulerService.ioScheduler(), (either, llmException) -> {
                // Callback may run on another thread: re-establish the logging context.
                this.setMDC();
                if (llmException != null) {
                    llmException = ExceptionUtils.unwrap(llmException);
                    SecureLogger.secureLogger().debug("LLM call failed", llmException);
                    this.updateStateAndCompleteWithError("Cannot complete task due to issue accessing reasoning engine");
                    return;
                }
                either.apply(this::onLLMOutput, this::onToolSelection);
            });
        } catch (Exception e) {
            this.completionCallback.error(e);
        }
    }

    /**
     * Handles the LLM choosing a tool. An unknown tool id is reported back to the LLM as a tool
     * response and counts as a consecutive error. A known tool is executed, and its response is fed
     * to the LLM before the next loop.
     *
     * @param selection the tool id and input chosen by the LLM
     */
    private void onToolSelection(ToolSelection selection) {
        this.iteration.setToolSelection(selection);
        String toolToCall = selection.getToolId();
        String toolInput = selection.getInput();
        this.reasoningLog("LLM selected tool " + toolToCall, "Tool Input: " + toolInput);
        Tool toolExecutor = this.toolExecutors.get(toolToCall);
        if (toolExecutor == null) {
            ToolResponse errorResponse = new ToolResponse("Selected function '" + toolToCall + "' doesn't exist", selection, null);
            this.iteration.setToolResponse(errorResponse);
            this.llmSession.addToolResponse(errorResponse);
            this.reasoningLog(errorResponse.getResult());
            if (this.maxConsecutiveErrorsReached()) {
                this.updateStateAndCompleteWithError("I'm sorry, but I was unable to determine next step. Task cannot be completed");
            } else {
                this.nextLoop();
            }
            return;
        }
        ToolRequest toolRequest = new ToolRequest(selection, this.taskContext, this.apiInstanceId);
        this.executeTool(toolExecutor, toolRequest, toolToCall, toolResponse -> {
            try {
                this.resetConsecutiveErrors();
                // false means the response already completed the task (e.g. A2A input required).
                if (this.handleToolResponse(toolToCall, (ToolResponse) toolResponse)) {
                    this.iteration.setToolResponse((ToolResponse) toolResponse);
                    this.llmSession.addToolResponse((ToolResponse) toolResponse);
                    this.nextLoop();
                }
            } catch (Exception e) {
                LOGGER.error("Error processing response for tool {}", toolToCall, e);
                this.completeAsInternalError();
            }
        });
    }

    /**
     * Handles a non-tool LLM output: completes the task, pauses it for human input, fails it, or,
     * if the output is logically invalid, asks the LLM to re-evaluate. Invalid outputs count as
     * consecutive errors.
     *
     * @param output the LLM's output for the current iteration
     */
    private void onLLMOutput(LLMOutput output) {
        this.iteration.setLlmOutput(output);
        if (output.isGoalComplete()) {
            this.taskContext.setConversationState(new ConversationState(TaskState.COMPLETED));
            this.safeUpsertTaskContext();
            this.completionCallback.success(this.a2aService.completedJsonResult(this.conversation, this.taskContext, output.getResult()));
        } else if (output.isAdditionalInputRequired()) {
            try {
                this.updateTaskContextForHITLByLLM();
                this.completionCallback.success(this.a2aService.inputRequiredJsonResult(this.conversation, this.taskContext, output.getResult()));
            } catch (Exception e) {
                LOGGER.error("Could not update context when requesting additional data from client", e);
                this.completeAsInternalError();
            }
        } else if (output.isGoalFailed()) {
            this.updateStateAndCompleteWithError(output.getResult());
        } else {
            String errorMessage = "LLM output is logically invalid. Goal is neither completed, failed or requesting additional input";
            this.reasoningLog(errorMessage);
            if (this.maxConsecutiveErrorsReached()) {
                this.completeAsInternalError();
            } else {
                Iteration retry = this.taskContext.newIteration();
                retry.setUserPrompt(errorMessage + ". Re-evaluate, provide a valid response or select a tool");
                this.nextLoop(retry);
            }
        }
    }

    /** Fails the task with a generic internal-error message. */
    private void completeAsInternalError() {
        this.updateStateAndCompleteWithError("Could not complete task due to an internal error");
    }

    /**
     * Executes a tool asynchronously on the scheduler's IO executor.
     *
     * <p>On failure, the error is recorded on {@link #iteration} as a tool response, persisted, and
     * fed back to the LLM. The loop then either continues or fails the task once
     * {@code maxConsecutiveErrors} is reached. On success, {@code onSuccess} receives the response.
     * Callers must ensure {@link #iteration} is set before calling this method.
     *
     * @param executor    the tool to run
     * @param toolRequest the request passed to the tool
     * @param toolName    tool name used in log/error messages
     * @param onSuccess   invoked with the tool's response when execution succeeds
     */
    private <T extends ToolResponse> void executeTool(Tool executor, ToolRequest toolRequest, String toolName, Consumer<T> onSuccess) {
        ConcurrencyUtils.completeAsyncInTCCL(executor.execute(toolRequest, this.extensionsClient), (Executor) this.schedulerService.ioScheduler(), (toolResponse, toolException) -> {
            this.setMDC();
            if (toolException != null) {
                toolException = ExceptionUtils.unwrap(toolException);
                String message = "Execution of tool " + toolName + " failed.";
                String secureAppendix = "Error message was: " + toolException.getMessage();
                this.reasoningLog(message, secureAppendix);
                ToolResponse failResponse = new ToolResponse(message + " " + secureAppendix, toolRequest.selection(), executor.getToolType());
                this.iteration.setToolResponse(failResponse);
                this.safeUpsertTaskContext();
                this.llmSession.addToolResponse(failResponse);
                if (this.maxConsecutiveErrorsReached()) {
                    this.updateStateAndCompleteWithError("Could not complete next step due to an error. Task failed.");
                } else {
                    this.nextLoop();
                }
            } else {
                onSuccess.accept(toolResponse);
            }
        });
    }

    /**
     * Logs a successful tool execution and applies A2A-specific handling.
     *
     * @return {@code true} if the loop should continue; {@code false} if the response already
     *         completed the callback (A2A agent requesting input)
     */
    private boolean handleToolResponse(String toolName, ToolResponse toolResponse) {
        this.reasoningLog("Executed tool " + toolName, "Output was: " + toolResponse.getResult());
        if (toolResponse.getToolType() == ToolType.A2A) {
            return this.handleA2AToolResponse(toolName, (A2AToolResponse) toolResponse);
        }
        return true;
    }

    /**
     * Marks the task as FAILED, persists it (best effort), and completes the callback
     * <em>successfully</em> with a FAILED A2A task carrying {@code message}. If building or
     * delivering that result throws, the callback is completed with the error instead.
     *
     * @param message the failure reason shown to the client
     */
    private void updateStateAndCompleteWithError(String message) {
        this.reasoningLog("Task failed: " + message);
        try {
            this.taskContext.setConversationState(new ConversationState(TaskState.FAILED, message));
            this.safeUpsertTaskContext();
        } finally {
            try {
                Message statusMessage = new Message.Builder()
                        .messageId(UUID.randomUUID().toString())
                        .taskId(this.taskContext.getTaskId())
                        .role(Message.Role.AGENT)
                        .parts(List.of(new TextPart(message)))
                        .build();
                Task failedTask = new Task.Builder()
                        .id(this.taskContext.getTaskId())
                        .contextId(this.conversation.getId())
                        .status(new TaskStatus(TaskState.FAILED, statusMessage, OffsetDateTime.now()))
                        .build();
                this.completionCallback.success(this.a2aService.asJsonResult(failedTask));
            } catch (Exception e) {
                this.completionCallback.error(e);
            }
        }
    }

    /**
     * Records the remote agent's task/context ids and state on the task context. If the remote
     * agent requires input, the task is paused (INPUT_REQUIRED, persisted) and the callback is
     * completed.
     *
     * @return {@code true} if the loop should continue, {@code false} if the task was paused
     */
    private boolean handleA2AToolResponse(String toolName, A2AToolResponse toolResponse) {
        boolean inputRequired = toolResponse.isInputRequired();
        String toolOutput = toolResponse.getResult();
        AgentToolContext toolContext = this.taskContext.getAgentToolContext(toolName);
        toolContext.setTaskState(toolResponse.getState());
        if (toolResponse.getContextId() != null) {
            toolContext.setContextId(toolResponse.getContextId());
        }
        if (toolResponse.getTaskId() != null) {
            toolContext.setTaskId(toolResponse.getTaskId());
        }
        if (inputRequired) {
            this.taskContext.setConversationState(new ConversationState(TaskState.INPUT_REQUIRED, new AdditionalInputRequired(AdditionalInputRequired.RequesterType.A2A_TOOL, this.iteration.getNumber(), toolName)));
            this.safeUpsertTaskContext();
            this.completionCallback.success(this.a2aService.inputRequiredJsonResult(this.conversation, this.taskContext, toolOutput));
            return false;
        }
        return true;
    }

    /**
     * Logs {@code message} at INFO on the regular logger. Each {@code secureAppendix} entry, which
     * may contain sensitive data such as tool input/output, goes to the secure logger at DEBUG.
     */
    private void reasoningLog(String message, String ... secureAppendix) {
        LOGGER.info(message);
        if (secureAppendix != null) {
            for (String s : secureAppendix) {
                SecureLogger.secureLogger().debug(s);
            }
        }
    }

    /** Returns an action that resumes an LLM-initiated input request with the user's new prompt. */
    private CheckedRunnable resumeHITLByLLM(AdditionalInputRequired additionalInputRequired) {
        return () -> {
            LOGGER.info("Resuming conversation {} from iteration {}", this.conversation.getId(), additionalInputRequired.getLoopNumber());
            Iteration resumed = this.taskContext.newIteration();
            resumed.setUserPrompt(this.llmSession.getInitialRequest().getPrompt());
            this.nextLoop(resumed);
        };
    }

    /**
     * Returns an action that re-invokes the A2A tool that requested input, using the user's new
     * prompt as the tool input.
     *
     * @throws ModuleException if the pending tool selection or its tool executor cannot be found
     */
    private CheckedRunnable resumeHITLFromA2ATool() {
        Iteration lastIteration = this.taskContext.getLastIteration();
        ToolSelection selection = lastIteration == null ? null : lastIteration.getToolSelection();
        if (selection == null) {
            throw new ModuleException("Could not find the pending A2A tool call to resume", (ErrorTypeDefinition) BrokerErrorTypes.A2A);
        }
        String waitingToolName = selection.getToolId();
        Tool tool = this.toolExecutors.get(waitingToolName);
        if (tool == null) {
            throw new ModuleException("Could not find A2A agent config to resume the tool - " + waitingToolName, (ErrorTypeDefinition) BrokerErrorTypes.A2A);
        }
        selection.setInput(this.llmSession.getInitialRequest().getPrompt());
        ToolRequest toolRequest = new ToolRequest(selection, this.taskContext, this.apiInstanceId);
        return () -> {
            // nextLoop() has not run on this path, so this.iteration would still be null. Both
            // executeTool's failure branch and handleA2AToolResponse dereference it.
            this.iteration = lastIteration;
            this.executeTool(tool, toolRequest, waitingToolName, toolResponse -> {
                try {
                    if (this.handleA2AToolResponse(waitingToolName, (A2AToolResponse) toolResponse)) {
                        lastIteration.setToolResponse((ToolResponse) toolResponse);
                        this.safeUpsertTaskContext();
                        this.llmSession.addToolResponse((ToolResponse) toolResponse);
                        this.nextLoop();
                    }
                } catch (Exception e) {
                    this.completionCallback.error(e);
                }
            });
        };
    }

    /** Pauses the task as INPUT_REQUIRED on behalf of the LLM, remembering the current loop. */
    private void updateTaskContextForHITLByLLM() {
        this.taskContext.setConversationState(new ConversationState(TaskState.INPUT_REQUIRED, new AdditionalInputRequired(AdditionalInputRequired.RequesterType.LLM, this.currentLoop)));
        this.safeUpsertTaskContext();
    }

    /** Persists the task context; failures are logged and otherwise ignored (best effort). */
    private void safeUpsertTaskContext() {
        try {
            this.conversationService.upsert(this.taskContext, this.broker);
        } catch (Exception e) {
            LOGGER.error("Exception found updating TaskContext", e);
        }
    }

    private void resetConsecutiveErrors() {
        this.consecutiveErrors = 0;
    }

    /**
     * Records one more consecutive error.
     *
     * @return {@code true} once the count reaches {@code maxConsecutiveErrors}
     */
    private boolean maxConsecutiveErrorsReached() {
        if (++this.consecutiveErrors >= this.maxConsecutiveErrors) {
            this.reasoningLog("Maximum consecutive errors reached. Task '%s' will fail".formatted(this.taskContext.getTaskId()));
            return true;
        }
        return false;
    }

    /** Puts agent, iteration, task, context and API-instance identifiers into the current thread's MDC. */
    private void setMDC() {
        MDC.put(AGENT_MDC_KEY, this.broker.getConfigName());
        MDC.put(ITERATION_MDC_KEY, String.valueOf(this.currentLoop));
        MDC.put(TASK_ID_MDC_KEY, this.taskId);
        MDC.put(CONTEXT_ID_MDC_KEY, this.contextId);
        MDC.put(API_INSTANCE_ID_MDC_KEY, this.apiInstanceId);
    }

    /**
     * Completion callback wrapper that, after delegating, replaces the completing thread's MDC with
     * the MDC captured when the wrapper was constructed.
     */
    private static final class CompletionCallbackDecorator<T, A>
    implements CompletionCallback<T, A> {
        private final CompletionCallback<T, A> delegate;
        /** MDC snapshot at construction; SLF4J documents this may be {@code null}. */
        private final Map<String, String> originalMdcContext;

        public CompletionCallbackDecorator(CompletionCallback<T, A> delegate) {
            this.delegate = delegate;
            this.originalMdcContext = MDC.getCopyOfContextMap();
        }

        @Override
        public void success(Result<T, A> result) {
            try {
                this.delegate.success(result);
            } finally {
                this.restoreMdcContext();
            }
        }

        @Override
        public void error(Throwable throwable) {
            try {
                this.delegate.error(throwable);
            } finally {
                this.restoreMdcContext();
            }
        }

        private void restoreMdcContext() {
            // Some MDC adapters throw on setContextMap(null); an absent snapshot means "empty".
            if (this.originalMdcContext == null) {
                MDC.clear();
            } else {
                MDC.setContextMap(this.originalMdcContext);
            }
        }
    }
}

