/*
 * Decompiled with CFR 0.152.
 */
package graphql.kickstart.servlet;

import graphql.kickstart.execution.GraphQLInvoker;
import graphql.kickstart.execution.GraphQLObjectMapper;
import graphql.kickstart.execution.subscriptions.GraphQLSubscriptionInvocationInputFactory;
import graphql.kickstart.execution.subscriptions.GraphQLSubscriptionMapper;
import graphql.kickstart.execution.subscriptions.SessionSubscriptions;
import graphql.kickstart.execution.subscriptions.SubscriptionConnectionListener;
import graphql.kickstart.execution.subscriptions.SubscriptionProtocolFactory;
import graphql.kickstart.execution.subscriptions.SubscriptionSession;
import graphql.kickstart.execution.subscriptions.apollo.ApolloSubscriptionConnectionListener;
import graphql.kickstart.servlet.GraphQLConfiguration;
import graphql.kickstart.servlet.apollo.ApolloWebSocketSubscriptionProtocolFactory;
import graphql.kickstart.servlet.subscriptions.FallbackSubscriptionProtocolFactory;
import graphql.kickstart.servlet.subscriptions.WebSocketSendSubscriber;
import graphql.kickstart.servlet.subscriptions.WebSocketSubscriptionProtocolFactory;
import java.io.EOFException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.websocket.CloseReason;
import javax.websocket.Endpoint;
import javax.websocket.EndpointConfig;
import javax.websocket.HandshakeResponse;
import javax.websocket.MessageHandler;
import javax.websocket.Session;
import javax.websocket.server.HandshakeRequest;
import javax.websocket.server.ServerEndpointConfig;
import lombok.Generated;
import org.reactivestreams.Subscriber;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * JSR-356 {@link Endpoint} that serves GraphQL subscriptions over a WebSocket.
 *
 * <p>The sub-protocol is chosen in {@link #modifyHandshake}: the first protocol requested by the
 * client that matches a registered {@link SubscriptionProtocolFactory} is used, otherwise the
 * fallback factory is used. Every open session is tracked together with its
 * {@link SessionSubscriptions} so that {@link #beginShutDown()} can close them all.
 */
public class GraphQLWebsocketServlet extends Endpoint {
    @Generated
    private static final Logger log = LoggerFactory.getLogger(GraphQLWebsocketServlet.class);

    /** User-property key under which {@link #modifyHandshake} stores the handshake request. */
    private static final String HANDSHAKE_REQUEST_KEY = HandshakeRequest.class.getName();
    /** User-property key under which {@link #modifyHandshake} stores the selected protocol factory. */
    private static final String PROTOCOL_FACTORY_REQUEST_KEY = SubscriptionProtocolFactory.class.getName();
    private static final CloseReason ERROR_CLOSE_REASON =
            new CloseReason(CloseReason.CloseCodes.UNEXPECTED_CONDITION, "Internal Server Error");
    private static final CloseReason SHUTDOWN_CLOSE_REASON =
            new CloseReason(CloseReason.CloseCodes.UNEXPECTED_CONDITION, "Server Shut Down");

    private final List<SubscriptionProtocolFactory> subscriptionProtocolFactories;
    private final SubscriptionProtocolFactory fallbackSubscriptionProtocolFactory;
    /** Names of all supported protocols (registered factories first, then the fallback); read-only. */
    private final List<String> allSubscriptionProtocols;
    private final Map<Session, SessionSubscriptions> sessionSubscriptionCache = new ConcurrentHashMap<>();
    private final AtomicBoolean isShuttingDown = new AtomicBoolean(false);
    private final AtomicBoolean isShutDown = new AtomicBoolean(false);
    /** Serializes session registration/removal against {@link #beginShutDown()}. */
    private final Object cacheLock = new Object();

    /**
     * Creates the endpoint from a servlet configuration, without connection listeners.
     *
     * @param configuration source of the invoker, invocation-input factory and object mapper
     */
    public GraphQLWebsocketServlet(GraphQLConfiguration configuration) {
        this(configuration, null);
    }

    /**
     * Creates the endpoint from a servlet configuration.
     *
     * @param configuration source of the invoker, invocation-input factory and object mapper
     * @param connectionListeners listeners to notify; may be {@code null}. Only
     *     {@link ApolloSubscriptionConnectionListener} instances are used.
     */
    public GraphQLWebsocketServlet(GraphQLConfiguration configuration, Collection<SubscriptionConnectionListener> connectionListeners) {
        this(configuration.getGraphQLInvoker(), configuration.getInvocationInputFactory(), configuration.getObjectMapper(), connectionListeners);
    }

    /**
     * Creates the endpoint from its individual collaborators, without connection listeners.
     */
    public GraphQLWebsocketServlet(GraphQLInvoker graphQLInvoker, GraphQLSubscriptionInvocationInputFactory invocationInputFactory, GraphQLObjectMapper graphQLObjectMapper) {
        this(graphQLInvoker, invocationInputFactory, graphQLObjectMapper, null);
    }

    /**
     * Creates the endpoint with the Apollo protocol factory plus a fallback factory.
     *
     * @param connectionListeners listeners to notify; may be {@code null}. Only
     *     {@link ApolloSubscriptionConnectionListener} instances are passed to the Apollo factory;
     *     any other listener type is ignored.
     */
    public GraphQLWebsocketServlet(GraphQLInvoker graphQLInvoker, GraphQLSubscriptionInvocationInputFactory invocationInputFactory, GraphQLObjectMapper graphQLObjectMapper, Collection<SubscriptionConnectionListener> connectionListeners) {
        List<ApolloSubscriptionConnectionListener> listeners = new ArrayList<>();
        if (connectionListeners != null) {
            connectionListeners.stream()
                    .filter(ApolloSubscriptionConnectionListener.class::isInstance)
                    .map(ApolloSubscriptionConnectionListener.class::cast)
                    .forEach(listeners::add);
        }
        this.subscriptionProtocolFactories = Collections.singletonList(
                new ApolloWebSocketSubscriptionProtocolFactory(graphQLObjectMapper, invocationInputFactory, graphQLInvoker, listeners));
        this.fallbackSubscriptionProtocolFactory = new FallbackSubscriptionProtocolFactory(
                new GraphQLSubscriptionMapper(graphQLObjectMapper), invocationInputFactory, graphQLInvoker);
        this.allSubscriptionProtocols = Collections.unmodifiableList(
                Stream.concat(this.subscriptionProtocolFactories.stream(), Stream.of(this.fallbackSubscriptionProtocolFactory))
                        .map(SubscriptionProtocolFactory::getProtocol)
                        .collect(Collectors.toList()));
    }

    /**
     * Registers a newly opened session and wires its inbound messages and outbound publisher.
     *
     * <p>Expects {@link #modifyHandshake} to have stored a {@link WebSocketSubscriptionProtocolFactory}
     * in the endpoint config's user properties.
     *
     * @throws IllegalStateException if {@link #beginShutDown()} has already been called
     */
    @Override
    public void onOpen(final Session session, EndpointConfig endpointConfig) {
        final WebSocketSubscriptionProtocolFactory subscriptionProtocolFactory =
                (WebSocketSubscriptionProtocolFactory) endpointConfig.getUserProperties().get(PROTOCOL_FACTORY_REQUEST_KEY);
        final SubscriptionSession subscriptionSession = subscriptionProtocolFactory.createSession(session);
        synchronized (cacheLock) {
            // Checked under the lock so no session can be registered after shutdown has snapshotted the cache.
            if (isShuttingDown.get()) {
                throw new IllegalStateException("Server is shutting down!");
            }
            sessionSubscriptionCache.put(session, subscriptionSession.getSubscriptions());
        }
        subscriptionSession.getPublisher().subscribe(new WebSocketSendSubscriber(session));
        log.debug("Session opened: {}, {}", session.getId(), endpointConfig);

        final Consumer<String> consumer = subscriptionProtocolFactory.createConsumer(subscriptionSession);
        // Deliberately an anonymous class rather than a lambda: JSR-356 containers resolve the message
        // type of addMessageHandler(MessageHandler) reflectively from the generic type argument.
        session.addMessageHandler(new MessageHandler.Whole<String>() {
            @Override
            public void onMessage(String text) {
                try {
                    consumer.accept(text);
                } catch (Exception t) {
                    log.error("Error executing websocket query for session: {}", session.getId(), t);
                    closeUnexpectedly(session);
                }
            }
        });
    }

    /**
     * Unregisters the session and closes its subscriptions, if it was still registered.
     */
    @Override
    public void onClose(Session session, CloseReason closeReason) {
        log.debug("Session closed: {}, {}", session.getId(), closeReason);
        SessionSubscriptions subscriptions;
        synchronized (cacheLock) {
            subscriptions = sessionSubscriptionCache.remove(session);
        }
        if (subscriptions != null) {
            subscriptions.close();
        }
    }

    /**
     * Handles a session error. An {@link EOFException} is treated as an abrupt disconnect and only
     * cleans the session up. Any other error is logged and the session is closed with
     * {@code UNEXPECTED_CONDITION}.
     */
    @Override
    public void onError(Session session, Throwable thr) {
        if (thr instanceof EOFException) {
            log.warn("Session {} was killed abruptly without calling onClose. Cleaning up session", session.getId());
            onClose(session, ERROR_CLOSE_REASON);
        } else {
            log.error("Error in websocket session: {}", session.getId(), thr);
            closeUnexpectedly(session);
        }
    }

    /**
     * Closes the session with {@code UNEXPECTED_CONDITION}. A failure to close is logged rather than
     * propagated. Callers log the triggering error themselves, so only the close failure is logged here.
     */
    private void closeUnexpectedly(Session session) {
        try {
            session.close(ERROR_CLOSE_REASON);
        } catch (IOException e) {
            log.error("Error closing websocket session for session: {}", session.getId(), e);
        }
    }

    /**
     * Selects the subscription protocol for an upgrade request and records it, along with the
     * request itself, in the endpoint config's user properties for {@link #onOpen}.
     *
     * <p>If the client sent {@code Sec-WebSocket-Protocol}, the selected protocol is echoed back in
     * the response.
     */
    public void modifyHandshake(ServerEndpointConfig sec, HandshakeRequest request, HandshakeResponse response) {
        sec.getUserProperties().put(HANDSHAKE_REQUEST_KEY, request);

        List<String> protocol = request.getHeaders().get("Sec-WebSocket-Protocol");
        if (protocol == null) {
            protocol = Collections.emptyList();
        }
        SubscriptionProtocolFactory subscriptionProtocolFactory = getSubscriptionProtocolFactory(protocol);
        sec.getUserProperties().put(PROTOCOL_FACTORY_REQUEST_KEY, subscriptionProtocolFactory);

        // NOTE(review): Sec-WebSocket-Accept is normally a response-only header (RFC 6455); confirm intent.
        if (request.getHeaders().get("Sec-WebSocket-Accept") != null) {
            // Hand the container a copy so it can never mutate this endpoint's shared protocol list.
            response.getHeaders().put("Sec-WebSocket-Accept", new ArrayList<>(allSubscriptionProtocols));
        }
        if (!protocol.isEmpty()) {
            response.getHeaders().put("Sec-WebSocket-Protocol",
                    new ArrayList<>(Collections.singletonList(subscriptionProtocolFactory.getProtocol())));
        }
    }

    /**
     * Rejects new sessions, closes every tracked session and its subscriptions, then marks the
     * endpoint as shut down. IOExceptions while closing a session are logged and skipped.
     */
    public void beginShutDown() {
        synchronized (cacheLock) {
            isShuttingDown.set(true);
            // Iterate over a snapshot: closing a session may trigger onClose, which removes cache entries.
            Map<Session, SessionSubscriptions> snapshot = new HashMap<>(sessionSubscriptionCache);
            snapshot.forEach((session, subscriptions) -> {
                subscriptions.close();
                try {
                    session.close(SHUTDOWN_CLOSE_REASON);
                } catch (IOException e) {
                    log.error("Error closing websocket session!", e);
                }
            });
            if (!sessionSubscriptionCache.isEmpty()) {
                log.error("GraphQLWebsocketServlet did not shut down cleanly!");
                sessionSubscriptionCache.clear();
            }
        }
        isShutDown.set(true);
    }

    /** @return {@code true} once {@link #beginShutDown()} has completed. */
    public boolean isShutDown() {
        return isShutDown.get();
    }

    /**
     * Returns the first registered factory whose protocol appears in the requested protocol header
     * values, honouring the client's order, or the fallback factory if none match.
     *
     * <p>Each header value may itself be a comma-separated list of protocols (RFC 6455), so it is
     * split and trimmed before comparison. Null header values are ignored.
     */
    private SubscriptionProtocolFactory getSubscriptionProtocolFactory(List<String> accept) {
        for (String headerValue : accept) {
            if (headerValue == null) {
                continue;
            }
            for (String token : headerValue.split(",")) {
                String protocol = token.trim();
                for (SubscriptionProtocolFactory subscriptionProtocolFactory : subscriptionProtocolFactories) {
                    if (subscriptionProtocolFactory.getProtocol().equals(protocol)) {
                        return subscriptionProtocolFactory;
                    }
                }
            }
        }
        return fallbackSubscriptionProtocolFactory;
    }

    /** @return number of currently tracked sessions. */
    public int getSessionCount() {
        return sessionSubscriptionCache.size();
    }

    /** @return total number of subscriptions across all tracked sessions. */
    public int getSubscriptionCount() {
        return sessionSubscriptionCache.values().stream().mapToInt(SessionSubscriptions::getSubscriptionCount).sum();
    }
}

