/*
 * Decompiled with CFR 0.152.
 */
package io.micronaut.http.server.netty.websocket;

import io.micronaut.context.BeanContext;
import io.micronaut.core.annotation.Internal;
import io.micronaut.core.annotation.NonNull;
import io.micronaut.core.annotation.Nullable;
import io.micronaut.core.execution.ExecutionFlow;
import io.micronaut.core.util.StringUtils;
import io.micronaut.http.HttpAttributes;
import io.micronaut.http.HttpHeaders;
import io.micronaut.http.HttpMethod;
import io.micronaut.http.HttpRequest;
import io.micronaut.http.HttpResponse;
import io.micronaut.http.HttpStatus;
import io.micronaut.http.MutableHttpResponse;
import io.micronaut.http.context.ServerRequestContext;
import io.micronaut.http.exceptions.HttpStatusException;
import io.micronaut.http.netty.NettyHttpHeaders;
import io.micronaut.http.netty.websocket.WebSocketSessionRepository;
import io.micronaut.http.server.RequestLifecycle;
import io.micronaut.http.server.RouteExecutor;
import io.micronaut.http.server.netty.NettyEmbeddedServices;
import io.micronaut.http.server.netty.NettyHttpRequest;
import io.micronaut.http.server.netty.websocket.NettyServerWebSocketHandler;
import io.micronaut.web.router.RouteMatch;
import io.micronaut.web.router.Router;
import io.micronaut.web.router.UriRouteMatch;
import io.micronaut.websocket.CloseReason;
import io.micronaut.websocket.annotation.OnMessage;
import io.micronaut.websocket.annotation.OnOpen;
import io.micronaut.websocket.annotation.ServerWebSocket;
import io.micronaut.websocket.context.WebSocketBean;
import io.micronaut.websocket.context.WebSocketBeanRegistry;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory;
import io.netty.handler.ssl.SslHandler;
import io.netty.util.AsciiString;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@Internal
public class NettyServerWebSocketUpgradeHandler
extends SimpleChannelInboundHandler<NettyHttpRequest<?>> {
    public static final String ID = "websocket-upgrade-handler";
    public static final String SCHEME_WEBSOCKET = "ws://";
    public static final String SCHEME_SECURE_WEBSOCKET = "wss://";
    public static final String COMPRESSION_HANDLER = "WebSocketServerCompressionHandler";
    private static final Logger LOG = LoggerFactory.getLogger(NettyServerWebSocketUpgradeHandler.class);
    private static final AsciiString WEB_SOCKET_HEADER_VALUE = AsciiString.cached("websocket");
    private final Router router;
    private final WebSocketBeanRegistry webSocketBeanRegistry;
    private final WebSocketSessionRepository webSocketSessionRepository;
    private final RouteExecutor routeExecutor;
    private final NettyEmbeddedServices nettyEmbeddedServices;
    private WebSocketServerHandshaker handshaker;
    private boolean cancelUpgrade = false;

    public NettyServerWebSocketUpgradeHandler(NettyEmbeddedServices embeddedServices, WebSocketSessionRepository webSocketSessionRepository) {
        this.router = embeddedServices.getRouter();
        this.webSocketBeanRegistry = WebSocketBeanRegistry.forServer((BeanContext)embeddedServices.getApplicationContext());
        this.webSocketSessionRepository = webSocketSessionRepository;
        this.routeExecutor = embeddedServices.getRouteExecutor();
        this.nettyEmbeddedServices = embeddedServices;
    }

    @Override
    public boolean acceptInboundMessage(Object msg) {
        return msg instanceof NettyHttpRequest && this.isWebSocketUpgrade((NettyHttpRequest)msg);
    }

    private boolean isWebSocketUpgrade(@NonNull NettyHttpRequest<?> request) {
        io.netty.handler.codec.http.HttpHeaders headers = request.getNativeRequest().headers();
        if (headers.containsValue(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE, true)) {
            return headers.containsValue(HttpHeaderNames.UPGRADE, WEB_SOCKET_HEADER_VALUE, true);
        }
        return false;
    }

    @Override
    protected final void channelRead0(ChannelHandlerContext ctx, NettyHttpRequest<?> msg) {
        ServerRequestContext.set(msg);
        Optional<UriRouteMatch> optionalRoute = this.router.find(HttpMethod.GET, msg.getPath(), msg).filter(rm -> rm.isAnnotationPresent(OnMessage.class) || rm.isAnnotationPresent(OnOpen.class)).findFirst();
        WebsocketRequestLifecycle requestLifecycle = new WebsocketRequestLifecycle(this.routeExecutor, msg, optionalRoute.orElse(null));
        ExecutionFlow<MutableHttpResponse> responseFlow = ExecutionFlow.async(ctx.channel().eventLoop(), requestLifecycle::handle);
        responseFlow.onComplete((response, throwable) -> {
            if (response != null) {
                this.writeResponse(ctx, msg, requestLifecycle.shouldProceedNormally, (MutableHttpResponse<?>)response);
            }
        });
    }

    private void writeResponse(ChannelHandlerContext ctx, NettyHttpRequest<?> msg, boolean shouldProceedNormally, MutableHttpResponse<?> actualResponse) {
        if (this.cancelUpgrade) {
            if (LOG.isDebugEnabled()) {
                LOG.debug("Cancelling websocket upgrade, handler was removed while request was processing");
            }
            return;
        }
        if (shouldProceedNormally) {
            UriRouteMatch routeMatch = actualResponse.getAttribute(HttpAttributes.ROUTE_MATCH, UriRouteMatch.class).orElseThrow(() -> new IllegalStateException("Route match is required!"));
            WebSocketBean webSocketBean = this.webSocketBeanRegistry.getWebSocket(routeMatch.getTarget().getClass());
            this.handleHandshake(ctx, msg, webSocketBean, actualResponse);
            ChannelPipeline pipeline = ctx.pipeline();
            try {
                NettyServerWebSocketHandler webSocketHandler = new NettyServerWebSocketHandler(this.nettyEmbeddedServices, this.webSocketSessionRepository, this.handshaker, webSocketBean, msg, routeMatch, ctx, this.routeExecutor.getCoroutineHelper().orElse(null));
                pipeline.addBefore(ctx.name(), "websocket-handler", webSocketHandler);
                pipeline.remove("http-streams-codec");
                pipeline.remove(this);
                ChannelHandler accessLoggerHandler = pipeline.get("http-access-logger");
                if (accessLoggerHandler != null) {
                    pipeline.remove(accessLoggerHandler);
                }
            }
            catch (Throwable e) {
                if (LOG.isErrorEnabled()) {
                    LOG.error("Error opening WebSocket: " + e.getMessage(), e);
                }
                ctx.writeAndFlush(new CloseWebSocketFrame(CloseReason.INTERNAL_ERROR.getCode(), CloseReason.INTERNAL_ERROR.getReason()));
            }
        } else {
            ctx.writeAndFlush(actualResponse);
        }
    }

    protected ChannelFuture handleHandshake(ChannelHandlerContext ctx, NettyHttpRequest req, WebSocketBean<?> webSocketBean, MutableHttpResponse<?> response) {
        io.netty.handler.codec.http.HttpHeaders nettyHeaders;
        int maxFramePayloadLength = webSocketBean.messageMethod().map(m -> m.intValue(OnMessage.class, "maxPayloadLength").orElse(65536)).orElse(65536);
        String subprotocols = webSocketBean.getBeanDefinition().stringValue(ServerWebSocket.class, "subprotocols").filter(s -> !StringUtils.isEmpty(s)).orElse(null);
        WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(this.getWebSocketURL(ctx, req), subprotocols, true, maxFramePayloadLength);
        this.handshaker = wsFactory.newHandshaker(req.getNativeRequest());
        HttpHeaders headers = response.getHeaders();
        if (headers instanceof NettyHttpHeaders) {
            nettyHeaders = ((NettyHttpHeaders)headers).getNettyHeaders();
        } else {
            nettyHeaders = new DefaultHttpHeaders();
            for (Map.Entry entry : headers) {
                nettyHeaders.add(entry.getKey(), (Iterable)entry.getValue());
            }
        }
        Channel channel = ctx.channel();
        if (this.handshaker == null) {
            return WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(channel);
        }
        return this.handshaker.handshake(channel, req.getNativeRequest(), nettyHeaders, channel.newPromise());
    }

    protected String getWebSocketURL(ChannelHandlerContext ctx, HttpRequest req) {
        boolean isSecure = ctx.pipeline().get(SslHandler.class) != null;
        return (isSecure ? SCHEME_SECURE_WEBSOCKET : SCHEME_WEBSOCKET) + (String)req.getHeaders().get(HttpHeaderNames.HOST) + req.getUri();
    }

    @Override
    public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
        super.handlerRemoved(ctx);
        this.cancelUpgrade = true;
    }

    @Override
    public void channelInactive(@NonNull ChannelHandlerContext ctx) throws Exception {
        super.channelInactive(ctx);
        this.cancelUpgrade = true;
    }

    private static final class WebsocketRequestLifecycle
    extends RequestLifecycle {
        @Nullable
        final RouteMatch<?> route;
        boolean shouldProceedNormally;

        WebsocketRequestLifecycle(RouteExecutor routeExecutor, HttpRequest<?> request, @Nullable RouteMatch<?> route) {
            super(routeExecutor, request);
            this.route = route;
        }

        ExecutionFlow<MutableHttpResponse<?>> handle() {
            MutableHttpResponse proceed = HttpResponse.ok();
            if (this.route != null) {
                this.request().setAttribute(HttpAttributes.ROUTE_MATCH, this.route);
                this.request().setAttribute(HttpAttributes.ROUTE_INFO, this.route);
                proceed.setAttribute(HttpAttributes.ROUTE_MATCH, this.route);
                proceed.setAttribute(HttpAttributes.ROUTE_INFO, this.route);
            }
            ExecutionFlow<MutableHttpResponse<?>> response = this.route != null ? this.runWithFilters(() -> ExecutionFlow.just(proceed)) : this.onError(new HttpStatusException(HttpStatus.NOT_FOUND, "WebSocket Not Found")).putInContext("micronaut.http.server.request", this.request());
            return response.map(r -> {
                if (r == proceed) {
                    this.shouldProceedNormally = true;
                }
                return r;
            });
        }
    }
}

