/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package org.mule.service.http.netty.impl.util;

import static org.mule.tck.probe.PollingProber.probe;

import static java.lang.Thread.sleep;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Optional.empty;
import static java.util.Optional.ofNullable;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.slf4j.LoggerFactory.getLogger;

import org.mule.service.http.netty.impl.streaming.BlockingBidirectionalStream;
import org.mule.tck.junit4.AbstractMuleTestCase;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Supplier;

import junit.framework.AssertionFailedError;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.slf4j.Logger;

/**
 * Tests the blocking behaviour of the streams exposed by {@link BlockingBidirectionalStream}: a background
 * {@link TestConsumer} reads from the input stream while each test writes to (or closes) the output stream and then checks what
 * the consumer observed.
 */
public class BlockingBufferTestCase extends AbstractMuleTestCase {

  private static final String TEST_PAYLOAD =
      "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.";

  private static final Logger LOGGER = getLogger(BlockingBufferTestCase.class);

  /**
   * Upper bound, in milliseconds, for each attempt to wait for the consumer thread to finish during tear down.
   */
  private static final long CONSUMER_JOIN_TIMEOUT_MILLIS = 5000;

  private OutputStream sink;
  private TestConsumer consumer;

  /**
   * Creates a fresh stream pair, starts a consumer thread reading from the input side in chunks of 8 bytes, and keeps the output
   * side as the {@code sink} the tests write to.
   */
  @Before
  public void setUp() {
    BlockingBidirectionalStream blockingBuffer = new BlockingBidirectionalStream();

    consumer = new TestConsumer(blockingBuffer.getInputStream(), 8);
    consumer.start();

    sink = blockingBuffer.getOutputStream();
  }

  /**
   * Waits for the consumer thread to finish. The wait is bounded: if a test failed before closing the sink or interrupting the
   * consumer, the consumer may still be blocked reading, and an unbounded join would hang the whole test run.
   *
   * @throws InterruptedException if the test thread is interrupted while waiting for the consumer.
   */
  @After
  public void tearDown() throws InterruptedException {
    if (consumer == null) {
      return;
    }

    consumer.join(CONSUMER_JOIN_TIMEOUT_MILLIS);
    if (consumer.isAlive()) {
      // The consumer is most likely still blocked on read(); interrupt it so that it can terminate.
      consumer.interrupt();
      consumer.join(CONSUMER_JOIN_TIMEOUT_MILLIS);
      if (consumer.isAlive()) {
        LOGGER.warn("Test consumer thread did not finish after being interrupted");
      }
    }
  }

  /**
   * With nothing written, the consumer must stay blocked without consuming data. Interrupting it must surface an
   * {@link IOException} caused by an {@link InterruptedException}.
   */
  @Test
  public void consumerBlocksWhenBufferIsEmpty() throws InterruptedException, IOException {
    // The test must pass regardless of this sleep(), but sleeping makes it more likely
    // to expose the bug (a consumer that does not block) if it exists.
    sleep(500);

    assertThat(consumer.finishedReading(), is(false));
    assertThat(consumer.getConsumedData().length, is(0));

    consumer.interrupt();

    probe(() -> {
      Exception error = consumer.getErrorWhileConsuming().orElseThrow(errorShouldBePresent());
      assertThat(error, instanceOf(IOException.class));
      assertThat(error.getCause(), instanceOf(InterruptedException.class));
      return true;
    });
  }

  /**
   * With nothing written, closing the sink must release the blocked consumer, which finishes without error and without data.
   */
  @Test
  public void consumerUnblocksWhenBufferIsClosedAndEmpty() throws InterruptedException, IOException {
    // The test must pass regardless of this sleep(), but sleeping makes it more likely
    // to expose the bug (a consumer that does not block) if it exists.
    sleep(500);

    assertThat(consumer.finishedReading(), is(false));
    assertThat(consumer.getConsumedData().length, is(0));

    sink.close();

    probe(() -> {
      assertThat(consumer.finishedReading(), is(true));
      assertThat(consumer.getErrorWhileConsuming(), is(empty()));
      assertThat(consumer.getConsumedData().length, is(0));
      return true;
    });
  }

  /**
   * Writes the payload word by word, so that the written chunks have different sizes, and checks that the consumer receives the
   * concatenation of all of them.
   */
  @Test
  public void writeAndReadAPayloadWithDifferentChunkSizes() throws IOException {
    String[] words = TEST_PAYLOAD.split(" ");
    for (String word : words) {
      // Each written chunk has the length of its word.
      sink.write(word.getBytes(UTF_8));
    }
    sink.close();

    // As we split by space, the expected consumed data doesn't contain them.
    final String testPayloadWithoutSpaces = TEST_PAYLOAD.replace(" ", "");
    probe(() -> {
      // Decode with the same charset used for writing instead of the platform default.
      assertThat(new String(consumer.getConsumedData(), UTF_8), is(testPayloadWithoutSpaces));
      assertThat(consumer.finishedReading(), is(true));
      assertThat(consumer.getErrorWhileConsuming(), is(empty()));
      return true;
    });
  }

  /**
   * Writes the payload one byte at a time and checks that the consumer receives it complete.
   */
  @Test
  public void writeBytePerByte() throws IOException {
    for (byte b : TEST_PAYLOAD.getBytes(UTF_8)) {
      sink.write(b);
    }
    sink.close();

    probe(() -> {
      assertThat(new String(consumer.getConsumedData(), UTF_8), is(TEST_PAYLOAD));
      assertThat(consumer.finishedReading(), is(true));
      assertThat(consumer.getErrorWhileConsuming(), is(empty()));
      return true;
    });
  }

  /**
   * Writes fewer bytes than the consumer's chunk size: the data must be delivered anyway, and the consumer must then keep
   * waiting until the sink is closed.
   */
  @Test
  public void writeLessBytesThanBufferSize() throws IOException {
    byte[] abc = "abc".getBytes(UTF_8);

    sink.write(abc);

    probe(() -> {
      // The data is read and the consumer is blocked waiting for another chunk.
      assertThat(new String(consumer.getConsumedData(), UTF_8), is("abc"));
      assertThat(consumer.finishedReading(), is(false));
      return true;
    });

    sink.close();

    probe(() -> {
      assertThat(consumer.finishedReading(), is(true));
      assertThat(consumer.getErrorWhileConsuming(), is(empty()));
      return true;
    });
  }

  /**
   * @return a supplier of the failure thrown when the consumer was expected to have recorded an error but has none.
   */
  private Supplier<? extends Throwable> errorShouldBePresent() {
    return () -> new AssertionFailedError("Error should be present");
  }

  /**
   * Thread that reads from an {@link InputStream} in chunks of a fixed maximum size, accumulating everything it reads, until a
   * read returns {@code -1} or {@code 0}, or an {@link IOException} is thrown.
   * <p>
   * Its state is queried from the test thread while the consumer thread is running, so all state shared between both threads is
   * either guarded by this object's monitor, atomic or volatile.
   */
  private static class TestConsumer extends Thread {

    private final InputStream inputStream;
    private final int consumerBufferSize;
    private final ByteArrayOutputStream consumedData;
    private final AtomicBoolean continueReading;
    // Written by the consumer thread and read by the test thread, hence volatile: otherwise the test
    // thread is not guaranteed to ever observe the stored exception.
    private volatile Exception errorWhileConsuming;

    /**
     * @param inputStream        the stream to read from.
     * @param consumerBufferSize the maximum number of bytes requested on each read.
     */
    public TestConsumer(InputStream inputStream, int consumerBufferSize) {
      this.inputStream = inputStream;
      this.consumerBufferSize = consumerBufferSize;
      this.consumedData = new ByteArrayOutputStream();
      this.continueReading = new AtomicBoolean(true);
      // A consumer left blocked by a failing test must not keep the JVM alive.
      setDaemon(true);
    }

    /**
     * @return a copy of all the bytes consumed so far.
     */
    public synchronized byte[] getConsumedData() {
      return consumedData.toByteArray();
    }

    /**
     * @return {@code true} once the consumer has stopped reading, either because the stream ended or because of an error.
     */
    public boolean finishedReading() {
      return !continueReading.get();
    }

    /**
     * @return the exception that stopped the consumer, or empty if no error has occurred.
     */
    public Optional<Exception> getErrorWhileConsuming() {
      return ofNullable(errorWhileConsuming);
    }

    @Override
    public void run() {
      while (continueReading.get()) {
        byte[] consumerBuffer = new byte[consumerBufferSize];
        try {
          int bytesRead = inputStream.read(consumerBuffer, 0, consumerBufferSize);
          if (bytesRead == -1 || bytesRead == 0) {
            // No more data to read: stop the loop.
            continueReading.set(false);
          } else {
            // Same monitor as getConsumedData(), so the test thread never sees a partially written chunk.
            synchronized (this) {
              LOGGER.debug("Reading this chunk [{}]", new String(consumerBuffer, 0, bytesRead, UTF_8));
              consumedData.write(consumerBuffer, 0, bytesRead);
            }
          }
        } catch (IOException e) {
          LOGGER.error("Found error while consuming", e);
          // Store the error before flipping the flag, so that whoever sees finishedReading() == true
          // also sees the error.
          this.errorWhileConsuming = e;
          continueReading.set(false);
        }
      }
    }
  }
}
