/*
 * (c) 2003-2020 MuleSoft, Inc. This software is protected under international copyright
 * law. All use of this software is subject to MuleSoft's Master Subscription Agreement
 * (or other master license agreement) separately entered into in writing between you and
 * MuleSoft. If such an agreement is not in place, you may not use the software.
 */
package com.mulesoft.service.http.impl.service.ws;

import static com.mulesoft.service.http.impl.service.ws.WebSocketUtils.DEFAULT_DATA_FRAME_SIZE;
import static com.mulesoft.service.http.impl.service.ws.WebSocketUtils.failedFuture;
import static java.lang.String.format;
import static java.lang.System.arraycopy;
import static java.util.Collections.newSetFromMap;
import static java.util.concurrent.CompletableFuture.completedFuture;
import static java.util.stream.Collectors.joining;
import static org.mule.runtime.api.i18n.I18nMessageFactory.createStaticMessage;
import static org.mule.runtime.api.metadata.MediaTypeUtils.isStringRepresentable;
import static org.slf4j.LoggerFactory.getLogger;

import org.mule.runtime.api.exception.MuleRuntimeException;
import org.mule.runtime.api.metadata.TypedValue;
import org.mule.runtime.api.scheduler.Scheduler;
import org.mule.runtime.api.util.concurrent.Latch;
import org.mule.runtime.core.api.retry.policy.NoRetryPolicyTemplate;
import org.mule.runtime.core.api.retry.policy.RetryPolicyTemplate;
import org.mule.runtime.http.api.ws.WebSocket;
import org.mule.runtime.http.api.ws.WebSocketBroadcaster;

import java.io.InputStream;
import java.util.Collection;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;

import org.slf4j.Logger;

/**
 * {@link WebSocketBroadcaster} based on Grizzly
 *
 * @since 1.2.0
 */
public class GrizzlyWebSocketBroadcaster implements WebSocketBroadcaster {

  private static final Logger LOGGER = getLogger(GrizzlyWebSocketBroadcaster.class);

  /**
   * Broadcasts the {@code content} to the given {@code sockets} by delegating to the five-argument overload with a
   * {@link NoRetryPolicyTemplate} and a {@code null} reconnection scheduler.
   *
   * @param sockets       the sockets to broadcast to
   * @param content       the content to be sent
   * @param errorCallback receives each socket for which an error was found, together with that error
   * @return the future produced by the delegated broadcast
   */
  @Override
  public CompletableFuture<Void> broadcast(Collection<WebSocket> sockets,
                                           TypedValue<InputStream> content,
                                           BiConsumer<WebSocket, Throwable> errorCallback) {
    final RetryPolicyTemplate noRetryPolicy = new NoRetryPolicyTemplate();
    final Scheduler noReconnectionScheduler = null;

    return broadcast(sockets, content, errorCallback, noRetryPolicy, noReconnectionScheduler);
  }

  /**
   * Broadcasts the {@code content} to the given {@code sockets}. All the work is carried out by a fresh
   * {@link BroadcastAction} created for this invocation.
   *
   * @param sockets               the sockets to broadcast to
   * @param content               the content to be sent
   * @param errorCallback         receives each socket for which an error was found, together with that error
   * @param retryPolicyTemplate   the policy handed to sockets which need to be reconnected
   * @param reconnectionScheduler the scheduler handed to sockets which need to be reconnected
   * @return the future produced by the {@link BroadcastAction}
   */
  @Override
  public CompletableFuture<Void> broadcast(Collection<WebSocket> sockets,
                                           TypedValue<InputStream> content,
                                           BiConsumer<WebSocket, Throwable> errorCallback,
                                           RetryPolicyTemplate retryPolicyTemplate,
                                           Scheduler reconnectionScheduler) {
    final BroadcastAction action =
        new BroadcastAction(sockets, content, errorCallback, retryPolicyTemplate, reconnectionScheduler);

    return action.broadcast();
  }

  /**
   * A single broadcast operation. Reads the content stream in chunks of up to {@code frameSize} bytes and sends each chunk
   * to every target socket, keeping track of the sockets for which an error was found so that they are skipped for the
   * remainder of the operation.
   */
  private class BroadcastAction {

    private final Collection<WebSocket> sockets;
    private final TypedValue<InputStream> content;
    private final BiConsumer<WebSocket, Throwable> errorCallback;
    // ids of the sockets for which an error was reported. Those sockets are skipped on subsequent sends
    private final Set<String> failedSockets = newSetFromMap(new ConcurrentHashMap<>());
    private final FrameFactory frameFactory;
    private final int frameSize = DEFAULT_DATA_FRAME_SIZE;
    private final RetryPolicyTemplate retryPolicyTemplate;
    private final Scheduler reconnectionScheduler;
    // set once a non-final fragment has been sent, so that the last chunk goes out as a closing fragment
    private boolean streaming = false;

    /**
     * Creates a new action. Text frames are produced when the content's media type is string representable, binary frames
     * otherwise.
     *
     * @param sockets               the sockets to broadcast to
     * @param content               the content to be sent
     * @param errorCallback         receives each socket for which an error was found, together with that error
     * @param retryPolicyTemplate   the policy handed to sockets which need to be reconnected
     * @param reconnectionScheduler the scheduler handed to sockets which need to be reconnected
     */
    public BroadcastAction(Collection<WebSocket> sockets,
                           TypedValue<InputStream> content,
                           BiConsumer<WebSocket, Throwable> errorCallback,
                           RetryPolicyTemplate retryPolicyTemplate,
                           Scheduler reconnectionScheduler) {
      this.sockets = sockets;
      this.content = content;
      this.errorCallback = errorCallback;
      this.retryPolicyTemplate = retryPolicyTemplate;
      this.reconnectionScheduler = reconnectionScheduler;

      frameFactory = isStringRepresentable(content.getDataType().getMediaType())
          ? new TextFrameFactory()
          : new BinaryFrameFactory();
    }

    /**
     * Performs the broadcast. The calling thread reads the content and blocks until all the intermediate fragments were
     * handled; only the send of the last chunk is left pending in the returned future.
     *
     * @return a future completed once the last chunk was handled for every socket. It is already completed when there was
     *         nothing to send or all the sockets failed, and it is a failed future when an exception was thrown along the way
     */
    private CompletableFuture<Void> broadcast() {
      byte[] readBuffer = new byte[frameSize];
      byte[] writeBuffer = new byte[frameSize];
      int read;
      int write = 0;
      CompletableFuture<Void> composedFuture = null;
      Latch latch = new Latch();

      InputStream value = content.getValue();

      try {
        // A chunk is always held back in writeBuffer: it is only sent as a non-final fragment once a following read
        // proves that more data exists. Whatever remains after the loop is the last chunk.
        while (!allSocketsFailed() && (read = value.read(readBuffer, 0, readBuffer.length)) != -1) {
          if (write > 0) {
            streaming = true;
            int len = write;
            CompletableFuture<Void> future =
                doBroadcast(sockets, writeBuffer, (ws, data) -> frameFactory.asFragment(ws, data, 0, len, false))
                    .whenComplete((v, e) -> {
                      if (e != null) {
                        latch.release();
                      }
                    });

            composedFuture = composedFuture == null ? future : composedFuture.thenCompose(v -> future);
          }
          arraycopy(readBuffer, 0, writeBuffer, 0, read);
          write = read;
        }

        if (composedFuture != null) {
          // wait for all the intermediate fragments before deciding what to do with the last chunk
          composedFuture.whenComplete((v, e) -> latch.release());
          latch.await();
        }

        if (write == 0 || allSocketsFailed()) {
          logBroadcastCompleted();
          return completedFuture(null);
        }

        // because of a bug in grizzly we need to create a byte array with the exact length
        if (write < writeBuffer.length) {
          byte[] exactSize = writeBuffer;
          writeBuffer = new byte[write];
          arraycopy(exactSize, 0, writeBuffer, 0, write);
        }

        return doBroadcast(sockets, writeBuffer, (ws, data) -> streaming
            ? frameFactory.asFragment(ws, data, true)
            : frameFactory.asFrame(ws, data)).whenComplete((v, e) -> logBroadcastCompleted());

      } catch (Throwable t) {
        if (t instanceof InterruptedException) {
          // the interruption is reported through the failed future, but the flag must survive for the thread's owner
          Thread.currentThread().interrupt();
        }

        if (LOGGER.isDebugEnabled()) {
          LOGGER.debug("Could not broadcast message: " + t.getMessage() + ". Target sockets were: " + socketsToString(), t);
        }

        return failedFuture(new MuleRuntimeException(createStaticMessage("Could not perform broadcast: " + t.getMessage()), t));
      }
    }

    /**
     * Sends one frame to each of the {@code sockets} which is connected (or could be reconnected) and has not failed before.
     * The frame is built only once, using the first eligible socket, and then reused for the rest.
     *
     * @param sockets      the target sockets
     * @param data         the payload of the frame
     * @param frameFactory builds the frame out of a socket and the payload
     * @return a future completed (never exceptionally) once every socket was either skipped, sent to, or reported as failed
     */
    private CompletableFuture<Void> doBroadcast(Collection<WebSocket> sockets, byte[] data,
                                                BiFunction<WebSocket, byte[], byte[]> frameFactory) {
      final int totalSockets = sockets.size();
      if (totalSockets == 0) {
        // with no sockets nothing would ever complete the sink, leaving the caller waiting forever
        return completedFuture(null);
      }

      byte[] frame = null;
      final AtomicInteger socketCount = new AtomicInteger(0);
      final CompletableFuture<Void> sink = new CompletableFuture<>();

      for (WebSocket ws : sockets) {
        try {
          final WebSocket socket = assureConnected(ws, retryPolicyTemplate.isEnabled());
          if (socket == null) {
            incrementAndComplete(socketCount, totalSockets, sink);
            continue;
          }

          if (frame == null) {
            frame = frameFactory.apply(socket, data);
          }

          socket.sendFrame(frame).whenComplete((v, e) -> {
            // record the failure BEFORE counting the socket as done, so that whoever observes the completed sink
            // already sees an up to date failedSockets. The finally guarantees the sink completes even if the
            // error callback throws.
            try {
              if (e != null) {
                handleSocketError(socket, e);
              }
            } finally {
              incrementAndComplete(socketCount, totalSockets, sink);
            }
          });
        } catch (Throwable t) {
          try {
            handleSocketError(ws, t);
          } finally {
            incrementAndComplete(socketCount, totalSockets, sink);
          }
        }
      }

      return sink;
    }

    /**
     * Counts one more socket as handled and completes the {@code sink} once {@code top} sockets were counted.
     */
    private void incrementAndComplete(AtomicInteger count, int top, CompletableFuture<Void> sink) {
      if (count.incrementAndGet() >= top) {
        sink.complete(null);
      }
    }

    /**
     * Marks the {@code socket} as failed for the rest of this broadcast and notifies the error callback.
     */
    private void handleSocketError(WebSocket socket, Throwable e) {
      if (LOGGER.isDebugEnabled()) {
        LOGGER.debug("Found exception while broadcasting to WebSocket. " + e.getMessage() + ". Socket was: " + socket.toString(),
                     e);
      }

      failedSockets.add(socket.getId());
      errorCallback.accept(socket, e);
    }

    /**
     * Resolves the socket that should actually be written to.
     *
     * @param socket    the candidate socket
     * @param reconnect whether a reconnection should be attempted if the socket is not connected
     * @return the connected socket (possibly the one obtained through reconnection), or {@code null} if the socket already
     *         failed during this broadcast, is not connected and cannot be reconnected, or the reconnection failed or was
     *         interrupted
     */
    private WebSocket assureConnected(WebSocket socket, boolean reconnect) {
      if (failedSockets.contains(socket.getId())) {
        return null;
      }

      if (socket.isConnected()) {
        return socket;
      }

      if (!reconnect || socket.isClosed()) {
        LOGGER.info("WebSocket '{}' is not connected. Will skip from broadcast", socket.getId());
        return null;
      }

      if (!socket.supportsReconnection()) {
        LOGGER.info("WebSocket '{}' is not connected and is not reconnectable. Will skip from broadcast", socket.getId());
        return null;
      }

      try {
        // blocks until the reconnection finishes. The result is checked again, this time without reconnecting
        return assureConnected(socket.reconnect(retryPolicyTemplate, reconnectionScheduler).get(), false);
      } catch (ExecutionException e) {
        LOGGER.error(format("WebSocket '%s' found exception during reconnection. Will skip from broadcast", socket.getId()),
                     e);
      } catch (InterruptedException e) {
        LOGGER.error(format("WebSocket '%s' got interrupted during reconnection. Will skip from broadcast", socket.getId()),
                     e);
        // the socket is skipped, but the interruption must remain visible to the owner of the thread
        Thread.currentThread().interrupt();
      }

      return null;
    }

    /**
     * @return whether at least as many sockets failed as there are targets (trivially {@code true} with no targets)
     */
    private boolean allSocketsFailed() {
      return failedSockets.size() >= sockets.size();
    }

    /**
     * Logs, at debug level, the outcome of the broadcast including the ids of the sockets that failed, if any.
     */
    private void logBroadcastCompleted() {
      if (LOGGER.isDebugEnabled()) {
        String recipients = "Recipient list was: " + socketsToString();
        if (failedSockets.isEmpty()) {
          LOGGER.debug("Broadcast successful to all target WebSockets. " + recipients);
        } else {
          String failed = String.join(", ", failedSockets);
          LOGGER.debug("Broadcast completed, but delivery to the following WebSockets failed: {}. {}", failed, recipients);
        }
      }
    }

    /**
     * @return the comma separated ids of all the target sockets
     */
    private String socketsToString() {
      return sockets.stream()
          .map(WebSocket::getId)
          .collect(joining(", "));
    }
  }


  /**
   * Turns payload bytes into the frame bytes that are handed to a {@link WebSocket} for sending.
   */
  private interface FrameFactory {

    /**
     * Builds a frame out of a range of the given bytes.
     *
     * @param socket the socket used to build the frame
     * @param bytes  the array holding the payload
     * @param offset the position at which the payload starts
     * @param len    the payload's length
     * @return the frame
     */
    byte[] asFrame(WebSocket socket, byte[] bytes, int offset, int len);

    /**
     * Builds a fragment out of a range of the given bytes.
     *
     * @param socket the socket used to build the fragment
     * @param bytes  the array holding the payload
     * @param offset the position at which the payload starts
     * @param len    the payload's length
     * @param last   whether this is the last fragment
     * @return the fragment
     */
    byte[] asFragment(WebSocket socket, byte[] bytes, int offset, int len, boolean last);

    /**
     * Builds a frame out of the entire array.
     *
     * @param socket the socket used to build the frame
     * @param bytes  the payload
     * @return the frame
     */
    default byte[] asFrame(WebSocket socket, byte[] bytes) {
      final int wholeLength = bytes.length;
      return asFrame(socket, bytes, 0, wholeLength);
    }

    /**
     * Builds a fragment out of the entire array.
     *
     * @param socket the socket used to build the fragment
     * @param bytes  the payload
     * @param last   whether this is the last fragment
     * @return the fragment
     */
    default byte[] asFragment(WebSocket socket, byte[] bytes, boolean last) {
      final int wholeLength = bytes.length;
      return asFragment(socket, bytes, 0, wholeLength, last);
    }
  }


  private class TextFrameFactory implements FrameFactory {

    @Override
    public byte[] asFrame(WebSocket socket, byte[] bytes, int offset, int len) {
      return socket.toTextFrame(new String(bytes, offset, len), true);
    }

    @Override
    public byte[] asFragment(WebSocket socket, byte[] bytes, int offset, int len, boolean last) {
      return socket.toTextFrame(new String(bytes, offset, len), last);
    }
  }


  private class BinaryFrameFactory implements FrameFactory {

    @Override
    public byte[] asFrame(WebSocket socket, byte[] bytes, int offset, int len) {
      return socket.toBinaryFrame(slice(bytes, offset, len), true);
    }

    @Override
    public byte[] asFragment(WebSocket socket, byte[] bytes, int offset, int len, boolean last) {
      return socket.toBinaryFrame(slice(bytes, offset, len), last);
    }

    private byte[] slice(byte[] bytes, int offset, int len) {
      if (bytes.length == len) {
        return bytes;
      }

      byte[] slice = new byte[len];
      arraycopy(bytes, offset, slice, 0, len);
      return slice;
    }
  }
}
