/*
 * Decompiled with CFR 0.152.
 */
package com.mulesoft.connectors.mcp.internal.server.connection.provider.streamable;

import com.fasterxml.jackson.core.type.TypeReference;
import com.mulesoft.connectors.mcp.api.server.StreamableMimeType;
import com.mulesoft.connectors.mcp.internal.McpUtils;
import com.mulesoft.connectors.mcp.internal.error.McpErrorTypes;
import com.mulesoft.connectors.mcp.internal.error.exception.SessionRejectedException;
import com.mulesoft.connectors.mcp.internal.server.connection.MuleServerSession;
import com.mulesoft.connectors.mcp.internal.server.connection.observer.InboundRequestContext;
import com.mulesoft.connectors.mcp.internal.server.connection.observer.InternalNewSessionRequest;
import com.mulesoft.connectors.mcp.internal.server.connection.observer.SessionObserver;
import com.mulesoft.connectors.mcp.internal.server.connection.provider.BaseServerTransportProvider;
import com.mulesoft.connectors.mcp.internal.server.session.SessionManager;
import com.mulesoft.connectors.mcp.internal.util.HttpTransportUtils;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpServerTransport;
import java.io.IOException;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicBoolean;
import org.mule.runtime.api.scheduler.Scheduler;
import org.mule.runtime.api.util.MultiMap;
import org.mule.runtime.core.api.util.StringUtils;
import org.mule.runtime.http.api.HttpConstants;
import org.mule.runtime.http.api.domain.message.request.HttpRequest;
import org.mule.runtime.http.api.domain.request.HttpRequestContext;
import org.mule.runtime.http.api.server.HttpServer;
import org.mule.runtime.http.api.server.RequestHandler;
import org.mule.runtime.http.api.server.RequestHandlerManager;
import org.mule.runtime.http.api.server.async.HttpResponseReadyCallback;
import org.mule.runtime.http.api.sse.server.SseClient;
import org.mule.runtime.http.api.sse.server.SseClientConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Mono;

/**
 * Server transport provider implementing the MCP "Streamable HTTP" transport on top of a Mule
 * {@link HttpServer}.
 *
 * <p>A single endpoint path is registered for three HTTP methods:
 * <ul>
 *   <li>{@code POST} carries exactly one JSON-RPC message. JSON-RPC requests are answered through a
 *       per-request {@link Transport}, either as a single JSON body or as an SSE stream depending on
 *       the configured {@link StreamableMimeType}; any other message is acknowledged with
 *       {@code 202 Accepted} once handled.</li>
 *   <li>{@code GET} is not supported and always answers {@code 405 Method Not Allowed}.</li>
 *   <li>{@code DELETE} unregisters and closes the session named by the {@code Mcp-Session-Id}
 *       header.</li>
 * </ul>
 *
 * <p>Sessions live in the {@link SessionManager}: a POST carrying {@code Mcp-Session-Id} recovers
 * the stored session, while a POST without it creates and registers a new one (subject to the
 * optional {@link SessionObserver}).
 */
public class StreamableHttpServerTransportProvider extends BaseServerTransportProvider {
    private static final Logger LOGGER = LoggerFactory.getLogger(StreamableHttpServerTransportProvider.class);
    /** Header carrying the MCP session identifier on requests and responses. */
    private static final String MCP_SESSION_ID_HEADER = "Mcp-Session-Id";
    private static final String APPLICATION_JSON = "application/json";
    private static final String TEXT_EVENT_STREAM = "text/event-stream";

    private final HttpServer httpServer;
    private final String mcpEndpointPath;
    private final Scheduler scheduler;
    private final MultiMap<String, String> defaultResponseHeaders;
    private final StreamableMimeType streamableMimeType;
    private RequestHandlerManager postHandlerManager;
    private RequestHandlerManager getHandlerManager;
    private RequestHandlerManager deleteHandlerManager;

    /**
     * Creates the provider. Handlers are not registered until {@link #doOpen()}.
     *
     * @param refName                name of the owning config, used in logs and SSE client ids
     * @param httpServer             server the MCP endpoint is registered on
     * @param scheduler              scheduler on which POST bodies are processed
     * @param mcpEndpointPath        endpoint path, normalized via {@link HttpTransportUtils#normalizePath}
     * @param streamableMimeType     whether JSON-RPC requests are answered as JSON or as an SSE stream
     * @param sessionManager         store used to register, recover and close sessions
     * @param defaultResponseHeaders headers added to responses; copied as an immutable map, or empty when {@code null}
     */
    public StreamableHttpServerTransportProvider(String refName, HttpServer httpServer, Scheduler scheduler,
            String mcpEndpointPath, StreamableMimeType streamableMimeType, SessionManager sessionManager,
            MultiMap<String, String> defaultResponseHeaders) {
        super(refName, sessionManager);
        this.httpServer = httpServer;
        this.scheduler = scheduler;
        this.streamableMimeType = streamableMimeType;
        this.mcpEndpointPath = HttpTransportUtils.normalizePath(mcpEndpointPath, "mcpEndpointPath");
        this.defaultResponseHeaders = defaultResponseHeaders != null
                ? defaultResponseHeaders.toImmutableMultiMap()
                : MultiMap.emptyMultiMap();
    }

    /** Registers and starts the POST, GET and DELETE handlers on the MCP endpoint path. */
    @Override
    protected void doOpen() {
        this.postHandlerManager = this.registerHandler("POST", this.onPost());
        this.getHandlerManager = this.registerHandler("GET", this.onGet());
        this.deleteHandlerManager = this.registerHandler("DELETE", this.onDelete());
    }

    /** Registers {@code handler} for a single HTTP {@code method} on the endpoint path and starts it. */
    private RequestHandlerManager registerHandler(String method, RequestHandler handler) {
        RequestHandlerManager manager = this.httpServer.addRequestHandler(List.of(method), this.mcpEndpointPath, handler);
        manager.start();
        return manager;
    }

    /** Stops and disposes every registered handler; failures are logged, not propagated. */
    @Override
    protected void doCloseGracefully() {
        this.close(this.postHandlerManager, "POST");
        this.close(this.getHandlerManager, "GET");
        this.close(this.deleteHandlerManager, "DELETE");
    }

    /**
     * Stops and disposes a handler manager if it was registered.
     *
     * @param handlerManager manager to close; ignored when {@code null}
     * @param method         HTTP method name, used only for logging
     */
    private void close(RequestHandlerManager handlerManager, String method) {
        if (handlerManager == null) {
            return;
        }
        try {
            handlerManager.stop();
            handlerManager.dispose();
        } catch (Exception e) {
            LOGGER.error("Exception found stopping {} endpoint of config {}", method, this.refName, e);
        }
    }

    /**
     * Builds the POST handler. After checking that the transport is open, the actual work is moved
     * to {@link #scheduler} so the HTTP thread is released immediately.
     */
    private RequestHandler onPost() {
        return (requestContext, responseCallback) -> {
            if (!this.assureTransportOpen(responseCallback)) {
                return;
            }
            // Per-request message: debug level to avoid flooding logs under normal traffic.
            LOGGER.debug("Handling post message");
            this.scheduler.submit(() -> this.handlePost(requestContext, responseCallback));
        };
    }

    /**
     * Processes one POSTed JSON-RPC message.
     *
     * <p>Responses sent directly from here:
     * <ul>
     *   <li>{@code 400} when the Accept header lacks the configured media type, or the body cannot be parsed;</li>
     *   <li>{@code 404} when {@code Mcp-Session-Id} names an unknown session. The MCP spec requires
     *       404 here so that clients start a new session; it also matches the DELETE handler;</li>
     *   <li>the rejection status when a new session is refused;</li>
     *   <li>{@code 202} once a non-request message (notification/response) has been handled.</li>
     * </ul>
     * JSON-RPC requests are answered through the {@link Transport} opened for them.
     */
    private void handlePost(HttpRequestContext requestContext, HttpResponseReadyCallback responseCallback) {
        HttpRequest request = requestContext.getRequest();
        if (!this.acceptsValidMimeType(request)) {
            McpSchema.JSONRPCResponse rpcResponse = McpUtils.rpcErrorResponse(null, McpErrorTypes.INVALID_REQUEST,
                    "Client MUST accept either %s, %s or both mediaTypes".formatted(APPLICATION_JSON, TEXT_EVENT_STREAM));
            HttpTransportUtils.sendHttpResponse(HttpConstants.HttpStatus.BAD_REQUEST.getStatusCode(),
                    (McpSchema.JSONRPCMessage) rpcResponse, responseCallback);
            return;
        }

        McpSchema.JSONRPCMessage message;
        try {
            message = HttpTransportUtils.parseMessageFromBody(request);
        } catch (IOException e) {
            HttpTransportUtils.sendHttpResponse(HttpConstants.HttpStatus.BAD_REQUEST.getStatusCode(),
                    (McpSchema.JSONRPCMessage) McpUtils.rpcErrorResponse(null, McpErrorTypes.PARSE_ERROR, e.getMessage()),
                    responseCallback);
            return;
        }

        String requestSessionId = request.getHeaderValue(MCP_SESSION_ID_HEADER);
        Transport transport = new Transport(requestContext, responseCallback);
        MuleServerSession session;
        if (requestSessionId != null) {
            MuleServerSession recoveredSession = this.sessionManager.recoverSession(requestSessionId).orElse(null);
            if (recoveredSession == null) {
                HttpTransportUtils.sendHttpResponse(HttpConstants.HttpStatus.NOT_FOUND.getStatusCode(),
                        "SessionId not found", responseCallback);
                return;
            }
            session = recoveredSession.restoreFrom(this.getSessionFactory().create(transport), transport);
        } else {
            try {
                session = this.createAndRegisterNewSession(transport);
            } catch (SessionRejectedException e) {
                HttpTransportUtils.sendHttpResponse(e.getStatusCode(), e.getMessage(), responseCallback);
                return;
            }
        }

        this.getRequestObserver().ifPresent(observer ->
                HttpTransportUtils.asInboundRequestContext(message, session.getId(), requestContext, this.httpServer)
                        .ifPresent(ctx -> observer.onRequest((InboundRequestContext) ctx)));

        // Only JSON-RPC requests expect a payload back through the transport; everything else is acked with 202.
        boolean isJsonRpcRequest = message instanceof McpSchema.JSONRPCRequest;
        if (isJsonRpcRequest) {
            transport.open(session);
        }
        session.handle(message)
                .doOnSuccess(ignored -> {
                    if (!isJsonRpcRequest) {
                        HttpTransportUtils.sendHttpResponse(HttpConstants.HttpStatus.ACCEPTED.getStatusCode(), "",
                                this.defaultResponseHeaders, responseCallback);
                    }
                })
                .doOnError(t -> {
                    LOGGER.error("Error processing message in session {} of config {}", session.getId(), this.refName, t);
                    HttpTransportUtils.sendHttpInternalErrorResponse(t, message, responseCallback);
                })
                .doAfterTerminate(session::close)
                // Errors are already logged and answered in doOnError; consuming them here stops
                // Reactor from also reporting them through onErrorDropped (ErrorCallbackNotImplemented).
                .subscribe(ignored -> { }, ignored -> { });
    }

    /**
     * Creates a session bound to {@code transport}, consults the optional {@link SessionObserver},
     * and registers the session with the {@link SessionManager}. The session is upserted again when
     * it initializes.
     *
     * @throws SessionRejectedException when the observer rejects the session, carrying the observer's
     *         status code. Also thrown with {@code 500} when consulting the observer fails.
     */
    private MuleServerSession createAndRegisterNewSession(Transport transport) throws SessionRejectedException {
        MuleServerSession session = new MuleServerSession(this.getSessionFactory().create(transport), transport);
        session.onInitialize(() -> this.sessionManager.upsert(session));
        SessionObserver sessionObserver = this.getSessionObserver().orElse(null);
        if (sessionObserver != null) {
            try {
                InternalNewSessionRequest sessionResponse = sessionObserver
                        .onNewSessionRequest(HttpTransportUtils.createNewSessionRequest(session, transport.getRequestContext()))
                        .get();
                if (sessionResponse.getRejectedStatusCode() != null) {
                    throw new SessionRejectedException(sessionResponse.getRejectedMessage(), sessionResponse.getRejectedStatusCode());
                }
            } catch (SessionRejectedException e) {
                // Deliberate rejection: propagate it with its own status code instead of turning it into a 500.
                throw e;
            } catch (Throwable t) {
                if (t instanceof InterruptedException) {
                    // Preserve the interrupt for the scheduler thread's owner.
                    Thread.currentThread().interrupt();
                }
                Throwable cause = t instanceof ExecutionException && t.getCause() != null ? t.getCause() : t;
                if (cause instanceof SessionRejectedException rejected) {
                    // Rejection raised asynchronously by the observer: keep its status code too.
                    throw rejected;
                }
                LOGGER.error("Exception found while processing new session request. Connection will be closed", cause);
                throw new SessionRejectedException(cause.getMessage(),
                        HttpConstants.HttpStatus.INTERNAL_SERVER_ERROR.getStatusCode(), cause);
            }
        }
        this.sessionManager.upsert(session);
        return session;
    }

    /** GET (server-initiated SSE stream) is not offered: always answers {@code 405}. */
    private RequestHandler onGet() {
        return (requestContext, responseCallback) ->
                HttpTransportUtils.sendHttpResponse(HttpConstants.HttpStatus.METHOD_NOT_ALLOWED.getStatusCode(), "", responseCallback);
    }

    /**
     * Builds the DELETE handler: {@code 400} when {@code Mcp-Session-Id} is blank, {@code 202} when
     * the session was unregistered and closed, {@code 404} when it was not found.
     */
    private RequestHandler onDelete() {
        return (requestContext, responseCallback) -> {
            String requestSessionId = requestContext.getRequest().getHeaderValue(MCP_SESSION_ID_HEADER);
            if (StringUtils.isBlank(requestSessionId)) {
                HttpTransportUtils.sendHttpResponse(HttpConstants.HttpStatus.BAD_REQUEST.getStatusCode(), "Invalid SessionId", responseCallback);
                return;
            }
            int statusCode = this.sessionManager.unregisterAndClose(requestSessionId)
                    ? HttpConstants.HttpStatus.ACCEPTED.getStatusCode()
                    : HttpConstants.HttpStatus.NOT_FOUND.getStatusCode();
            HttpTransportUtils.sendHttpResponse(statusCode, "", responseCallback);
        };
    }

    /**
     * Returns whether any {@code Accept} header value contains the media type required by the
     * configured {@link StreamableMimeType} ({@code text/event-stream} for SSE, otherwise
     * {@code application/json}). Matching is a case-sensitive substring check.
     */
    private boolean acceptsValidMimeType(HttpRequest request) {
        String requiredMimeType = this.streamableMimeType == StreamableMimeType.SSE ? TEXT_EVENT_STREAM : APPLICATION_JSON;
        for (String acceptValue : request.getHeaderValues("Accept")) {
            if (acceptValue != null && acceptValue.contains(requiredMimeType)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Transport bound to a single POST request and its response callback. It is opened only for
     * JSON-RPC requests. In SSE mode it streams messages and closes the stream after the first
     * JSON-RPC response. In JSON mode it sends each message as the HTTP response body.
     */
    private class Transport implements McpServerTransport {
        private final AtomicBoolean open = new AtomicBoolean(false);
        // Set once the HTTP exchange is finished (response sent or SSE stream closed).
        private final AtomicBoolean completed = new AtomicBoolean(false);
        private final HttpRequestContext requestContext;
        private final HttpResponseReadyCallback responseCallback;
        // volatile: cleared from the SSE onClose callback, read by the sending thread.
        private volatile SseClient sseClient;
        // Null until open(...) runs; transports for notifications/responses are never opened.
        private volatile MuleServerSession session;

        private Transport(HttpRequestContext requestContext, HttpResponseReadyCallback responseCallback) {
            this.requestContext = requestContext;
            this.responseCallback = responseCallback;
        }

        /**
         * Opens the transport for {@code session}; subsequent calls are no-ops. In SSE mode this
         * starts the SSE response, adding the default headers and {@code Mcp-Session-Id}.
         */
        public void open(MuleServerSession session) {
            if (!this.open.compareAndSet(false, true)) {
                return;
            }
            this.session = session;
            if (streamableMimeType != StreamableMimeType.SSE) {
                return;
            }
            SseClient client = this.responseCallback.startSseResponse(SseClientConfig.builderFrom(this.requestContext)
                    .withClientId("sse-server-" + refName + "-" + session.getId())
                    .customizeResponse(customizer -> {
                        defaultResponseHeaders.entryList().forEach(entry -> customizer.addResponseHeader(entry.getKey(), entry.getValue()));
                        customizer.addResponseHeader(MCP_SESSION_ID_HEADER, session.getId());
                    })
                    .build());
            this.sseClient = client;
            client.onClose(t -> {
                this.completed.set(true);
                this.sseClient = null;
            });
        }

        /**
         * Closes an open transport: SSE mode closes the stream; JSON mode answers {@code 503} unless
         * a response was already sent. Completes immediately.
         */
        @Override
        public Mono<Void> closeGracefully() {
            if (this.open.compareAndSet(true, false)) {
                if (streamableMimeType == StreamableMimeType.SSE) {
                    this.closeSseClient();
                } else if (this.completed.compareAndSet(false, true)) {
                    HttpTransportUtils.sendHttpResponse(HttpConstants.HttpStatus.SERVICE_UNAVAILABLE.getStatusCode(),
                            "Service stopped", defaultResponseHeaders, this.responseCallback);
                }
            }
            return Mono.empty();
        }

        /** Closes the current SSE client, if any; failures are logged at debug level. */
        private void closeSseClient() {
            // Read once: onClose may clear the field concurrently.
            SseClient client = this.sseClient;
            if (client == null) {
                return;
            }
            try {
                LOGGER.debug("Closing SSE client with client id: {}", client.getClientId());
                client.close();
            } catch (Exception e) {
                LOGGER.debug("Error closing SSE client for session id {}", this.sessionId(), e);
            }
        }

        /** Session id for logging; {@code null} when the transport was never opened. */
        private String sessionId() {
            MuleServerSession current = this.session;
            return current != null ? current.getId() : null;
        }

        /**
         * Sends {@code message} to the client, or ignores it if the transport is not open. Sending
         * happens synchronously; the returned {@link Mono} is always empty.
         */
        @Override
        public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
            if (!this.open.get()) {
                LOGGER.debug("Attempted to send message on session {} but server transport stopped. Ignoring...", this.sessionId());
                return Mono.empty();
            }
            if (streamableMimeType == StreamableMimeType.SSE) {
                this.sendSseMessage(message);
            } else {
                this.sendJsonMessage(message);
            }
            return Mono.empty();
        }

        /** Emits {@code message} as an SSE event; a JSON-RPC response ends the stream. */
        private void sendSseMessage(McpSchema.JSONRPCMessage message) {
            try {
                HttpTransportUtils.sendMessageEvent(this.sseClient, message);
            } finally {
                if (message instanceof McpSchema.JSONRPCResponse) {
                    this.completed.set(true);
                    LOGGER.debug("Response emitted for session {}. Closing SSE connection", this.sessionId());
                    this.closeSseClient();
                }
            }
        }

        /**
         * Sends {@code message} as a {@code 200} JSON body. {@code Mcp-Session-Id} is included only
         * in the response to {@code initialize}.
         */
        private void sendJsonMessage(McpSchema.JSONRPCMessage message) {
            MultiMap<String, String> headers = new MultiMap<>(defaultResponseHeaders);
            headers.put("Content-Type", APPLICATION_JSON);
            if (message instanceof McpSchema.JSONRPCResponse response && response.result() instanceof McpSchema.InitializeResult) {
                headers.put(MCP_SESSION_ID_HEADER, this.session.getId());
            }
            LOGGER.debug("Sending JSON message for session {}", this.sessionId());
            HttpTransportUtils.sendHttpResponse(HttpConstants.HttpStatus.OK.getStatusCode(), message, headers,
                    HttpTransportUtils.afterSent(this.responseCallback, () -> this.completed.set(true)));
        }

        /** Converts {@code data} to {@code T} using the provider's object mapper. */
        @Override
        public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
            return (T) objectMapper.convertValue(data, typeRef);
        }

        private HttpRequestContext getRequestContext() {
            return this.requestContext;
        }
    }
}

