/*
 * (c) 2003-2021 MuleSoft, Inc. This software is protected under international copyright
 * law. All use of this software is subject to MuleSoft's Master Subscription Agreement
 * (or other master license agreement) separately entered into in writing between you and
 * MuleSoft. If such an agreement is not in place, you may not use the software.
 */
package com.mulesoft.service.http.impl.functional.ws;

import static java.util.concurrent.TimeUnit.SECONDS;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.CoreMatchers.sameInstance;
import static org.hamcrest.Matchers.hasSize;
import static org.junit.Assert.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.mule.runtime.http.api.ws.WebSocketProtocol.WS;

import org.mule.runtime.api.scheduler.Scheduler;
import org.mule.runtime.core.api.retry.policy.RetryPolicyTemplate;
import org.mule.runtime.http.api.ws.WebSocket;

import com.mulesoft.service.http.impl.service.client.ws.OutboundWebSocket;
import com.mulesoft.service.http.impl.service.client.ws.reconnect.OutboundWebSocketReconnectionHandler;

import java.net.URI;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;

import com.ning.http.client.AsyncHttpProviderConfig;
import com.ning.http.client.providers.grizzly.websocket.GrizzlyWebSocketAdapter;
import org.glassfish.grizzly.websockets.ProtocolHandler;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;

/**
 * Tests reconnection of {@link OutboundWebSocket} against a mocked {@link OutboundWebSocketReconnectionHandler}:
 * each serial reconnection is delegated to the handler, and concurrent reconnection requests made while a
 * reconnection is still pending all receive the same future.
 */
@RunWith(MockitoJUnitRunner.class)
public class OutboundSocketReconnectionTestCase {

  @Mock
  private AsyncHttpProviderConfig httpProviderConfig;

  @Mock
  private ProtocolHandler protocolHandler;

  @Mock
  private OutboundWebSocketReconnectionHandler reconnectionHandler;

  @Mock
  private RetryPolicyTemplate retryPolicyTemplate;

  @Mock
  private Scheduler scheduler;

  // Future handed back by the mocked handler; reassigned when a test needs a fresh reconnection.
  private CompletableFuture<WebSocket> reconnectionFuture = new CompletableFuture<>();
  private GrizzlyWebSocketAdapter adapter;
  private OutboundWebSocket socket;

  /**
   * Builds the socket under test and stubs the handler so that reconnecting returns {@link #reconnectionFuture}.
   */
  @Before
  public void before() {
    adapter = GrizzlyWebSocketAdapter.newInstance(httpProviderConfig, protocolHandler);
    socket = new OutboundWebSocket("id",
                                   URI.create("http://mulesoft.com"),
                                   WS,
                                   adapter,
                                   reconnectionHandler);
    when(reconnectionHandler.reconnect(socket, retryPolicyTemplate, scheduler)).thenReturn(reconnectionFuture);
  }

  /**
   * Once a reconnection has completed, a subsequent reconnection must go to the handler again and yield a new
   * future rather than reusing the completed one.
   */
  @Test
  public void reconnectsTwiceSerially() throws Exception {
    CompletableFuture<WebSocket> firstFuture = socket.reconnect(retryPolicyTemplate, scheduler);
    verify(reconnectionHandler).reconnect(socket, retryPolicyTemplate, scheduler);

    WebSocket newSocket = mock(WebSocket.class);
    reconnectionFuture.complete(newSocket);

    assertThat(firstFuture.get(), is(sameInstance(newSocket)));

    // Reset the handler so the verification below only counts the second reconnection.
    reset(reconnectionHandler);
    reconnectionFuture = new CompletableFuture<>();
    when(reconnectionHandler.reconnect(socket, retryPolicyTemplate, scheduler)).thenReturn(reconnectionFuture);

    CompletableFuture<WebSocket> secondFuture = socket.reconnect(retryPolicyTemplate, scheduler);
    // The already-completed first reconnection must not be handed out again.
    assertThat(secondFuture, is(not(sameInstance(firstFuture))));
    assertThat(secondFuture, is(not(sameInstance(reconnectionFuture))));

    verify(reconnectionHandler).reconnect(socket, retryPolicyTemplate, scheduler);
  }

  /**
   * Several threads requesting a reconnection while one is pending (the stubbed future is never completed here)
   * must all receive the very same future instance.
   */
  @Test
  public void concurrentReconnectionReturnSameFuture() throws Exception {
    List<CompletableFuture<WebSocket>> futures = new CopyOnWriteArrayList<>();
    int top = 5;
    // Released once every thread has been started, so the reconnect calls actually overlap.
    CountDownLatch startGate = new CountDownLatch(1);
    CountDownLatch done = new CountDownLatch(top);
    for (int i = 0; i < top; i++) {
      new Thread(() -> {
        try {
          startGate.await();
          futures.add(socket.reconnect(retryPolicyTemplate, scheduler));
        } catch (InterruptedException e) {
          Thread.currentThread().interrupt();
        } finally {
          // Always count down so a failing worker shows up as a size mismatch instead of a timeout.
          done.countDown();
        }
      }).start();
    }

    startGate.countDown();
    assertThat(done.await(5, SECONDS), is(true));

    assertThat(futures, hasSize(top));
    CompletableFuture<WebSocket> first = futures.get(0);
    assertThat(futures.stream().allMatch(f -> f == first), is(true));
  }
}
