/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package org.mule.service.http.netty.impl.client;

import static java.lang.Math.ceil;
import static java.nio.charset.StandardCharsets.UTF_8;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;

import org.mule.runtime.http.api.domain.entity.HttpEntity;
import org.mule.runtime.http.api.domain.entity.InputStreamHttpEntity;
import org.mule.tck.junit4.AbstractMuleTestCase;

import java.io.ByteArrayInputStream;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

import io.netty.buffer.ByteBuf;
import org.junit.Test;
import org.reactivestreams.Subscription;
import reactor.core.CoreSubscriber;

/**
 * Tests for {@link ChunkedHttpEntityPublisher} that check how many chunks a subscriber receives, what content those chunks
 * carry and whether completion is signalled, for different combinations of buffer size and requested demand.
 */
public class ChunkedHttpEntityPublisherTestCase extends AbstractMuleTestCase {

  /** Payload shared by every test. It is pure ASCII, so its character count equals its UTF-8 byte count. */
  private static final String MESSAGE = "Hello from client, post request to existing path. This is an stream";

  /** Buffer size smaller than {@link #MESSAGE}, so the payload has to be split into several chunks. */
  private static final int SMALL_BUFFER_SIZE = 20;

  /**
   * Requests more chunks than the entity can produce, with a buffer smaller than the entity. Expects every chunk to be
   * delivered, the full content to be received and completion to be signalled.
   */
  @Test
  public void requestLargerAmountOfChunks_EntityLargerThanBuffer() {
    int expectedNumberOfChunks = calculateNumberOfChunks(MESSAGE, SMALL_BUFFER_SIZE);
    HttpEntity entity = createEntity(MESSAGE);
    ChunkedHttpEntityPublisher publisher = new ChunkedHttpEntityPublisher(entity, SMALL_BUFFER_SIZE);

    TestSubscriber testSubscriber = new TestSubscriber(10);
    publisher.subscribe(testSubscriber);

    assertThat(testSubscriber.getReceivedChunks(), is(expectedNumberOfChunks));
    assertThat(testSubscriber.getAggregatedContentAsString(), is(MESSAGE));
    assertThat(testSubscriber.isCompleteCalled(), is(true));
  }

  /**
   * Requests more chunks than the entity can produce, with a buffer twice as large as the entity. Expects the computed
   * number of chunks to be delivered, the full content to be received and completion to be signalled.
   */
  @Test
  public void requestLargerAmountOfChunks_EntitySmallerThanBuffer() {
    int bufferSize = MESSAGE.length() * 2;
    int expectedNumberOfChunks = calculateNumberOfChunks(MESSAGE, bufferSize);
    HttpEntity entity = createEntity(MESSAGE);
    ChunkedHttpEntityPublisher publisher = new ChunkedHttpEntityPublisher(entity, bufferSize);

    TestSubscriber testSubscriber = new TestSubscriber(10);
    publisher.subscribe(testSubscriber);

    assertThat(testSubscriber.getReceivedChunks(), is(expectedNumberOfChunks));
    assertThat(testSubscriber.getAggregatedContentAsString(), is(MESSAGE));
    assertThat(testSubscriber.isCompleteCalled(), is(true));
  }

  /**
   * Requests exactly as many chunks as the entity is expected to produce. Expects all the content to be received while
   * completion is not signalled.
   */
  @Test
  public void requestExactAmountOfChunks_RequestLargerThanBuffer() {
    int expectedNumberOfChunks = calculateNumberOfChunks(MESSAGE, SMALL_BUFFER_SIZE);
    HttpEntity entity = createEntity(MESSAGE);
    ChunkedHttpEntityPublisher publisher = new ChunkedHttpEntityPublisher(entity, SMALL_BUFFER_SIZE);

    TestSubscriber testSubscriber = new TestSubscriber(expectedNumberOfChunks);
    publisher.subscribe(testSubscriber);

    assertThat(testSubscriber.getReceivedChunks(), is(expectedNumberOfChunks));
    assertThat(testSubscriber.getAggregatedContentAsString(), is(MESSAGE));
    assertThat(testSubscriber.isCompleteCalled(), is(false));
  }

  /**
   * Requests one chunk fewer than the entity is expected to produce. Expects only the requested chunks to be delivered, the
   * received content to be the matching prefix of the message, and completion not to be signalled.
   */
  @Test
  public void requestLessAmountOfChunks_RequestLargerThanBuffer() {
    int chunksToRequest = calculateNumberOfChunks(MESSAGE, SMALL_BUFFER_SIZE) - 1;
    HttpEntity entity = createEntity(MESSAGE);
    ChunkedHttpEntityPublisher publisher = new ChunkedHttpEntityPublisher(entity, SMALL_BUFFER_SIZE);

    TestSubscriber testSubscriber = new TestSubscriber(chunksToRequest);
    publisher.subscribe(testSubscriber);

    assertThat(testSubscriber.getReceivedChunks(), is(chunksToRequest));
    // Every requested chunk is expected to be full, so the received text is a prefix of that many buffers.
    assertThat(testSubscriber.getAggregatedContentAsString(), is(MESSAGE.substring(0, SMALL_BUFFER_SIZE * chunksToRequest)));
    assertThat(testSubscriber.isCompleteCalled(), is(false));
  }

  /**
   * Computes how many chunks of {@code chunkSize} are needed to hold {@code data}, rounding up.
   *
   * @param data      the payload; its length in characters is used, which matches its byte length only for ASCII content.
   * @param chunkSize the maximum size of each chunk.
   * @return the ceiling of {@code data.length() / chunkSize}.
   */
  private static int calculateNumberOfChunks(String data, int chunkSize) {
    // Divide in double precision: float cannot represent every int exactly.
    return (int) ceil((double) data.length() / chunkSize);
  }

  /**
   * Wraps the UTF-8 bytes of {@code message} in a stream-backed entity.
   *
   * @param message the text to expose as the entity content.
   * @return an {@link InputStreamHttpEntity} reading from an in-memory copy of the message bytes.
   */
  private static InputStreamHttpEntity createEntity(String message) {
    return new InputStreamHttpEntity(new ByteArrayInputStream(message.getBytes(UTF_8)));
  }

  /**
   * Subscriber that requests a fixed number of chunks when subscribed and records what it receives: the number of chunks,
   * their concatenated content decoded as UTF-8, whether completion was signalled and the last error, if any.
   */
  public static class TestSubscriber implements CoreSubscriber<ByteBuf> {

    private final AtomicInteger receivedChunks = new AtomicInteger(0);
    private final AtomicBoolean isCompleteCalled = new AtomicBoolean(false);
    private final int chunksToRequestOnSubscribe;
    private final StringBuilder receivedContentBuilder;
    // Last error reported through onError, kept so that a failure is not silently discarded.
    private volatile Throwable error;

    /**
     * @param chunksToRequestOnSubscribe the demand to request from the subscription as soon as it is received.
     */
    public TestSubscriber(int chunksToRequestOnSubscribe) {
      this.chunksToRequestOnSubscribe = chunksToRequestOnSubscribe;
      this.receivedContentBuilder = new StringBuilder();
    }

    /** Requests the configured number of chunks once; no further demand is ever requested. */
    @Override
    public void onSubscribe(Subscription subscription) {
      subscription.request(chunksToRequestOnSubscribe);
    }

    /** Counts the chunk and appends its readable bytes, decoded as UTF-8, to the aggregated content. */
    @Override
    public void onNext(ByteBuf byteBuf) {
      receivedChunks.incrementAndGet();

      // NOTE(review): the buffer is not released here; confirm whether the publisher or the subscriber owns it.
      byte[] bytes = new byte[byteBuf.readableBytes()];
      byteBuf.readBytes(bytes);
      // Decode with the same charset used to encode the entity instead of the platform default.
      receivedContentBuilder.append(new String(bytes, UTF_8));
    }

    /** Stores the error so that it can be inspected through {@link #getError()}. */
    @Override
    public void onError(Throwable throwable) {
      error = throwable;
    }

    /** Marks that completion was signalled. */
    @Override
    public void onComplete() {
      isCompleteCalled.set(true);
    }

    /**
     * @return the number of chunks received so far.
     */
    public int getReceivedChunks() {
      return receivedChunks.get();
    }

    /**
     * @return {@code true} if {@link #onComplete()} has been invoked.
     */
    public boolean isCompleteCalled() {
      return isCompleteCalled.get();
    }

    /**
     * @return the last error passed to {@link #onError(Throwable)}, or {@code null} if none was reported.
     */
    public Throwable getError() {
      return error;
    }

    /**
     * @return the content of all the received chunks, concatenated in arrival order.
     */
    public String getAggregatedContentAsString() {
      return receivedContentBuilder.toString();
    }
  }
}
