/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package org.mule.service.http.netty.impl.client;

import static org.mule.runtime.api.util.MuleSystemProperties.SYSTEM_PROPERTY_PREFIX;
import static org.mule.service.http.netty.impl.util.MuleToNettyUtils.addAllRequestHeaders;

import static java.lang.Integer.getInteger;
import static java.lang.Math.min;
import static java.lang.Runtime.getRuntime;
import static java.lang.Thread.currentThread;

import static org.slf4j.LoggerFactory.getLogger;
import static reactor.netty.NettyPipeline.HttpCodec;

import org.mule.runtime.api.scheduler.Scheduler;
import org.mule.runtime.http.api.client.HttpClient;
import org.mule.runtime.http.api.client.HttpRequestOptions;
import org.mule.runtime.http.api.client.auth.HttpAuthentication;
import org.mule.runtime.http.api.client.proxy.ProxyConfig;
import org.mule.runtime.http.api.client.ws.WebSocketCallback;
import org.mule.runtime.http.api.domain.message.request.HttpRequest;
import org.mule.runtime.http.api.domain.message.response.HttpResponse;
import org.mule.runtime.http.api.ws.WebSocket;
import org.mule.service.http.netty.impl.client.auth.AuthenticationEngine;
import org.mule.service.http.netty.impl.client.auth.AuthenticationHandler;
import org.mule.service.http.netty.impl.client.proxy.ProxyPipelineConfigurer;
import org.mule.service.http.netty.impl.util.HttpLoggingHandler;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.time.Duration;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import io.netty.channel.ChannelOption;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioDatagramChannel;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.ssl.SslContext;
import io.netty.resolver.AddressResolverGroup;
import io.netty.resolver.dns.DefaultDnsCache;
import io.netty.resolver.dns.DnsNameResolverBuilder;
import io.netty.resolver.dns.RoundRobinDnsAddressResolverGroup;
import org.slf4j.Logger;
import reactor.netty.resources.ConnectionProvider;
import reactor.netty.tcp.SslProvider;

/**
 * An HTTP client that allows you to send HTTP/1 and HTTP/2 requests to a server.
 * <p>
 * Instances are created through {@link #builder()}. {@link #start()} creates the connection pool, the selector event loop
 * group and the address resolver; it must be called before sending requests. {@link #stop()} releases what {@link #start()}
 * created.
 */
// TODO: Should include {@link io.netty.handler.codec.http2.Http2ClientUpgradeCodec} if the HTTP/2 server you are hitting doesn't
// support h2c/prior knowledge.
public class NettyHttpClient implements HttpClient {

  private static final Logger LOGGER = getLogger(NettyHttpClient.class);
  private static final int DEFAULT_SELECTORS_COUNT = min(getRuntime().availableProcessors(), 2);
  private static final int CLIENT_DNS_EVENT_LOOP_COUNT = min(getRuntime().availableProcessors(), 2);

  // Reactor netty does not have a way to configure infinite connections but this is an approximation
  private static final int MAX_CONNECTIONS_UNLIMITED = Integer.MAX_VALUE;
  // Background eviction interval (milliseconds) for the connection pool, used only when an idle timeout is configured.
  private static final long EVICT_IN_BACKGROUND_AFTER = 10;

  private static final int MAX_NUM_HEADERS_DEFAULT = 100;
  private static final String MAX_CLIENT_REQUEST_HEADERS_KEY = SYSTEM_PROPERTY_PREFIX + "http.MAX_CLIENT_REQUEST_HEADERS";
  // volatile: rewritten by refreshMaxClientRequestHeaders() while request-sending threads read it.
  private static volatile int MAX_CLIENT_REQUEST_HEADERS = getInteger(MAX_CLIENT_REQUEST_HEADERS_KEY, MAX_NUM_HEADERS_DEFAULT);

  // Number of currently open connections: incremented on connect, decremented on disconnect.
  private final AtomicInteger totalConnections = new AtomicInteger();
  private SslContext sslContext;
  private int connectionIdleTimeout;
  private boolean usePersistentConnections;
  private ReactorNettyClient reactorNettyClient;
  private WebSocketsProvider webSocketsProvider = new WebSocketsProvider() {};

  private int maxConnections;
  private Scheduler selectorsScheduler;
  private Scheduler dnsEventLoopScheduler;
  private int selectorsCount = DEFAULT_SELECTORS_COUNT;
  private int dnsEventLoopCount = CLIENT_DNS_EVENT_LOOP_COUNT;
  private ProxyConfig proxyConfig;
  // Pool created by start(); kept so that stop() can dispose it.
  private ConnectionProvider connectionProvider;
  private NioEventLoopGroup selectorsGroup;
  private NioEventLoopGroup dnsEventLoopGroup;
  private AddressResolverGroup<InetSocketAddress> resolverGroup;

  private NettyHttpClient() {}

  /**
   * Creates a new builder for a {@link NettyHttpClient}.
   *
   * @return a fresh builder
   */
  public static Builder builder() {
    return new Builder();
  }

  /**
   * Creates the connection pool, the selector event loop group, the address resolver and the underlying reactor-netty
   * client. Must be invoked before any request is sent.
   */
  @Override
  public void start() {
    final int resultingMaxConnections = maxConnections > 0 ? maxConnections : MAX_CONNECTIONS_UNLIMITED;
    this.connectionProvider = ConnectionProvider.builder("CustomConnectionProvider")
        .maxConnections(resultingMaxConnections)
        .maxIdleTime(Duration.ofMillis(connectionIdleTimeout))
        .evictInBackground(Duration.ofMillis(connectionIdleTimeout > 0 ? EVICT_IN_BACKGROUND_AFTER : 0))
        .build();

    // Initialize selectorsGroup for the HTTP client
    this.selectorsGroup = new NioEventLoopGroup(selectorsCount, selectorsScheduler);

    // Initialize resolver group
    this.resolverGroup = initializeResolverGroup();

    reactor.netty.http.client.HttpClient httpClient = reactor.netty.http.client.HttpClient
        .create(connectionProvider)
        .runOn(selectorsGroup)
        // NOTE(review): the connect timeout reuses connectionIdleTimeout; confirm this is intended.
        .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectionIdleTimeout)
        .doOnConnect(config -> {
          // ">=" is equivalent to "get() + 1 >" but cannot overflow when the limit is Integer.MAX_VALUE.
          if (totalConnections.get() >= resultingMaxConnections) {
            throw new RuntimeException("Connection limit exceeded, cannot process request");
          }
        })
        .doOnConnected(connection -> totalConnections.incrementAndGet())
        // Without the matching decrement the counter only grows, and once the lifetime total of connections reached the
        // limit every further connection would be rejected.
        .doOnDisconnected(connection -> totalConnections.decrementAndGet())
        .doOnChannelInit((observer, channel, remoteAddress) -> {
          // "First" in this pipeline means "the nearest to the server", so "before HttpCodec" we see bytes, and "after HttpCodec"
          // we see HttpMessages.

          // For backwards compatibility we use a blind tunnel proxy iif using SSL.
          boolean useTunnelingProxy = sslContext != null;
          new ProxyPipelineConfigurer(proxyConfig).configurePipeline(channel.pipeline(), remoteAddress, useTunnelingProxy);

          channel.pipeline().addBefore(HttpCodec, "HttpLogging", new HttpLoggingHandler());

          channel.pipeline().addAfter(HttpCodec, "removeContentLengthHandler", new RemoveContentLengthHandler());
          channel.pipeline().addAfter(HttpCodec, "100ContinueClientHandler", new ClientExpectContinueHandler());

          channel.pipeline().addLast(new HttpClientHandler(usePersistentConnections));
        })
        .followRedirect(true)
        .resolver(resolverGroup);

    if (sslContext != null) {
      httpClient = httpClient.secure(SslProvider.builder().sslContext(sslContext).build());
    }

    this.reactorNettyClient = new ReactorNettyClient(httpClient);
  }

  /**
   * Chooses the address resolver for this client.
   * <p>
   * Without a proxy, a round-robin DNS resolver is built on a dedicated DNS event loop group (stored in
   * {@code dnsEventLoopGroup} so {@link #stop()} can shut it down). With a proxy, {@code InetNoopAddressResolverGroup} is
   * returned so that no DNS resolution is performed locally.
   *
   * @return the resolver group to install on the reactor-netty client
   */
  private AddressResolverGroup<InetSocketAddress> initializeResolverGroup() {
    if (proxyConfig == null) {
      LOGGER.info("Initializing DNS resolver for direct connections.");

      // Initialize EventLoopGroup for DNS resolution
      this.dnsEventLoopGroup = new NioEventLoopGroup(dnsEventLoopCount, dnsEventLoopScheduler);

      DnsNameResolverBuilder dnsNameResolverBuilder = new DnsNameResolverBuilder(dnsEventLoopGroup.next())
          .datagramChannelType(NioDatagramChannel.class)
          .optResourceEnabled(true)
          .resolveCache(new DefaultDnsCache());

      return new RoundRobinDnsAddressResolverGroup(dnsNameResolverBuilder);
    } else {
      LOGGER.info("Proxy is configured. Using InetNoopAddressResolverGroup.");
      // Use InetNoopAddressResolverGroup to bypass DNS resolution
      return InetNoopAddressResolverGroup.INSTANCE;
    }
  }

  /**
   * Returns the resolver group created by the last {@link #start()}.
   *
   * @return the resolver group, or {@code null} before {@link #start()} or after {@link #stop()}
   */
  public AddressResolverGroup<InetSocketAddress> getResolverGroup() {
    return this.resolverGroup;
  }

  /**
   * Releases the resources created by {@link #start()}: disposes the connection pool and shuts down the selector and DNS
   * event loop groups, waiting for their termination.
   */
  @Override
  public void stop() {
    if (connectionProvider != null) {
      // The pool is created per client in start(), so it is safe (and necessary) to dispose it here.
      connectionProvider.dispose();
      connectionProvider = null;
    }
    if (selectorsGroup != null) {
      selectorsGroup.shutdownGracefully(0, 0, TimeUnit.SECONDS).syncUninterruptibly();
      selectorsGroup = null;
    }
    if (dnsEventLoopGroup != null) {
      dnsEventLoopGroup.shutdownGracefully(0, 0, TimeUnit.SECONDS).syncUninterruptibly();
      dnsEventLoopGroup = null;
    }
    resolverGroup = null;
  }

  /**
   * Sends the request and blocks until the response is available.
   *
   * @throws IOException if the calling thread is interrupted (the interrupt flag is restored) or if the asynchronous send
   *                     fails; in the latter case the failure's cause is wrapped
   */
  @Override
  public HttpResponse send(HttpRequest request, HttpRequestOptions options) throws IOException {
    try {
      return sendAsync(request, options).get();
    } catch (InterruptedException e) {
      currentThread().interrupt();
      throw new IOException(e);
    } catch (ExecutionException e) {
      throw new IOException(e.getCause());
    }
  }

  /**
   * Sends the request asynchronously. This method does not throw: any failure raised while preparing the request is
   * delivered through the returned future.
   */
  @Override
  public CompletableFuture<HttpResponse> sendAsync(HttpRequest request, HttpRequestOptions options) {
    try {
      return doSendAsync(request, options);
    } catch (Throwable t) {
      CompletableFuture<HttpResponse> errorFuture = new CompletableFuture<>();
      errorFuture.completeExceptionally(t);
      return errorFuture;
    }
  }

  @Override
  public CompletableFuture<WebSocket> openWebSocket(HttpRequest request, String socketId, WebSocketCallback callback) {
    return webSocketsProvider.openWebSocket(request, socketId, callback, sslContext);
  }

  @Override
  public CompletableFuture<WebSocket> openWebSocket(HttpRequest request, HttpRequestOptions requestOptions, String socketId,
                                                    WebSocketCallback callback) {
    return webSocketsProvider.openWebSocket(request, requestOptions, socketId, callback, sslContext);
  }

  /**
   * Builds the request headers (including authentication headers, when configured) and dispatches the request. The returned
   * future is completed by the response-handling callback passed to {@link ReactorNettyClient}.
   */
  private CompletableFuture<HttpResponse> doSendAsync(HttpRequest request, HttpRequestOptions options) {
    final Optional<HttpAuthentication> authentication = options.getAuthentication();
    final AuthenticationEngine authHeadersProvider = authentication
        .map(httpAuthentication -> new AuthenticationEngine(httpAuthentication, request.getUri(), request.getMethod()))
        .orElse(null);

    CompletableFuture<HttpResponse> result = new CompletableFuture<>();
    HttpHeaders headers = constructHeaders(options, request, authHeadersProvider);
    reactorNettyClient.sendAsyncRequest(request, options, headers, (response, content) -> {
      // Only requests configured with authentication may need an extra authentication round-trip, so the handler is
      // created only in that case.
      if (authentication.isPresent()) {
        AuthenticationHandler authenticationHandler = new AuthenticationHandler(reactorNettyClient, authHeadersProvider);
        if (authenticationHandler.needsAuth(response, options)) {
          return authenticationHandler.doHandle(request, options, response, result);
        }
      }
      return reactorNettyClient.receiveContent(response, content, result);
    }, result)
        .subscribe();
    return result;
  }

  /**
   * Copies the request headers into a Netty {@link HttpHeaders}, enforces the configured header limit, and appends the
   * authentication headers when authentication is configured.
   *
   * @throws IllegalArgumentException if the request carries more headers than the configured limit
   */
  private HttpHeaders constructHeaders(HttpRequestOptions options, HttpRequest request,
                                       AuthenticationEngine authEngine) {
    HttpHeaders headers = new DefaultHttpHeaders();
    addAllRequestHeaders(request, headers, request.getUri());

    // Checked before the authentication headers are added.
    checkMaxRequestHeadersLimit(headers);

    if (options.getAuthentication().isPresent()) {
      headers.add(authEngine.getAuthHeaders(headers));
    }
    return headers;
  }

  /**
   * Rejects header sets larger than {@code MAX_CLIENT_REQUEST_HEADERS}.
   *
   * @throws IllegalArgumentException if the limit is exceeded
   */
  private static void checkMaxRequestHeadersLimit(HttpHeaders headers) {
    // Read the volatile limit once so the check and the reported values agree even if it is refreshed concurrently.
    final int limit = MAX_CLIENT_REQUEST_HEADERS;
    final int headerCount = headers.size();
    if (headerCount > limit) {
      LOGGER.warn("Exceeded max client request headers limit: {}. Current header count (including default headers): {}",
                  limit,
                  headerCount);
      throw new IllegalArgumentException("Exceeded max client request headers limit: " + limit);
    }
  }

  /**
   * Re-reads the max client request headers limit from its system property, falling back to the default.
   */
  public static void refreshMaxClientRequestHeaders() {
    MAX_CLIENT_REQUEST_HEADERS = getInteger(MAX_CLIENT_REQUEST_HEADERS_KEY, MAX_NUM_HEADERS_DEFAULT);
  }

  /**
   * @return the current max client request headers limit
   */
  public static int getMaxClientRequestHeaders() {
    return MAX_CLIENT_REQUEST_HEADERS;
  }

  /**
   * Builder for {@link NettyHttpClient}. Each {@code with*} method mutates the single instance that {@link #build()}
   * returns.
   */
  public static class Builder {

    private final NettyHttpClient product;

    private Builder() {
      product = new NettyHttpClient();
    }

    /**
     * @return the configured client; {@link NettyHttpClient#start()} must still be called on it
     */
    public NettyHttpClient build() {
      return product;
    }

    public Builder withSslContext(SslContext sslContext) {
      product.sslContext = sslContext;
      return this;
    }

    public Builder withConnectionIdleTimeout(int connectionIdleTimeout) {
      product.connectionIdleTimeout = connectionIdleTimeout;
      return this;
    }

    public Builder withUsingPersistentConnections(boolean usePersistentConnections) {
      product.usePersistentConnections = usePersistentConnections;
      return this;
    }

    /**
     * @param maxConnections maximum number of connections; non-positive values mean "unlimited"
     */
    public Builder withMaxConnections(int maxConnections) {
      int resultingMaxConnections = maxConnections > 0 ? maxConnections : MAX_CONNECTIONS_UNLIMITED;
      product.maxConnections = resultingMaxConnections;
      return this;
    }

    public Builder withWebSocketsProvider(WebSocketsProvider webSocketsProvider) {
      product.webSocketsProvider = webSocketsProvider;
      return this;
    }

    public Builder withSelectorsScheduler(Scheduler selectorsScheduler) {
      product.selectorsScheduler = selectorsScheduler;
      return this;
    }

    public Builder withSelectorsCount(int selectorsCount) {
      product.selectorsCount = selectorsCount;
      return this;
    }

    /**
     * @param proxyConfig proxy configuration; when non-null, local DNS resolution is bypassed
     */
    public Builder withProxyConfig(ProxyConfig proxyConfig) {
      product.proxyConfig = proxyConfig;
      return this;
    }

    public Builder withDnsEventLoopScheduler(Scheduler dnsEventLoopScheduler) {
      product.dnsEventLoopScheduler = dnsEventLoopScheduler;
      return this;
    }

    public Builder withDnsEventLoopCount(int dnsEventLoopCount) {
      product.dnsEventLoopCount = dnsEventLoopCount;
      return this;
    }
  }
}
