/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.mcp.client;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.mcp.client.McpClient;
import dev.langchain4j.mcp.client.McpGetPromptResult;
import dev.langchain4j.mcp.client.McpPrompt;
import dev.langchain4j.mcp.client.McpReadResourceResult;
import dev.langchain4j.mcp.client.McpResource;
import dev.langchain4j.mcp.client.McpResourceTemplate;
import dev.langchain4j.mcp.client.PromptsHelper;
import dev.langchain4j.mcp.client.ResourcesHelper;
import dev.langchain4j.mcp.client.ToolExecutionHelper;
import dev.langchain4j.mcp.client.ToolSpecificationHelper;
import dev.langchain4j.mcp.client.logging.DefaultMcpLogMessageHandler;
import dev.langchain4j.mcp.client.logging.McpLogMessageHandler;
import dev.langchain4j.mcp.client.protocol.CancellationNotification;
import dev.langchain4j.mcp.client.protocol.InitializeParams;
import dev.langchain4j.mcp.client.protocol.McpCallToolRequest;
import dev.langchain4j.mcp.client.protocol.McpGetPromptRequest;
import dev.langchain4j.mcp.client.protocol.McpInitializeRequest;
import dev.langchain4j.mcp.client.protocol.McpListPromptsRequest;
import dev.langchain4j.mcp.client.protocol.McpListResourceTemplatesRequest;
import dev.langchain4j.mcp.client.protocol.McpListResourcesRequest;
import dev.langchain4j.mcp.client.protocol.McpListToolsRequest;
import dev.langchain4j.mcp.client.protocol.McpPingRequest;
import dev.langchain4j.mcp.client.protocol.McpReadResourceRequest;
import dev.langchain4j.mcp.client.transport.McpOperationHandler;
import dev.langchain4j.mcp.client.transport.McpTransport;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Default {@link McpClient} implementation that talks to an MCP server through an {@link McpTransport}.
 *
 * <p>The client starts the transport and sends an {@code initialize} request when it is constructed. When the
 * transport reports a failure, the client waits {@code reconnectInterval} and initializes again. Resource,
 * resource-template and prompt lists are fetched on first use and then cached. The tool list is fetched on
 * first use and fetched again after the callback handed to {@link McpOperationHandler} marks it out of date.
 *
 * <p>For every request timeout configured on the {@link Builder}, a zero duration is treated as "effectively no
 * timeout" ({@link Integer#MAX_VALUE} milliseconds).
 */
public class DefaultMcpClient implements McpClient {
    private static final Logger log = LoggerFactory.getLogger(DefaultMcpClient.class);
    static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    private final AtomicLong idGenerator = new AtomicLong(0L);
    private final McpTransport transport;
    private final String key;
    private final String clientName;
    private final String clientVersion;
    private final String protocolVersion;
    private final Duration initializationTimeout;
    private final Duration toolExecutionTimeout;
    private final Duration resourcesTimeout;
    private final Duration promptsTimeout;
    private final Duration pingTimeout;
    private final Duration reconnectInterval;
    private final String toolExecutionTimeoutErrorMessage;
    // Canned response passed to ToolExecutionHelper when a tool call times out.
    private final JsonNode timeoutResult;
    // Futures of in-flight requests keyed by operation id; also handed to the McpOperationHandler.
    private final Map<Long, CompletableFuture<JsonNode>> pendingOperations = new ConcurrentHashMap<>();
    private final McpOperationHandler messageHandler;
    private final McpLogMessageHandler logHandler;
    private final AtomicReference<List<McpResource>> resourceRefs = new AtomicReference<>();
    private final AtomicReference<List<McpResourceTemplate>> resourceTemplateRefs = new AtomicReference<>();
    private final AtomicReference<List<McpPrompt>> promptRefs = new AtomicReference<>();
    private final AtomicReference<List<ToolSpecification>> toolListRefs = new AtomicReference<>();
    // true until the tool list has been fetched successfully, and again whenever the handler flags a change.
    private final AtomicBoolean toolListOutOfDate = new AtomicBoolean(true);
    // Non-null while one thread is refreshing the tool list; concurrent callers wait on this future.
    private final AtomicReference<CompletableFuture<Void>> toolListUpdateInProgress = new AtomicReference<>(null);

    /**
     * Creates the client and immediately initializes the MCP session.
     *
     * @param builder configuration; {@code transport} is mandatory, all other settings fall back to defaults
     * @throws RuntimeException if the transport is missing or initialization fails or times out
     */
    public DefaultMcpClient(Builder builder) {
        this.transport = ValidationUtils.ensureNotNull(builder.transport, "transport");
        this.key = Utils.getOrDefault(builder.key, () -> UUID.randomUUID().toString());
        this.clientName = Utils.getOrDefault(builder.clientName, "langchain4j");
        this.clientVersion = Utils.getOrDefault(builder.clientVersion, "1.0");
        this.protocolVersion = Utils.getOrDefault(builder.protocolVersion, "2024-11-05");
        this.initializationTimeout = Utils.getOrDefault(builder.initializationTimeout, Duration.ofSeconds(30L));
        this.toolExecutionTimeout = Utils.getOrDefault(builder.toolExecutionTimeout, Duration.ofSeconds(60L));
        this.resourcesTimeout = Utils.getOrDefault(builder.resourcesTimeout, Duration.ofSeconds(60L));
        this.promptsTimeout = Utils.getOrDefault(builder.promptsTimeout, Duration.ofSeconds(60L));
        this.logHandler = Utils.getOrDefault(builder.logHandler, new DefaultMcpLogMessageHandler());
        this.pingTimeout = Utils.getOrDefault(builder.pingTimeout, Duration.ofSeconds(10L));
        this.reconnectInterval = Utils.getOrDefault(builder.reconnectInterval, Duration.ofSeconds(5L));
        this.toolExecutionTimeoutErrorMessage = Utils.getOrDefault(
                builder.toolExecutionTimeoutErrorMessage, "There was a timeout executing the tool");
        this.timeoutResult = createTimeoutResult(this.toolExecutionTimeoutErrorMessage);
        this.messageHandler = new McpOperationHandler(
                this.pendingOperations,
                this.transport,
                this.logHandler::handleLogMessage,
                () -> this.toolListOutOfDate.set(true));
        this.transport.onFailure(this::reconnect);
        this.initialize();
    }

    /** Builds {@code {"result":{"content":[{"type":"text","text":message}]}}}, the reply used on tool timeout. */
    private static JsonNode createTimeoutResult(String message) {
        ObjectNode response = JsonNodeFactory.instance.objectNode();
        response.putObject("result")
                .putArray("content")
                .addObject()
                .put("type", "text")
                .put("text", message);
        return response;
    }

    /** Transport-failure callback: waits {@code reconnectInterval}, then initializes the session again. */
    private void reconnect() {
        try {
            TimeUnit.MILLISECONDS.sleep(this.reconnectInterval.toMillis());
        } catch (InterruptedException e) {
            // Keep the interrupt visible to whoever owns this thread.
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
        log.info("Trying to reconnect...");
        this.initialize();
    }

    /** Starts the transport and performs the {@code initialize} handshake, bounded by {@code initializationTimeout}. */
    private void initialize() {
        this.transport.start(this.messageHandler);
        long operationId = this.idGenerator.getAndIncrement();
        McpInitializeRequest request = new McpInitializeRequest(operationId);
        request.setParams(this.createInitializeParams());
        JsonNode capabilities =
                this.awaitResponseOrThrow(operationId, () -> this.transport.initialize(request), this.initializationTimeout);
        log.debug("MCP server capabilities: {}", capabilities == null ? null : capabilities.get("result"));
    }

    /** Builds the {@code initialize} parameters: protocol version, client info, and roots without list-change support. */
    private InitializeParams createInitializeParams() {
        InitializeParams params = new InitializeParams();
        params.setProtocolVersion(this.protocolVersion);
        InitializeParams.ClientInfo clientInfo = new InitializeParams.ClientInfo();
        clientInfo.setName(this.clientName);
        clientInfo.setVersion(this.clientVersion);
        params.setClientInfo(clientInfo);
        InitializeParams.Capabilities capabilities = new InitializeParams.Capabilities();
        InitializeParams.Capabilities.Roots roots = new InitializeParams.Capabilities.Roots();
        roots.setListChanged(false);
        capabilities.setRoots(roots);
        params.setCapabilities(capabilities);
        return params;
    }

    /**
     * Sends a request and blocks until its response arrives or the timeout elapses. The pending-operation entry
     * for {@code operationId} is always removed afterwards, whatever the outcome.
     *
     * @param operationId id the request was created with
     * @param request sends the request and returns the future completed with the response
     * @param timeout maximum wait; zero means effectively unbounded
     * @return the response node
     * @throws TimeoutException if no response arrives in time
     * @throws RuntimeException wrapping an {@link ExecutionException}, or an {@link InterruptedException}
     *     (in which case the thread's interrupt status is restored)
     */
    private JsonNode awaitResponse(long operationId, Supplier<CompletableFuture<JsonNode>> request, Duration timeout)
            throws TimeoutException {
        try {
            return request.get().get(timeoutMillis(timeout), TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (ExecutionException e) {
            throw new RuntimeException(e);
        } finally {
            this.pendingOperations.remove(operationId);
        }
    }

    /** Like {@link #awaitResponse}, but reports a timeout as a {@link RuntimeException} wrapping the {@link TimeoutException}. */
    private JsonNode awaitResponseOrThrow(long operationId, Supplier<CompletableFuture<JsonNode>> request, Duration timeout) {
        try {
            return this.awaitResponse(operationId, request, timeout);
        } catch (TimeoutException e) {
            throw new RuntimeException(e);
        }
    }

    /** Converts a timeout to milliseconds, mapping zero to {@link Integer#MAX_VALUE} ("no practical limit"). */
    private static long timeoutMillis(Duration timeout) {
        long millis = timeout.toMillis();
        return millis == 0L ? Integer.MAX_VALUE : millis;
    }

    @Override
    public String key() {
        return this.key;
    }

    /**
     * Returns the server's tools, fetching them first if the cached list is missing or flagged out of date.
     *
     * <p>Only one thread performs a refresh at a time. Concurrent callers wait for it to finish and then return
     * whatever list is cached, which may be {@code null} if the very first fetch failed. If a refresh fails, the
     * list stays flagged out of date so the next call retries.
     */
    @Override
    public List<ToolSpecification> listTools() {
        if (!this.toolListOutOfDate.get()) {
            return this.toolListRefs.get();
        }
        CompletableFuture<Void> update = new CompletableFuture<>();
        if (!this.toolListUpdateInProgress.compareAndSet(null, update)) {
            // Another thread is already refreshing: wait for it instead of issuing a duplicate request.
            CompletableFuture<Void> inProgress = this.toolListUpdateInProgress.get();
            if (inProgress != null) {
                inProgress.join();
            }
            return this.toolListRefs.get();
        }
        boolean refreshed = false;
        try {
            // Clear the flag before fetching so a change notification arriving mid-fetch is not lost.
            this.toolListOutOfDate.set(false);
            this.obtainToolList();
            refreshed = true;
        } finally {
            if (!refreshed) {
                this.toolListOutOfDate.set(true);
            }
            this.toolListUpdateInProgress.set(null);
            update.complete(null);
        }
        return this.toolListRefs.get();
    }

    /**
     * Calls a tool on the server and returns its extracted result.
     *
     * <p>If the call exceeds {@code toolExecutionTimeout}, a cancellation notification is sent for the operation
     * and the configured timeout message is returned instead of throwing.
     *
     * @throws RuntimeException if the arguments are not a JSON object or the call fails
     */
    @Override
    public String executeTool(ToolExecutionRequest executionRequest) {
        ObjectNode arguments = parseArguments(executionRequest.arguments());
        long operationId = this.idGenerator.getAndIncrement();
        McpCallToolRequest operation = new McpCallToolRequest(operationId, executionRequest.name(), arguments);
        JsonNode result;
        try {
            result = this.awaitResponse(
                    operationId, () -> this.transport.executeOperationWithResponse(operation), this.toolExecutionTimeout);
        } catch (TimeoutException timeout) {
            this.transport.executeOperationWithoutResponse(new CancellationNotification(operationId, "Timeout"));
            return ToolExecutionHelper.extractResult(this.timeoutResult);
        }
        return ToolExecutionHelper.extractResult(result);
    }

    /** Parses tool arguments as a JSON object; {@code null} or blank input becomes an empty object. */
    private static ObjectNode parseArguments(String args) {
        String json = Utils.isNullOrBlank(args) ? "{}" : args;
        try {
            return OBJECT_MAPPER.readValue(json, ObjectNode.class);
        } catch (JsonProcessingException e) {
            throw new RuntimeException(e);
        }
    }

    /** Returns the server's resources; fetched on the first call and cached (this class never resets the cache). */
    @Override
    public List<McpResource> listResources() {
        if (this.resourceRefs.get() == null) {
            this.obtainResourceList();
        }
        return this.resourceRefs.get();
    }

    /** Reads one resource by URI, bounded by {@code resourcesTimeout}. */
    @Override
    public McpReadResourceResult readResource(String uri) {
        long operationId = this.idGenerator.getAndIncrement();
        McpReadResourceRequest operation = new McpReadResourceRequest(operationId, uri);
        JsonNode result = this.awaitResponseOrThrow(
                operationId, () -> this.transport.executeOperationWithResponse(operation), this.resourcesTimeout);
        return ResourcesHelper.parseResourceContents(result);
    }

    /** Returns the server's prompts; fetched on the first call and cached (this class never resets the cache). */
    @Override
    public List<McpPrompt> listPrompts() {
        if (this.promptRefs.get() == null) {
            this.obtainPromptList();
        }
        return this.promptRefs.get();
    }

    /** Renders a prompt with the given arguments, bounded by {@code promptsTimeout}. */
    @Override
    public McpGetPromptResult getPrompt(String name, Map<String, Object> arguments) {
        long operationId = this.idGenerator.getAndIncrement();
        McpGetPromptRequest operation = new McpGetPromptRequest(operationId, name, arguments);
        JsonNode result = this.awaitResponseOrThrow(
                operationId, () -> this.transport.executeOperationWithResponse(operation), this.promptsTimeout);
        return PromptsHelper.parsePromptContents(result);
    }

    /**
     * Checks the transport, then sends a ping and waits up to {@code pingTimeout} for the reply.
     *
     * @throws RuntimeException if the ping fails or times out
     */
    @Override
    public void checkHealth() {
        this.transport.checkHealth();
        long operationId = this.idGenerator.getAndIncrement();
        McpPingRequest ping = new McpPingRequest(operationId);
        this.awaitResponseOrThrow(operationId, () -> this.transport.executeOperationWithResponse(ping), this.pingTimeout);
    }

    /** Returns the server's resource templates; fetched on the first call and cached. */
    @Override
    public List<McpResourceTemplate> listResourceTemplates() {
        if (this.resourceTemplateRefs.get() == null) {
            this.obtainResourceTemplateList();
        }
        return this.resourceTemplateRefs.get();
    }

    /**
     * Fetches the tool list and stores it in {@code toolListRefs}, bounded by {@code toolExecutionTimeout}.
     *
     * @throws IllegalStateException if the response has no {@code result.tools} array
     */
    private synchronized void obtainToolList() {
        long operationId = this.idGenerator.getAndIncrement();
        McpListToolsRequest operation = new McpListToolsRequest(operationId);
        JsonNode result = this.awaitResponseOrThrow(
                operationId, () -> this.transport.executeOperationWithResponse(operation), this.toolExecutionTimeout);
        JsonNode tools = result.path("result").path("tools");
        if (!tools.isArray()) {
            throw new IllegalStateException("Unexpected response to the tool list request: " + result);
        }
        this.toolListRefs.set(ToolSpecificationHelper.toolSpecificationListFromMcpResponse((ArrayNode) tools));
    }

    /** Fetches and caches the resource list unless another caller already did, bounded by {@code resourcesTimeout}. */
    private synchronized void obtainResourceList() {
        if (this.resourceRefs.get() != null) {
            return;
        }
        long operationId = this.idGenerator.getAndIncrement();
        McpListResourcesRequest operation = new McpListResourcesRequest(operationId);
        JsonNode result = this.awaitResponseOrThrow(
                operationId, () -> this.transport.executeOperationWithResponse(operation), this.resourcesTimeout);
        this.resourceRefs.set(ResourcesHelper.parseResourceRefs(result));
    }

    /** Fetches and caches the resource-template list unless already cached, bounded by {@code resourcesTimeout}. */
    private synchronized void obtainResourceTemplateList() {
        if (this.resourceTemplateRefs.get() != null) {
            return;
        }
        long operationId = this.idGenerator.getAndIncrement();
        McpListResourceTemplatesRequest operation = new McpListResourceTemplatesRequest(operationId);
        JsonNode result = this.awaitResponseOrThrow(
                operationId, () -> this.transport.executeOperationWithResponse(operation), this.resourcesTimeout);
        this.resourceTemplateRefs.set(ResourcesHelper.parseResourceTemplateRefs(result));
    }

    /** Fetches and caches the prompt list unless already cached, bounded by {@code promptsTimeout}. */
    private synchronized void obtainPromptList() {
        if (this.promptRefs.get() != null) {
            return;
        }
        long operationId = this.idGenerator.getAndIncrement();
        McpListPromptsRequest operation = new McpListPromptsRequest(operationId);
        JsonNode result = this.awaitResponseOrThrow(
                operationId, () -> this.transport.executeOperationWithResponse(operation), this.promptsTimeout);
        this.promptRefs.set(PromptsHelper.parsePromptRefs(result));
    }

    /** Closes the transport; a failure is logged as a warning rather than thrown. */
    @Override
    public void close() {
        try {
            this.transport.close();
        } catch (Exception e) {
            log.warn("Cannot close MCP transport", e);
        }
    }

    /** Builder for {@link DefaultMcpClient}. Unset values fall back to the defaults applied by the client constructor. */
    public static class Builder {
        private String toolExecutionTimeoutErrorMessage;
        private McpTransport transport;
        private String key;
        private String clientName;
        private String clientVersion;
        private String protocolVersion;
        private Duration initializationTimeout;
        private Duration toolExecutionTimeout;
        private Duration resourcesTimeout;
        private Duration pingTimeout;
        private Duration promptsTimeout;
        private McpLogMessageHandler logHandler;
        private Duration reconnectInterval;

        /** Transport used to reach the server (required). */
        public Builder transport(McpTransport transport) {
            this.transport = transport;
            return this;
        }

        /** Identifier returned by {@link DefaultMcpClient#key()}; defaults to a random UUID. */
        public Builder key(String key) {
            this.key = key;
            return this;
        }

        /** Client name sent in the initialize request; defaults to {@code "langchain4j"}. */
        public Builder clientName(String clientName) {
            this.clientName = clientName;
            return this;
        }

        /** Client version sent in the initialize request; defaults to {@code "1.0"}. */
        public Builder clientVersion(String clientVersion) {
            this.clientVersion = clientVersion;
            return this;
        }

        /** Protocol version sent in the initialize request; defaults to {@code "2024-11-05"}. */
        public Builder protocolVersion(String protocolVersion) {
            this.protocolVersion = protocolVersion;
            return this;
        }

        /** Timeout for the initialize handshake (default 30s; zero = effectively unbounded). */
        public Builder initializationTimeout(Duration initializationTimeout) {
            this.initializationTimeout = initializationTimeout;
            return this;
        }

        /** Timeout for tool calls and the tool-list request (default 60s; zero = effectively unbounded). */
        public Builder toolExecutionTimeout(Duration toolExecutionTimeout) {
            this.toolExecutionTimeout = toolExecutionTimeout;
            return this;
        }

        /** Timeout for resource and resource-template requests (default 60s; zero = effectively unbounded). */
        public Builder resourcesTimeout(Duration resourcesTimeout) {
            this.resourcesTimeout = resourcesTimeout;
            return this;
        }

        /** Timeout for prompt requests (default 60s; zero = effectively unbounded). */
        public Builder promptsTimeout(Duration promptsTimeout) {
            this.promptsTimeout = promptsTimeout;
            return this;
        }

        /** Text returned as the tool result when a tool call times out. */
        public Builder toolExecutionTimeoutErrorMessage(String toolExecutionTimeoutErrorMessage) {
            this.toolExecutionTimeoutErrorMessage = toolExecutionTimeoutErrorMessage;
            return this;
        }

        /** Handler for server log messages; defaults to {@link DefaultMcpLogMessageHandler}. */
        public Builder logHandler(McpLogMessageHandler logHandler) {
            this.logHandler = logHandler;
            return this;
        }

        /** Timeout for the ping sent by {@code checkHealth()} (default 10s; zero = effectively unbounded). */
        public Builder pingTimeout(Duration pingTimeout) {
            this.pingTimeout = pingTimeout;
            return this;
        }

        /** Delay before re-initializing after a transport failure (default 5s). */
        public Builder reconnectInterval(Duration reconnectInterval) {
            this.reconnectInterval = reconnectInterval;
            return this;
        }

        /** Creates the client; this connects and initializes immediately. */
        public DefaultMcpClient build() {
            return new DefaultMcpClient(this);
        }
    }
}

