/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package org.mule.service.http.test.common.client;

import static org.mule.runtime.http.api.HttpConstants.HttpStatus.BAD_REQUEST;
import static org.mule.runtime.http.api.HttpConstants.HttpStatus.OK;
import static org.mule.runtime.http.api.HttpHeaders.Names.CONTENT_LENGTH;
import static org.mule.runtime.http.api.HttpHeaders.Names.CONTENT_TYPE;
import static org.mule.service.http.netty.impl.server.ForwardingToListenerHandler.ALLOW_PAYLOAD_FOR_UNDEFINED_METHODS;
import static org.mule.service.http.test.netty.AllureConstants.HttpStory.TRANSFER_TYPE;

import static java.lang.Long.valueOf;
import static java.lang.String.format;
import static java.util.OptionalLong.empty;
import static java.util.OptionalLong.of;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;

import org.mule.runtime.http.api.client.HttpClient;
import org.mule.runtime.http.api.client.HttpClientConfiguration;
import org.mule.runtime.http.api.domain.entity.ByteArrayHttpEntity;
import org.mule.runtime.http.api.domain.entity.EmptyHttpEntity;
import org.mule.runtime.http.api.domain.entity.HttpEntity;
import org.mule.runtime.http.api.domain.entity.InputStreamHttpEntity;
import org.mule.runtime.http.api.domain.message.request.HttpRequest;
import org.mule.runtime.http.api.domain.message.response.HttpResponse;
import org.mule.runtime.http.api.domain.message.response.HttpResponseBuilder;
import org.mule.service.http.netty.impl.message.content.StreamedMultipartHttpEntity;

import java.io.ByteArrayInputStream;
import java.util.OptionalLong;
import java.util.function.Consumer;

import io.qameta.allure.Story;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

/**
 * Checks how request and response entity lengths are propagated between the HTTP client and the test server for
 * byte-array, multipart, length-delimited stream, chunked stream and empty entities.
 * <p>
 * Each concrete subclass chooses the HTTP service implementation to load and the value that
 * {@code ALLOW_PAYLOAD_FOR_UNDEFINED_METHODS} takes while its tests run. The original value of that static flag is
 * restored in {@link #tearDown()}.
 */
@Story(TRANSFER_TYPE)
public abstract class AbstractHttpTransferLengthTestCase extends AbstractHttpClientTestCase {

  private static final String RESPONSE = "TEST";
  private static final String REQUEST = "tests";
  private static final String BYTE = "/byte";
  private static final String MULTIPART = "/multipart";
  private static final String STREAM = "/stream";
  private static final String CHUNKED = "/chunked";
  private static final String EMPTY = "/empty";
  private static final String MULTIPART_DATA_FORMAT = """
      --bounds\r
      Content-Type: text/plain\r
      Content-Disposition: form-data; name="part1"\r
      \r
      %s\r
      --bounds--\r
      """;

  private HttpClient client;

  private final boolean isAllowPayload;
  private final boolean isAllowPayloadDefault;

  /**
   * @param serviceToLoad  service to load, passed through to the parent test case.
   * @param isAllowPayload value to assign to {@code ALLOW_PAYLOAD_FOR_UNDEFINED_METHODS} for this test instance.
   */
  protected AbstractHttpTransferLengthTestCase(String serviceToLoad, boolean isAllowPayload) {
    super(serviceToLoad);

    this.isAllowPayload = isAllowPayload;
    // NOTE(review): this calls an overridable method from the constructor; subclasses overriding it must not rely on
    // their own fields being initialized yet.
    isAllowPayloadDefault = setAllowPayloadForUndefinedMethod(isAllowPayload);
  }

  /**
   * Overwrites the static {@code ALLOW_PAYLOAD_FOR_UNDEFINED_METHODS} flag.
   *
   * @param isAllowPayload the new flag value.
   * @return the value the flag had before this call, so it can be restored later.
   */
  protected boolean setAllowPayloadForUndefinedMethod(boolean isAllowPayload) {
    boolean defaultIsAllowPayload = ALLOW_PAYLOAD_FOR_UNDEFINED_METHODS;
    ALLOW_PAYLOAD_FOR_UNDEFINED_METHODS = isAllowPayload;

    return defaultIsAllowPayload;
  }

  /**
   * Puts the static {@code ALLOW_PAYLOAD_FOR_UNDEFINED_METHODS} flag back to a previously captured value.
   *
   * @param isAllowPayloadDefault the value to restore.
   */
  protected void restoreAllowPayloadForUndefinedMethod(boolean isAllowPayloadDefault) {
    ALLOW_PAYLOAD_FOR_UNDEFINED_METHODS = isAllowPayloadDefault;
  }

  /** Creates and starts a streaming-enabled client named {@code transfer-type-test} before each test. */
  @BeforeEach
  public void createClient() {
    HttpClientConfiguration clientConf =
        new HttpClientConfiguration.Builder().setName("transfer-type-test").setStreaming(true).build();
    client = service.getClientFactory().create(clientConf);
    client.start();
  }

  /**
   * Tears down the parent fixture, stops the client and restores the static payload flag.
   * <p>
   * Each step runs in a {@code finally} so that a failure in an earlier step neither leaks the client nor leaves
   * the global flag modified for subsequent tests.
   */
  @Override
  @AfterEach
  public void tearDown() throws Exception {
    try {
      super.tearDown();
    } finally {
      try {
        if (client != null) {
          client.stop();
        }
      } finally {
        restoreAllowPayloadForUndefinedMethod(isAllowPayloadDefault);
      }
    }
  }

  /**
   * Server-side handler. It validates the received request against the expectations for its path and answers with
   * an entity of the kind matching that path.
   * <p>
   * Any {@link AssertionError} raised while validating is converted into a 500 response whose body is the assertion
   * message. The client-side checks in {@code send} only accept 200 or 400, so such a response makes the test fail.
   */
  @Override
  protected HttpResponse setUpHttpResponse(HttpRequest request) {
    String path = request.getPath();
    HttpEntity entity = request.getEntity();
    HttpResponseBuilder builder = HttpResponse.builder();
    try {
      OptionalLong expectedRequestLength = of(valueOf(REQUEST.length()));
      if (BYTE.equals(path)) {
        assertThat(request.containsHeader(CONTENT_LENGTH), is(true));
        assertThat(entity.isStreaming(), is(true));

        builder.entity(new ByteArrayHttpEntity(RESPONSE.getBytes()));
      } else if (MULTIPART.equals(path)) {
        assertThat(request.containsHeader(CONTENT_LENGTH), is(true));
        expectedRequestLength = of(multipartBody(REQUEST).length);
        assertThat(entity, is(instanceOf(StreamedMultipartHttpEntity.class)));

        builder
            .entity(new ByteArrayHttpEntity(multipartBody(RESPONSE)))
            .addHeader(CONTENT_TYPE, "multipart/form-data; boundary=\"bounds\"");
      } else if (STREAM.equals(path)) {
        assertThat(request.containsHeader(CONTENT_LENGTH), is(true));
        assertThat(entity, is(instanceOf(InputStreamHttpEntity.class)));

        builder.entity(new InputStreamHttpEntity(new ByteArrayInputStream(RESPONSE.getBytes()),
                                                 valueOf(RESPONSE.length())));
      } else if (CHUNKED.equals(path)) {
        assertThat(request.containsHeader(CONTENT_LENGTH), is(false));
        assertThat(entity, is(instanceOf(InputStreamHttpEntity.class)));
        expectedRequestLength = empty();

        builder.entity(new InputStreamHttpEntity(new ByteArrayInputStream(RESPONSE.getBytes())));
      } else { // empty request
        assertThat(request.containsHeader(CONTENT_LENGTH), is(false));
        expectedRequestLength = of(0L);
        assertThat(entity, is(instanceOf(EmptyHttpEntity.class)));
      }
      assertThat(entity.getBytesLength(), is(expectedRequestLength));
      return builder.build();
    } catch (AssertionError e) {
      // Use a fresh builder: the branch above may already have set a multipart Content-Type, which would make the
      // client try to parse this plain-text error body as multipart and hide the actual assertion message.
      String message = String.valueOf(e.getMessage());
      return HttpResponse.builder().statusCode(500).entity(new ByteArrayHttpEntity(message.getBytes())).build();
    }
  }

  @Test
  void propagatesLengthWhenByte() throws Exception {
    HttpRequest request = HttpRequest.builder()
        .uri(getUri() + BYTE)
        .entity(new ByteArrayHttpEntity(REQUEST.getBytes())).build();
    send(request, false, response -> {
      assertThat(response.getEntity().getBytesLength(), is(equalTo(of(RESPONSE.length()))));
      assertThat(response.getEntity().isStreaming(), is(true));
    });
  }

  @Test
  void propagatesLengthWhenMultipart() throws Exception {
    HttpRequest request = HttpRequest.builder()
        .uri(getUri() + MULTIPART)
        .addHeader(CONTENT_TYPE, "multipart/form-data; boundary=bounds")
        .entity(new ByteArrayHttpEntity(multipartBody(REQUEST)))
        .build();
    send(request, false, response -> {
      assertThat(response.getEntity().getBytesLength(), is(equalTo(of(multipartBody(RESPONSE).length))));
      assertThat(response.getEntity().isComposed(), is(true));
    });
  }

  @Test
  void propagatesLengthWhenEmpty() throws Exception {
    HttpRequest request = HttpRequest.builder().uri(getUri() + EMPTY).build();
    send(request, true, response -> {
      assertThat(response.getEntity().getBytesLength(), is(equalTo(of(0L))));
    });
  }

  @Test
  void propagatesLengthWhenStream() throws Exception {
    HttpRequest request = HttpRequest.builder()
        .uri(getUri() + STREAM)
        .entity(new InputStreamHttpEntity(new ByteArrayInputStream(REQUEST.getBytes()), valueOf(REQUEST.length())))
        .build();
    send(request, false, response -> {
      assertThat(response.getEntity().getBytesLength(), is(equalTo(of(RESPONSE.length()))));
      assertThat(response.getEntity().isStreaming(), is(true));
    });
  }

  @Test
  void doesNotPropagateLengthWhenChunked() throws Exception {
    HttpRequest request = HttpRequest.builder()
        .uri(getUri() + CHUNKED)
        .entity(new InputStreamHttpEntity(new ByteArrayInputStream(REQUEST.getBytes())))
        .build();

    send(request, false, response -> {
      assertThat(response.getEntity().getBytesLength(), is(OptionalLong.empty()));
      assertThat(response.getEntity().isStreaming(), is(true));
    });
  }

  /**
   * Builds the single-part multipart body used by both the requests and the server responses.
   *
   * @param partContent the text placed inside the {@code part1} part.
   * @return the encoded body; its length is the expected transfer length for that content.
   */
  private static byte[] multipartBody(String partContent) {
    return format(MULTIPART_DATA_FORMAT, partContent).getBytes();
  }

  /**
   * Sends {@code request} and checks the response status.
   * <p>
   * If payloads are allowed for undefined methods, or the request carries no payload, a 200 is expected and
   * {@code onSuccessResponse} is run against the response. Otherwise a 400 is expected and no further checks are made.
   *
   * @param request           the request to send.
   * @param hasEmptyPayload   whether {@code request} has no body.
   * @param onSuccessResponse additional assertions applied to a 200 response.
   */
  private void send(HttpRequest request, boolean hasEmptyPayload, Consumer<HttpResponse> onSuccessResponse) throws Exception {
    HttpResponse response = client.send(request, getDefaultOptions(TIMEOUT));

    if (isAllowPayload || hasEmptyPayload) {
      assertThat(response.getStatusCode(), is(OK.getStatusCode()));
      onSuccessResponse.accept(response);
    } else {
      assertThat(response.getStatusCode(), is(BAD_REQUEST.getStatusCode()));
    }
  }

}
