/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package org.mule.service.http.netty.impl.server;

import static org.mule.runtime.api.i18n.I18nMessageFactory.createStaticMessage;
import static org.mule.runtime.api.util.MuleSystemProperties.SYSTEM_PROPERTY_PREFIX;
import static org.mule.service.http.netty.impl.util.HttpLoggingHandler.hexDump;

import static java.lang.Integer.parseInt;
import static java.lang.String.format;
import static java.lang.System.getProperty;
import static java.util.concurrent.TimeUnit.MILLISECONDS;

import static io.netty.handler.codec.http.HttpObjectDecoder.DEFAULT_MAX_CHUNK_SIZE;
import static io.netty.handler.codec.http.HttpObjectDecoder.DEFAULT_MAX_HEADER_SIZE;
import static io.netty.handler.flush.FlushConsolidationHandler.DEFAULT_EXPLICIT_FLUSH_AFTER_FLUSHES;
import static io.netty.handler.ssl.ApplicationProtocolNames.HTTP_1_1;
import static io.netty.handler.ssl.ApplicationProtocolNames.HTTP_2;

import org.mule.runtime.api.exception.MuleRuntimeException;
import org.mule.runtime.http.api.Http1ProtocolConfig;
import org.mule.runtime.http.api.Http2ProtocolConfig;
import org.mule.runtime.http.api.server.HttpServerConfiguration;
import org.mule.service.http.netty.impl.server.util.HttpListenerRegistry;
import org.mule.service.http.netty.impl.util.HttpLoggingHandler;

import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;

import io.netty.channel.Channel;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.http.HttpServerUpgradeHandler;
import io.netty.handler.codec.http2.CleartextHttp2ServerUpgradeHandler;
import io.netty.handler.codec.http2.Http2FrameCodecBuilder;
import io.netty.handler.codec.http2.Http2MultiplexHandler;
import io.netty.handler.flush.FlushConsolidationHandler;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.timeout.IdleStateHandler;

/**
 * Implementation of {@link ChannelInitializer} that sets up the pipeline of every accepted connection: it handles TLS if a
 * non-null {@link SslContext} is received, then adds the HTTP/1 and/or HTTP/2 handlers enabled by the protocol configurations,
 * ending with the handlers that delegate to the passed {@link HttpListenerRegistry}.
 */
public class AcceptedConnectionChannelInitializer extends ChannelInitializer<SocketChannel> {

  /**
   * System property defining the maximum size in bytes accepted for the HTTP request header section (request line + headers).
   */
  public static final String MAXIMUM_HEADER_SECTION_SIZE_PROPERTY_KEY = SYSTEM_PROPERTY_PREFIX + "http.headerSectionSize";

  /**
   * Name under which the {@link FlushConsolidationHandler} is registered in each accepted channel's pipeline.
   */
  public static final String FLUSH_CONSOLIDATION_HANDLER_NAME = "flushConsolidationHandler";

  private final SslContext sslContext;
  private final HttpListenerRegistry httpListenerRegistry;
  private final boolean usePersistentConnections;
  private final long connectionIdleTimeout;
  private final long readTimeout;

  // Both limits are fed from the same "header section size" value (see the main constructor).
  private final int maxInitialLineLength;
  private final int maxHeaderSize;

  private final ConnectionsCounterHandler connectionsCountHandler;
  private final ExecutorService ioExecutor;

  private final Http1ProtocolConfig http1Config;
  private final Http2ProtocolConfig http2Config;

  /**
   * Creates an initializer taking the connection settings and protocol configurations from the server configuration. The
   * header section size is read from the {@link #MAXIMUM_HEADER_SECTION_SIZE_PROPERTY_KEY} system property.
   *
   * @param httpListenerRegistry registry used to dispatch requests to listeners.
   * @param configuration        server configuration providing persistence, timeouts and HTTP/1 and HTTP/2 configs.
   * @param sslContext           TLS context, or {@code null} for cleartext connections.
   * @param ioExecutor           executor handed to the request-forwarding handlers.
   * @throws MuleRuntimeException if the header section size system property is invalid.
   */
  public AcceptedConnectionChannelInitializer(HttpListenerRegistry httpListenerRegistry,
                                              HttpServerConfiguration configuration,
                                              SslContext sslContext,
                                              ExecutorService ioExecutor) {
    this(httpListenerRegistry, configuration.isUsePersistentConnections(), configuration.getConnectionIdleTimeout(),
         configuration.getReadTimeout(), sslContext, ioExecutor, configuration.getHttp1Config(), configuration.getHttp2Config());
  }

  /**
   * Creates an initializer with the default protocol configurations ({@code new Http1ProtocolConfig(true)} and
   * {@code new Http2ProtocolConfig(true)}) and the header section size read from the
   * {@link #MAXIMUM_HEADER_SECTION_SIZE_PROPERTY_KEY} system property.
   *
   * @throws MuleRuntimeException if the header section size system property is invalid.
   */
  public AcceptedConnectionChannelInitializer(HttpListenerRegistry httpListenerRegistry,
                                              boolean usePersistentConnections,
                                              int connectionIdleTimeout,
                                              long readTimeout,
                                              SslContext sslContext,
                                              ExecutorService ioExecutor) {
    this(httpListenerRegistry, usePersistentConnections, connectionIdleTimeout, readTimeout, sslContext,
         retrieveMaximumHeaderSectionSize(), ioExecutor, new Http1ProtocolConfig(true), new Http2ProtocolConfig(true));
  }

  /**
   * Creates an initializer with explicit protocol configurations and the header section size read from the
   * {@link #MAXIMUM_HEADER_SECTION_SIZE_PROPERTY_KEY} system property.
   *
   * @throws MuleRuntimeException if the header section size system property is invalid.
   */
  public AcceptedConnectionChannelInitializer(HttpListenerRegistry httpListenerRegistry,
                                              boolean usePersistentConnections,
                                              int connectionIdleTimeout,
                                              long readTimeout,
                                              SslContext sslContext,
                                              ExecutorService ioExecutor,
                                              Http1ProtocolConfig http1Config,
                                              Http2ProtocolConfig http2Config) {
    this(httpListenerRegistry, usePersistentConnections, connectionIdleTimeout, readTimeout, sslContext,
         retrieveMaximumHeaderSectionSize(), ioExecutor, http1Config, http2Config);
  }

  /**
   * Creates an initializer with an explicit header section size and the default protocol configurations
   * ({@code new Http1ProtocolConfig(true)} and {@code new Http2ProtocolConfig(true)}).
   */
  public AcceptedConnectionChannelInitializer(HttpListenerRegistry httpListenerRegistry,
                                              boolean usePersistentConnections,
                                              int connectionIdleTimeout,
                                              long readTimeout,
                                              SslContext sslContext,
                                              int maxHeaderSectionSize,
                                              ExecutorService ioExecutor) {
    this(httpListenerRegistry, usePersistentConnections, connectionIdleTimeout, readTimeout, sslContext, maxHeaderSectionSize,
         ioExecutor, new Http1ProtocolConfig(true), new Http2ProtocolConfig(true));
  }

  /**
   * Main constructor; all the others delegate here.
   *
   * @param httpListenerRegistry     registry used to dispatch requests to listeners.
   * @param usePersistentConnections passed to the HTTP/1 keep-alive handler.
   * @param connectionIdleTimeout    idle timeout in milliseconds; together with {@code readTimeout}, {@code -1} for both
   *                                 means no idle-state handler is added.
   * @param readTimeout              read timeout in milliseconds.
   * @param sslContext               TLS context, or {@code null} for cleartext connections.
   * @param maxHeaderSectionSize     limit used both as the maximum initial line length and as the maximum header size.
   * @param ioExecutor               executor handed to the request-forwarding handlers.
   * @param http1Config              HTTP/1 protocol configuration.
   * @param http2Config              HTTP/2 protocol configuration.
   */
  public AcceptedConnectionChannelInitializer(HttpListenerRegistry httpListenerRegistry,
                                              boolean usePersistentConnections,
                                              int connectionIdleTimeout,
                                              long readTimeout,
                                              SslContext sslContext,
                                              int maxHeaderSectionSize,
                                              ExecutorService ioExecutor,
                                              Http1ProtocolConfig http1Config,
                                              Http2ProtocolConfig http2Config) {
    this.httpListenerRegistry = httpListenerRegistry;
    this.usePersistentConnections = usePersistentConnections;
    this.connectionIdleTimeout = connectionIdleTimeout;
    this.readTimeout = readTimeout;
    this.sslContext = sslContext;

    this.maxInitialLineLength = maxHeaderSectionSize;
    this.maxHeaderSize = maxHeaderSectionSize;

    this.connectionsCountHandler = new ConnectionsCounterHandler();
    this.ioExecutor = ioExecutor;

    this.http1Config = http1Config;
    this.http2Config = http2Config;
  }

  /**
   * Builds the pipeline of an accepted connection: connection counting, flush consolidation and timeouts first, then either
   * TLS + ALPN, h2c with HTTP/1 fallback, h2c only, or HTTP/1 only, depending on the TLS context and enabled protocols.
   *
   * @throws IllegalStateException if neither HTTP/1 nor HTTP/2 is enabled.
   */
  @Override
  protected void initChannel(SocketChannel socketChannel) {
    socketChannel.pipeline().addFirst(FLUSH_CONSOLIDATION_HANDLER_NAME,
                                      new FlushConsolidationHandler(DEFAULT_EXPLICIT_FLUSH_AFTER_FLUSHES, true));
    socketChannel.pipeline().addFirst(connectionsCountHandler);

    configureTimeouts(socketChannel);

    // Validate before branching so TLS connections are covered too: otherwise apnFallback() would report HTTP/2 as the
    // fallback protocol even though it is disabled.
    if (!http1Config.isEnabled() && !http2Config.isEnabled()) {
      throw new IllegalStateException("Both HTTP/1 and HTTP/2 protocols are disabled");
    }

    if (null != sslContext) {
      // The decision between http versions is done by ALPN.
      configureWithSslAndALPN(socketChannel);
      return;
    }

    if (http1Config.isEnabled() && http2Config.isEnabled()) {
      configureHttp2CleartextWithFallbackToHttp1(socketChannel);
      return;
    }

    if (http2Config.isEnabled()) {
      // Only h2c (aka prior-knowledge)
      configureHttp2(socketChannel.pipeline(), null);
      return;
    }

    // Only HTTP/1
    configureHttp1(socketChannel.pipeline(), null);
  }

  /**
   * Cleartext pipeline accepting h2c (prior knowledge or upgrade) and otherwise handling the connection as HTTP/1.
   */
  private void configureHttp2CleartextWithFallbackToHttp1(SocketChannel socketChannel) {
    socketChannel.pipeline().addLast("HTTP/2 Logging Handler", hexDump());
    socketChannel.pipeline().addLast("Cleartext HTTP/2 Server Upgrade Handler", createHttp2CleartextUpgradeHandler());

    // Fallback to HTTP/1, but without the codec since it's already in the upgrade handler above.
    configureHttp1PostCodec(socketChannel.pipeline(), null);
  }

  /**
   * Adds the HTTP/1 handlers that must sit after the HTTP/1 codec, ending with the forwarding to listeners.
   */
  private void configureHttp1PostCodec(ChannelPipeline pipeline, SslHandler sslHandler) {
    pipeline.addLast("Expect Continue Handler", new MuleHttpServerExpectContinueHandler());
    pipeline.addLast("Keep Alive", new KeepAliveHandler(usePersistentConnections));
    pipeline.addLast("Forward Initializer", new ForwardingToListenerInitializer(httpListenerRegistry, sslHandler, ioExecutor));
  }

  /**
   * Creates the handler that detects h2c (prior-knowledge preface or {@code Upgrade: h2c}) on a cleartext connection.
   */
  private CleartextHttp2ServerUpgradeHandler createHttp2CleartextUpgradeHandler() {
    // This is just a normal HTTP/1 Codec that will parse the objects needed by the upgrade handlers.
    var sourceCodec = new HttpServerCodec(maxInitialLineLength, maxHeaderSize, DEFAULT_MAX_CHUNK_SIZE);

    // This factory will be invoked to instrument the pipeline when the Upgrade mechanism is used.
    var upgradeCodecFactory = new UpgradeToHttp2CleartextCodecFactory(httpListenerRegistry);

    // This handler will be used when the upgrade to h2c is requested by client (Upgrade: h2c header).
    var upgradeHandler = new HttpServerUpgradeHandler(sourceCodec, upgradeCodecFactory);

    // This handler does the following:
    // - It uses the first arg to parse the incoming bytes.
    // - If the incoming bytes start with the prior-knowledge preface, it just adds the third arg to the pipeline.
    // - Otherwise, it starts the upgrade mechanism, that will be controlled by the second arg.
    return new CleartextHttp2ServerUpgradeHandler(sourceCodec, upgradeHandler, new ChannelInitializer<SocketChannel>() {

      @Override
      protected void initChannel(SocketChannel ch) {
        // When the upgrade is done, the pipeline becomes h2c.
        configureHttp2(ch.pipeline(), null);
      }
    });
  }

  /**
   * Waits for all the in-flight connections to be closed, for the time passed as parameter.
   *
   * @param timeout  time to wait. Zero means don't wait.
   * @param timeUnit unit for the timeout parameter.
   */
  public void waitForConnectionsToBeClosed(Long timeout, TimeUnit timeUnit) {
    connectionsCountHandler.waitForConnectionsToBeClosed(timeout, timeUnit);
  }

  /**
   * Configures the pipeline for TLS and installs the protocol negotiation handler, which configures either HTTP/1
   * ({@link #configureHttp1}) or HTTP/2 ({@link #configureHttp2}) based on ALPN, using {@link #apnFallback()} as the
   * fallback protocol.
   */
  private void configureWithSslAndALPN(Channel channel) {
    SslHandler sslHandler = sslContext.newHandler(channel.alloc());
    channel.pipeline()
        .addLast("SSL Handler", sslHandler)
        .addLast("Protocol Negotiation Handler",
                 new HttpWithAPNServerHandler(this::configureHttp1, this::configureHttp2, sslHandler, apnFallback()));
  }

  /**
   * @return HTTP/1.1 if it is enabled, HTTP/2 otherwise.
   */
  private String apnFallback() {
    return http1Config.isEnabled() ? HTTP_1_1 : HTTP_2;
  }

  /**
   * Adds an {@link IdleStateHandler} unless both timeouts are {@code -1}. The read timeout drives reader idleness, and the
   * connection idle timeout drives both writer and all-idle detection, all in milliseconds.
   */
  private void configureTimeouts(SocketChannel socketChannel) {
    if (connectionIdleTimeout != -1 || readTimeout != -1) {
      socketChannel.pipeline().addLast("idleStateHandler",
                                       new IdleStateHandler(readTimeout, connectionIdleTimeout, connectionIdleTimeout,
                                                            MILLISECONDS));
    }
  }

  /**
   * Appends the HTTP/1 handlers (textual logging, codec and post-codec handlers) to the pipeline.
   *
   * @param sslHandler TLS handler of the connection, or {@code null} for cleartext.
   */
  protected void configureHttp1(ChannelPipeline pipeline, SslHandler sslHandler) {
    pipeline.addLast("Logging Handler", HttpLoggingHandler.textual());
    pipeline.addLast("HTTP/1 Codec", new HttpServerCodec(maxInitialLineLength, maxHeaderSize, DEFAULT_MAX_CHUNK_SIZE));
    configureHttp1PostCodec(pipeline, sslHandler);
  }

  /**
   * Appends the HTTP/2 handlers (hex-dump logging, frame codec and stream multiplexer) to the pipeline.
   *
   * @param sslHandler TLS handler of the connection, or {@code null} for cleartext.
   */
  protected void configureHttp2(ChannelPipeline pipeline, SslHandler sslHandler) {
    pipeline.addLast("Logging Handler", HttpLoggingHandler.hexDump());
    pipeline.addLast("HTTP/2 Codec", Http2FrameCodecBuilder.forServer().build());
    pipeline.addLast("Multiplexer",
                     new Http2MultiplexHandler(new MultiplexerChannelInitializer(httpListenerRegistry, sslHandler)));
  }

  /**
   * Reads the header section size from {@link #MAXIMUM_HEADER_SECTION_SIZE_PROPERTY_KEY}, defaulting to Netty's
   * {@code DEFAULT_MAX_HEADER_SIZE}.
   *
   * @return the configured, strictly positive, header section size.
   * @throws MuleRuntimeException if the property is not an integer or is not positive.
   */
  private static int retrieveMaximumHeaderSectionSize() {
    // Read the property once, so the value that is reported on failure is exactly the one that was parsed.
    String configuredValue = getProperty(MAXIMUM_HEADER_SECTION_SIZE_PROPERTY_KEY, String.valueOf(DEFAULT_MAX_HEADER_SIZE));
    int maxHeaderSectionSize;
    try {
      maxHeaderSectionSize = parseInt(configuredValue);
    } catch (NumberFormatException e) {
      throw new MuleRuntimeException(createStaticMessage(format("Invalid value %s for %s configuration",
                                                                configuredValue,
                                                                MAXIMUM_HEADER_SECTION_SIZE_PROPERTY_KEY)),
                                     e);
    }
    // HttpServerCodec rejects non-positive limits; fail here at configuration time instead of on every accepted connection.
    if (maxHeaderSectionSize <= 0) {
      throw new MuleRuntimeException(createStaticMessage(format("Invalid value %s for %s configuration",
                                                                configuredValue,
                                                                MAXIMUM_HEADER_SECTION_SIZE_PROPERTY_KEY)));
    }
    return maxHeaderSectionSize;
  }
}
