/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package org.mule.service.http.netty.impl.client;

import static org.mule.runtime.api.util.MuleSystemProperties.SYSTEM_PROPERTY_PREFIX;
import static org.mule.service.http.netty.impl.transport.NativeChannelTransportSpecifics.createLoopGroup;
import static org.mule.service.http.netty.impl.transport.NativeChannelTransportSpecifics.getDatagramChannelType;
import static org.mule.service.http.netty.impl.util.HttpLoggingHandler.textual;
import static org.mule.service.http.netty.impl.util.MuleToNettyUtils.addAllRequestHeaders;

import static java.lang.Boolean.getBoolean;
import static java.lang.Integer.getInteger;
import static java.lang.Math.min;
import static java.lang.Runtime.getRuntime;
import static java.lang.Thread.currentThread;

import static org.slf4j.LoggerFactory.getLogger;
import static reactor.netty.NettyPipeline.HttpCodec;
import static reactor.netty.NettyPipeline.SslReader;

import org.mule.runtime.api.scheduler.Scheduler;
import org.mule.runtime.http.api.client.HttpClient;
import org.mule.runtime.http.api.client.HttpRequestOptions;
import org.mule.runtime.http.api.client.auth.HttpAuthentication;
import org.mule.runtime.http.api.client.proxy.ProxyConfig;
import org.mule.runtime.http.api.client.ws.WebSocketCallback;
import org.mule.runtime.http.api.domain.message.request.HttpRequest;
import org.mule.runtime.http.api.domain.message.response.HttpResponse;
import org.mule.runtime.http.api.sse.client.SseSource;
import org.mule.runtime.http.api.sse.client.SseSourceConfig;
import org.mule.runtime.http.api.tcp.TcpClientSocketProperties;
import org.mule.runtime.http.api.ws.WebSocket;
import org.mule.service.http.common.client.sse.DefaultSseSource;
import org.mule.service.http.common.client.sse.InternalClient;
import org.mule.service.http.common.client.sse.NoOpProgressiveBodyDataListener;
import org.mule.service.http.common.client.sse.ProgressiveBodyDataListener;
import org.mule.service.http.netty.impl.client.auth.AuthenticationEngine;
import org.mule.service.http.netty.impl.client.auth.AuthenticationHandler;
import org.mule.service.http.netty.impl.client.proxy.ProxyPipelineConfigurer;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.time.Duration;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;

import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.HttpContentDecompressor;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.ssl.SslContext;
import io.netty.resolver.AddressResolverGroup;
import io.netty.resolver.dns.DefaultDnsCache;
import io.netty.resolver.dns.DnsNameResolverBuilder;
import io.netty.resolver.dns.RoundRobinDnsAddressResolverGroup;
import org.slf4j.Logger;
import org.slf4j.MDC;
import reactor.netty.resources.ConnectionProvider;
import reactor.netty.tcp.SslProvider;

/**
 * An HTTP client that allows you to send HTTP/1 and HTTP/2 requests to a server.
 * <p>
 * Instances are created through {@link #builder()}, must be {@link #start() started} before sending requests and
 * {@link #stop() stopped} to release the event loops, DNS resolver and connection pool created on start.
 */
public class NettyHttpClient implements HttpClient, InternalClient {

  private static final Logger LOGGER = getLogger(NettyHttpClient.class);
  private static final int DEFAULT_SELECTORS_COUNT = min(getRuntime().availableProcessors(), 2);
  private static final int CLIENT_DNS_EVENT_LOOP_COUNT = min(getRuntime().availableProcessors(), 2);
  private static final Duration PENDING_ACQUIRE_TIMEOUT = Duration.ofMillis(100);
  private static final int GRACEFUL_SHUTDOWN_WAIT = 5000;

  // Reactor netty does not have a way to configure infinite connections but this is an approximation
  private static final int MAX_CONNECTIONS_UNLIMITED = Integer.MAX_VALUE;
  private static final long EVICT_IN_BACKGROUND_AFTER = 10;
  private SslContext sslContext;
  private int connectionIdleTimeout;
  private boolean usePersistentConnections;
  private reactor.netty.http.client.HttpClient noProxyHttpClient;
  private ReactorNettyClient reactorNettyClient;
  private WebSocketsProvider webSocketsProvider = new WebSocketsProvider() {};

  // Connection pool created in start(); kept so that stop() can dispose it.
  private ConnectionProvider connectionProvider;

  private int maxConnections;

  private int selectorsCount = DEFAULT_SELECTORS_COUNT;
  private Supplier<Scheduler> selectorsSchedulerSupplier = () -> null;
  private Scheduler selectorsScheduler;
  private int dnsEventLoopCount = CLIENT_DNS_EVENT_LOOP_COUNT;
  private ScheduledExecutorService ioTasksScheduler;

  private ProxyConfig proxyConfig;

  private static final int MAX_NUM_HEADERS_DEFAULT = 100;
  private static final String MAX_CLIENT_REQUEST_HEADERS_KEY = SYSTEM_PROPERTY_PREFIX + "http.MAX_CLIENT_REQUEST_HEADERS";
  private static int MAX_CLIENT_REQUEST_HEADERS;

  private EventLoopGroup selectorsGroup;
  private EventLoopGroup dnsEventLoopGroup;
  private AddressResolverGroup<InetSocketAddress> resolverGroup;
  private boolean responseStreamingEnabled = false;
  private String name;
  private TcpClientSocketProperties tcpProperties = TcpClientSocketProperties.builder().build();

  private static final String DEFAULT_DECOMPRESS_KEY = "mule.http.client.decompress";
  private static boolean DEFAULT_DECOMPRESS;
  // Read at construction time, so it reflects the system properties as of the last refreshSystemProperties() call.
  private boolean decompressionEnabled = getDefaultDecompression();

  static {
    refreshSystemProperties();
  }

  private NettyHttpClient() {}

  /**
   * Creates a builder for a new, not yet started, client.
   *
   * @return a new {@link Builder}
   */
  public static Builder builder() {
    return new Builder();
  }

  /**
   * Creates the connection pool, the selector event loops, the DNS resolver and the underlying Reactor Netty client.
   * The pipeline of every new channel gets wire logging, optional decompression, and the Mule-specific HTTP handlers.
   */
  @Override
  public void start() {
    final int resultingMaxConnections = maxConnections > 0 ? maxConnections : MAX_CONNECTIONS_UNLIMITED;
    ConnectionProvider.Builder connectionProviderBuilder =
        ConnectionProvider.builder("CustomConnectionProvider")
            .maxConnections(resultingMaxConnections)
            .maxIdleTime(Duration.ofMillis(connectionIdleTimeout))
            .evictInBackground(Duration.ofMillis(connectionIdleTimeout > 0 ? EVICT_IN_BACKGROUND_AFTER : 0));

    if (maxConnections > 0) {
      // Bound the waiting queue only when a connection limit was explicitly configured.
      connectionProviderBuilder
          .pendingAcquireMaxCount(maxConnections)
          .pendingAcquireTimeout(PENDING_ACQUIRE_TIMEOUT);
    }

    // Initialize selectorsGroup for the HTTP client
    this.selectorsScheduler = selectorsSchedulerSupplier.get();
    this.selectorsGroup = createLoopGroup(selectorsCount, selectorsScheduler);

    // Initialize resolver group
    resolverGroup = initializeResolverGroup();

    this.connectionProvider = connectionProviderBuilder.build();

    reactor.netty.http.client.HttpClient httpClient = reactor.netty.http.client.HttpClient
        .create(connectionProvider)
        .runOn(selectorsGroup)
        .doOnChannelInit((observer, channel, remoteAddress) -> {
          // Note: "first" in this pipeline means "the nearest to the server".
          var wireLogger = textual(name);
          if (channel.pipeline().names().contains(SslReader)) {
            channel.pipeline().addAfter(SslReader, "Client HTTP Logging", wireLogger);
          } else {
            channel.pipeline().addFirst("Client HTTP Logging", wireLogger);
          }

          // NOTE(review): this adds a second textual logger next to the codec; when nothing sits between it and
          // "Client HTTP Logging", the same bytes may be logged twice — confirm this is intended.
          channel.pipeline().addBefore(HttpCodec, "HttpLogging", textual(name));

          if (decompressionEnabled) {
            channel.pipeline().addAfter(HttpCodec, "decompress", new HttpContentDecompressor());
          }
          // Each addAfter(HttpCodec, ...) inserts right after the codec, so handlers added later sit closer to it.
          channel.pipeline().addAfter(HttpCodec, "removeContentLengthHandler", new RemoveContentLengthHandler());
          channel.pipeline().addAfter(HttpCodec, "100ContinueClientHandler", new ClientExpectContinueHandler(ioTasksScheduler));
          channel.pipeline().addAfter(HttpCodec, "RedirectMethodChangeHandler", new RedirectMethodChangeHandler());

          channel.pipeline().addLast(new HttpClientHandler(usePersistentConnections));
        })
        .followRedirect(true)
        .resolver(resolverGroup);

    httpClient = configureTcpOptions(httpClient);

    if (sslContext != null) {
      httpClient = httpClient.secure(SslProvider.builder().sslContext(sslContext).build());
    }

    this.noProxyHttpClient = httpClient;
    this.reactorNettyClient =
        new ReactorNettyClient(name, configureProxy(proxyConfig), ioTasksScheduler, responseStreamingEnabled);
  }

  /**
   * Applies the configured {@link TcpClientSocketProperties} as channel options. Optional socket properties are only
   * set when present; connect timeout, keep-alive and TCP_NODELAY are always applied.
   */
  private reactor.netty.http.client.HttpClient configureTcpOptions(reactor.netty.http.client.HttpClient httpClient) {
    if (tcpProperties.getClientTimeout() != null) {
      httpClient = httpClient.option(ChannelOption.SO_TIMEOUT, tcpProperties.getClientTimeout());
    }
    if (tcpProperties.getSendBufferSize() != null) {
      httpClient = httpClient.option(ChannelOption.SO_SNDBUF, tcpProperties.getSendBufferSize());
    }
    if (tcpProperties.getReceiveBufferSize() != null) {
      httpClient = httpClient.option(ChannelOption.SO_RCVBUF, tcpProperties.getReceiveBufferSize());
    }
    if (tcpProperties.getLinger() != null) {
      httpClient = httpClient.option(ChannelOption.SO_LINGER, tcpProperties.getLinger());
    }
    return httpClient
        .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, tcpProperties.getConnectionTimeout())
        .option(ChannelOption.SO_KEEPALIVE, tcpProperties.getKeepAlive())
        .option(ChannelOption.TCP_NODELAY, tcpProperties.getSendTcpNoDelay());
  }

  /**
   * Derives a client from the no-proxy client whose channel pipelines are configured by a
   * {@link ProxyPipelineConfigurer} for the given proxy configuration.
   *
   * @param proxyCfg the proxy configuration handed to the pipeline configurer (may be the client-level one)
   * @return the proxy-aware Reactor Netty client
   */
  private reactor.netty.http.client.HttpClient configureProxy(ProxyConfig proxyCfg) {
    return this.noProxyHttpClient.doOnChannelInit((observer, channel, remoteAddress) -> {
      // For backwards compatibility, we use a blind tunnel proxy iif using SSL.
      boolean useTunnelingProxy = sslContext != null;
      new ProxyPipelineConfigurer(proxyCfg).configurePipeline(channel.pipeline(), remoteAddress, useTunnelingProxy);
    });
  }

  /**
   * Chooses the address resolver. A round-robin DNS resolver (on its own event loop group) is used unless both a
   * proxy and SSL are configured, in which case resolution is left to the proxy.
   */
  private AddressResolverGroup<InetSocketAddress> initializeResolverGroup() {
    if (proxyConfig == null || sslContext == null) {
      LOGGER.debug("Initializing DNS resolver for direct connections");

      // Initialize EventLoopGroup for DNS resolution
      this.dnsEventLoopGroup = createLoopGroup(dnsEventLoopCount, ioTasksScheduler);

      DnsNameResolverBuilder dnsNameResolverBuilder = new DnsNameResolverBuilder(dnsEventLoopGroup.next())
          .datagramChannelType(getDatagramChannelType())
          .optResourceEnabled(true)
          .resolveCache(new DefaultDnsCache());

      return new RoundRobinDnsAddressResolverGroup(dnsNameResolverBuilder);
    } else {
      LOGGER.debug("Proxy is configured with SSL, so the DNS resolution should be done by the proxy. Using NOOP resolver");
      // Use InetNoopAddressResolverGroup to bypass DNS resolution
      return InetNoopAddressResolverGroup.INSTANCE;
    }
  }

  /**
   * @return the resolver group created by {@link #start()}, or {@code null} before start / after stop
   */
  public AddressResolverGroup<InetSocketAddress> getResolverGroup() {
    return this.resolverGroup;
  }

  /**
   * Releases the resources created by {@link #start()}: the connection pool, the selector event loops and their
   * scheduler, and the DNS event loops. The IO tasks scheduler is provided externally and is not stopped here.
   */
  @Override
  public void stop() {
    if (connectionProvider != null) {
      // Dispose the pool before the event loops go away so its connections and the background eviction configured
      // in start() are released; otherwise every start/stop cycle leaves a pool behind.
      connectionProvider.dispose();
      connectionProvider = null;
    }
    if (selectorsGroup != null) {
      // TODO(W-19756817): implement proper fix
      if (!selectorsGroup.shutdownGracefully(0, 0, TimeUnit.SECONDS).awaitUninterruptibly(GRACEFUL_SHUTDOWN_WAIT)) {
        LOGGER.warn("Selector EventLoop Group hasn't finished gracefully shutting down after timeout");
      }
      selectorsGroup = null;
    }
    if (selectorsScheduler != null) {
      selectorsScheduler.stop();
      selectorsScheduler = null;
    }
    if (dnsEventLoopGroup != null) {
      // TODO(W-19756817): implement proper fix
      if (!dnsEventLoopGroup.shutdownGracefully(0, 0, TimeUnit.SECONDS).awaitUninterruptibly(GRACEFUL_SHUTDOWN_WAIT)) {
        LOGGER.warn("DNS EventLoop Group hasn't finished gracefully shutting down after timeout");
      }
      dnsEventLoopGroup = null;
    }
    resolverGroup = null;
  }

  /**
   * Sends the request and blocks until the response is available.
   *
   * @throws IOException wrapping the failure cause, or the interruption (the interrupt flag is restored)
   */
  @Override
  public HttpResponse send(HttpRequest request, HttpRequestOptions options) throws IOException {
    try {
      return sendAsync(request, options).get();
    } catch (InterruptedException e) {
      currentThread().interrupt();
      throw new IOException(e);
    } catch (ExecutionException e) {
      throw new IOException(e.getCause());
    }
  }

  /**
   * Sends the request asynchronously. Any failure while preparing the request is reported through the returned
   * future rather than thrown.
   */
  @Override
  public CompletableFuture<HttpResponse> sendAsync(HttpRequest request, HttpRequestOptions options) {
    try {
      return doSendAsync(request, options, new NoOpProgressiveBodyDataListener());
    } catch (Throwable t) {
      return CompletableFuture.failedFuture(t);
    }
  }

  /**
   * Creates an SSE source backed by this client.
   *
   * @throws IllegalStateException if response streaming is not enabled for this client
   */
  @Override
  public SseSource sseSource(SseSourceConfig config) {
    if (!responseStreamingEnabled) {
      throw new IllegalStateException("SSE source requires streaming enabled for client '%s'".formatted(name));
    }
    return new DefaultSseSource(config, this, ioTasksScheduler);
  }

  @Override
  public CompletableFuture<WebSocket> openWebSocket(HttpRequest request, String socketId, WebSocketCallback callback) {
    return webSocketsProvider.openWebSocket(request, socketId, callback, sslContext);
  }

  @Override
  public CompletableFuture<WebSocket> openWebSocket(HttpRequest request, HttpRequestOptions requestOptions, String socketId,
                                                    WebSocketCallback callback) {
    return webSocketsProvider.openWebSocket(request, requestOptions, socketId, callback, sslContext);
  }

  /**
   * Sends the request, handling authentication challenges when authentication is configured in {@code options}.
   * A per-request proxy configuration, if present, overrides the client-level one. The returned future is completed
   * with the caller's MDC installed, which is restored afterwards.
   */
  @Override
  public CompletableFuture<HttpResponse> doSendAsync(HttpRequest request, HttpRequestOptions options,
                                                     ProgressiveBodyDataListener dataListener) {
    var client = options.getProxyConfig()
        .map(proxyCfg -> new ReactorNettyClient(name, configureProxy(proxyCfg), ioTasksScheduler, responseStreamingEnabled))
        .orElse(this.reactorNettyClient);

    final Optional<HttpAuthentication> authentication = options.getAuthentication();
    final AuthenticationEngine authHeadersProvider = authentication
        .map(httpAuthentication -> new AuthenticationEngine(httpAuthentication, request.getUri(), request.getMethod()))
        .orElse(null);

    client.prepareContentForRepeatability(request.getEntity());

    CompletableFuture<HttpResponse> result = new CompletableFuture<>();
    HttpHeaders headers = constructHeaders(options, request, authHeadersProvider);
    client.sendAsyncRequest(request, options, headers,
                            (response, content) -> {
                              client.rewindStreamContent(request.getEntity());

                              AuthenticationHandler authenticationHandler =
                                  new AuthenticationHandler(client, authHeadersProvider);
                              if (authentication.isPresent()
                                  && authenticationHandler.needsAuth(response, options)) {
                                return authenticationHandler.doHandle(request, options,
                                                                      response, result, dataListener);
                              } else {
                                return client.receiveContent(response, content, result, dataListener);
                              }
                            }, result)
        .subscribe();

    CompletableFuture<HttpResponse> returnableFuture = new CompletableFuture<>();
    Map<String, String> propagatedMdc = MDC.getCopyOfContextMap();
    result.whenComplete((response, exception) -> {
      Map<String, String> overriddenMdc = MDC.getCopyOfContextMap();
      try {
        // Dependents of returnableFuture run synchronously here, so they observe the caller's MDC.
        restoreMdc(propagatedMdc);
        if (exception != null) {
          returnableFuture.completeExceptionally(exception);
        } else {
          returnableFuture.complete(response);
        }
      } finally {
        restoreMdc(overriddenMdc);
      }
    });
    return returnableFuture;
  }

  /**
   * Installs the given MDC context map on the current thread. {@link MDC#getCopyOfContextMap()} may return
   * {@code null}, which is mapped to clearing the MDC instead of relying on every adapter accepting a null map.
   */
  private static void restoreMdc(Map<String, String> contextMap) {
    if (contextMap == null) {
      MDC.clear();
    } else {
      MDC.setContextMap(contextMap);
    }
  }

  @Override
  public String getName() {
    return name;
  }

  /**
   * Builds the outgoing headers from the request, enforces the header-count limit, and appends the authentication
   * headers when authentication is configured.
   *
   * @throws IllegalArgumentException if the request carries more headers than the configured limit
   */
  private HttpHeaders constructHeaders(HttpRequestOptions options, HttpRequest request,
                                       AuthenticationEngine authEngine) {
    HttpHeaders headers = new DefaultHttpHeaders();
    addAllRequestHeaders(request, headers, request.getUri());

    checkMaxRequestHeadersLimit(headers);

    if (options.getAuthentication().isPresent()) {
      headers.add(authEngine.getAuthHeaders(headers));
    }
    return headers;
  }

  /**
   * Fails when the number of header entries exceeds {@link #getMaxClientRequestHeaders()}.
   *
   * @throws IllegalArgumentException if the limit is exceeded
   */
  private static void checkMaxRequestHeadersLimit(HttpHeaders headers) {
    final int headerCount = headers.size();
    if (headerCount > MAX_CLIENT_REQUEST_HEADERS) {
      LOGGER.warn("Exceeded max client request headers limit: {}. Current header count (including default headers): {}",
                  MAX_CLIENT_REQUEST_HEADERS,
                  headerCount);
      throw new IllegalArgumentException("Exceeded max client request headers limit: " + MAX_CLIENT_REQUEST_HEADERS);
    }
  }

  /**
   * Re-reads the header limit and the default decompression flag from system properties. The decompression default
   * only affects clients created after this call.
   */
  public static void refreshSystemProperties() {
    MAX_CLIENT_REQUEST_HEADERS = getInteger(MAX_CLIENT_REQUEST_HEADERS_KEY, MAX_NUM_HEADERS_DEFAULT);
    DEFAULT_DECOMPRESS = getBoolean(DEFAULT_DECOMPRESS_KEY);
  }

  /**
   * @return the maximum number of request headers allowed, as last read by {@link #refreshSystemProperties()}
   */
  public static int getMaxClientRequestHeaders() {
    return MAX_CLIENT_REQUEST_HEADERS;
  }

  private static boolean getDefaultDecompression() {
    return DEFAULT_DECOMPRESS;
  }

  /**
   * Fluent builder for {@link NettyHttpClient}. Each setter mutates the single client instance returned by
   * {@link #build()}.
   */
  public static class Builder {

    private final NettyHttpClient product;

    private Builder() {
      product = new NettyHttpClient();
    }

    public NettyHttpClient build() {
      return product;
    }

    public Builder withName(String name) {
      product.name = name;
      return this;
    }

    public Builder withSslContext(SslContext sslContext) {
      product.sslContext = sslContext;
      return this;
    }

    public Builder withConnectionIdleTimeout(int connectionIdleTimeout) {
      product.connectionIdleTimeout = connectionIdleTimeout;
      return this;
    }

    public Builder withUsingPersistentConnections(boolean usePersistentConnections) {
      product.usePersistentConnections = usePersistentConnections;
      return this;
    }

    /**
     * Sets the maximum number of connections; non-positive values mean "unlimited".
     */
    public Builder withMaxConnections(int maxConnections) {
      // NOTE(review): normalizing here means start() always sees maxConnections > 0 once this method was called, so
      // the pending-acquire limits are applied even for "unlimited" — confirm this is intended.
      int resultingMaxConnections = maxConnections > 0 ? maxConnections : MAX_CONNECTIONS_UNLIMITED;
      product.maxConnections = resultingMaxConnections;
      return this;
    }

    public Builder withWebSocketsProvider(WebSocketsProvider webSocketsProvider) {
      product.webSocketsProvider = webSocketsProvider;
      return this;
    }

    public Builder withSelectorsScheduler(Supplier<Scheduler> selectorsScheduler) {
      product.selectorsSchedulerSupplier = selectorsScheduler;
      return this;
    }

    public Builder withSelectorsCount(int selectorsCount) {
      product.selectorsCount = selectorsCount;
      return this;
    }

    public Builder withProxyConfig(ProxyConfig proxyConfig) {
      product.proxyConfig = proxyConfig;
      return this;
    }

    public Builder withIOTasksScheduler(ScheduledExecutorService ioTasksScheduler) {
      product.ioTasksScheduler = ioTasksScheduler;
      return this;
    }

    public Builder withDnsEventLoopCount(int dnsEventLoopCount) {
      product.dnsEventLoopCount = dnsEventLoopCount;
      return this;
    }

    public Builder withResponseStreamingEnabled(boolean responseStreamingEnabled) {
      product.responseStreamingEnabled = responseStreamingEnabled;
      return this;
    }

    public Builder withTcpProperties(TcpClientSocketProperties tcpProperties) {
      product.tcpProperties = tcpProperties;
      return this;
    }

    /**
     * Overrides the system-property default for response decompression; {@code null} keeps the default.
     */
    public Builder withDecompressionEnabled(Boolean decompressionEnabled) {
      if (decompressionEnabled != null) {
        product.decompressionEnabled = decompressionEnabled;
      }
      return this;
    }
  }
}
