/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package org.mule.service.http.test.netty.impl.server;

import static org.mule.runtime.http.api.HttpConstants.HttpStatus.INTERNAL_SERVER_ERROR;
import static org.mule.runtime.http.api.HttpConstants.HttpStatus.OK;
import static org.mule.service.http.test.netty.utils.TestUtils.measuringNanoseconds;

import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.concurrent.TimeUnit.MILLISECONDS;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalToCompressingWhiteSpace;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThan;

import org.mule.runtime.http.api.server.HttpServer;
import org.mule.service.http.netty.impl.message.content.StringHttpEntity;
import org.mule.service.http.netty.impl.server.AcceptedConnectionChannelInitializer;
import org.mule.service.http.netty.impl.server.NettyHttpServer;
import org.mule.service.http.netty.impl.server.util.HttpListenerRegistry;
import org.mule.service.http.test.common.AbstractHttpTestCase;
import org.mule.service.http.test.netty.tck.ExecutorRule;
import org.mule.service.http.test.netty.utils.NoOpResponseStatusCallback;
import org.mule.service.http.test.netty.utils.ResponseWithoutHeaders;
import org.mule.service.http.test.netty.utils.TcpTextClient;
import org.mule.tck.junit4.rule.DynamicPort;

import java.io.IOException;
import java.net.InetSocketAddress;

import org.apache.commons.io.IOUtils;
import org.junit.After;
import org.junit.Before;
import org.junit.ClassRule;
import org.junit.Rule;
import org.junit.Test;

import io.qameta.allure.Issue;

/**
 * Verifies that the server closes a persistent (keep-alive) connection once it has been idle for longer than the configured
 * connection timeout, while requests sent before that moment are still answered normally.
 * <p>
 * The read timeout is configured to a very large value so that any connection closure observed by the test can only be
 * attributed to the connection timeout.
 */
@Issue("W-17464403")
public class ServerConnectionTimeoutTestCase extends AbstractHttpTestCase {

  @ClassRule
  public static ExecutorRule executorRule = new ExecutorRule();

  // We want the timeouts to be caused by connection timeout, not by read timeout.
  private static final int SMALL_CONNECTION_TIMEOUT_MILLIS = 500;
  private static final long LARGE_READ_TIMEOUT_MILLIS = 1000000L;

  // Extra time granted on top of the connection timeout before the test considers the closure "too late".
  // NOTE(review): 50 ms is tight for heavily loaded CI machines - consider widening if this test proves flaky.
  private static final long TIMEOUT_TOLERANCE_MILLIS = 50L;

  // Delimiter marking the end of the header section of an HTTP message.
  private static final String END_OF_HEADERS = "\r\n\r\n";

  // Body-less GET request to the test endpoint; the single placeholder is the server port.
  private static final String GET_REQUEST_TEMPLATE = """
      GET /test HTTP/1.1
      Host: localhost:%d

      """;

  // Response head expected for a body-less request (the handler echoes the empty request body).
  private static final String EXPECTED_OK_RESPONSE = """
      HTTP/1.1 200 OK
      content-length: 0""";

  @Rule
  public DynamicPort serverPort = new DynamicPort("serverPort");

  private HttpServer httpServer;

  /**
   * Starts a server on the dynamic port, configured with the small connection timeout and the large read timeout, and
   * registers an echo handler on {@code /test}.
   *
   * @throws Exception if the server cannot be built or started.
   */
  @Before
  public void setup() throws Exception {
    HttpListenerRegistry listenerRegistry = new HttpListenerRegistry();
    httpServer = NettyHttpServer.builder()
        .withName("test-server")
        .withServerAddress(new InetSocketAddress(serverPort.getNumber()))
        .withHttpListenerRegistry(listenerRegistry)
        .withShutdownTimeout(() -> 5000L)
        .withClientChannelHandler(new AcceptedConnectionChannelInitializer(listenerRegistry, "test-server", true,
                                                                           SMALL_CONNECTION_TIMEOUT_MILLIS,
                                                                           LARGE_READ_TIMEOUT_MILLIS, null, 300,
                                                                           executorRule.getExecutor()))
        .build();
    httpServer.start();

    // The handler runs on the rule's executor and answers with the request body as the response body.
    httpServer.addRequestHandler("/test", (requestContext, responseCallback) -> executorRule.getExecutor().submit(() -> {
      try {
        var asString = IOUtils.toString(requestContext.getRequest().getEntity().getContent(), UTF_8);
        responseCallback.responseReady(new ResponseWithoutHeaders(OK, new StringHttpEntity(asString)),
                                       new NoOpResponseStatusCallback());
      } catch (IOException e) {
        // Reading the request body failed: report it to the client instead of leaving the request unanswered.
        responseCallback.responseReady(new ResponseWithoutHeaders(INTERNAL_SERVER_ERROR, new StringHttpEntity(e.toString())),
                                       new NoOpResponseStatusCallback());
      }
    }));
  }

  /**
   * Stops (if still running) and disposes the server created by {@link #setup()}.
   */
  @After
  public void tearDown() {
    // JUnit still runs @After methods when @Before fails, so the server may never have been created.
    if (httpServer == null) {
      return;
    }
    try {
      if (!httpServer.isStopped()) {
        httpServer.stop();
      }
    } finally {
      // Dispose even when stop() throws, so the server is not leaked into subsequent tests.
      httpServer.dispose();
    }
  }

  /**
   * Sends two complete requests over the same connection, expects both to be answered, and then expects the server to
   * close the idle connection within the connection timeout (plus a small tolerance).
   * <p>
   * Note: despite the method name, the requests sent here are complete; what times out is the idle persistent connection.
   *
   * @throws IOException if the raw TCP client fails to connect, send or receive.
   */
  @Test
  public void sendPartialRequestShouldTimeout() throws IOException {
    try (TcpTextClient tcpTextClient = new TcpTextClient("localhost", serverPort.getNumber())) {
      // Send one request... The response should be received normally.
      sendGetAndAssertOk(tcpTextClient);

      // Send second request (persistent connection). The response should be received normally again.
      sendGetAndAssertOk(tcpTextClient);

      // Connection will be closed by connection timeout, so nothing is expected to be received.
      long elapsedNanos = measuringNanoseconds(() -> {
        String content = tcpTextClient.receiveUntil(END_OF_HEADERS);
        assertThat(content, is(""));
      });

      long toleranceNanos = MILLISECONDS.toNanos(TIMEOUT_TOLERANCE_MILLIS);
      long connectionTimeoutNanos = MILLISECONDS.toNanos(SMALL_CONNECTION_TIMEOUT_MILLIS);
      assertThat(elapsedNanos, is(lessThan(connectionTimeoutNanos + toleranceNanos)));
    }
  }

  /**
   * Sends a body-less GET to {@code /test} through the given client and asserts that a {@code 200 OK} response head with an
   * empty body is received.
   *
   * @param tcpTextClient the already connected raw TCP client to use.
   * @throws IOException if sending the request or receiving the response fails.
   */
  private void sendGetAndAssertOk(TcpTextClient tcpTextClient) throws IOException {
    tcpTextClient.sendString(GET_REQUEST_TEMPLATE.formatted(serverPort.getNumber()));
    assertThat(tcpTextClient.receiveUntil(END_OF_HEADERS), equalToCompressingWhiteSpace(EXPECTED_OK_RESPONSE));
  }
}
