/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package org.mule.service.http.test.common.server;

import static org.mule.runtime.api.metadata.MediaType.TEXT;
import static org.mule.runtime.http.api.HttpConstants.HttpStatus.OK;
import static org.mule.runtime.http.api.HttpConstants.Method.POST;
import static org.mule.runtime.http.api.HttpHeaders.Names.CONTENT_TYPE;

import static java.lang.Integer.parseInt;
import static java.lang.String.format;
import static java.lang.Thread.sleep;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Collections.singletonList;
import static java.util.concurrent.Executors.newCachedThreadPool;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.fail;

import org.mule.runtime.http.api.domain.entity.InputStreamHttpEntity;
import org.mule.runtime.http.api.domain.message.response.HttpResponse;
import org.mule.runtime.http.api.server.HttpServerConfiguration;
import org.mule.runtime.http.api.server.async.HttpResponseReadyCallback;

import java.io.IOException;
import java.io.InputStream;
import java.io.InterruptedIOException;
import java.io.UncheckedIOException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import org.apache.hc.client5.http.fluent.Request;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

/**
 * Checks that the HTTP server keeps answering when request handlers issue nested (recursive) requests
 * to the same server and stream their responses slowly.
 * <p>
 * Each top-level client request posts the counter {@code 2}. The handler waits {@link #SOME_TIME} ms and
 * then continues on another thread. While the counter is positive, that thread posts {@code counter - 1}
 * back to the same endpoint. Finally it answers with a deliberately slow body. {@link #REQUESTS} top-level
 * requests are fired concurrently, and all of them must receive {@code 200 OK} within
 * {@link #LATCH_TIMEOUT} ms.
 */
public class SlowResponsesWithRecursionTestCase extends AbstractHttpServerTestCase {

  private static final long SOME_TIME = 100;
  private static final int REQUESTS = 20;
  private static final int LATCH_TIMEOUT = 5000;
  private static final String ENDPOINT = "/recursiveCalls";
  private final CountDownLatch started = new CountDownLatch(REQUESTS);
  private final CountDownLatch finished = new CountDownLatch(REQUESTS);
  // Shared by the client-side request tasks and the server-side response tasks.
  private final ExecutorService executor = newCachedThreadPool();

  public SlowResponsesWithRecursionTestCase(String serviceToLoad) {
    super(serviceToLoad);
  }

  @Override
  protected void setUpServer() throws Exception {
    server = service.getServerFactory().create(configureServer(new HttpServerConfiguration.Builder()
        .setHost("localhost")
        .setPort(port)
        .setName(getServerName())
        .setSchedulerSupplier(() -> getSchedulerService().ioScheduler()))
            .build());
    server.start();
  }

  @BeforeEach
  public void setUp() throws Exception {
    setUpServer();
    registerHandler();
  }

  /**
   * Stops the executor so that its threads do not outlive the test. {@code shutdownNow} interrupts any
   * task that is still sleeping or streaming a slow body.
   */
  @AfterEach
  public void shutDownExecutor() {
    executor.shutdownNow();
  }

  /**
   * Registers the POST handler for {@link #ENDPOINT}. The handler sleeps {@link #SOME_TIME} ms, parses the
   * request body as an integer, and hands the response off to {@link #respondInAnotherThread}.
   */
  private void registerHandler() {
    server.addRequestHandler(singletonList(POST.name()), ENDPOINT, (requestContext, responseCallback) -> {
      try {
        sleep(SOME_TIME);
      } catch (InterruptedException e) {
        // Preserve the interrupt status for the thread running the handler.
        Thread.currentThread().interrupt();
        throw new RuntimeException(e);
      }
      try {
        int message = parseInt(new String(requestContext.getRequest().getEntity().getBytes(), UTF_8));
        respondInAnotherThread(message, responseCallback);
      } catch (IOException e) {
        throw new UncheckedIOException(e);
      }
    }).start();
  }

  @Override
  protected String getServerName() {
    return "sarasa";
  }

  @Test
  void test() throws Exception {
    AtomicInteger correctCount = new AtomicInteger(0);
    for (int i = 0; i < REQUESTS; i++) {
      performRequestInThread(correctCount);
    }
    // Bounded wait: a stuck executor must fail the test instead of hanging the build.
    if (!started.await(LATCH_TIMEOUT, TimeUnit.MILLISECONDS)) {
      fail(format("Only %d of %d requests were started", REQUESTS - started.getCount(), REQUESTS));
    }
    if (!finished.await(LATCH_TIMEOUT, TimeUnit.MILLISECONDS)) {
      fail(format("We only got %d responses out of %d", correctCount.get(), REQUESTS));
    }
  }

  /**
   * Fires one top-level request on the executor. The request asks for two levels of recursion
   * (2 -> 1 -> 0). On {@code 200 OK} it increments {@code counter} and counts down {@link #finished}.
   *
   * @param counter number of top-level requests that completed with {@code 200 OK}
   */
  private void performRequestInThread(AtomicInteger counter) {
    executor.submit(() -> {
      started.countDown();
      try {
        Request request = Request.post(urlForPath(ENDPOINT));
        request.bodyByteArray("2".getBytes(UTF_8));
        org.apache.hc.core5.http.HttpResponse response = request.execute().returnResponse();
        assertThat(response.getCode(), is(OK.getStatusCode()));
        counter.incrementAndGet();
        finished.countDown();
      } catch (Exception ignored) {
        // A failed request never counts down 'finished'; test() reports the shortfall through its
        // timeout message. Assertion errors are not caught here and end up in the discarded Future,
        // which has the same effect.
      }
    });
  }

  /**
   * Completes a handled request on the executor. If {@code message} is positive, it first issues a nested
   * request carrying {@code message - 1} to the same endpoint. It then responds with a {@link SlowInputStream}
   * body.
   *
   * @param message          recursion counter parsed from the incoming request body
   * @param responseCallback callback used to send the response
   */
  private void respondInAnotherThread(int message, HttpResponseReadyCallback responseCallback) {
    executor.submit(() -> {
      if (message > 0) {
        try {
          Request request = Request.post(urlForPath(ENDPOINT));
          request.bodyByteArray(Integer.toString(message - 1).getBytes(UTF_8));
          org.apache.hc.core5.http.HttpResponse response = request.execute().returnResponse();
          assertThat(response.getCode(), is(OK.getStatusCode()));
        } catch (IOException ignored) {
          // Best effort: the response below is still sent, so an I/O failure of the nested call
          // does not block the caller waiting on this one.
        }
      }

      responseCallback.responseReady(HttpResponse.builder().entity(new InputStreamHttpEntity(new SlowInputStream()))
          .addHeader(CONTENT_TYPE, TEXT.toRfcString())
          .build(), new IgnoreResponseStatusCallback());
    });
  }

  /**
   * Input stream that yields 101 bytes of value {@code 1} and then end-of-stream. It sleeps
   * {@code SOME_TIME / 10} ms before every read, so responses are streamed slowly.
   */
  private static final class SlowInputStream extends InputStream {

    private int count = 0;

    @Override
    public int read() throws IOException {
      try {
        sleep(SOME_TIME / 10);
      } catch (InterruptedException e) {
        // Report interruption via the IOException contract of read(), keeping the interrupt flag set.
        Thread.currentThread().interrupt();
        InterruptedIOException interrupted = new InterruptedIOException("Interrupted while streaming slow response");
        interrupted.initCause(e);
        throw interrupted;
      }
      return (count++ <= 100) ? 1 : -1;
    }
  }
}
