/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package org.mule.service.http.test.common.http2;

import static org.mule.service.http.test.netty.AllureConstants.HTTP_2;

import static java.lang.Thread.currentThread;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Collections.newSetFromMap;
import static java.util.concurrent.CompletableFuture.allOf;
import static java.util.concurrent.TimeUnit.SECONDS;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThan;

import org.mule.runtime.http.api.Http1ProtocolConfig;
import org.mule.runtime.http.api.Http2ProtocolConfig;
import org.mule.runtime.http.api.client.HttpClient;
import org.mule.runtime.http.api.client.HttpClientConfiguration;
import org.mule.runtime.http.api.client.HttpRequestOptions;
import org.mule.runtime.http.api.domain.entity.ByteArrayHttpEntity;
import org.mule.runtime.http.api.domain.message.request.HttpRequest;
import org.mule.runtime.http.api.domain.message.response.HttpResponse;
import org.mule.runtime.http.api.server.HttpServer;
import org.mule.runtime.http.api.server.HttpServerConfiguration;
import org.mule.service.http.test.common.AbstractHttpServiceTestCase;
import org.mule.tck.junit5.DynamicPort;

import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;

import io.qameta.allure.Description;
import io.qameta.allure.Feature;
import io.qameta.allure.Issue;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Tests HTTP/2 multiplexing fix to prevent connection explosion.
 */
@Feature(HTTP_2)
@Issue("W-19860076")
class Http2ReverseProxyMultiplexingTestCase extends AbstractHttpServiceTestCase {

  private static final Logger LOGGER = LoggerFactory.getLogger(Http2ReverseProxyMultiplexingTestCase.class);

  private static final String BACKEND_ENDPOINT = "/test";

  private static final int TOTAL_REQUESTS = 1000;

  // For unlimited connections: effectiveMaxConnections = 500, minConnections = max(1, 500/50) = 10
  private static final int HTTP2_UNLIMITED_EFFECTIVE_MAX = 500;

  // Customer's maxConnections configuration
  private static final int CUSTOMER_MAX_CONNECTIONS = 12;

  @DynamicPort(systemProperty = "backendPort")
  private Integer backendPort;

  private HttpServer backendServer;
  private HttpClient testClient;
  private ExecutorService requestHandlerExecutor;

  // Tracks unique TCP connections by client address
  private final Set<String> uniqueConnectionIdentifiers = newSetFromMap(new ConcurrentHashMap<>());

  public Http2ReverseProxyMultiplexingTestCase(String serviceToLoad) {
    super(serviceToLoad);
  }

  @BeforeEach
  void setUp() {
    uniqueConnectionIdentifiers.clear();
  }

  @AfterEach
  void tearDown() {
    if (testClient != null) {
      testClient.stop();
    }
    if (backendServer != null) {
      backendServer.stop();
      backendServer.dispose();
    }
    if (requestHandlerExecutor != null) {
      requestHandlerExecutor.shutdownNow();
    }
  }

  @Test
  @Description("HTTP/1.1 baseline: Uses one connection per concurrent request (no multiplexing).")
  void testHttp1BaselineConnection() throws Exception {
    startServer(true, false); // HTTP/1 only
    startClient(true, false, TOTAL_REQUESTS); // Max 1000 connections

    runLoadTest(TOTAL_REQUESTS);
    int connections = uniqueConnectionIdentifiers.size();

    LOGGER.info("[Test 1] HTTP/1.1: {} connections created", connections);

    // HTTP/1.1 should use at least 80% of total requests as connections (no multiplexing)
    assertThat("HTTP/1.1 should use ~1 connection per concurrent request (no multiplexing)",
               connections, greaterThan((int) (TOTAL_REQUESTS * 0.8)));
  }

  @Test
  @Description("HTTP/2 (Unlimited): Should multiplex efficiently and keep connections low.")
  void testHttp2UnlimitedConnection() throws Exception {
    startServer(false, true); // HTTP/2 only
    startClient(false, true, -1); // unlimited connections

    runLoadTest(TOTAL_REQUESTS);
    int connections = uniqueConnectionIdentifiers.size();

    LOGGER.info("[Test 2] HTTP/2 Unlimited: {} connections created", connections);

    // HTTP/2 should multiplex efficiently - allow up to 80% of effectiveMaxConnections (400)
    assertThat("HTTP/2 should multiplex efficiently over a small pool of connections",
               connections, lessThan((int) (HTTP2_UNLIMITED_EFFECTIVE_MAX * 0.8)));
  }

  @Test
  @Description("HTTP/2 (Limited): Should respect explicit maxConnections and multiplex efficiently.")
  void testHttp2LimitedConnection() throws Exception {
    startServer(false, true); // HTTP/2 only
    startClient(false, true, CUSTOMER_MAX_CONNECTIONS); // Explicit limit 12

    runLoadTest(TOTAL_REQUESTS);
    int connections = uniqueConnectionIdentifiers.size();

    LOGGER.info("[Test 3] HTTP/2 Limited: {} connections created", connections);

    // Should stay within maxConnections + minConnections (12 + 1 = 13)
    assertThat("Should adhere to maxConnections limit and multiplex efficiently",
               connections, lessThan(CUSTOMER_MAX_CONNECTIONS + Math.max(1, CUSTOMER_MAX_CONNECTIONS / 50)));
  }

  @Test
  @Description("HTTP/2 (Single Connection): Should handle 10 requests over exactly 1 connection with 10 max streams.")
  void testHttp2SingleConnectionWithLimitedStreams() throws Exception {
    int requestCount = 10;
    startServer(false, true, requestCount);

    // configure client with 1 max/min connection and 10 max concurrent streams
    // this will internally use Http2AllocationStrategy with:
    testClient = service.getClientFactory().create(new HttpClientConfiguration.Builder()
        .setName("TestClient")
        .setHttp1Config(new Http1ProtocolConfig(false))
        .setHttp2Config(new Http2ProtocolConfig(true).setMaxConcurrentStreams(10L))
        .setUsePersistentConnections(true)
        .setMaxConnections(1)
        .build());
    testClient.start();

    runLoadTest(requestCount);
    int connections = uniqueConnectionIdentifiers.size();

    LOGGER.info("[Test 4] HTTP/2 Single Connection: {} connections created for {} requests with 10 max streams",
                connections, requestCount);

    assertThat("Should handle all requests over exactly 1 connection", connections, is(1));
  }

  private void runLoadTest(int requestCount) throws Exception {
    String url = "http://localhost:" + backendPort + BACKEND_ENDPOINT;
    HttpRequest request = HttpRequest.builder().uri(url).method("GET").build();
    HttpRequestOptions options = HttpRequestOptions.builder().responseTimeout(30000).build();

    // 1) fire all requests asynchronously (non-blocking)
    List<CompletableFuture<HttpResponse>> futures = new ArrayList<>(requestCount);
    for (int i = 0; i < requestCount; i++) {
      futures.add(testClient.sendAsync(request, options));
    }

    // 2) wait for all requests to reach the server and establish connections
    if (!requestsArrivedLatch.await(10, SECONDS)) {
      throw new AssertionError("Not all requests arrived at the server within timeout");
    }

    // verify all requests are being handled in parallel
    assertThat("All requests should be handled concurrently", maxConcurrentRequestsSeen, is(requestCount));

    // 3) release all requests to complete immediately (don't hold them blocked)
    requestLatch.countDown();

    // 4) wait for all requests to complete
    allOf(futures.toArray(new CompletableFuture[0])).get(25, SECONDS);

    // 5) verify all requests succeeded
    int successCount = 0;
    for (CompletableFuture<HttpResponse> future : futures) {
      if (future.get().getStatusCode() == 200) {
        successCount++;
      }
    }

    assertThat("All requests must succeed", successCount, is(requestCount));
  }

  private CountDownLatch requestLatch;
  private CountDownLatch requestsArrivedLatch;
  private volatile int concurrentRequestsInServer = 0;
  private volatile int maxConcurrentRequestsSeen = 0;

  private void startServer(boolean h1, boolean h2) throws Exception {
    startServer(h1, h2, TOTAL_REQUESTS);
  }

  private void startServer(boolean h1, boolean h2, int expectedRequestCount) throws Exception {
    // clean up any existing server first
    if (backendServer != null) {
      backendServer.stop();
      backendServer.dispose();
    }

    // clean up existing executor
    if (requestHandlerExecutor != null) {
      requestHandlerExecutor.shutdownNow();
    }

    // initialize latches: one to block requests, one to count arrivals
    requestLatch = new CountDownLatch(1);
    requestsArrivedLatch = new CountDownLatch(expectedRequestCount);
    concurrentRequestsInServer = 0;
    maxConcurrentRequestsSeen = 0;

    // create thread pool sized for concurrent request handling
    requestHandlerExecutor = Executors.newFixedThreadPool(expectedRequestCount);

    backendServer = service.getServerFactory().create(new HttpServerConfiguration.Builder()
        .setName("Backend")
        .setHost("localhost")
        .setPort(backendPort)
        .setHttp1Config(new Http1ProtocolConfig(h1))
        .setHttp2Config(new Http2ProtocolConfig(h2))
        .build());
    backendServer.start();

    backendServer.addRequestHandler(BACKEND_ENDPOINT, (reqCtx, callback) -> {
      String remoteAddress = reqCtx.getClientConnection().getRemoteHostAddress().toString();
      uniqueConnectionIdentifiers.add(remoteAddress);

      requestHandlerExecutor.submit(() -> {
        try {
          // track concurrent requests
          synchronized (this) {
            concurrentRequestsInServer++;
            maxConcurrentRequestsSeen = Math.max(maxConcurrentRequestsSeen, concurrentRequestsInServer);
          }

          // signal that this request has arrived at the server
          requestsArrivedLatch.countDown();

          // block here until the test releases the latch
          requestLatch.await();

          callback.responseReady(HttpResponse.builder().statusCode(200).entity(new ByteArrayHttpEntity("OK".getBytes())).build(),
                                 new IgnoreResponseStatusCallback());
        } catch (InterruptedException e) {
          currentThread().interrupt();
        } finally {
          synchronized (this) {
            concurrentRequestsInServer--;
          }
        }
      });
    });
  }

  private void startClient(boolean h1, boolean h2, int maxConnections) {
    // clean up any existing client first
    if (testClient != null) {
      testClient.stop();
    }

    testClient = service.getClientFactory().create(new HttpClientConfiguration.Builder()
        .setName("TestClient")
        .setHttp1Config(new Http1ProtocolConfig(h1))
        .setHttp2Config(new Http2ProtocolConfig(h2))
        .setUsePersistentConnections(true)
        .setMaxConnections(maxConnections)
        .build());
    testClient.start();
  }
}
