/*
 * Decompiled with CFR 0.152.
 */
package com.mulesoft.modules.agent.broker.internal.tool.a2a;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.github.benmanes.caffeine.cache.Cache;
import com.mulesoft.modules.agent.broker.api.model.a2a.A2AClient;
import com.mulesoft.modules.agent.broker.internal.error.BrokerErrorTypes;
import com.mulesoft.modules.agent.broker.internal.log.SecureLogger;
import com.mulesoft.modules.agent.broker.internal.serializer.DwConverter;
import com.mulesoft.modules.agent.broker.internal.state.model.AgentToolContext;
import com.mulesoft.modules.agent.broker.internal.state.model.Conversation;
import com.mulesoft.modules.agent.broker.internal.state.model.TaskContext;
import com.mulesoft.modules.agent.broker.internal.tool.Tool;
import com.mulesoft.modules.agent.broker.internal.tool.ToolRequest;
import com.mulesoft.modules.agent.broker.internal.tool.ToolResponse;
import com.mulesoft.modules.agent.broker.internal.tool.ToolType;
import com.mulesoft.modules.agent.broker.internal.tool.ToolVisitor;
import com.mulesoft.modules.agent.broker.internal.tool.a2a.A2AToolResponse;
import com.mulesoft.modules.agent.broker.internal.tool.a2a.A2AUtils;
import com.mulesoft.modules.agent.broker.internal.tool.a2a.AgentSummary;
import com.mulesoft.modules.agent.broker.internal.util.ToolUtils;
import io.a2a.spec.Artifact;
import io.a2a.spec.EventKind;
import io.a2a.spec.Message;
import io.a2a.spec.MessageSendParams;
import io.a2a.spec.Part;
import io.a2a.spec.Task;
import io.a2a.spec.TaskState;
import io.a2a.spec.TaskStatus;
import io.a2a.spec.TextPart;
import io.a2a.util.Utils;
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.time.OffsetDateTime;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicInteger;
import javax.inject.Inject;
import org.mule.runtime.api.lifecycle.Initialisable;
import org.mule.runtime.api.lifecycle.InitialisationException;
import org.mule.runtime.api.metadata.DataType;
import org.mule.runtime.api.metadata.MediaType;
import org.mule.runtime.api.metadata.TypedValue;
import org.mule.runtime.api.scheduler.SchedulerService;
import org.mule.runtime.core.api.el.ExpressionManager;
import org.mule.runtime.core.api.util.IOUtils;
import org.mule.runtime.extension.api.client.ExtensionsClient;
import org.mule.runtime.extension.api.client.OperationParameterizer;
import org.mule.runtime.extension.api.error.ErrorTypeDefinition;
import org.mule.runtime.extension.api.exception.ModuleException;
import org.mule.runtime.extension.api.runtime.operation.Result;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Exposes remote A2A (agent-to-agent) agents as broker {@link Tool}s and converts between broker state
 * and A2A protocol JSON payloads.
 *
 * <p>Each configured {@link A2AClient} is turned into one {@link A2ATool} by fetching its agent card
 * through the A2A connector's {@code getCard} operation. Resulting tools are cached per A2A client
 * config reference, so a card is only fetched on a cache miss.
 */
public class A2AService implements Initialisable {
    private static final Logger LOGGER = LoggerFactory.getLogger(A2AService.class);
    private static final String A2A = "A2A";
    private static final DataType AGENT_SUMMARY_DATA_TYPE = DataType.fromType(AgentSummary.class);

    /**
     * DataWeave script converting an agent card JSON document (bound as {@code payload}) into an
     * {@link AgentSummary}. Skills are appended to the description as a bulleted list (U+2022 BULLET);
     * the default input/output modes become the input/output descriptions.
     */
    private static final String AGENT_SUMMARY_SCRIPT = """
            %dw 2.0
            output application/java
            var parsedPayload = read(payload, "application/json")
            var baseDescription = parsedPayload.description default ""
            var skills = parsedPayload.skills default []
            var skillsSummary = "\n\nAvailable skills:\n" ++ ((skills map ("\u2022 " ++ $.name ++ ": " ++ $.description)) joinBy "\n")
            var inputModes = parsedPayload.defaultInputModes default []
            var outputModes = parsedPayload.defaultOutputModes default []
            var inputDescription = "Accepts input in formats: " ++ (inputModes joinBy ", ")
            var outputDescription = "Returns output in formats: " ++ (outputModes joinBy ", ")
            ---
            {
                name: parsedPayload.name,
                description: baseDescription ++ skillsSummary,
                inputDescription: inputDescription,
                outputDescription: outputDescription
            }
            """;

    /** Discovered tools keyed by A2A client config reference. */
    private final Cache<String, Tool> tools = ToolUtils.newToolCache();
    @Inject
    private SchedulerService schedulerService;
    @Inject
    private ExpressionManager expressionManager;
    /** Converter running {@link #AGENT_SUMMARY_SCRIPT}; created in {@link #initialise()}. */
    private DwConverter agentSummaryWriter;

    /**
     * Creates the agent-card-to-{@link AgentSummary} converter. The value handed to the converter is
     * bound to the script as a {@code payload} string.
     */
    @Override
    public void initialise() throws InitialisationException {
        this.agentSummaryWriter = new DwConverter(this.expressionManager, AGENT_SUMMARY_SCRIPT,
                (value, builder) -> builder.addBinding("payload", new TypedValue(value, DataType.STRING)));
    }

    /**
     * Resolves one tool per A2A client. Cached tools are collected immediately; uncached ones are
     * discovered on the IO scheduler by fetching the agent card.
     *
     * @param a2aClients the A2A clients to expose as tools; {@code null} or empty yields an empty map
     * @param extensionsClient client used to invoke the A2A connector's {@code getCard} operation
     * @return a future completed with the tools keyed by tool id once every client has been resolved,
     *     or completed exceptionally with a {@link ModuleException} as soon as one discovery fails
     */
    public CompletableFuture<Map<String, Tool>> getTools(List<A2AClient> a2aClients, ExtensionsClient extensionsClient) {
        if (a2aClients == null || a2aClients.isEmpty()) {
            return CompletableFuture.completedFuture(Map.of());
        }
        Map<String, Tool> collectedTools = new ConcurrentHashMap<>();
        AtomicInteger countDown = new AtomicInteger(a2aClients.size());
        CompletableFuture<Map<String, Tool>> future = new CompletableFuture<>();
        try {
            for (A2AClient a2aClient : a2aClients) {
                String a2aConfigRef = a2aClient.getA2AClientConfigRef();
                Tool cached = this.tools.getIfPresent(a2aConfigRef);
                if (cached != null) {
                    this.collect(cached, collectedTools, countDown, future);
                    continue;
                }
                this.schedulerService.ioScheduler().submit(() -> {
                    try {
                        // Failures inside the mapping function propagate out of get() and leave no cache entry.
                        Tool tool = this.tools.get(a2aConfigRef, key -> this.newA2ATool(key, extensionsClient));
                        this.collect(tool, collectedTools, countDown, future);
                    } catch (Exception e) {
                        this.handleDiscoveryException(countDown, future, e, a2aConfigRef);
                    }
                });
            }
        } catch (Exception e) {
            future.completeExceptionally(e);
        }
        return future;
    }

    /**
     * Builds a COMPLETED A2A task carrying {@code output} as a single text artifact.
     *
     * @param conversation conversation whose id becomes the task's context id
     * @param taskContext supplies the task id
     * @param output text placed in the artifact
     * @return the task serialized as a JSON result
     */
    public Result<String, Void> completedJsonResult(Conversation conversation, TaskContext taskContext, String output) {
        Artifact artifact = new Artifact.Builder()
                .artifactId(UUID.randomUUID().toString())
                .parts(List.of(new TextPart(output)))
                .build();
        Task task = new Task.Builder()
                .id(taskContext.getTaskId())
                .contextId(conversation.getId())
                .status(new TaskStatus(TaskState.COMPLETED))
                .artifacts(List.of(artifact))
                .build();
        return this.asJsonResult(task);
    }

    /**
     * Builds an INPUT_REQUIRED A2A task whose status carries an agent message asking for more input.
     *
     * @param conversation conversation whose id becomes the task's context id
     * @param taskContext supplies the task id
     * @param message text of the agent message
     * @return the task serialized as a JSON result
     */
    public Result<String, Void> inputRequiredJsonResult(Conversation conversation, TaskContext taskContext, String message) {
        Message agentMessage = new Message.Builder()
                .role(Message.Role.AGENT)
                .parts(new TextPart(message))
                .taskId(taskContext.getTaskId())
                .messageId(UUID.randomUUID().toString())
                .build();
        Task task = new Task.Builder()
                .id(taskContext.getTaskId())
                .contextId(conversation.getId())
                .status(new TaskStatus(TaskState.INPUT_REQUIRED, agentMessage, OffsetDateTime.now()))
                .build();
        return this.asJsonResult(task);
    }

    /**
     * Serializes {@code task} to JSON and wraps it in an {@code application/json} result.
     *
     * @throws ModuleException with {@code A2A} error type if serialization fails
     */
    public Result<String, Void> asJsonResult(Task task) {
        return Result.<String, Void>builder()
                .output(this.toJson(task))
                .mediaType(MediaType.APPLICATION_JSON)
                .build();
    }

    /**
     * Reads {@code json} fully and deserializes it as {@code type} using the A2A SDK object mapper.
     *
     * @throws ModuleException with {@code A2A} error type if the payload does not match the schema
     *     (the payload itself is only written to the secure debug log) or cannot be read
     */
    public <T> T unmarshall(InputStream json, Class<T> type) {
        String consumedPayload = null;
        try {
            consumedPayload = IOUtils.toString(json);
            return Utils.OBJECT_MAPPER.readValue(consumedPayload, type);
        } catch (JsonProcessingException e) {
            SecureLogger.secureLogger().debug("Received A2A payload does not comply with the schema. Payload was:\n{}", (Object) consumedPayload, (Object) e);
            throw new ModuleException("Received A2A payload does not comply with the schema", BrokerErrorTypes.A2A, e);
        } catch (Exception e) {
            throw new ModuleException("Error unmarshalling A2A payload", BrokerErrorTypes.A2A, e);
        }
    }

    /** Fetches the agent card for {@code a2aConfigRef} and builds the {@link A2ATool} that wraps it. */
    private A2ATool newA2ATool(String a2aConfigRef, ExtensionsClient extensionsClient) {
        AgentSummary agentSummary = this.fetchAgentSummary(a2aConfigRef, extensionsClient);
        return new A2ATool(ToolUtils.generateToolId(a2aConfigRef, agentSummary.getName()), agentSummary.getName(),
                agentSummary.getDescription(), agentSummary.getInputDescription(), agentSummary.getOutputDescription(),
                a2aConfigRef);
    }

    /**
     * Invokes the A2A connector's {@code getCard} operation (blocking until it finishes) and converts the
     * returned card JSON into an {@link AgentSummary}.
     *
     * @throws ModuleException with {@code TOOL_ERROR} error type if the operation fails or the calling
     *     thread is interrupted (the interrupt flag is restored)
     */
    private AgentSummary fetchAgentSummary(String a2aConfigRef, ExtensionsClient extensionsClient) {
        LOGGER.debug("Cache miss for config: {}. Attempting to fetch agent card...", a2aConfigRef);
        Result<?, ?> result;
        try {
            result = extensionsClient.execute(A2A, "getCard", params -> params.withConfigRef(a2aConfigRef)).get();
        } catch (ExecutionException e) {
            throw new ModuleException("Failed to fetch agent card for config '%s'".formatted(a2aConfigRef), BrokerErrorTypes.TOOL_ERROR, e.getCause());
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new ModuleException("Fetch agent card for config %s was interrupted.".formatted(a2aConfigRef), BrokerErrorTypes.TOOL_ERROR, e);
        }
        // NOTE(review): assumes getCard returns the card as a JSON String — confirm against the connector.
        String cardJson = (String) result.getOutput();
        AgentSummary parsedAgentData = (AgentSummary) this.agentSummaryWriter.evaluate(new TypedValue<>(cardJson, DataType.STRING), AGENT_SUMMARY_DATA_TYPE);
        LOGGER.debug("Successfully parsed AgentSummary for config {}: name={}, description={}", a2aConfigRef, parsedAgentData.getName(), parsedAgentData.getDescription());
        return parsedAgentData;
    }

    /**
     * Records a resolved tool and completes {@code future} with {@code collected} once the last pending
     * client has been accounted for.
     */
    private void collect(Tool tool, Map<String, Tool> collected, AtomicInteger countDown, CompletableFuture<Map<String, Tool>> future) {
        collected.put(tool.getId(), tool);
        if (countDown.decrementAndGet() <= 0) {
            future.complete(collected);
        }
    }

    /**
     * Fails discovery: completes the future exceptionally with a wrapping {@link ModuleException}.
     * The countdown is forced negative so later {@link #collect} calls only hit a no-op
     * {@code complete} on the already-failed future.
     */
    private void handleDiscoveryException(AtomicInteger countDown, CompletableFuture<?> toolHandlersFuture, Throwable t, String a2aConfigRef) {
        toolHandlersFuture.completeExceptionally(new ModuleException("Exception discovering agent from A2A Client config %s: %s".formatted(a2aConfigRef, t.getMessage()), BrokerErrorTypes.TOOL_ERROR, t));
        countDown.set(-1);
    }

    /**
     * Builds the JSON {@code MessageSendParams} for a tool invocation. A user message carrying the
     * selection's input is sent in the tool's A2A context. If the previous task on that tool reached a
     * final state it is only referenced; otherwise the message continues it by task id.
     */
    private String toA2ARequest(ToolRequest request) {
        AgentToolContext toolContext = request.taskContext().getAgentToolContext(request.selection().getToolId());
        Message.Builder builder = new Message.Builder()
                .messageId(UUID.randomUUID().toString())
                .role(Message.Role.USER)
                .parts(new TextPart(request.selection().getInput()));
        builder.contextId(toolContext.getContextId());
        if (toolContext.getTaskId() != null) {
            if (toolContext.getTaskState().isFinal()) {
                builder.referenceTaskIds(List.of(toolContext.getTaskId()));
            } else {
                builder.taskId(toolContext.getTaskId());
            }
        }
        return this.toJson(new MessageSendParams.Builder().message(builder.build()).build());
    }

    /**
     * Serializes {@code value} with the A2A SDK object mapper.
     *
     * @throws ModuleException with {@code A2A} error type on failure
     */
    private String toJson(Object value) {
        try {
            return Utils.OBJECT_MAPPER.writeValueAsString(value);
        } catch (Exception e) {
            throw new ModuleException("Error marshalling A2A payload", BrokerErrorTypes.A2A, e);
        }
    }

    /** A remote A2A agent exposed as a broker tool, bound to one A2A client config. */
    public class A2ATool extends Tool {
        private final String configRef;

        private A2ATool(String id, String name, String description, String input, String output, String configRef) {
            super(id, name, description, input, output);
            this.configRef = configRef;
        }

        /**
         * Sends the request to the remote agent through the A2A connector's {@code sendMessage} operation
         * and parses the reply.
         *
         * @return a future with the parsed {@link ToolResponse}; fails with a {@code TOOL_ERROR}
         *     {@link ModuleException} if the request cannot be built or the reply cannot be parsed
         */
        @Override
        public CompletableFuture<ToolResponse> execute(ToolRequest request, ExtensionsClient extensionsClient) {
            try {
                String a2aMessage = A2AService.this.toA2ARequest(request);
                SecureLogger.secureLogger().debug("Sending A2A message to agent {}:\n {}", (Object) request.selection().getToolId(), (Object) a2aMessage);
                // Encode explicitly: the JSON body must not depend on the platform default charset.
                byte[] payload = a2aMessage.getBytes(StandardCharsets.UTF_8);
                return extensionsClient.execute(A2AService.A2A, "sendMessage", params -> params
                                .withConfigRef(this.configRef)
                                .withParameter("message", new ByteArrayInputStream(payload))
                                .withParameter("Request", "additionalProperties", this.collectHeaders(request)))
                        .thenApply(result -> this.parseResponse(request, result.getOutput()));
            } catch (Exception e) {
                return CompletableFuture.failedFuture(new ModuleException("Failed to invoke Agent " + request.selection().getToolId(), BrokerErrorTypes.TOOL_ERROR, e));
            }
        }

        /**
         * Parses the agent's JSON reply, which must be an A2A {@link Task} or {@link Message}.
         * A bare message is treated as a COMPLETED interaction.
         *
         * @throws ModuleException with {@code TOOL_ERROR} error type for a missing, non-string,
         *     unparseable or unsupported reply
         */
        private ToolResponse parseResponse(ToolRequest request, Object rawResponse) {
            if (rawResponse == null) {
                throw new ModuleException("Tool did not return any response", BrokerErrorTypes.TOOL_ERROR);
            }
            if (!(rawResponse instanceof String json)) {
                throw new ModuleException("Tool response is of unexpected type: " + rawResponse.getClass().getName(), BrokerErrorTypes.TOOL_ERROR);
            }
            try {
                EventKind response = Utils.OBJECT_MAPPER.readValue(json, EventKind.class);
                if (response instanceof Task task) {
                    TaskStatus status = task.getStatus();
                    if (status == null) {
                        throw new ModuleException("Tool returned a task without a status", BrokerErrorTypes.TOOL_ERROR);
                    }
                    A2AToolOutput output = new A2AToolOutput(A2AUtils.collectParts(status.message()), task.getArtifacts());
                    return new A2AToolResponse(this.toJson(output), request.selection(), task.getId(), task.getContextId(), status.state());
                }
                if (response instanceof Message message) {
                    A2AToolOutput output = new A2AToolOutput(A2AUtils.collectParts(message), null);
                    return new A2AToolResponse(this.toJson(output), request.selection(), message.getTaskId(), message.getContextId(), TaskState.COMPLETED);
                }
                if (response == null) {
                    // A literal JSON "null" body deserializes to null.
                    throw new ModuleException("Tool did not return any response", BrokerErrorTypes.TOOL_ERROR);
                }
                throw new ModuleException("Tool response is of unexpected type: " + response.getKind(), BrokerErrorTypes.TOOL_ERROR);
            } catch (ModuleException e) {
                // Already descriptive; don't wrap it a second time.
                throw e;
            } catch (Exception e) {
                throw new ModuleException("Failed to parse tool response: " + e.getMessage(), BrokerErrorTypes.TOOL_ERROR, e);
            }
        }

        /**
         * Serializes the tool output with the A2A SDK object mapper.
         *
         * @throws ModuleException with {@code TOOL_ERROR} error type on failure
         */
        private String toJson(A2AToolOutput output) {
            try {
                return Utils.OBJECT_MAPPER.writeValueAsString(output);
            } catch (Exception e) {
                throw new ModuleException("Failed to serialize tool response", BrokerErrorTypes.TOOL_ERROR, e);
            }
        }

        @Override
        public void accept(ToolVisitor visitor) {
            visitor.visit(this);
        }

        @Override
        public ToolType getToolType() {
            return ToolType.A2A;
        }

        /**
         * JSON shape handed back to the broker: the parts of the agent's message and, for task replies,
         * the task artifacts. Per the {@code NON_ABSENT} inclusion setting, absent values are omitted.
         */
        @JsonInclude(value = JsonInclude.Include.NON_ABSENT)
        @JsonIgnoreProperties(ignoreUnknown = true)
        public record A2AToolOutput(List<Part<?>> message, List<Artifact> artifacts) {
        }
    }
}

