/*
 * (c) 2003-2018 MuleSoft, Inc. This software is protected under international copyright
 * law. All use of this software is subject to MuleSoft's Master Subscription Agreement
 * (or other master license agreement) separately entered into in writing between you and
 * MuleSoft. If such an agreement is not in place, you may not use the software.
 */
package com.mulesoft.service.http.impl.service.ws;

import static com.mulesoft.service.http.impl.service.ws.WebSocketUtils.DEFAULT_DATA_FRAME_SIZE;
import static com.mulesoft.service.http.impl.service.ws.WebSocketUtils.failedFuture;
import static java.lang.System.arraycopy;
import static java.util.Collections.newSetFromMap;
import static java.util.concurrent.CompletableFuture.completedFuture;
import static java.util.stream.Collectors.joining;
import static org.glassfish.grizzly.utils.Futures.completable;
import static org.mule.runtime.api.i18n.I18nMessageFactory.createStaticMessage;
import static org.mule.runtime.api.metadata.MediaTypeUtils.isStringRepresentable;
import static org.slf4j.LoggerFactory.getLogger;
import org.mule.runtime.api.exception.MuleRuntimeException;
import org.mule.runtime.api.metadata.TypedValue;
import org.mule.runtime.api.util.concurrent.Latch;
import org.mule.runtime.http.api.ws.WebSocket;
import org.mule.runtime.http.api.ws.WebSocketBroadcaster;

import com.mulesoft.service.http.impl.service.client.ws.InboundWebSocket;
import com.mulesoft.service.http.impl.service.client.ws.OutboundWebSocket;

import java.io.InputStream;
import java.util.Collection;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;

import org.glassfish.grizzly.websockets.SimpleWebSocket;
import org.slf4j.Logger;

/**
 * {@link WebSocketBroadcaster} based on Grizzly.
 * <p>
 * The content to broadcast is read from its {@link InputStream} in chunks of {@code DEFAULT_DATA_FRAME_SIZE} bytes. Content
 * that fits in one chunk is sent as a single frame; larger content is sent as a sequence of fragments, the last of which is
 * flagged as final. Failures on individual sockets are reported through the supplied error callback and do not abort the
 * delivery to the remaining sockets.
 *
 * @since 1.2.0
 */
public class GrizzlyWebSocketBroadcaster implements WebSocketBroadcaster {

  private static final Logger LOGGER = getLogger(GrizzlyWebSocketBroadcaster.class);

  /**
   * Sends the given {@code content} to every socket in {@code sockets}.
   *
   * @param sockets       the target sockets
   * @param content       the content to send; its media type decides between text and binary frames
   * @param errorCallback invoked with the socket and the error each time delivery to a socket fails
   * @return a future completed once the last frame was handed to every socket, or completed exceptionally if the broadcast as
   *         a whole could not be performed (e.g. the content could not be read)
   */
  @Override
  public CompletableFuture<Void> broadcast(Collection<WebSocket> sockets, TypedValue<InputStream> content,
                                           BiConsumer<WebSocket, Throwable> errorCallback) {
    return new BroadcastAction(sockets, content, errorCallback).broadcast();
  }

  /**
   * Holds the state of one broadcast operation: the target sockets, the content, and the ids of the sockets for which delivery
   * already failed (those are skipped for the remaining frames).
   */
  private static class BroadcastAction {

    private final Collection<WebSocket> sockets;
    private final TypedValue<InputStream> content;
    private final BiConsumer<WebSocket, Throwable> errorCallback;
    // written from send-completion callbacks, hence the concurrent set
    private final Set<String> failedSockets = newSetFromMap(new ConcurrentHashMap<>());
    private final FrameFactory frameFactory;
    private final int frameSize = DEFAULT_DATA_FRAME_SIZE;
    // becomes true as soon as a first non-final fragment is sent, i.e. the content spans more than one chunk
    private boolean streaming = false;

    public BroadcastAction(Collection<WebSocket> sockets,
                           TypedValue<InputStream> content,
                           BiConsumer<WebSocket, Throwable> errorCallback) {
      this.sockets = sockets;
      this.content = content;
      this.errorCallback = errorCallback;

      frameFactory = isStringRepresentable(content.getDataType().getMediaType())
          ? new TextFrameFactory()
          : new BinaryFrameFactory();
    }

    /**
     * Reads the content chunk by chunk and sends it to all sockets.
     * <p>
     * Each chunk is held back in {@code writeBuffer} until the next read tells whether more data follows: if it does, the held
     * chunk is sent as a non-final fragment; the chunk still held when the stream ends is sent last, either as the final
     * fragment or, if nothing was sent before, as a single complete frame.
     *
     * @return a future completed when the last frame was processed for every socket, an already completed future if there was
     *         nothing to send or every socket failed, or a failed future if reading or sending threw
     */
    private CompletableFuture<Void> broadcast() {
      byte[] readBuffer = new byte[frameSize];
      byte[] writeBuffer = new byte[frameSize];
      int read;
      int write = 0;
      CompletableFuture<Void> composedFuture = null;
      Latch latch = new Latch();

      InputStream value = content.getValue();

      try {
        while (!allSocketsFailed() && (read = value.read(readBuffer, 0, readBuffer.length)) != -1) {
          if (write > 0) {
            // more data followed the held chunk, so it goes out as a non-final fragment
            streaming = true;
            int len = write;
            CompletableFuture<Void> future =
                doBroadcast(sockets, writeBuffer, (ws, data) -> frameFactory.asFragment(ws, data, 0, len, false))
                    .whenComplete((v, e) -> {
                      if (e != null) {
                        latch.release();
                      }
                    });

            composedFuture = composedFuture == null ? future : composedFuture.thenCompose(v -> future);
          }
          arraycopy(readBuffer, 0, writeBuffer, 0, read);
          write = read;
        }

        if (composedFuture != null) {
          // wait for all intermediate fragments before deciding whether the final one is still worth sending
          composedFuture.whenComplete((v, e) -> latch.release());
          latch.await();
        }

        if (write == 0 || allSocketsFailed()) {
          logBroadcastCompleted();
          return completedFuture(null);
        }

        // because of a bug in grizzly we need to create a byte array with the exact length
        if (write < writeBuffer.length) {
          byte[] exactSize = writeBuffer;
          writeBuffer = new byte[write];
          arraycopy(exactSize, 0, writeBuffer, 0, write);
        }

        return doBroadcast(sockets, writeBuffer, (ws, data) -> streaming
            ? frameFactory.asFragment(ws, data, true)
            : frameFactory.asFrame(ws, data)).whenComplete((v, e) -> logBroadcastCompleted());

      } catch (Throwable t) {
        if (t instanceof InterruptedException) {
          // the wait on the latch was interrupted: restore the flag so the calling thread can still observe it
          Thread.currentThread().interrupt();
        }

        if (LOGGER.isDebugEnabled()) {
          LOGGER.debug("Could not broadcast message: " + t.getMessage() + ". Target sockets were: " + socketsToString(), t);
        }

        return failedFuture(new MuleRuntimeException(createStaticMessage("Could not perform broadcast: " + t.getMessage()), t));
      }
    }

    /**
     * Sends one frame to every socket that is still active.
     * <p>
     * The raw frame is built lazily, once, with the first active socket found, and the same bytes are then sent to all sockets.
     * Sockets that already failed or are not connected are counted as processed without sending anything.
     *
     * @param sockets      the target sockets
     * @param data         the payload of the frame
     * @param frameBuilder turns the payload into the raw frame bytes
     * @return a future completed (never exceptionally) once every socket was processed; per-socket failures are recorded and
     *         reported through the error callback before the future completes
     */
    private CompletableFuture<Void> doBroadcast(Collection<WebSocket> sockets, byte[] data,
                                                BiFunction<SimpleWebSocket, byte[], byte[]> frameBuilder) {
      final int totalSockets = sockets.size();
      final CompletableFuture<Void> sink = new CompletableFuture<>();
      if (totalSockets == 0) {
        // nothing will ever be counted, so the sink has to be completed here or it would stay pending forever
        sink.complete(null);
        return sink;
      }

      final AtomicInteger socketCount = new AtomicInteger(0);
      byte[] frame = null;

      for (WebSocket socket : sockets) {
        try {
          SimpleWebSocket ws = asActiveWebSocket(socket);
          if (ws == null) {
            incrementAndComplete(socketCount, totalSockets, sink);
            continue;
          }

          if (frame == null) {
            frame = frameBuilder.apply(ws, data);
          }

          completable(ws.sendRaw(frame)).whenComplete((v, e) -> {
            // record the failure BEFORE counting the socket: completing the sink runs its dependents, which read
            // failedSockets; the finally guarantees the socket is counted even if the error callback throws
            try {
              if (e != null) {
                handleSocketError(socket, e);
              }
            } finally {
              incrementAndComplete(socketCount, totalSockets, sink);
            }
          });
        } catch (Throwable t) {
          try {
            handleSocketError(socket, t);
          } finally {
            incrementAndComplete(socketCount, totalSockets, sink);
          }
        }
      }

      return sink;
    }

    /**
     * Counts one more processed socket and completes {@code sink} when {@code top} sockets have been counted.
     */
    private void incrementAndComplete(AtomicInteger count, int top, CompletableFuture<Void> sink) {
      if (count.addAndGet(1) >= top) {
        sink.complete(null);
      }
    }

    /**
     * Marks the socket as failed, so it is skipped for any further frame, and notifies the error callback.
     */
    private void handleSocketError(WebSocket socket, Throwable e) {
      if (LOGGER.isDebugEnabled()) {
        LOGGER.debug("Found exception while broadcasting to WebSocket. " + e.getMessage() + ". Socket was: " + socket.toString(),
                     e);
      }

      failedSockets.add(socket.getId());
      errorCallback.accept(socket, e);
    }

    /**
     * Resolves the Grizzly socket behind the given {@link WebSocket}.
     *
     * @param socket the socket to resolve
     * @return the underlying {@link SimpleWebSocket}, or {@code null} if the socket already failed during this broadcast or is
     *         not connected
     * @throws IllegalArgumentException if the socket is neither an {@link InboundWebSocket} nor an {@link OutboundWebSocket}
     */
    private SimpleWebSocket asActiveWebSocket(WebSocket socket) {
      if (failedSockets.contains(socket.getId())) {
        return null;
      }

      SimpleWebSocket ws;
      if (socket instanceof InboundWebSocket) {
        ws = (SimpleWebSocket) socket;
      } else if (socket instanceof OutboundWebSocket) {
        ws = (SimpleWebSocket) ((OutboundWebSocket) socket).getGrizzlyWebSocket();
      } else {
        throw new IllegalArgumentException("Invalid socket of type: " + socket.getClass().getName());
      }

      return ws.isConnected() ? ws : null;
    }

    /**
     * @return whether at least as many sockets failed as there are targets; also {@code true} when there are no targets
     */
    private boolean allSocketsFailed() {
      return failedSockets.size() >= sockets.size();
    }

    /**
     * Logs, at debug level, the outcome of the broadcast along with the recipient list.
     */
    private void logBroadcastCompleted() {
      if (LOGGER.isDebugEnabled()) {
        String recipients = "Recipient list was: " + socketsToString();
        if (failedSockets.isEmpty()) {
          LOGGER.debug("Broadcast successful to all target WebSockets. " + recipients);
        } else {
          String failed = String.join(", ", failedSockets);
          LOGGER.debug("Broadcast completed, but delivery to the following WebSockets failed: {}. {}", failed, recipients);
        }
      }
    }

    /**
     * @return the ids of all target sockets, comma separated
     */
    private String socketsToString() {
      return sockets.stream()
          .map(WebSocket::getId)
          .collect(joining(", "));
    }
  }


  /**
   * Turns a payload into the raw bytes of a WebSocket frame, either a complete frame or a fragment of a larger message.
   */
  private interface FrameFactory {

    /**
     * Builds a complete frame out of the whole array.
     */
    default byte[] asFrame(SimpleWebSocket socket, byte[] bytes) {
      return asFrame(socket, bytes, 0, bytes.length);
    }

    /**
     * Builds a fragment out of the whole array.
     *
     * @param last whether this is the final fragment of the message
     */
    default byte[] asFragment(SimpleWebSocket socket, byte[] bytes, boolean last) {
      return asFragment(socket, bytes, 0, bytes.length, last);
    }

    /**
     * Builds a complete frame out of {@code len} bytes starting at {@code offset}.
     */
    byte[] asFrame(SimpleWebSocket socket, byte[] bytes, int offset, int len);

    /**
     * Builds a fragment out of {@code len} bytes starting at {@code offset}.
     *
     * @param last whether this is the final fragment of the message
     */
    byte[] asFragment(SimpleWebSocket socket, byte[] bytes, int offset, int len, boolean last);
  }


  /**
   * {@link FrameFactory} which sends the payload as text.
   * <p>
   * NOTE(review): the bytes are decoded with the JVM default charset, and when streaming a chunk boundary may fall in the
   * middle of a multi-byte character. Confirm whether the charset declared by the content should be used instead.
   */
  private static class TextFrameFactory implements FrameFactory {

    @Override
    public byte[] asFrame(SimpleWebSocket socket, byte[] bytes, int offset, int len) {
      return socket.toRawData(new String(bytes, offset, len));
    }

    @Override
    public byte[] asFragment(SimpleWebSocket socket, byte[] bytes, int offset, int len, boolean last) {
      return socket.toRawData(new String(bytes, offset, len), last);
    }
  }


  /**
   * {@link FrameFactory} which sends the payload as binary data.
   */
  private static class BinaryFrameFactory implements FrameFactory {

    @Override
    public byte[] asFrame(SimpleWebSocket socket, byte[] bytes, int offset, int len) {
      return socket.toRawData(slice(bytes, offset, len));
    }

    @Override
    public byte[] asFragment(SimpleWebSocket socket, byte[] bytes, int offset, int len, boolean last) {
      return socket.toRawData(slice(bytes, offset, len), last);
    }

    /**
     * @return the requested region of {@code bytes}; the array itself (no copy) when the region covers all of it
     */
    private byte[] slice(byte[] bytes, int offset, int len) {
      // the no-copy shortcut is only valid when the region really is the whole array
      if (offset == 0 && bytes.length == len) {
        return bytes;
      }

      byte[] slice = new byte[len];
      arraycopy(bytes, offset, slice, 0, len);
      return slice;
    }
  }
}
