/*
 * (c) 2003-2021 MuleSoft, Inc. This software is protected under international copyright
 * law. All use of this software is subject to MuleSoft's Master Subscription Agreement
 * (or other master license agreement) separately entered into in writing between you and
 * MuleSoft. If such an agreement is not in place, you may not use the software.
 */
package com.mulesoft.service.http.impl.functional.ws;

import static java.lang.Integer.MAX_VALUE;
import static java.lang.String.format;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Collections.emptyList;
import static org.apache.commons.lang3.RandomStringUtils.randomAlphanumeric;
import static org.apache.commons.lang3.RandomUtils.nextBytes;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;
import static org.junit.Assert.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mule.runtime.api.metadata.DataType.INPUT_STREAM;
import static org.mule.runtime.api.metadata.DataType.TEXT_STRING;
import static org.mule.runtime.api.metadata.MediaType.BINARY;
import static org.mule.runtime.core.api.util.IOUtils.toByteArray;
import static org.mule.tck.probe.PollingProber.check;
import static org.mule.tck.probe.PollingProber.checkNot;
import static org.mule.tck.probe.PollingProber.probe;

import org.mule.runtime.api.metadata.TypedValue;
import org.mule.runtime.core.api.retry.policy.NoRetryPolicyTemplate;
import org.mule.runtime.core.api.retry.policy.RetryPolicyTemplate;
import org.mule.runtime.core.api.retry.policy.SimpleRetryPolicyTemplate;
import org.mule.runtime.core.api.util.func.CheckedConsumer;
import org.mule.runtime.http.api.client.HttpRequestOptions;
import org.mule.runtime.http.api.client.ws.WebSocketCallback;
import org.mule.runtime.http.api.domain.message.request.HttpRequest;
import org.mule.runtime.http.api.server.ws.WebSocketConnectionHandler;
import org.mule.runtime.http.api.server.ws.WebSocketConnectionRejectedException;
import org.mule.runtime.http.api.server.ws.WebSocketHandler;
import org.mule.runtime.http.api.server.ws.WebSocketMessageHandler;
import org.mule.runtime.http.api.server.ws.WebSocketRequest;
import org.mule.runtime.http.api.ws.WebSocket;
import org.mule.runtime.http.api.ws.WebSocketCloseCode;
import org.mule.tck.SimpleUnitTestSupportCustomScheduler;

import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;

/**
 * Functional tests for the WebSocket broadcaster.
 * <p>
 * {@link #CLIENT_COUNT} client sockets connect to a server-side handler registered on {@link #PATH}. Each test then
 * broadcasts a message through all client sockets and verifies that every still-connected server-side socket received
 * exactly the broadcast content with the expected data type.
 */
public class BroadcastWebSocketTestCase extends AbstractWebSocketTestCase {

  private static final String PATH = "/quotes";
  private static final String SHORT_TEXT_MESSAGE = "Hello There!";
  private static final int TIMEOUT_MILLIS = 15000;
  private static final int POLL_DELAY_MILLIS = 1000;
  private static final int LARGE_MESSAGE_SIZE = (24 * 1024) + 100;
  private static final int CLIENT_COUNT = 3;

  // These lists are appended to from WebSocket connection callbacks invoked by the HTTP service, not directly by the
  // test thread. The test thread then reads and iterates them. CopyOnWriteArrayList gives safe publication and allows
  // iteration while callbacks (e.g. on reconnection) may still add elements.
  private final List<WebSocket> serverSockets = new CopyOnWriteArrayList<>();
  private final List<WebSocket> clientSockets = new CopyOnWriteArrayList<>();
  // Last message received by each server-side socket, keyed by server socket id.
  private final Map<String, TypedValue<InputStream>> messages = new ConcurrentHashMap<>();
  // Errors reported by the broadcaster, keyed by the id of the socket that failed.
  private final Map<String, List<Throwable>> exceptions = new ConcurrentHashMap<>();
  private final AtomicInteger serverSocketCounter = new AtomicInteger(0);
  private final AtomicInteger clientSocketCounter = new AtomicInteger(0);
  private final AtomicInteger closedClientSocketCounter = new AtomicInteger(0);
  private final long idleSocketTimeout = 15000;

  private BiConsumer<WebSocket, Throwable> errorCallback;
  // Replaced by individual tests that need the broadcaster to retry (see broadcastToReconnectableSockets).
  private RetryPolicyTemplate retryPolicyTemplate = new NoRetryPolicyTemplate();
  private final SimpleUnitTestSupportCustomScheduler scheduler =
      new SimpleUnitTestSupportCustomScheduler(1, Thread::new, new ThreadPoolExecutor.AbortPolicy());

  /**
   * Registers and starts the server-side WebSocket handler, then opens {@link #CLIENT_COUNT} client connections and
   * waits until both sides have registered every socket.
   */
  @Before
  public void before() throws Exception {
    // Per-socket lists may be appended to concurrently by the broadcaster, so they must be thread-safe too.
    errorCallback = (ws, t) -> exceptions.computeIfAbsent(ws.getId(), k -> new CopyOnWriteArrayList<>()).add(t);
    handlerManager = server.addWebSocketHandler(new WebSocketHandler() {

      @Override
      public String getPath() {
        return PATH;
      }

      @Override
      public WebSocketConnectionHandler getConnectionHandler() {
        return new WebSocketConnectionHandler() {

          @Override
          public String getSocketId(WebSocketRequest request) {
            // Server sockets get sequential ids "1", "2", ...
            return String.valueOf(serverSocketCounter.incrementAndGet());
          }

          @Override
          public void onConnect(WebSocket socket, WebSocketRequest request) throws WebSocketConnectionRejectedException {
            serverSockets.add(socket);
          }

          @Override
          public void onClose(WebSocket socket, WebSocketRequest request, WebSocketCloseCode closeCode, String reason) {}
        };
      }

      @Override
      public WebSocketMessageHandler getMessageHandler() {
        return message -> messages.put(message.getSocket().getId(), message.getContent());
      }

      @Override
      public long getIdleSocketTimeoutMills() {
        return idleSocketTimeout;
      }
    });

    handlerManager.start();

    for (int i = 0; i < CLIENT_COUNT; i++) {
      connect(new WebSocketCallback() {

        @Override
        public void onConnect(WebSocket webSocket) {
          clientSockets.add(webSocket);
        }

        @Override
        public void onClose(WebSocket webSocket, WebSocketCloseCode code, String reason) {
          closedClientSocketCounter.incrementAndGet();
        }

        @Override
        public void onMessage(WebSocket webSocket, TypedValue<InputStream> content) {}
      });
    }

    // Nothing visible here guarantees the server-side onConnect callback has already run when the client future
    // completes. Wait for both sides before asserting the exact counts.
    probe(TIMEOUT_MILLIS, POLL_DELAY_MILLIS,
          () -> clientSockets.size() >= CLIENT_COUNT && serverSockets.size() >= CLIENT_COUNT);
    assertThat(clientSockets, hasSize(CLIENT_COUNT));
    assertThat(serverSockets, hasSize(CLIENT_COUNT));
  }

  /**
   * Runs the superclass teardown and always stops the scheduler, even if the superclass teardown fails.
   */
  @After
  @Override
  public void after() {
    try {
      super.after();
    } finally {
      scheduler.stop();
    }
  }

  @Test
  public void broadcastToEmptyCollection() {
    @SuppressWarnings("unchecked")
    BiConsumer<WebSocket, Throwable> errorCallbackMock = mock(BiConsumer.class);

    service.newWebSocketBroadcaster().broadcast(emptyList(), new TypedValue<>(null, INPUT_STREAM), errorCallbackMock);
    checkNot(TIMEOUT_MILLIS, POLL_DELAY_MILLIS, () -> !messages.isEmpty());

    verifyZeroInteractions(errorCallbackMock);
  }

  @Test
  public void broadcastShortTextMessage() throws Exception {
    assertTextBroadcast(SHORT_TEXT_MESSAGE);
  }

  @Test
  public void broadcastLargeTextMessage() throws Exception {
    assertTextBroadcast(randomAlphanumeric(LARGE_MESSAGE_SIZE));
  }

  @Test
  public void broadcastShortBinaryMessage() throws Exception {
    assertBinaryBroadcast(nextBytes(30));
  }

  @Test
  public void broadcastLargeBinaryMessage() throws Exception {
    assertBinaryBroadcast(nextBytes(LARGE_MESSAGE_SIZE));
  }

  /**
   * Waits until at least one client socket reports it was closed. Presumably this closure comes from the handler's
   * idle socket timeout. The test then broadcasts with a retrying policy.
   */
  @Test
  public void broadcastToReconnectableSockets() throws Exception {
    check(idleSocketTimeout + 1000, 5000, () -> closedClientSocketCounter.get() > 0);
    retryPolicyTemplate = new SimpleRetryPolicyTemplate(1000, 2);

    assertTextBroadcast(SHORT_TEXT_MESSAGE);
  }

  /**
   * Broadcasts {@code text} as a {@code TEXT_STRING} stream and checks that each receiver got it as an
   * {@link InputStream} with the text media type.
   */
  private void assertTextBroadcast(String text) throws Exception {
    // Explicit charset: the default getBytes() depends on the platform encoding.
    TypedValue<InputStream> content = new TypedValue<>(new ByteArrayInputStream(text.getBytes(UTF_8)), TEXT_STRING);
    assertBroadcast(content, message -> {
      assertThat(message.getDataType().getType(), equalTo(InputStream.class));
      assertThat(message.getDataType().getMediaType(), equalTo(TEXT_STRING.getMediaType()));
    });
  }

  /**
   * Broadcasts {@code data} as a binary stream and checks that each receiver got it with the {@code BINARY} media type.
   */
  private void assertBinaryBroadcast(byte[] data) throws Exception {
    TypedValue<InputStream> content = new TypedValue<>(new ByteArrayInputStream(data), INPUT_STREAM);
    assertBroadcast(content, message -> assertThat(message.getDataType().getMediaType(), equalTo(BINARY)));
  }

  /**
   * Broadcasts {@code content} through a snapshot of the client sockets. It then waits until {@link #CLIENT_COUNT}
   * server-side messages have arrived.
   * <p>
   * Finally, for every still-connected server socket, it applies {@code mediaTypeAssertion} and compares the received
   * bytes against the original content.
   *
   * @param content            the content to broadcast; its stream must support {@code mark}/{@code reset}
   * @param mediaTypeAssertion data-type checks applied to each received message
   */
  private void assertBroadcast(TypedValue<InputStream> content, CheckedConsumer<TypedValue<InputStream>> mediaTypeAssertion)
      throws Exception {
    // Mark the start so the original bytes can be re-read for every comparison below.
    content.getValue().mark(MAX_VALUE);
    service.newWebSocketBroadcaster().broadcast(new ArrayList<>(clientSockets), content, errorCallback, retryPolicyTemplate,
                                                scheduler);

    probe(TIMEOUT_MILLIS, POLL_DELAY_MILLIS, () -> messages.size() == CLIENT_COUNT);

    for (WebSocket socket : serverSockets) {
      // Sockets closed before the broadcast (e.g. replaced on reconnection) are not expected to receive anything.
      if (!socket.isConnected()) {
        continue;
      }

      TypedValue<InputStream> message = messages.get(socket.getId());
      assertThat(message, is(notNullValue()));

      mediaTypeAssertion.accept(message);
      content.getValue().reset();
      assertThat(toByteArray(message.getValue()), equalTo(toByteArray(content.getValue())));
    }
  }

  /**
   * Opens one client WebSocket against {@link #PATH}, retrying until it succeeds or {@link #TIMEOUT_MILLIS} elapses.
   * Each attempt uses a fresh sequential client socket id.
   */
  private void connect(WebSocketCallback callback) throws Exception {
    String uri = format("ws://localhost:%d%s", port.getNumber(), PATH);
    probe(TIMEOUT_MILLIS, POLL_DELAY_MILLIS, () -> {
      client.openWebSocket(HttpRequest.builder().uri(uri)
          .method("GET")
          .build(),
                           HttpRequestOptions.builder()
                               .responseTimeout(3000)
                               .followsRedirect(true)
                               .authentication(null)
                               .build(),
                           String.valueOf(clientSocketCounter.incrementAndGet()),
                           callback)
          .get();

      return true;
    });
  }
}
