/*
 * Decompiled with CFR 0.152.
 */
package com.mulesoft.connectors.mcp.internal.server.connection.provider.sse;

import com.fasterxml.jackson.core.type.TypeReference;
import com.mulesoft.connectors.mcp.internal.McpUtils;
import com.mulesoft.connectors.mcp.internal.error.McpErrorTypes;
import com.mulesoft.connectors.mcp.internal.server.connection.MuleServerSession;
import com.mulesoft.connectors.mcp.internal.server.connection.observer.InboundRequestContext;
import com.mulesoft.connectors.mcp.internal.server.connection.observer.InternalNewSessionRequest;
import com.mulesoft.connectors.mcp.internal.server.connection.observer.SessionObserver;
import com.mulesoft.connectors.mcp.internal.server.connection.provider.BaseServerTransportProvider;
import com.mulesoft.connectors.mcp.internal.server.session.InMemorySessionManager;
import com.mulesoft.connectors.mcp.internal.util.HttpTransportUtils;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpServerTransport;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.RejectedExecutionException;
import org.mule.runtime.api.scheduler.Scheduler;
import org.mule.runtime.api.util.MultiMap;
import org.mule.runtime.core.api.util.StringUtils;
import org.mule.runtime.http.api.HttpConstants;
import org.mule.runtime.http.api.domain.message.request.HttpRequest;
import org.mule.runtime.http.api.server.HttpServer;
import org.mule.runtime.http.api.server.RequestHandler;
import org.mule.runtime.http.api.server.RequestHandlerManager;
import org.mule.runtime.http.api.sse.server.SseClient;
import org.mule.runtime.http.api.sse.server.SseEndpointManager;
import org.mule.runtime.http.api.sse.server.SseRequestContext;
import org.mule.sdk.api.error.ErrorTypeDefinition;
import org.mule.sdk.api.exception.ModuleException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Mono;

/**
 * MCP server transport provider for the HTTP+SSE transport.
 *
 * <p>A client first opens an SSE stream on {@code connectionEndpointPath}. Once connected, it receives an
 * {@code endpoint} event carrying {@code messagesEndpointPath?sessionId=<id>}. JSON-RPC messages are then
 * POSTed to that URL and acknowledged with {@code 202 Accepted}. Messages produced by the session are
 * pushed back to the client as SSE events.
 */
public class SSEServerTransportProvider
extends BaseServerTransportProvider {
    private static final Logger LOGGER = LoggerFactory.getLogger(SSEServerTransportProvider.class);
    private static final String SESSION_ID_PARAM = "sessionId";
    private static final String ENDPOINT_EVENT_NAME = "endpoint";
    private static final String JSON_MEDIA_TYPE = "application/json";
    private final HttpServer httpServer;
    private final String connectionEndpointPath;
    private final String messagesEndpointPath;
    private final Scheduler scheduler;
    private final MultiMap<String, String> defaultSseResponseHeaders;
    private final MultiMap<String, String> defaultMessageResponseHeaders;
    private SseEndpointManager sseEndpointManager;
    private RequestHandlerManager postHandlerManager;

    /**
     * Creates the provider. Sessions are tracked by an {@link InMemorySessionManager}.
     *
     * @param refName name of the owning configuration, used in log and error messages
     * @param httpServer server on which both endpoints are registered when the transport opens
     * @param scheduler scheduler on which POSTed messages are validated and dispatched
     * @param connectionEndpointPath path of the SSE endpoint; normalized via {@code HttpTransportUtils.normalizePath}
     * @param messagesEndpointPath path of the POST endpoint; normalized via {@code HttpTransportUtils.normalizePath}
     * @param defaultSseResponseHeaders headers added to every SSE connection response
     * @param defaultMessageResponseHeaders headers added to every {@code 202 Accepted} POST response
     */
    public SSEServerTransportProvider(String refName, HttpServer httpServer, Scheduler scheduler, String connectionEndpointPath, String messagesEndpointPath, MultiMap<String, String> defaultSseResponseHeaders, MultiMap<String, String> defaultMessageResponseHeaders) {
        super(refName, new InMemorySessionManager());
        this.httpServer = httpServer;
        this.scheduler = scheduler;
        this.connectionEndpointPath = HttpTransportUtils.normalizePath(connectionEndpointPath, "connectionEndpointPath");
        this.messagesEndpointPath = HttpTransportUtils.normalizePath(messagesEndpointPath, "messagesEndpointPath");
        this.defaultSseResponseHeaders = McpUtils.immutable(defaultSseResponseHeaders);
        this.defaultMessageResponseHeaders = McpUtils.immutable(defaultMessageResponseHeaders);
    }

    /**
     * Registers and starts the SSE connection endpoint, then the POST message endpoint, on the HTTP server.
     */
    @Override
    protected void doOpen() {
        this.sseEndpointManager = this.httpServer.sse(this.connectionEndpointPath, this::onSSERequest, this::onNewSseClient);
        this.sseEndpointManager.start();
        this.postHandlerManager = this.httpServer.addRequestHandler(List.of("POST"), this.messagesEndpointPath, this.onPostMessage());
        this.postHandlerManager.start();
    }

    /**
     * Stops and disposes both endpoints. A failure on one endpoint is logged and does not prevent the other
     * from being stopped.
     */
    @Override
    public void doCloseGracefully() {
        if (this.sseEndpointManager != null) {
            try {
                this.sseEndpointManager.stop();
                this.sseEndpointManager.dispose();
            }
            catch (Exception e) {
                LOGGER.error("Exception found stopping SSE endpoint of config {}", this.refName, e);
            }
        }
        if (this.postHandlerManager != null) {
            try {
                this.postHandlerManager.stop();
                this.postHandlerManager.dispose();
            }
            catch (Exception e) {
                LOGGER.error("Exception found stopping message endpoint of config {}", this.refName, e);
            }
        }
    }

    /**
     * Builds the handler for {@code POST messagesEndpointPath}.
     *
     * <p>Validation and dispatch run on {@link #scheduler}. The request is answered with {@code 400} when the
     * content type is not JSON, the {@code sessionId} query parameter is blank, or the body cannot be parsed.
     * It is answered with {@code 503} when no session matches the id or the scheduler refuses the task.
     * Otherwise the message is handed to the session. The client then receives {@code 202 Accepted} on
     * success, or an internal error response on failure.
     */
    private RequestHandler onPostMessage() {
        return (requestContext, responseCallback) -> {
            if (!this.assureTransportOpen(responseCallback)) {
                return;
            }
            LOGGER.info("Handling post message");
            try {
                this.scheduler.submit(() -> {
                    HttpRequest request = requestContext.getRequest();
                    String contentType = request.getHeaderValue("Content-Type");
                    // Compare the media type only, so that "application/json; charset=utf-8" is accepted.
                    if (!isJsonContentType(contentType)) {
                        HttpTransportUtils.sendHttpResponse(HttpConstants.HttpStatus.BAD_REQUEST.getStatusCode(), "Unsupported content type: " + contentType, responseCallback);
                        return;
                    }
                    String sessionId = request.getQueryParams().get(SESSION_ID_PARAM);
                    if (StringUtils.isBlank(sessionId)) {
                        HttpTransportUtils.sendHttpResponse(HttpConstants.HttpStatus.BAD_REQUEST.getStatusCode(), "Missing session id", responseCallback);
                        return;
                    }
                    MuleServerSession session = this.sessionManager.recoverSession(sessionId).orElse(null);
                    if (session == null) {
                        HttpTransportUtils.sendHttpResponse(HttpConstants.HttpStatus.SERVICE_UNAVAILABLE.getStatusCode(), "No active SSE connection for sessionId: " + sessionId, responseCallback);
                        return;
                    }
                    McpSchema.JSONRPCMessage message;
                    try {
                        message = HttpTransportUtils.parseMessageFromBody(request);
                    }
                    catch (Exception e) {
                        LOGGER.debug("Unable to parse message for session {} of config {}", sessionId, this.refName, e);
                        HttpTransportUtils.sendHttpResponse(HttpConstants.HttpStatus.BAD_REQUEST.getStatusCode(), "Invalid message format", responseCallback);
                        return;
                    }
                    this.getRequestObserver().ifPresent(observer -> HttpTransportUtils.asInboundRequestContext(message, sessionId, requestContext).ifPresent(ctx -> observer.onRequest((InboundRequestContext)ctx)));
                    // Handle the error inside the chain instead of doOnError + bare subscribe(). A bare
                    // subscribe() has no error consumer, so Reactor would also report the already-handled
                    // error as ErrorCallbackNotImplemented.
                    session.handle(message)
                        .doOnSuccess(v -> HttpTransportUtils.sendHttpResponse(HttpConstants.HttpStatus.ACCEPTED.getStatusCode(), HttpConstants.HttpStatus.ACCEPTED.getReasonPhrase(), this.defaultMessageResponseHeaders, responseCallback))
                        .onErrorResume(t -> {
                            LOGGER.error("Error processing message in session {} of config {}", session.getId(), this.refName, t);
                            HttpTransportUtils.sendHttpInternalErrorResponse(t, message, responseCallback);
                            return Mono.empty();
                        })
                        .subscribe();
                });
            }
            catch (RejectedExecutionException e) {
                // Without this, the response callback would never be invoked and the client would hang.
                LOGGER.warn("Unable to schedule message processing for config {}", this.refName, e);
                HttpTransportUtils.sendHttpResponse(HttpConstants.HttpStatus.SERVICE_UNAVAILABLE.getStatusCode(), "Unable to process message: server busy", responseCallback);
            }
        };
    }

    /**
     * Returns whether a Content-Type header value denotes JSON. The comparison ignores case, surrounding
     * whitespace, and any media-type parameters (for example {@code ; charset=UTF-8}).
     *
     * @param contentType raw header value; may be {@code null}
     * @return {@code true} if the media type is {@code application/json}
     */
    private static boolean isJsonContentType(String contentType) {
        if (contentType == null) {
            return false;
        }
        int paramsStart = contentType.indexOf(';');
        String mediaType = paramsStart >= 0 ? contentType.substring(0, paramsStart) : contentType;
        return JSON_MEDIA_TYPE.equalsIgnoreCase(mediaType.strip());
    }

    /**
     * Handles an incoming SSE connection request.
     *
     * <p>Creates a session bound to a new {@link SseTransport}. If a session observer is present, it is
     * consulted first and may reject the request; observer failures reject the request with {@code 500}.
     * When the request is accepted, the default SSE headers are applied, the session id becomes the SSE
     * client id, and the session is registered.
     */
    private void onSSERequest(SseRequestContext ctx) {
        LOGGER.info("Handling new SSE connection");
        SseTransport transport = new SseTransport();
        MuleServerSession session = new MuleServerSession(this.getSessionFactory().create(transport), transport);
        SessionObserver sessionObserver = this.getSessionObserver().orElse(null);
        if (sessionObserver != null) {
            try {
                InternalNewSessionRequest request = sessionObserver.onNewSessionRequest(HttpTransportUtils.createNewSessionRequest(session, ctx.getRequestContext())).get();
                if (request.getRejectedStatusCode() != null) {
                    ctx.reject(request.getRejectedStatusCode().intValue(), request.getRejectedMessage());
                    session.close();
                    return;
                }
            }
            catch (Throwable t) {
                if (t instanceof InterruptedException) {
                    // Keep the thread's interrupt status visible to the caller.
                    Thread.currentThread().interrupt();
                }
                // Unwrap the failure raised by the observer, but never replace it with a null cause.
                Throwable failure = t instanceof ExecutionException && t.getCause() != null ? t.getCause() : t;
                LOGGER.error("Exception found while processing new connection request. Connection will be closed", failure);
                ctx.reject(HttpConstants.HttpStatus.INTERNAL_SERVER_ERROR.getStatusCode(), failure.getMessage());
                session.close();
                return;
            }
        }
        ctx.customizeResponse(customizer -> this.defaultSseResponseHeaders.entryList().forEach(entry -> customizer.addResponseHeader(entry.getKey(), entry.getValue())));
        ctx.setClientId(session.getId());
        this.sessionManager.upsert(session);
    }

    /**
     * Binds a newly connected SSE client to its session.
     *
     * <p>Registers a close hook that unregisters and closes the session, then sends the {@code endpoint}
     * event telling the client where to POST messages.
     *
     * @throws ModuleException if no session is registered under the client's id
     */
    private void onNewSseClient(SseClient client) {
        String sessionId = client.getClientId();
        MuleServerSession session = this.sessionManager.recoverSession(sessionId).orElseThrow(() -> new ModuleException("Newly created session %s of config '%s' lost".formatted(sessionId, this.refName), McpErrorTypes.INTERNAL_ERROR));
        ((SseTransport)session.getTransport()).setClient(client);
        client.onClose(t -> this.sessionManager.unregisterAndClose(sessionId));
        HttpTransportUtils.sendMessageEvent(client, ENDPOINT_EVENT_NAME, this.messagesEndpointPath + "?" + SESSION_ID_PARAM + "=" + sessionId);
    }

    /**
     * Per-session {@link McpServerTransport} that pushes outgoing messages to the session's SSE client.
     *
     * <p>The client is attached after construction, once the SSE connection is established. It is written
     * on the SSE callback thread and read on scheduler threads, so the field is {@code volatile}.
     */
    private class SseTransport
    implements McpServerTransport {
        private volatile SseClient client;

        private SseTransport() {
        }

        /**
         * Closes the attached SSE client, if any. Failures are logged and not propagated.
         */
        @Override
        public Mono<Void> closeGracefully() {
            SseClient current = this.client;
            if (current != null) {
                try {
                    current.close();
                }
                catch (Exception e) {
                    LOGGER.error("Exception closing SSE client for session {} of config {}", current.getClientId(), SSEServerTransportProvider.this.refName, e);
                }
            }
            return Mono.empty();
        }

        /**
         * Sends a message as an SSE event.
         *
         * @return an empty Mono when sent or when the provider is no longer open; an error Mono when the
         *     client is not attached yet or sending fails
         */
        @Override
        public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
            SseClient current = this.client;
            if (!SSEServerTransportProvider.this.open.get()) {
                LOGGER.debug("Attempted to send message on session {} but server transport stopped. Ignoring...", current != null ? current.getClientId() : null);
                return Mono.empty();
            }
            if (current == null) {
                return Mono.error(new IllegalStateException("SSE client not yet connected for config '%s'".formatted(SSEServerTransportProvider.this.refName)));
            }
            try {
                HttpTransportUtils.sendMessageEvent(current, message);
                return Mono.empty();
            }
            catch (Throwable t) {
                return Mono.error(t);
            }
        }

        /**
         * Converts {@code data} to {@code T} using the provider's object mapper.
         */
        @Override
        public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
            return (T)objectMapper.convertValue(data, typeRef);
        }

        /**
         * Attaches the SSE client that outgoing messages are sent to.
         */
        public void setClient(SseClient client) {
            this.client = client;
        }
    }
}

