/*
 * (c) 2003-2021 MuleSoft, Inc. This software is protected under international copyright
 * law. All use of this software is subject to MuleSoft's Master Subscription Agreement
 * (or other master license agreement) separately entered into in writing between you and
 * MuleSoft. If such an agreement is not in place, you may not use the software.
 */
package com.mulesoft.service.http.impl.functional.ws;

import static java.lang.String.format;
import static java.lang.Thread.sleep;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.apache.commons.lang3.RandomStringUtils.randomAlphanumeric;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.notNullValue;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.hamcrest.Matchers.hasSize;
import static org.junit.Assert.assertThat;
import static org.junit.rules.ExpectedException.none;
import static org.mule.runtime.api.metadata.DataType.TEXT_STRING;
import static org.mule.runtime.api.metadata.MediaType.TEXT;
import static org.mule.runtime.http.api.ws.WebSocket.WebSocketType.INBOUND;
import static org.mule.runtime.http.api.ws.WebSocketCloseCode.NORMAL_CLOSURE;
import static org.mule.tck.probe.PollingProber.check;
import static org.mule.tck.probe.PollingProber.probe;

import org.mule.runtime.api.metadata.TypedValue;
import org.mule.runtime.api.util.Pair;
import org.mule.runtime.api.util.Reference;
import org.mule.runtime.api.util.concurrent.Latch;
import org.mule.runtime.core.api.retry.policy.SimpleRetryPolicyTemplate;
import org.mule.runtime.http.api.client.HttpRequestOptions;
import org.mule.runtime.http.api.client.ws.WebSocketCallback;
import org.mule.runtime.http.api.domain.message.request.HttpRequest;
import org.mule.runtime.http.api.server.ws.WebSocketConnectionHandler;
import org.mule.runtime.http.api.server.ws.WebSocketHandler;
import org.mule.runtime.http.api.server.ws.WebSocketMessage;
import org.mule.runtime.http.api.server.ws.WebSocketMessageHandler;
import org.mule.runtime.http.api.server.ws.WebSocketRequest;
import org.mule.runtime.http.api.ws.WebSocket;
import org.mule.runtime.http.api.ws.WebSocketCloseCode;
import org.mule.runtime.http.api.ws.exception.WebSocketClosedException;
import org.mule.runtime.http.api.ws.exception.WebSocketConnectionException;
import org.mule.tck.SimpleUnitTestSupportCustomScheduler;

import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ThreadPoolExecutor.AbortPolicy;

import org.apache.commons.io.IOUtils;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;

public class FullDuplexWebSocketTestCase extends AbstractWebSocketTestCase {

  /** Path under which the test handlers are registered and which {@code connect} appends to the client URI. */
  private static final String PATH = "/chat";
  /** Value sent by {@code connect} as the {@code id} query param; the full duplex scenario uses it as the server socket id. */
  private static final String SERVER_CONNECTION_ID = "serverConnection";
  /** Socket id passed by {@code connect} when opening the client side WebSocket. */
  private static final String CLIENT_SOCKET_ID = "clientConnection";
  /** Payload sent from client to server in the short message scenario. */
  private static final String SHORT_TEXT_MESSAGE = "Hello There!";
  /** Payload sent back from server to client in the short message scenario. */
  private static final String SHORT_RESPONSE_MESSAGE = "ACK";
  /** Timeout handed to {@code check}; presumably milliseconds -- confirm against {@code PollingProber}. */
  private static final int PROBE_TIMEOUT = 5000;
  /** Polling interval handed to {@code check}; presumably milliseconds -- confirm against {@code PollingProber}. */
  private static final int PROBE_MILLIS = 100;

  /** Used by {@code alreadyClosedSocket} to declare the exception (and cause) the test must end with. */
  @Rule
  public ExpectedException expectedException = none();

  /**
   * Runs the full duplex scenario on {@link #PATH} using the short text message and the short response message.
   *
   * @throws Exception if the shared scenario fails
   */
  @Test
  public void fullDuplex() throws Exception {
    final TypedValue<byte[]> request = new TypedValue<>(SHORT_TEXT_MESSAGE.getBytes(), TEXT_STRING);
    final TypedValue<byte[]> reply = new TypedValue<>(SHORT_RESPONSE_MESSAGE.getBytes(), TEXT_STRING);

    fullDuplex(PATH, request, reply);
  }

  /**
   * Runs the full duplex scenario on {@link #PATH} with two random alphanumeric payloads of 16 KiB plus 54 characters each.
   * <p>
   * NOTE(review): the size is presumably chosen to exceed an internal buffer so that streaming kicks in -- confirm.
   *
   * @throws Exception if the shared scenario fails
   */
  @Test
  public void fullDuplexWithStreaming() throws Exception {
    final int payloadSize = 16 * 1024 + 54;

    // Generated in the same order as they are sent: outbound first, response second
    final byte[] outboundPayload = randomAlphanumeric(payloadSize).getBytes();
    final byte[] responsePayload = randomAlphanumeric(payloadSize).getBytes();

    fullDuplex(PATH, new TypedValue<>(outboundPayload, TEXT_STRING), new TypedValue<>(responsePayload, TEXT_STRING));
  }

  /**
   * Registers a handler with a 5000 ms idle timeout, connects a client and checks that both ends report the connection as
   * established and that the server side has not been told about a close. The server socket, if any, is closed at the end.
   * <p>
   * NOTE(review): if {@code check} returns as soon as the probe is satisfied, the second check does not span the whole idle
   * timeout -- confirm against {@code PollingProber}.
   *
   * @throws Exception if the connection cannot be established, a check fails or the final close fails
   */
  @Test
  public void idleInboundSocketTimeout() throws Exception {
    final int idleTimeoutMillis = 5000;

    final Latch received = new Latch();
    final Reference<WebSocketMessage> lastMessage = new Reference<>();
    final Reference<Boolean> connectedOnServer = new Reference<>(false);
    final Reference<Boolean> closedOnServer = new Reference<>(false);
    final Reference<WebSocket> serverSide = new Reference<>();

    // Keeps the last inbound message and signals its arrival
    final WebSocketMessageHandler onMessage = message -> {
      lastMessage.set(message);
      received.release();
    };

    // Records the server side lifecycle events so the test thread can inspect them
    final WebSocketConnectionHandler onConnection = new WebSocketConnectionHandler() {

      @Override
      public String getSocketId(WebSocketRequest request) {
        return "idle";
      }

      @Override
      public void onConnect(WebSocket socket, WebSocketRequest request) {
        connectedOnServer.set(true);
        serverSide.set(socket);
      }

      @Override
      public void onClose(WebSocket socket, WebSocketRequest request, WebSocketCloseCode closeCode, String reason) {
        closedOnServer.set(true);
      }
    };

    handlerManager = server.addWebSocketHandler(new TestWebSocketHandler(PATH, onConnection, onMessage, idleTimeoutMillis));
    handlerManager.start();

    final TestWebSocketCallback clientCallback = new TestWebSocketCallback();
    connect(clientCallback);

    try {
      check(1000, 100, () -> {
        assertThat(clientCallback.isConnected(), is(true));
        assertThat(connectedOnServer.get(), is(true));

        return true;
      });

      check(idleTimeoutMillis + 1000, 500, () -> {
        assertThat(closedOnServer.get(), is(false));
        return true;
      });
    } finally {
      final WebSocket toClose = serverSide.get();
      if (toClose != null) {
        toClose.close(NORMAL_CLOSURE, "").get();
      }
    }
  }

  /**
   * Verifies that sending through a client socket after a long silence fails with a {@link WebSocketConnectionException}.
   * <p>
   * The handler is registered with a 500 ms idle timeout and the test stays silent for 5 seconds before sending.
   * NOTE(review): presumably the server drops the idle connection during that pause -- confirm against the server
   * implementation.
   *
   * @throws Exception if the connection cannot be established, a check fails or the final close of the server socket fails
   */
  @Test
  public void remotelyClosedSocketThrowsConnectionException() throws Exception {
    Latch messageLatch = new Latch();
    Reference<Boolean> serverSocketConnected = new Reference<>(false);
    Reference<Boolean> serverSocketClosed = new Reference<>(false);
    Reference<WebSocket> serverSocketReference = new Reference<>();

    WebSocketConnectionHandler connectionHandler = new WebSocketConnectionHandler() {

      @Override
      public String getSocketId(WebSocketRequest request) {
        return "idle";
      }

      @Override
      public void onConnect(WebSocket socket, WebSocketRequest request) {
        serverSocketConnected.set(true);
        serverSocketReference.set(socket);
      }

      @Override
      public void onClose(WebSocket socket, WebSocketRequest request, WebSocketCloseCode closeCode, String reason) {
        serverSocketClosed.set(true);
      }
    };

    Reference<WebSocketMessage> messageHolder = new Reference<>();
    WebSocketMessageHandler messageHandler = message -> {
      messageHolder.set(message);
      messageLatch.release();
    };

    final int idleTimeoutMillis = 500;
    WebSocketHandler handler = new TestWebSocketHandler(PATH, connectionHandler, messageHandler, idleTimeoutMillis);
    handlerManager = server.addWebSocketHandler(handler);
    handlerManager.start();

    TestWebSocketCallback callback = new TestWebSocketCallback();
    WebSocket client = connect(callback);

    try {
      check(idleTimeoutMillis, 10, () -> {
        assertThat(callback.isConnected(), is(true));
        assertThat(serverSocketConnected.get(), is(true));

        return true;
      });

      check(idleTimeoutMillis, 500, () -> {
        assertThat(serverSocketClosed.get(), is(false));
        return true;
      });

      // Stay silent for much longer than the idle timeout before trying to send
      sleep(5000);
      Reference<Throwable> error = new Reference<>();
      Latch closedLatch = new Latch();

      client.send(new ByteArrayInputStream("Hello closed!".getBytes()), TEXT).whenComplete((v, e) -> {
        // Released on success too, so an unexpectedly successful send fails the test right away instead of after the full wait
        error.set(e);
        closedLatch.release();
      });

      assertThat("Send over the remotely closed socket did not complete in time", closedLatch.await(10, SECONDS), is(true));
      assertThat(error.get(), is(instanceOf(WebSocketConnectionException.class)));
    } finally {
      if (serverSocketReference.get() != null) {
        serverSocketReference.get().close(NORMAL_CLOSURE, "").get();
      }
    }
  }

  /**
   * Verifies that a client socket whose send failed with a {@link WebSocketConnectionException} can be reconnected and used
   * again.
   * <p>
   * The handler is registered with a 500 ms idle timeout and the test stays silent for 5 seconds before the first send.
   * NOTE(review): presumably the server drops the idle connection during that pause -- confirm against the server
   * implementation. After the failed send the socket is reconnected with a {@link SimpleRetryPolicyTemplate} and a second
   * message is sent, which must reach the server side message handler.
   *
   * @throws Exception if the connection cannot be established, a check fails, the reconnection fails or the cleanup fails
   */
  @Test
  public void remotelyClosedSocketReconnects() throws Exception {
    Latch messageLatch = new Latch();
    Reference<Boolean> serverSocketConnected = new Reference<>(false);
    Reference<Boolean> serverSocketClosed = new Reference<>(false);
    Reference<WebSocket> serverSocketReference = new Reference<>();

    WebSocketConnectionHandler connectionHandler = new WebSocketConnectionHandler() {

      @Override
      public String getSocketId(WebSocketRequest request) {
        return "idle";
      }

      @Override
      public void onConnect(WebSocket socket, WebSocketRequest request) {
        serverSocketConnected.set(true);
        serverSocketReference.set(socket);
      }

      @Override
      public void onClose(WebSocket socket, WebSocketRequest request, WebSocketCloseCode closeCode, String reason) {
        serverSocketClosed.set(true);
      }
    };

    Reference<WebSocketMessage> messageHolder = new Reference<>();
    WebSocketMessageHandler messageHandler = message -> {
      messageHolder.set(message);
      messageLatch.release();
    };

    final int idleTimeoutMillis = 500;
    WebSocketHandler handler = new TestWebSocketHandler(PATH, connectionHandler, messageHandler, idleTimeoutMillis);
    handlerManager = server.addWebSocketHandler(handler);
    handlerManager.start();

    TestWebSocketCallback callback = new TestWebSocketCallback();
    WebSocket client = connect(callback);

    SimpleUnitTestSupportCustomScheduler scheduler = new SimpleUnitTestSupportCustomScheduler(1, Thread::new, new AbortPolicy());

    try {
      check(idleTimeoutMillis, 10, () -> {
        assertThat(callback.isConnected(), is(true));
        assertThat(serverSocketConnected.get(), is(true));

        return true;
      });

      check(idleTimeoutMillis, 500, () -> {
        assertThat(serverSocketClosed.get(), is(false));
        return true;
      });

      // Stay silent for much longer than the idle timeout before trying to send
      sleep(5000);
      Reference<Throwable> error = new Reference<>();
      Latch closedLatch = new Latch();

      client.send(new ByteArrayInputStream("Hello closed!".getBytes()), TEXT).whenComplete((v, e) -> {
        // Released on success too, so an unexpectedly successful send fails the test right away instead of after the full wait
        error.set(e);
        closedLatch.release();
      });

      assertThat("Send over the remotely closed socket did not complete in time", closedLatch.await(10, SECONDS), is(true));
      assertThat(error.get(), is(instanceOf(WebSocketConnectionException.class)));

      assertThat(client.supportsReconnection(), is(true));

      client = client.reconnect(new SimpleRetryPolicyTemplate(500, 2), scheduler).get();

      Latch reconnectedLatch = new Latch();
      error.set(null);

      String reconnectionMessage = "Hello reconnected!";
      client.send(new ByteArrayInputStream(reconnectionMessage.getBytes()), TEXT).whenComplete((v, e) -> {
        if (e != null) {
          error.set(e);
        }
        reconnectedLatch.release();
      });

      try {
        assertThat(reconnectedLatch.await(5, SECONDS), is(true));
        assertThat(error.get(), is(nullValue()));

        check(5000, 500, () -> {
          WebSocketMessage message = messageHolder.get();
          assertThat(message, is(notNullValue()));
          assertThat(IOUtils.toString(message.getContent().getValue()), equalTo(reconnectionMessage));
          return true;
        });

      } finally {
        client.close(NORMAL_CLOSURE, "");
      }
    } finally {
      try {
        scheduler.stop();
      } finally {
        // Runs even if stopping the scheduler fails, so the server socket is never left open
        if (serverSocketReference.get() != null) {
          serverSocketReference.get().close(NORMAL_CLOSURE, "").get();
        }
      }
    }
  }

  /**
   * Closes the client socket and then sends through it, expecting the returned future to fail with an
   * {@link ExecutionException} caused by a {@link WebSocketClosedException}. The server socket, if any, is closed at the end.
   *
   * @throws Exception the expected {@link ExecutionException}, or any failure while connecting, checking or cleaning up
   */
  @Test
  public void alreadyClosedSocket() throws Exception {
    final int idleTimeoutMillis = 5000;

    final Latch received = new Latch();
    final Reference<WebSocketMessage> lastMessage = new Reference<>();
    final Reference<Boolean> connectedOnServer = new Reference<>(false);
    final Reference<Boolean> closedOnServer = new Reference<>(false);
    final Reference<WebSocket> serverSide = new Reference<>();

    // Keeps the last inbound message and signals its arrival
    final WebSocketMessageHandler onMessage = message -> {
      lastMessage.set(message);
      received.release();
    };

    // Records the server side lifecycle events so the test thread can inspect them
    final WebSocketConnectionHandler onConnection = new WebSocketConnectionHandler() {

      @Override
      public String getSocketId(WebSocketRequest request) {
        return "idle";
      }

      @Override
      public void onConnect(WebSocket socket, WebSocketRequest request) {
        connectedOnServer.set(true);
        serverSide.set(socket);
      }

      @Override
      public void onClose(WebSocket socket, WebSocketRequest request, WebSocketCloseCode closeCode, String reason) {
        closedOnServer.set(true);
      }
    };

    handlerManager = server.addWebSocketHandler(new TestWebSocketHandler(PATH, onConnection, onMessage, idleTimeoutMillis));
    handlerManager.start();

    final TestWebSocketCallback clientCallback = new TestWebSocketCallback();
    final WebSocket clientSocket = connect(clientCallback);

    try {
      check(1000, 100, () -> {
        assertThat(clientCallback.isConnected(), is(true));
        assertThat(connectedOnServer.get(), is(true));

        return true;
      });

      clientSocket.close(NORMAL_CLOSURE, "").get();

      // The expectation is declared only after the close succeeded, so just the send below may satisfy it
      expectedException.expect(ExecutionException.class);
      expectedException.expectCause(is(instanceOf(WebSocketClosedException.class)));

      final InputStream payload = new ByteArrayInputStream("I'm closed!".getBytes());
      clientSocket.send(payload, TEXT).get();
    } finally {
      final WebSocket toClose = serverSide.get();
      if (toClose != null) {
        toClose.close(NORMAL_CLOSURE, "").get();
      }
    }
  }

  /**
   * Shared full duplex scenario: registers a handler on {@code path}, connects a client, sends {@code outboundMessage} from
   * client to server, verifies what the server received, sends {@code responseMessage} back and verifies what the client
   * received. Finally the server socket is closed and the server side close notification is awaited.
   * <p>
   * Note that {@code connect} always targets {@link #PATH}, so {@code path} is expected to be equal to it.
   *
   * @param path            the path on which the server side handler is registered
   * @param outboundMessage the payload (and data type) sent from the client to the server
   * @param responseMessage the payload (and data type) sent from the server back to the client
   * @throws Exception if any step of the exchange or the cleanup fails
   */
  private void fullDuplex(String path, TypedValue<byte[]> outboundMessage, TypedValue<byte[]> responseMessage) throws Exception {
    Latch messageLatch = new Latch();
    Reference<Boolean> serverSocketConnected = new Reference<>(false);
    Reference<Boolean> serverSocketClosed = new Reference<>(false);
    Reference<WebSocket> serverSocketReference = new Reference<>();

    WebSocketConnectionHandler connectionHandler = new WebSocketConnectionHandler() {

      @Override
      public String getSocketId(WebSocketRequest request) {
        return request.getQueryParams().get("id");
      }

      @Override
      public void onConnect(WebSocket socket, WebSocketRequest request) {
        serverSocketConnected.set(true);
        serverSocketReference.set(socket);
      }

      @Override
      public void onClose(WebSocket socket, WebSocketRequest request, WebSocketCloseCode closeCode, String reason) {
        serverSocketClosed.set(true);
      }
    };

    Reference<WebSocketMessage> messageHolder = new Reference<>();
    WebSocketMessageHandler messageHandler = message -> {
      messageHolder.set(message);
      messageLatch.release();
    };

    WebSocketHandler handler = new TestWebSocketHandler(path, connectionHandler, messageHandler);
    handlerManager = server.addWebSocketHandler(handler);
    handlerManager.start();

    TestWebSocketCallback callback = new TestWebSocketCallback();
    WebSocket client = connect(callback);

    try {
      client.send(new ByteArrayInputStream(outboundMessage.getValue()), outboundMessage.getDataType().getMediaType()).get();

      // Bounded wait: a message that never arrives must fail the test instead of hanging the whole suite
      assertThat("Server did not receive the outbound message in time", messageLatch.await(10, SECONDS), is(true));

      WebSocket serverSocket = serverSocketReference.get();
      assertThat(serverSocketConnected.get(), is(true));
      assertThat(serverSocket.getUri().getPath(), equalTo(path));

      assertThat(serverSocket.getId(), equalTo(SERVER_CONNECTION_ID));
      assertThat(serverSocket.getType(), is(INBOUND));

      WebSocketMessage message = messageHolder.get();
      assertThat(message.getSocket().getId(), equalTo(SERVER_CONNECTION_ID));
      assertThat(IOUtils.toString(message.getContent().getValue()), equalTo(new String(outboundMessage.getValue())));

      assertThat(callback.isConnected(), is(true));
      assertThat(callback.isClosed(), is(false));

      serverSocket.send(new ByteArrayInputStream(responseMessage.getValue()), responseMessage.getDataType().getMediaType()).get();

      check(PROBE_TIMEOUT, PROBE_MILLIS, () -> !callback.getMessages().isEmpty());
      assertThat(callback.getMessages(), hasSize(1));
      Pair<WebSocket, TypedValue<InputStream>> response = callback.getMessages().get(0);

      assertThat(response.getFirst().getId(), equalTo(CLIENT_SOCKET_ID));
      assertThat(IOUtils.toString(response.getSecond().getValue()), equalTo(new String(responseMessage.getValue())));
      assertThat(serverSocketClosed.get(), is(false));
    } finally {
      // The server socket may be missing if the failure happened before the server side saw the connection
      WebSocket socketToClose = serverSocketReference.get();
      if (socketToClose != null) {
        socketToClose.close(NORMAL_CLOSURE, "").get();
      }
    }

    check(PROBE_TIMEOUT, PROBE_MILLIS, serverSocketClosed::get);
  }

  /**
   * Opens a client WebSocket against {@code ws://localhost:<port>/chat}, sending {@link #SERVER_CONNECTION_ID} as the
   * {@code id} query param and using {@link #CLIENT_SOCKET_ID} as the client socket id. The attempt is driven through
   * {@code probe(5000, 1000, ...)}.
   *
   * @param callback the callback handed to the client for the new socket
   * @return the socket obtained by the last attempt
   * @throws Exception if the probe fails
   */
  private WebSocket connect(WebSocketCallback callback) throws Exception {
    final Reference<WebSocket> opened = new Reference<>();

    probe(5000, 1000, () -> {
      final String target = format("ws://localhost:%d%s", port.getNumber(), PATH);

      final HttpRequest handshake = HttpRequest.builder()
          .uri(target)
          .method("GET")
          .addQueryParam("id", SERVER_CONNECTION_ID)
          .build();

      final HttpRequestOptions options = HttpRequestOptions.builder()
          .responseTimeout(3000)
          .followsRedirect(true)
          .authentication(null)
          .build();

      final WebSocket webSocket = client.openWebSocket(handshake, options, CLIENT_SOCKET_ID, callback).get();
      opened.set(webSocket);

      return webSocket != null;
    });

    return opened.get();
  }

  private class TestWebSocketCallback implements WebSocketCallback {

    private boolean connected, closed = false;
    private List<Pair<WebSocket, TypedValue<InputStream>>> messages = new CopyOnWriteArrayList<>();

    @Override
    public void onConnect(WebSocket webSocket) {
      connected = true;
    }

    @Override
    public void onClose(WebSocket webSocket, WebSocketCloseCode code, String reason) {
      closed = true;
    }

    @Override
    public void onMessage(WebSocket webSocket, TypedValue<InputStream> content) {
      messages.add(new Pair<>(webSocket, content));
    }

    public boolean isConnected() {
      return connected;
    }

    public boolean isClosed() {
      return closed;
    }

    public List<Pair<WebSocket, TypedValue<InputStream>>> getMessages() {
      return messages;
    }
  }


  /**
   * Immutable {@link WebSocketHandler} that simply hands back the values it was built with.
   */
  private class TestWebSocketHandler implements WebSocketHandler {

    private final String path;
    private final WebSocketConnectionHandler connectionHandler;
    private final WebSocketMessageHandler messageHandler;
    private final long idleTimeoutMillis;

    /**
     * Creates a handler with an idle socket timeout of 30000 milliseconds.
     *
     * @param path              the value returned by {@link #getPath()}
     * @param connectionHandler the value returned by {@link #getConnectionHandler()}
     * @param messageHandler    the value returned by {@link #getMessageHandler()}
     */
    public TestWebSocketHandler(String path,
                                WebSocketConnectionHandler connectionHandler,
                                WebSocketMessageHandler messageHandler) {
      this(path, connectionHandler, messageHandler, 30_000L);
    }

    /**
     * Creates a handler with an explicit idle socket timeout.
     *
     * @param handlerPath       the value returned by {@link #getPath()}
     * @param onConnection      the value returned by {@link #getConnectionHandler()}
     * @param onMessage         the value returned by {@link #getMessageHandler()}
     * @param idleTimeout       the value returned by {@link #getIdleSocketTimeoutMills()}
     */
    public TestWebSocketHandler(String handlerPath,
                                WebSocketConnectionHandler onConnection,
                                WebSocketMessageHandler onMessage,
                                long idleTimeout) {
      path = handlerPath;
      connectionHandler = onConnection;
      messageHandler = onMessage;
      idleTimeoutMillis = idleTimeout;
    }

    @Override
    public String getPath() {
      return this.path;
    }

    @Override
    public WebSocketConnectionHandler getConnectionHandler() {
      return this.connectionHandler;
    }

    @Override
    public WebSocketMessageHandler getMessageHandler() {
      return this.messageHandler;
    }

    @Override
    public long getIdleSocketTimeoutMills() {
      return this.idleTimeoutMillis;
    }
  }
}
