/*
 * Copyright (c) MuleSoft, Inc.  All rights reserved.  http://www.mulesoft.com
 * The software in this package is published under the terms of the CPAL v1.0
 * license, a copy of which has been included with this distribution in the
 * LICENSE.txt file.
 */
package org.mule.service.http.impl.functional;

import static java.lang.Long.valueOf;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Collections.singletonList;
import static java.util.OptionalLong.empty;
import static java.util.OptionalLong.of;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertThat;
import static org.mule.runtime.http.api.HttpConstants.HttpStatus.OK;
import static org.mule.runtime.http.api.HttpHeaders.Names.CONTENT_TYPE;
import static org.mule.service.http.impl.AllureConstants.HttpFeature.HttpStory.TRANSFER_TYPE;

import org.mule.runtime.http.api.client.HttpClient;
import org.mule.runtime.http.api.client.HttpClientConfiguration;
import org.mule.runtime.http.api.domain.entity.ByteArrayHttpEntity;
import org.mule.runtime.http.api.domain.entity.EmptyHttpEntity;
import org.mule.runtime.http.api.domain.entity.HttpEntity;
import org.mule.runtime.http.api.domain.entity.InputStreamHttpEntity;
import org.mule.runtime.http.api.domain.entity.multipart.HttpPart;
import org.mule.runtime.http.api.domain.entity.multipart.MultipartHttpEntity;
import org.mule.runtime.http.api.domain.message.request.HttpRequest;
import org.mule.runtime.http.api.domain.message.response.HttpResponse;
import org.mule.runtime.http.api.domain.message.response.HttpResponseBuilder;
import org.mule.service.http.impl.functional.client.AbstractHttpClientTestCase;
import org.mule.service.http.impl.service.domain.entity.multipart.StreamedMultipartHttpEntity;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;

import java.io.ByteArrayInputStream;
import java.util.OptionalLong;

import io.qameta.allure.Story;

@Story(TRANSFER_TYPE)
public class HttpTransferLengthTestCase extends AbstractHttpClientTestCase {

  private static final String RESPONSE = "TEST";
  private static final String REQUEST = "tests";
  private static final String BYTE = "/byte";
  private static final String MULTIPART = "/multipart";
  private static final String STREAM = "/stream";
  private static final String CHUNKED = "/chunked";

  private HttpClient client;

  public HttpTransferLengthTestCase(String serviceToLoad) {
    super(serviceToLoad);
  }

  @Before
  public void createClient() {
    HttpClientConfiguration clientConf = new HttpClientConfiguration.Builder().setName("transfer-type-test").build();
    client = service.getClientFactory().create(clientConf);
    client.start();
  }

  @After
  public void closeClient() {
    if (client != null) {
      client.stop();
    }
  }

  @Override
  protected HttpResponse setUpHttpResponse(HttpRequest request) {
    String path = request.getPath();
    HttpEntity entity = request.getEntity();
    HttpResponseBuilder builder = HttpResponse.builder();
    try {
      OptionalLong expectedRequestLength = of(valueOf(REQUEST.length()));
      if (BYTE.equals(path)) {
        assertThat(entity, is(instanceOf(InputStreamHttpEntity.class)));

        builder.entity(new ByteArrayHttpEntity(RESPONSE.getBytes()));
      } else if (MULTIPART.equals(path)) {
        expectedRequestLength = of(142L);
        assertThat(entity, is(instanceOf(StreamedMultipartHttpEntity.class)));

        HttpPart part = new HttpPart("part1", "TEST".getBytes(), "text/plain", 4);
        builder
            .entity(new MultipartHttpEntity(singletonList(part)))
            .addHeader(CONTENT_TYPE, "multipart/form-data; boundary=\"bounds\"");
      } else if (STREAM.equals(path)) {
        assertThat(entity, is(instanceOf(InputStreamHttpEntity.class)));

        builder.entity(new InputStreamHttpEntity(new ByteArrayInputStream("TEST".getBytes()), 4L));
      } else if (CHUNKED.equals(path)) {
        assertThat(entity, is(instanceOf(InputStreamHttpEntity.class)));
        expectedRequestLength = empty();

        builder.entity(new InputStreamHttpEntity(new ByteArrayInputStream("TEST".getBytes())));
      } else {
        expectedRequestLength = of(0L);
        assertThat(entity, is(instanceOf(EmptyHttpEntity.class)));
      }
      assertThat(request.getEntity().getBytesLength(), is(expectedRequestLength));
      return builder.build();
    } catch (AssertionError e) {
      return builder.statusCode(500).entity(new ByteArrayHttpEntity(e.getMessage().getBytes())).build();
    }
  }

  @Test
  public void propagatesLengthWhenByte() throws Exception {
    HttpRequest request = HttpRequest.builder()
        .uri(getUri() + BYTE)
        .entity(new ByteArrayHttpEntity(REQUEST.getBytes())).build();
    HttpResponse response = send(request);

    assertThat(response.getEntity().getBytesLength().getAsLong(), is(equalTo(4L)));
    assertThat(response.getEntity(), instanceOf(InputStreamHttpEntity.class));
  }

  @Test
  public void propagatesLengthWhenMultipart() throws Exception {
    HttpPart part = new HttpPart("part1", REQUEST.getBytes(), "text/plain", 5);
    HttpRequest request = HttpRequest.builder()
        .uri(getUri() + MULTIPART)
        .addHeader(CONTENT_TYPE, "multipart/form-data; boundary=\"bounds\"")
        .entity(new MultipartHttpEntity(singletonList(part)))
        .build();
    HttpResponse response = send(request);

    assertThat(response.getEntity().getBytesLength().getAsLong(), is(equalTo(102L)));
    assertThat(response.getEntity(), instanceOf(StreamedMultipartHttpEntity.class));
  }

  @Test
  public void propagatesLengthWhenEmpty() throws Exception {
    HttpRequest request = HttpRequest.builder().uri(getUri() + "/empty").build();
    HttpResponse response = send(request);

    assertThat(response.getEntity().getBytesLength().getAsLong(), is(equalTo(0L)));
    assertThat(response.getEntity(), instanceOf(EmptyHttpEntity.class));
  }

  @Test
  public void propagatesLengthWhenStream() throws Exception {
    HttpRequest request = HttpRequest.builder()
        .uri(getUri() + STREAM)
        .entity(new InputStreamHttpEntity(new ByteArrayInputStream(REQUEST.getBytes()), 5L))
        .build();
    HttpResponse response = send(request);

    assertThat(response.getEntity().getBytesLength().getAsLong(), is(equalTo(4L)));
    assertThat(response.getEntity(), instanceOf(InputStreamHttpEntity.class));
  }

  @Test
  public void doesNotPropagateLengthWhenChunked() throws Exception {
    HttpRequest request = HttpRequest.builder()
        .uri(getUri() + CHUNKED)
        .entity(new InputStreamHttpEntity(new ByteArrayInputStream(REQUEST.getBytes())))
        .build();

    HttpResponse response = send(request);

    assertThat(response.getEntity().getBytesLength(), is(OptionalLong.empty()));
    assertThat(response.getEntity(), instanceOf(InputStreamHttpEntity.class));
  }

  private HttpResponse send(HttpRequest request) throws Exception {
    HttpResponse response = client.send(request, getDefaultOptions(TIMEOUT));

    assertThat(response.getStatusCode(), is(OK.getStatusCode()));
    return response;
  }

}
