/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package org.mule.service.http.test.netty.impl.server;

import static org.mule.service.http.test.netty.utils.TestUtils.measuringNanoseconds;

import static java.lang.Thread.sleep;
import static java.util.concurrent.Executors.newFixedThreadPool;
import static java.util.concurrent.TimeUnit.DAYS;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.NANOSECONDS;
import static java.util.concurrent.TimeUnit.SECONDS;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.lessThan;

import org.mule.runtime.api.util.concurrent.Latch;
import org.mule.service.http.netty.impl.server.ConnectionsCounterHandler;
import org.mule.tck.junit4.AbstractMuleTestCase;

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;

import io.netty.channel.ChannelHandlerContext;
import io.qameta.allure.Issue;
import org.junit.Rule;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnit;
import org.mockito.junit.MockitoRule;

@Issue("W-15867731")
public class ConnectionsCounterHandlerTestCase extends AbstractMuleTestCase {

  private static final long HALF_SECOND_IN_NANOS = 500_000_000L;
  private static final long TWO_SECONDS_IN_NANOS = 2_000_000_000L;

  // Upper bound for the blocking calls the tests themselves make (Future#get, Thread#join). The handler is asked to wait
  // for 7 days in some tests, so without this bound a regression would hang the build instead of failing the test.
  private static final long SAFETY_TIMEOUT_SECONDS = 10;

  @Rule
  public MockitoRule mockitoRule = MockitoJUnit.rule();

  @Mock
  private ChannelHandlerContext ctx;

  private final ConnectionsCounterHandler handler = new ConnectionsCounterHandler();

  /**
   * With one active channel and no matching inactive callback, the wait lasts at least the requested timeout.
   */
  @Test
  public void waitsWhenAChannelIsActive() throws Exception {
    long timeoutNanos = HALF_SECOND_IN_NANOS;
    handler.channelActive(ctx);
    long nanosElapsed = measuringNanoseconds(() -> handler.waitForConnectionsToBeClosed(timeoutNanos, NANOSECONDS));
    assertThat(nanosElapsed, greaterThanOrEqualTo(timeoutNanos));
  }

  /**
   * When every {@code channelActive} call has been matched by a {@code channelInactive} call, the wait returns promptly
   * even though the requested timeout is 7 days.
   */
  @Test
  public void doesNotWaitWhenNoChannelIsActive() throws Exception {
    int numberOfConnections = 10;
    for (int i = 0; i < numberOfConnections; i++) {
      handler.channelActive(ctx);
    }
    for (int i = 0; i < numberOfConnections; i++) {
      handler.channelInactive(ctx);
    }
    long nanosElapsed = measuringNanoseconds(() -> handler.waitForConnectionsToBeClosed(7, DAYS));
    assertThat(nanosElapsed, lessThan(HALF_SECOND_IN_NANOS));
  }

  /**
   * A thread blocked waiting for connections is released when the last active channel becomes inactive. It is released
   * after the inactive callback (so at least half a second elapses) but long before its 7-day timeout.
   */
  @Test
  public void lastChannelInactiveCallbackUnblocksTheWait() throws Exception {
    handler.channelActive(ctx);

    // Per-test executor, shut down in finally, so its worker thread never outlives the test. shutdownNow also interrupts
    // the waiter if an assertion failed while it was still blocked.
    ExecutorService executorService = newFixedThreadPool(1);
    try {
      Latch timeMeasurementStarted = new Latch();
      Future<Long> elapsedNanosFuture =
          executorService.submit(() -> measuringNanoseconds(() -> {
            timeMeasurementStarted.release();
            handler.waitForConnectionsToBeClosed(7, DAYS);
          }));
      timeMeasurementStarted.await();

      sleep(NANOSECONDS.toMillis(HALF_SECOND_IN_NANOS));

      handler.channelInactive(ctx);
      // Bounded get: if the inactive callback does not release the waiter, fail instead of blocking for 7 days.
      long elapsedNanos = elapsedNanosFuture.get(SAFETY_TIMEOUT_SECONDS, SECONDS);
      assertThat(elapsedNanos, greaterThanOrEqualTo(HALF_SECOND_IN_NANOS));
      assertThat(elapsedNanos, lessThan(TWO_SECONDS_IN_NANOS));
    } finally {
      executorService.shutdownNow();
    }
  }

  /**
   * A zero timeout returns promptly even while a channel is still active.
   */
  @Test
  public void zeroTimeoutMeansNoWait() throws Exception {
    handler.channelActive(ctx);
    long elapsedNanos = measuringNanoseconds(() -> handler.waitForConnectionsToBeClosed(0, SECONDS));
    assertThat(elapsedNanos, lessThan(HALF_SECOND_IN_NANOS));
  }

  /**
   * Interrupting a thread that is waiting for connections to close makes it return promptly instead of waiting for its
   * 7-day timeout.
   */
  @Test
  public void interruptingTheThreadUnblocksTheWait() throws Exception {
    handler.channelActive(ctx);
    Thread thread = new Thread(() -> handler.waitForConnectionsToBeClosed(7, DAYS));
    // Daemon, so a waiter that ignores the interruption cannot keep the JVM alive for the 7-day timeout.
    thread.setDaemon(true);
    thread.start();
    // NOTE(review): the interrupt may land before the thread reaches the wait. The test then relies on the wait also
    // honouring an interrupt status that is already set.
    thread.interrupt();
    // Bounded join: if the wait is not released, the join times out and the elapsed-time assertion fails (instead of
    // the test hanging).
    long nanosElapsed = measuringNanoseconds(() -> thread.join(SECONDS.toMillis(SAFETY_TIMEOUT_SECONDS)));
    assertThat(nanosElapsed, lessThan(HALF_SECOND_IN_NANOS));
  }

  /**
   * Requests a timeout just below one millisecond while a channel is active and checks that the call does not return
   * instantly.
   */
  @Test
  public void waitLessThanOneMillisecond() throws Exception {
    handler.channelActive(ctx);
    long lessThanOneMillisecondInNanos = MILLISECONDS.toNanos(1) - 1;
    long nanosElapsed =
        measuringNanoseconds(() -> handler.waitForConnectionsToBeClosed(lessThanOneMillisecondInNanos, NANOSECONDS));
    // NOTE(review): a 10ns lower bound is very loose. The intent presumably is that a sub-millisecond timeout is not
    // truncated to a zero (no-wait) timeout. Consider asserting against lessThanOneMillisecondInNanos once the handler's
    // timing guarantees are confirmed.
    assertThat(nanosElapsed, greaterThanOrEqualTo(10L));
  }
}
