/*
 * (c) 2003-2020 MuleSoft, Inc. This software is protected under international copyright
 * law. All use of this software is subject to MuleSoft's Master Subscription Agreement
 * (or other master license agreement) separately entered into in writing between you and
 * MuleSoft. If such an agreement is not in place, you may not use the software.
 */
package com.mulesoft.service.http.impl.service.client.ws;

import static com.mulesoft.service.http.impl.service.ws.WebSocketUtils.asVoid;
import static com.mulesoft.service.http.impl.service.ws.WebSocketUtils.failedFuture;
import static com.mulesoft.service.http.impl.service.ws.WebSocketUtils.mapWsException;
import static com.mulesoft.service.http.impl.service.ws.WebSocketUtils.streamInDataFrames;
import static java.lang.System.arraycopy;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Collections.unmodifiableList;
import static org.mule.runtime.api.metadata.MediaTypeUtils.isStringRepresentable;
import static org.mule.runtime.http.api.ws.WebSocket.WebSocketType.OUTBOUND;

import org.mule.runtime.api.metadata.MediaType;
import org.mule.runtime.api.scheduler.Scheduler;
import org.mule.runtime.core.api.retry.policy.RetryPolicyTemplate;
import org.mule.runtime.http.api.ws.WebSocket;
import org.mule.runtime.http.api.ws.WebSocketCloseCode;
import org.mule.runtime.http.api.ws.WebSocketProtocol;
import org.mule.runtime.http.api.ws.exception.WebSocketClosedException;
import org.mule.runtime.http.api.ws.exception.WebSocketConnectionException;

import com.mulesoft.service.http.impl.service.client.ws.reconnect.OutboundWebSocketReconnectionHandler;
import com.mulesoft.service.http.impl.service.ws.DataFrameEmitter;
import com.mulesoft.service.http.impl.service.ws.FragmentHandler;
import com.mulesoft.service.http.impl.service.ws.FragmentHandlerProvider;
import com.mulesoft.service.http.impl.service.ws.PipedFragmentHandlerProvider;

import java.io.InputStream;
import java.net.URI;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;

import com.ning.http.client.providers.grizzly.websocket.GrizzlyWebSocketAdapter;
import org.glassfish.grizzly.websockets.SimpleWebSocket;

/**
 * Grizzly based implementation of an OUTBOUND {@link WebSocket}.
 * <p>
 * Reconnection is supported on this implementation.
 *
 * @since 1.3.0
 */
public class OutboundWebSocket implements WebSocket {

  private final String id;
  private final URI uri;
  private final WebSocketProtocol protocol;
  private final Set<String> groups = new HashSet<>();
  private final GrizzlyWebSocketAdapter delegate;
  private final FragmentHandlerProvider fragmentHandlerProvider;
  private final AtomicBoolean closed = new AtomicBoolean(false);
  private final OutboundWebSocketReconnectionHandler reconnectionHandler;

  private final Lock reconnectionLock = new ReentrantLock();
  private final AtomicReference<CompletableFuture<WebSocket>> ongoingReconnection = new AtomicReference<>(null);

  public OutboundWebSocket(String id,
                           URI uri,
                           WebSocketProtocol protocol,
                           GrizzlyWebSocketAdapter delegate,
                           OutboundWebSocketReconnectionHandler reconnectionHandler) {
    this.id = id;
    this.uri = uri;
    this.delegate = delegate;
    this.protocol = protocol;
    this.reconnectionHandler = reconnectionHandler;
    fragmentHandlerProvider = new PipedFragmentHandlerProvider(id);
  }

  @Override
  public CompletableFuture<Void> send(InputStream content, MediaType mediaType) {
    if (closed.get()) {
      return failedFuture(new WebSocketClosedException(this));
    }

    if (!delegate.isOpen()) {
      return failedFuture(new WebSocketConnectionException(this));
    }

    try {
      DataFrameEmitter emitter = isStringRepresentable(mediaType)
          ? textEmitter()
          : binaryEmitter();
      return streamInDataFrames(content, emitter, t -> mapWsException(t, this));
    } catch (Throwable t) {
      return failedFuture(mapWsException(t, this));
    }
  }

  @Override
  public CompletableFuture<Void> sendFrame(byte[] frameBytes) {
    if (closed.get()) {
      return failedFuture(new WebSocketClosedException(this));
    }

    if (!delegate.isOpen()) {
      return failedFuture(new WebSocketConnectionException(this));
    }

    try {
      return asVoid(getGrizzlyWebSocket().sendRaw(frameBytes));
    } catch (Throwable t) {
      return failedFuture(mapWsException(t, this));
    }
  }

  @Override
  public byte[] toTextFrame(String data, boolean last) {
    return getGrizzlyWebSocket().toRawData(data, last);
  }

  @Override
  public byte[] toBinaryFrame(byte[] data, boolean last) {
    return getGrizzlyWebSocket().toRawData(data, last);
  }

  /**
   * {@inheritDoc}
   *
   * @return {@code true}
   */
  @Override
  public boolean supportsReconnection() {
    return true;
  }

  /**
   * {@inheritDoc}
   * This method is supported on this implementation.
   */
  @Override
  public CompletableFuture<WebSocket> reconnect(RetryPolicyTemplate retryPolicyTemplate, Scheduler scheduler) {
    reconnectionLock.lock();
    CompletableFuture<WebSocket> f;
    try {
      f = ongoingReconnection.get();
      if (f != null) {
        return f;
      }
      f = new CompletableFuture<>();
      ongoingReconnection.set(f);
    } finally {
      reconnectionLock.unlock();
    }

    final CompletableFuture<WebSocket> effectiveFuture = f;
    reconnectionHandler.reconnect(this, retryPolicyTemplate, scheduler).whenComplete((v, e) -> {
      reconnectionLock.lock();
      try {
        if (e != null) {
          effectiveFuture.completeExceptionally(e);
        } else {
          effectiveFuture.complete(v);
        }
      } finally {
        ongoingReconnection.set(null);
        reconnectionLock.unlock();
      }
    });

    return effectiveFuture;
  }

  public FragmentHandler getFragmentHandler(Consumer<FragmentHandler> newFragmentHandlerCallback) {
    return fragmentHandlerProvider.getFragmentHandler(newFragmentHandlerCallback);
  }

  private DataFrameEmitter textEmitter() {
    return new DataFrameEmitter() {

      @Override
      public CompletableFuture<Void> stream(byte[] bytes, int offset, int len, boolean last) {
        return asVoid(delegate.completableStream(new String(bytes, offset, len), last),
                      t -> mapWsException(t, OutboundWebSocket.this));
      }

      @Override
      public CompletableFuture<Void> send(byte[] bytes, int offset, int len) {
        return asVoid(delegate.completableSend(new String(bytes, offset, len)), t -> mapWsException(t, OutboundWebSocket.this));
      }
    };
  }

  private DataFrameEmitter binaryEmitter() {
    return new DataFrameEmitter() {

      @Override
      public CompletableFuture<Void> stream(byte[] bytes, int offset, int len, boolean last) {
        return asVoid(delegate.completableStream(bytes, offset, len, last), t -> mapWsException(t, OutboundWebSocket.this));
      }

      @Override
      public CompletableFuture<Void> send(byte[] bytes, int offset, int len) {
        if (offset != 0 || len != bytes.length) {
          byte[] aux = new byte[len];
          arraycopy(bytes, offset, aux, 0, len);
          bytes = aux;
        }
        return asVoid(delegate.completableSend(bytes), t -> mapWsException(t, OutboundWebSocket.this));
      }
    };
  }

  @Override
  public String getId() {
    return id;
  }

  @Override
  public List<String> getGroups() {
    synchronized (groups) {
      return unmodifiableList(new ArrayList<>(groups));
    }
  }

  @Override
  public void addGroup(String group) {
    synchronized (groups) {
      groups.add(group);
    }
  }

  @Override
  public void removeGroup(String group) {
    synchronized (groups) {
      groups.remove(group);
    }
  }

  @Override
  public CompletableFuture<Void> close(WebSocketCloseCode code, String reason) {
    try {
      closed.set(true);
      return asVoid(delegate.close(code.getProtocolCode(), reason));
    } catch (Throwable t) {
      return failedFuture(t);
    }
  }

  @Override
  public WebSocketType getType() {
    return OUTBOUND;
  }

  @Override
  public WebSocketProtocol getProtocol() {
    return protocol;
  }

  @Override
  public URI getUri() {
    return uri;
  }

  @Override
  public boolean isClosed() {
    return closed.get();
  }

  @Override
  public boolean isConnected() {
    return delegate.getGrizzlyWebSocket().isConnected();
  }

  private SimpleWebSocket getGrizzlyWebSocket() {
    return (SimpleWebSocket) delegate.getGrizzlyWebSocket();
  }

  @Override
  public boolean equals(Object obj) {
    if (obj instanceof OutboundWebSocket) {
      return id.equals(((OutboundWebSocket) obj).getId());
    }

    return false;
  }

  @Override
  public int hashCode() {
    return id.hashCode();
  }

  @Override
  public String toString() {
    return "WebSocket Id: " + id + "\nType: " + OUTBOUND + "\nURI: " + getUri();
  }
}
