/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package org.mule.service.http.netty.impl.server;

import static org.mule.runtime.api.util.Preconditions.checkArgument;
import static org.mule.runtime.http.api.HttpConstants.Protocol.HTTP;
import static org.mule.runtime.http.api.HttpConstants.Protocol.HTTPS;
import static org.mule.runtime.http.api.server.MethodRequestMatcher.acceptAll;

import static java.lang.Math.min;
import static java.lang.Runtime.getRuntime;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.SECONDS;

import static org.slf4j.LoggerFactory.getLogger;

import org.mule.runtime.api.scheduler.Scheduler;
import org.mule.runtime.http.api.HttpConstants;
import org.mule.runtime.http.api.server.HttpServer;
import org.mule.runtime.http.api.server.MethodRequestMatcher;
import org.mule.runtime.http.api.server.PathAndMethodRequestMatcher;
import org.mule.runtime.http.api.server.RequestHandler;
import org.mule.runtime.http.api.server.RequestHandlerManager;
import org.mule.runtime.http.api.server.ServerAddress;
import org.mule.runtime.http.api.server.ws.WebSocketHandler;
import org.mule.runtime.http.api.server.ws.WebSocketHandlerManager;
import org.mule.service.http.netty.impl.server.util.DefaultServerAddress;
import org.mule.service.http.netty.impl.server.util.HttpListenerRegistry;
import org.mule.service.http.netty.impl.util.HttpLoggingHandler;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.Collection;
import java.util.function.Supplier;

import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.ssl.SslContext;
import org.slf4j.Logger;

/**
 * {@link HttpServer} implementation backed by Netty.
 * <p>
 * Instances are created through {@link #builder()}. The lifecycle is:
 * <ul>
 * <li>{@link #start()} creates the acceptor and worker event loop groups on the selectors {@link Scheduler}, if they don't
 * exist yet, and binds the server channel to the configured {@link ServerAddress}.</li>
 * <li>{@link #stop()} closes that channel, waits for accepted connections to be closed and shuts the event loop groups
 * down.</li>
 * <li>{@link #dispose()} stops the server if needed, runs the dispose callback and stops the selectors scheduler.</li>
 * </ul>
 * <p>
 * NOTE(review): the lifecycle status is kept in a plain (non-volatile) field. Confirm whether {@link #isStopping()} and
 * {@link #isStopped()} are expected to be called from threads other than the one driving start/stop.
 */
public final class NettyHttpServer implements HttpServer {

  private static final Logger LOGGER = getLogger(NettyHttpServer.class);

  // Default total number of selector threads: the available processors, capped at 2.
  // One thread goes to the acceptor group and the rest to the worker group (see start()).
  private static final int DEFAULT_SELECTORS_COUNT = min(getRuntime().availableProcessors(), 2);

  // Lifecycle states stored in 'status'.
  private static final int STOPPED = 0;
  private static final int STARTING = 1;
  private static final int STARTED = 2;
  private static final int STOPPING = 3;

  private int status = STOPPED;

  // Netty stuff:
  // Within this class, only consulted by getProtocol(): HTTPS when present, HTTP otherwise.
  private SslContext sslContext;
  private EventLoopGroup acceptorGroup;
  private EventLoopGroup workerGroup;
  // Bound server channel; non-null only between a successful start() and the next stop().
  private Channel serverChannel;

  private ServerAddress serverAddress;
  private Runnable onDisposeCallback;
  // Installed as the child handler of every accepted connection.
  // Also asked to wait for those connections to close on stop().
  private AcceptedConnectionChannelInitializer clientChannelHandler;
  private HttpListenerRegistry httpListenerRegistry;
  // Default instance: an anonymous implementation adding no members of its own.
  // Presumably relies on default methods of WebSocketsHandlersRegistry (confirm).
  // Replaced via Builder#withWebSocketsHandlersRegistry.
  private WebSocketsHandlersRegistry webSocketsHandlersRegistry = new WebSocketsHandlersRegistry() {};
  // Executor for both event loop groups; stopped on dispose().
  private Scheduler selectorsScheduler;
  private int selectorsCount = DEFAULT_SELECTORS_COUNT;
  // Supplies the time, in milliseconds, to wait for accepted connections to close on stop().
  private Supplier<Long> shutdownTimeout;

  private NettyHttpServer() {}

  /**
   * Creates a new builder for a {@link NettyHttpServer}.
   *
   * @return a fresh builder
   */
  public static Builder builder() {
    return new Builder();
  }

  /**
   * Binds the server to its configured address and starts accepting connections.
   * <p>
   * Event loop groups are created on the selectors scheduler if they don't exist yet:
   * <ul>
   * <li>the acceptor group, with 1 thread;</li>
   * <li>the worker group, with {@code selectorsCount - 1} threads.</li>
   * </ul>
   * If the start fails, the resources acquired so far are released through {@link #stop()} before the failure is
   * propagated, leaving the server in the stopped state.
   *
   * @return this server
   * @throws IOException if the calling thread is interrupted while waiting for the bind to complete. The thread's
   *                     interrupt status is restored. Other bind failures are rethrown unchanged, as reported by
   *                     Netty's {@code sync()}, for example a {@code java.net.BindException}.
   */
  @Override
  public HttpServer start() throws IOException {
    status = STARTING;
    try {
      if (acceptorGroup == null) {
        acceptorGroup = new NioEventLoopGroup(1, selectorsScheduler);
      }
      if (workerGroup == null) {
        workerGroup = new NioEventLoopGroup(selectorsCount - 1, selectorsScheduler);
      }

      ServerBootstrap bootstrap = new ServerBootstrap();

      // TODO: With this method we can configure options such as SO_TIMEOUT, SO_KEEPALIVE, etc.
      bootstrap.option(ChannelOption.SO_BACKLOG, 1024);

      bootstrap.group(acceptorGroup, workerGroup)
          .channel(NioServerSocketChannel.class)
          .handler(new HttpLoggingHandler())
          .childHandler(clientChannelHandler);
      serverChannel = bootstrap.bind(serverAddress.getAddress(), serverAddress.getPort()).sync().channel();

      status = STARTED;
      LOGGER.info("HTTP Server is listening on address: {}", serverAddress);
      return this;
    } catch (InterruptedException e) {
      releaseAfterFailedStart(e);
      // Restore the interrupt status for our callers.
      // This is done after stop() so the shutdown waits run with the flag cleared.
      Thread.currentThread().interrupt();
      throw new IOException(e);
    } catch (Exception e) {
      // Netty's sync() rethrows the bind failure cause unchanged, even a checked one such as java.net.BindException,
      // so it has to be caught as Exception. Without this cleanup, the event loop groups would leak and the status
      // would remain STARTING.
      releaseAfterFailedStart(e);
      // Precise rethrow: only unchecked or sneaky-thrown exceptions can reach here, so no new 'throws' clause is needed.
      throw e;
    }
  }

  // Stops the server after a failed start.
  // A failure of the cleanup itself is attached to the original cause as a suppressed exception,
  // instead of replacing that cause.
  private void releaseAfterFailedStart(Exception cause) {
    try {
      stop();
    } catch (RuntimeException stopFailure) {
      cause.addSuppressed(stopFailure);
    }
  }

  /**
   * Stops accepting connections and releases the Netty resources.
   * <p>
   * If the server channel is bound, it is closed. The client channel handler is then asked to wait for accepted
   * connections to be closed, bounded by the shutdown timeout in milliseconds, and the channel close is awaited. The
   * acceptor and worker event loop groups are shut down afterwards, even if closing the channel failed. The status
   * always ends up as stopped.
   *
   * @return this server
   */
  @Override
  public HttpServer stop() {
    status = STOPPING;
    try {
      closeServerChannel();
    } finally {
      // Always release the event loop groups, otherwise their threads would leak when closing the channel fails.
      shutdownEventLoopGroups();
      status = STOPPED;
    }
    return this;
  }

  // Closes the bound server channel (if any), waiting for accepted connections and for the close itself.
  private void closeServerChannel() {
    if (serverChannel == null) {
      return;
    }
    ChannelFuture channelFuture = serverChannel.close();
    try {
      clientChannelHandler.waitForConnectionsToBeClosed(shutdownTimeout.get(), MILLISECONDS);
      channelFuture.syncUninterruptibly();
    } finally {
      serverChannel = null; // Clear reference
    }
  }

  // Shuts down the acceptor and worker groups immediately (no quiet period) and waits for their termination.
  private void shutdownEventLoopGroups() {
    if (acceptorGroup != null) {
      acceptorGroup.shutdownGracefully(0, 0, SECONDS).syncUninterruptibly();
      acceptorGroup = null;
    }
    if (workerGroup != null) {
      workerGroup.shutdownGracefully(0, 0, SECONDS).syncUninterruptibly();
      workerGroup = null;
    }
  }

  /**
   * Stops the server if it isn't stopped yet, then performs the remaining cleanup:
   * <ul>
   * <li>runs the dispose callback, if any;</li>
   * <li>stops the selectors scheduler, if any.</li>
   * </ul>
   */
  @Override
  public void dispose() {
    if (!isStopped()) {
      stop();
    }
    if (onDisposeCallback != null) {
      onDisposeCallback.run();
    }
    if (selectorsScheduler != null) {
      selectorsScheduler.stop();
      selectorsScheduler = null;
    }
  }

  @Override
  public ServerAddress getServerAddress() {
    return serverAddress;
  }

  /**
   * @return {@code HTTPS} if an {@link SslContext} was configured, {@code HTTP} otherwise
   */
  @Override
  public HttpConstants.Protocol getProtocol() {
    return sslContext == null ? HTTP : HTTPS;
  }

  @Override
  public boolean isStopping() {
    return STOPPING == status;
  }

  @Override
  public boolean isStopped() {
    return STOPPED == status;
  }

  /**
   * Registers a handler for requests matching the given path and any of the given methods.
   *
   * @param methods        HTTP methods accepted by the handler
   * @param path           path matched by the handler
   * @param requestHandler handler to register
   * @return the manager returned by the listener registry for this registration
   */
  @Override
  public RequestHandlerManager addRequestHandler(Collection<String> methods, String path, RequestHandler requestHandler) {
    return httpListenerRegistry.addRequestHandler(this, requestHandler, PathAndMethodRequestMatcher.builder()
        .methodRequestMatcher(MethodRequestMatcher.builder(methods).build())
        .path(path)
        .build());
  }

  /**
   * Registers a handler for requests matching the given path, regardless of their method.
   *
   * @param path           path matched by the handler
   * @param requestHandler handler to register
   * @return the manager returned by the listener registry for this registration
   */
  @Override
  public RequestHandlerManager addRequestHandler(String path, RequestHandler requestHandler) {
    return httpListenerRegistry.addRequestHandler(this, requestHandler, PathAndMethodRequestMatcher.builder()
        .methodRequestMatcher(acceptAll())
        .path(path)
        .build());
  }

  /**
   * Registers a WebSocket handler in the configured WebSockets handlers registry.
   *
   * @param handler handler to register
   * @return the manager returned by the registry for this registration
   */
  @Override
  public WebSocketHandlerManager addWebSocketHandler(WebSocketHandler handler) {
    return webSocketsHandlersRegistry.addWebSocketHandler(handler);
  }

  /**
   * Builder for {@link NettyHttpServer}.
   * <p>
   * Every setter mutates the single server instance created by the builder's constructor.
   */
  public static class Builder {

    private final NettyHttpServer product;

    Builder() {
      this.product = new NettyHttpServer();
    }

    /**
     * Validates the mandatory settings and returns the configured server.
     * <p>
     * The mandatory settings are the server address, the listener registry and the shutdown timeout supplier. The
     * returned object is the same instance this builder has been configuring, not a copy.
     *
     * @return the configured server
     * @throws IllegalArgumentException if a mandatory setting is missing
     */
    public HttpServer build() {
      checkArgument(product.serverAddress != null, "Server address can't be null");
      checkArgument(product.httpListenerRegistry != null, "Listener registry can't be null");
      checkArgument(product.shutdownTimeout != null, "Shutdown timeout Supplier can't be null");
      return product;
    }

    /**
     * Sets the SSL context.
     * <p>
     * In the server, a non-null value makes {@code getProtocol()} report HTTPS.
     */
    public Builder withSslContext(SslContext sslContext) {
      product.sslContext = sslContext;
      return this;
    }

    /**
     * Sets the address and port the server binds to.
     * <p>
     * NOTE(review): {@code getAddress()} is null for an unresolved {@link InetSocketAddress}; confirm callers always
     * pass resolved addresses.
     */
    public Builder withServerAddress(InetSocketAddress socketAddress) {
      product.serverAddress = new DefaultServerAddress(socketAddress.getAddress(), socketAddress.getPort());
      return this;
    }

    /** Sets the registry where request handlers are added (mandatory). */
    public Builder withHttpListenerRegistry(HttpListenerRegistry httpListenerRegistry) {
      product.httpListenerRegistry = httpListenerRegistry;
      return this;
    }

    /** Sets the registry where WebSocket handlers are added, replacing the default instance. */
    public Builder withWebSocketsHandlersRegistry(WebSocketsHandlersRegistry webSocketsHandlersRegistry) {
      product.webSocketsHandlersRegistry = webSocketsHandlersRegistry;
      return this;
    }

    /** Sets a callback run on {@code dispose()}, after the server has been stopped. */
    public Builder doOnDispose(Runnable onDisposeCallback) {
      product.onDisposeCallback = onDisposeCallback;
      return this;
    }

    /**
     * Sets the initializer used as child handler for accepted connections.
     * <p>
     * It is also asked to wait for those connections to be closed when the server stops.
     */
    public Builder withClientChannelHandler(AcceptedConnectionChannelInitializer clientChannelHandler) {
      product.clientChannelHandler = clientChannelHandler;
      return this;
    }

    /** Sets the scheduler backing the acceptor and worker event loop groups; it is stopped on {@code dispose()}. */
    public Builder withSelectorsScheduler(Scheduler selectorsScheduler) {
      product.selectorsScheduler = selectorsScheduler;
      return this;
    }

    /**
     * Sets the total number of selector threads: one for the acceptor group and {@code selectorsCount - 1} for the
     * worker group.
     * <p>
     * NOTE(review): a count of 1 passes 0 threads to the worker {@code NioEventLoopGroup}, which Netty interprets as
     * its default thread count. Confirm this is intended.
     */
    public Builder withSelectorsCount(int selectorsCount) {
      product.selectorsCount = selectorsCount;
      return this;
    }

    /** Sets the supplier of the time, in milliseconds, to wait for accepted connections to close on stop (mandatory). */
    public Builder withShutdownTimeout(Supplier<Long> shutdownTimeout) {
      product.shutdownTimeout = shutdownTimeout;
      return this;
    }
  }
}
