/*
 * Decompiled with CFR 0.152.
 */
package io.quarkiverse.langchain4j.mcp.runtime.http;

import com.fasterxml.jackson.databind.JsonNode;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.mcp.client.protocol.InitializationNotification;
import dev.langchain4j.mcp.client.protocol.McpClientMessage;
import dev.langchain4j.mcp.client.protocol.McpInitializeRequest;
import dev.langchain4j.mcp.client.transport.McpOperationHandler;
import dev.langchain4j.mcp.client.transport.McpTransport;
import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory;
import io.quarkiverse.langchain4j.mcp.auth.McpClientAuthProvider;
import io.quarkiverse.langchain4j.mcp.runtime.http.McpClientAuthFilter;
import io.quarkiverse.langchain4j.mcp.runtime.http.McpHttpClientLogger;
import io.quarkiverse.langchain4j.mcp.runtime.http.McpPostEndpoint;
import io.quarkiverse.langchain4j.mcp.runtime.http.McpSseEndpoint;
import io.quarkiverse.langchain4j.mcp.runtime.http.SseSubscriber;
import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder;
import io.smallrye.mutiny.Uni;
import java.io.IOException;
import java.net.URI;
import java.time.Duration;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import org.jboss.logging.Logger;
import org.jboss.resteasy.reactive.client.api.ClientLogger;
import org.jboss.resteasy.reactive.client.api.LoggingScope;
import org.jboss.resteasy.reactive.server.jackson.JacksonBasicMessageBodyReader;

public class QuarkusHttpMcpTransport
implements McpTransport {
    private static final Logger log = Logger.getLogger(QuarkusHttpMcpTransport.class);
    private final String sseUrl;
    private final McpSseEndpoint sseEndpoint;
    private final Duration timeout;
    private final boolean logResponses;
    private final boolean logRequests;
    private volatile String postUrl;
    private volatile McpPostEndpoint postEndpoint;
    private volatile McpOperationHandler operationHandler;
    private final McpClientAuthProvider mcpClientAuthProvider;
    private volatile Runnable onFailure;
    private volatile boolean closed;

    public QuarkusHttpMcpTransport(Builder builder) {
        this.sseUrl = (String)ValidationUtils.ensureNotNull((Object)builder.sseUrl, (String)"Missing SSE endpoint URL");
        this.timeout = (Duration)Utils.getOrDefault((Object)builder.timeout, (Object)Duration.ofSeconds(60L));
        this.logRequests = builder.logRequests;
        this.logResponses = builder.logResponses;
        QuarkusRestClientBuilder clientBuilder = (QuarkusRestClientBuilder)QuarkusRestClientBuilder.newBuilder().baseUri(URI.create(builder.sseUrl)).connectTimeout(this.timeout.toSeconds(), TimeUnit.SECONDS).readTimeout(this.timeout.toSeconds(), TimeUnit.SECONDS).loggingScope(LoggingScope.ALL).register((Object)new JacksonBasicMessageBodyReader(QuarkusJsonCodecFactory.ObjectMapperHolder.MAPPER));
        this.mcpClientAuthProvider = McpClientAuthProvider.resolve(builder.mcpClientName).orElse(null);
        if (this.mcpClientAuthProvider != null) {
            clientBuilder.register((Object)new McpClientAuthFilter(this.mcpClientAuthProvider));
        }
        if (this.logRequests || this.logResponses) {
            clientBuilder.loggingScope(LoggingScope.REQUEST_RESPONSE);
            clientBuilder.clientLogger((ClientLogger)new McpHttpClientLogger(this.logRequests, this.logResponses));
        }
        this.sseEndpoint = (McpSseEndpoint)clientBuilder.build(McpSseEndpoint.class);
    }

    public void start(McpOperationHandler messageHandler) {
        this.operationHandler = messageHandler;
        this.startSseChannel(this.logResponses);
        QuarkusRestClientBuilder builder = (QuarkusRestClientBuilder)QuarkusRestClientBuilder.newBuilder().baseUri(URI.create(this.postUrl)).connectTimeout(this.timeout.toSeconds(), TimeUnit.SECONDS).readTimeout(this.timeout.toSeconds(), TimeUnit.SECONDS).register((Object)new JacksonBasicMessageBodyReader(QuarkusJsonCodecFactory.ObjectMapperHolder.MAPPER));
        if (this.mcpClientAuthProvider != null) {
            builder.register((Object)new McpClientAuthFilter(this.mcpClientAuthProvider));
        }
        if (this.logRequests || this.logResponses) {
            builder.loggingScope(LoggingScope.REQUEST_RESPONSE);
            builder.clientLogger((ClientLogger)new McpHttpClientLogger(this.logRequests, this.logResponses));
        }
        this.postEndpoint = (McpPostEndpoint)builder.build(McpPostEndpoint.class);
    }

    public CompletableFuture<JsonNode> initialize(McpInitializeRequest request) {
        return this.execute((McpClientMessage)request, request.getId()).onItem().transformToUni(response -> this.execute((McpClientMessage)new InitializationNotification(), null).onItem().transform(ignored -> response)).subscribeAsCompletionStage();
    }

    public void checkHealth() {
    }

    public void onFailure(Runnable actionOnFailure) {
        this.onFailure = actionOnFailure;
    }

    public CompletableFuture<JsonNode> executeOperationWithResponse(McpClientMessage operation) {
        return this.execute(operation, operation.getId()).subscribeAsCompletionStage();
    }

    public void executeOperationWithoutResponse(McpClientMessage operation) {
        this.execute(operation, null).subscribe().with(ignored -> {});
    }

    private Uni<JsonNode> execute(McpClientMessage request, Long id) {
        CompletableFuture future = new CompletableFuture();
        Uni uni = Uni.createFrom().completionStage(future);
        if (id != null) {
            this.operationHandler.startOperation(id, future);
        }
        this.postEndpoint.post(request).onFailure().invoke(future::completeExceptionally).onItem().invoke(response -> {
            int statusCode = response.getStatus();
            if (!this.isExpectedStatusCode(statusCode)) {
                future.completeExceptionally(new RuntimeException("Unexpected status code: " + statusCode));
            }
            if (id == null) {
                future.complete(null);
            }
        }).subscribeAsCompletionStage();
        return uni;
    }

    private boolean isExpectedStatusCode(int statusCode) {
        return statusCode >= 200 && statusCode < 300;
    }

    private void startSseChannel(boolean logResponses) {
        CompletableFuture<String> initializationFinished = new CompletableFuture<String>();
        SseSubscriber listener = new SseSubscriber(this.operationHandler, logResponses, initializationFinished);
        this.sseEndpoint.get().subscribe().with((Consumer)listener, throwable -> {
            if (!initializationFinished.isDone()) {
                log.warn((Object)"Failed to connect to the SSE channel, the MCP client will not be used", throwable);
                initializationFinished.completeExceptionally((Throwable)throwable);
            }
            if (!this.closed) {
                this.onFailure.run();
            }
        });
        try {
            long timeoutMillis = this.timeout.toMillis() > 0L ? this.timeout.toMillis() : Integer.MAX_VALUE;
            String relativePostUrl = initializationFinished.get(timeoutMillis, TimeUnit.MILLISECONDS);
            this.postUrl = this.buildAbsolutePostUrl(relativePostUrl);
            log.debug((Object)("Received the server's POST URL: " + this.postUrl));
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    private String buildAbsolutePostUrl(String relativePostUrl) {
        try {
            return URI.create(this.sseUrl).resolve(relativePostUrl).toString();
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    public void close() throws IOException {
        this.closed = true;
    }

    public static class Builder {
        private String sseUrl;
        private String mcpClientName;
        private Duration timeout;
        private boolean logRequests = false;
        private boolean logResponses = false;

        public Builder sseUrl(String sseUrl) {
            this.sseUrl = sseUrl;
            return this;
        }

        public Builder mcpClientName(String mcpClientName) {
            this.mcpClientName = mcpClientName;
            return this;
        }

        public Builder timeout(Duration timeout) {
            this.timeout = timeout;
            return this;
        }

        public Builder logRequests(boolean logRequests) {
            this.logRequests = logRequests;
            return this;
        }

        public Builder logResponses(boolean logResponses) {
            this.logResponses = logResponses;
            return this;
        }

        public QuarkusHttpMcpTransport build() {
            return new QuarkusHttpMcpTransport(this);
        }
    }
}

