/*
 * Copyright 2015-2020 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.rsocket.core;

import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.util.IllegalReferenceCountException;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.ReferenceCounted;
import io.netty.util.collection.IntObjectMap;
import io.rsocket.DuplexConnection;
import io.rsocket.Payload;
import io.rsocket.RSocket;
import io.rsocket.frame.CancelFrameCodec;
import io.rsocket.frame.ErrorFrameCodec;
import io.rsocket.frame.FrameHeaderCodec;
import io.rsocket.frame.FrameType;
import io.rsocket.frame.PayloadFrameCodec;
import io.rsocket.frame.RequestChannelFrameCodec;
import io.rsocket.frame.RequestNFrameCodec;
import io.rsocket.frame.RequestStreamFrameCodec;
import io.rsocket.frame.decoder.PayloadDecoder;
import io.rsocket.internal.SynchronizedIntObjectHashMap;
import io.rsocket.internal.UnboundedProcessor;
import io.rsocket.lease.ResponderLeaseHandler;
import java.nio.channels.ClosedChannelException;
import java.util.concurrent.CancellationException;
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
import java.util.function.Consumer;
import java.util.function.LongConsumer;
import java.util.function.Supplier;
import org.reactivestreams.Processor;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.Disposable;
import reactor.core.Exceptions;
import reactor.core.publisher.BaseSubscriber;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.SignalType;
import reactor.core.publisher.UnicastProcessor;
import reactor.util.annotation.Nullable;
import reactor.util.concurrent.Queues;

/**
 * Responder side of RSocket. Receives {@link ByteBuf}s from a peer's {@link RSocketRequester},
 * dispatches request frames to the wrapped request handler and queues the resulting response
 * frames on a send processor that is written to the connection.
 *
 * <p>Per-stream state is kept in two maps keyed by stream id: {@code sendingSubscriptions} holds
 * the subscription of every in-flight response (so REQUEST_N and CANCEL frames can be routed to
 * it), and {@code channelProcessors} holds the inbound half of every open request-channel.
 */
class RSocketResponder implements RSocket {
  private static final Logger LOGGER = LoggerFactory.getLogger(RSocketResponder.class);

  /**
   * Releases reference-counted elements discarded by Reactor operators. Elements that are already
   * released (refCnt 0, or released concurrently) are tolerated silently.
   */
  private static final Consumer<ReferenceCounted> DROPPED_ELEMENTS_CONSUMER =
      referenceCounted -> {
        if (referenceCounted.refCnt() > 0) {
          try {
            referenceCounted.release();
          } catch (IllegalReferenceCountException e) {
            // ignored: another party released it between the refCnt check and release()
          }
        }
      };

  /** Shared termination cause used when the connection closes without an error. */
  private static final Exception CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException();

  private final DuplexConnection connection;
  private final RSocket requestHandler;

  /**
   * {@code requestHandler} cast to the deprecated {@code ResponderRSocket} when it implements that
   * interface, otherwise {@code null}; enables the {@code requestChannel(Payload, Publisher)}
   * variant.
   */
  @SuppressWarnings("deprecation")
  private final io.rsocket.ResponderRSocket responderRSocket;

  private final PayloadDecoder payloadDecoder;
  private final ResponderLeaseHandler leaseHandler;
  private final Disposable leaseHandlerDisposable;

  /** First cause that terminated this responder; written at most once via {@link #TERMINATION_ERROR}. */
  private volatile Throwable terminationError;
  private static final AtomicReferenceFieldUpdater<RSocketResponder, Throwable> TERMINATION_ERROR =
      AtomicReferenceFieldUpdater.newUpdater(
          RSocketResponder.class, Throwable.class, "terminationError");

  // Limits used by PayloadValidationUtils.isValid to reject oversized outbound payloads.
  private final int mtu;
  private final int maxFrameLength;

  /** Subscriptions of in-flight outbound responses, keyed by stream id. */
  private final IntObjectMap<Subscription> sendingSubscriptions;
  /** Inbound halves of open request-channels, keyed by stream id. */
  private final IntObjectMap<Processor<Payload, Payload>> channelProcessors;

  /** Outbound frame queue; subscribed to by {@code connection.send} in the constructor. */
  private final UnboundedProcessor<ByteBuf> sendProcessor;
  private final ByteBufAllocator allocator;

  /**
   * Creates the responder and immediately wires it to {@code connection}: the outbound send
   * processor is subscribed first, then inbound frames are consumed, lease frames are forwarded to
   * the send processor, and connection close/error terminates the responder.
   *
   * @param connection the connection to read request frames from and write response frames to
   * @param requestHandler application handler the requests are delegated to
   * @param payloadDecoder turns inbound frames into {@link Payload}s
   * @param leaseHandler decides whether a request may be served and emits lease frames
   * @param mtu fragmentation MTU used when validating outbound payloads
   * @param maxFrameLength maximum frame length used when validating outbound payloads
   */
  RSocketResponder(
      DuplexConnection connection,
      RSocket requestHandler,
      PayloadDecoder payloadDecoder,
      ResponderLeaseHandler leaseHandler,
      int mtu,
      int maxFrameLength) {
    this.connection = connection;
    this.allocator = connection.alloc();
    this.mtu = mtu;
    this.maxFrameLength = maxFrameLength;

    this.requestHandler = requestHandler;
    this.responderRSocket =
        (requestHandler instanceof io.rsocket.ResponderRSocket)
            ? (io.rsocket.ResponderRSocket) requestHandler
            : null;

    this.payloadDecoder = payloadDecoder;
    this.leaseHandler = leaseHandler;
    this.sendingSubscriptions = new SynchronizedIntObjectHashMap<>();
    this.channelProcessors = new SynchronizedIntObjectHashMap<>();

    // DO NOT Change the order here. The Send processor must be subscribed to before receiving
    // connections
    this.sendProcessor = new UnboundedProcessor<>();

    connection.send(sendProcessor).subscribe(null, this::handleSendProcessorError);

    // Inbound errors are ignored here; connection termination is handled via onClose() below.
    connection.receive().subscribe(this::handleFrame, e -> {});
    leaseHandlerDisposable = leaseHandler.send(sendProcessor::onNextPrioritized);

    this.connection
        .onClose()
        .subscribe(null, this::tryTerminateOnConnectionError, this::tryTerminateOnConnectionClose);
  }

  /**
   * Invoked when the outbound send pipeline fails: cancels in-flight responses and errors open
   * channels. Does not record a termination error or dispose the connection itself.
   */
  private void handleSendProcessorError(Throwable t) {
    cleanUpSendingSubscriptions();
    cleanUpChannelProcessors(t);
  }

  /** Terminates the responder with the error the connection closed with. */
  private void tryTerminateOnConnectionError(Throwable e) {
    tryTerminate(() -> e);
  }

  /** Terminates the responder after a normal connection close. */
  private void tryTerminateOnConnectionClose() {
    tryTerminate(() -> CLOSED_CHANNEL_EXCEPTION);
  }

  /**
   * Runs {@link #cleanup(Throwable)} exactly once: only the caller that wins the CAS on {@code
   * terminationError} performs the cleanup. The supplier is only invoked when no termination
   * error is recorded yet, so it is not evaluated on the common already-terminated path.
   */
  private void tryTerminate(Supplier<Throwable> errorSupplier) {
    if (terminationError == null) {
      Throwable e = errorSupplier.get();
      if (TERMINATION_ERROR.compareAndSet(this, null, e)) {
        cleanup(e);
      }
    }
  }

  /**
   * Delegates to the request handler if the lease allows it; otherwise releases {@code payload}
   * and returns the lease error. Exceptions thrown synchronously by the handler are converted into
   * an error signal.
   */
  @Override
  public Mono<Void> fireAndForget(Payload payload) {
    try {
      if (leaseHandler.useLease()) {
        return requestHandler.fireAndForget(payload);
      } else {
        payload.release();
        return Mono.error(leaseHandler.leaseError());
      }
    } catch (Throwable t) {
      return Mono.error(t);
    }
  }

  /**
   * Delegates to the request handler if the lease allows it; otherwise releases {@code payload}
   * and returns the lease error. Exceptions thrown synchronously by the handler are converted into
   * an error signal.
   */
  @Override
  public Mono<Payload> requestResponse(Payload payload) {
    try {
      if (leaseHandler.useLease()) {
        return requestHandler.requestResponse(payload);
      } else {
        payload.release();
        return Mono.error(leaseHandler.leaseError());
      }
    } catch (Throwable t) {
      return Mono.error(t);
    }
  }

  /**
   * Delegates to the request handler if the lease allows it; otherwise releases {@code payload}
   * and returns the lease error. Exceptions thrown synchronously by the handler are converted into
   * an error signal.
   */
  @Override
  public Flux<Payload> requestStream(Payload payload) {
    try {
      if (leaseHandler.useLease()) {
        return requestHandler.requestStream(payload);
      } else {
        payload.release();
        return Flux.error(leaseHandler.leaseError());
      }
    } catch (Throwable t) {
      return Flux.error(t);
    }
  }

  /**
   * Delegates to the request handler if the lease allows it; otherwise returns the lease error.
   * There is no single payload to release here; {@code payloads} is simply not subscribed.
   */
  @Override
  public Flux<Payload> requestChannel(Publisher<Payload> payloads) {
    try {
      if (leaseHandler.useLease()) {
        return requestHandler.requestChannel(payloads);
      } else {
        return Flux.error(leaseHandler.leaseError());
      }
    } catch (Throwable t) {
      return Flux.error(t);
    }
  }

  /**
   * Request-channel variant for handlers implementing the deprecated {@code ResponderRSocket}:
   * passes the initial payload separately. Only called when {@code responderRSocket} is non-null.
   */
  private Flux<Payload> requestChannel(Payload payload, Publisher<Payload> payloads) {
    try {
      if (leaseHandler.useLease()) {
        return responderRSocket.requestChannel(payload, payloads);
      } else {
        payload.release();
        return Flux.error(leaseHandler.leaseError());
      }
    } catch (Throwable t) {
      return Flux.error(t);
    }
  }

  /**
   * Delegates metadata-push to the request handler (no lease check); synchronous exceptions are
   * converted into an error signal.
   */
  @Override
  public Mono<Void> metadataPush(Payload payload) {
    try {
      return requestHandler.metadataPush(payload);
    } catch (Throwable t) {
      return Mono.error(t);
    }
  }

  /** Terminates the responder with a {@link CancellationException}, if not already terminated. */
  @Override
  public void dispose() {
    tryTerminate(() -> new CancellationException("Disposed"));
  }

  @Override
  public boolean isDisposed() {
    return connection.isDisposed();
  }

  @Override
  public Mono<Void> onClose() {
    return connection.onClose();
  }

  /**
   * Tears everything down after termination: cancels in-flight responses, errors open channels
   * with {@code e}, then disposes the connection, lease handler, request handler and send
   * processor.
   */
  private void cleanup(Throwable e) {
    cleanUpSendingSubscriptions();
    cleanUpChannelProcessors(e);

    connection.dispose();
    leaseHandlerDisposable.dispose();
    requestHandler.dispose();
    sendProcessor.dispose();
  }

  /**
   * Cancels every registered outbound subscription (exceptions are logged at debug level and
   * dropped) and then clears the map.
   */
  private synchronized void cleanUpSendingSubscriptions() {
    // Iterate explicitly to handle collisions with concurrent removals
    for (IntObjectMap.PrimitiveEntry<Subscription> entry : sendingSubscriptions.entries()) {
      try {
        entry.value().cancel();
      } catch (Throwable ex) {
        if (LOGGER.isDebugEnabled()) {
          LOGGER.debug("Dropped exception", ex);
        }
      }
    }
    sendingSubscriptions.clear();
  }

  /**
   * Signals {@code e} to the inbound processor of every open channel (exceptions are logged at
   * debug level and dropped) and then clears the map.
   */
  private synchronized void cleanUpChannelProcessors(Throwable e) {
    // Iterate explicitly to handle collisions with concurrent removals
    for (IntObjectMap.PrimitiveEntry<Processor<Payload, Payload>> entry :
        channelProcessors.entries()) {
      try {
        entry.value().onError(e);
      } catch (Throwable ex) {
        if (LOGGER.isDebugEnabled()) {
          LOGGER.debug("Dropped exception", ex);
        }
      }
    }
    channelProcessors.clear();
  }

  /**
   * Dispatches one inbound frame by its type. Request frames start new streams; REQUEST_N and
   * CANCEL are routed to the outbound subscription of their stream; NEXT/COMPLETE/ERROR frames
   * feed the inbound half of an open channel. PAYLOAD frames are currently ignored, and SETUP,
   * LEASE and unknown types are answered with an ERROR frame.
   *
   * <p>The frame is always released here: after dispatch, or before rethrowing on failure.
   */
  private void handleFrame(ByteBuf frame) {
    try {
      int streamId = FrameHeaderCodec.streamId(frame);
      Subscriber<Payload> receiver;
      FrameType frameType = FrameHeaderCodec.frameType(frame);
      switch (frameType) {
        case REQUEST_FNF:
          handleFireAndForget(streamId, fireAndForget(payloadDecoder.apply(frame)));
          break;
        case REQUEST_RESPONSE:
          handleRequestResponse(streamId, requestResponse(payloadDecoder.apply(frame)));
          break;
        case CANCEL:
          handleCancelFrame(streamId);
          break;
        case REQUEST_N:
          handleRequestN(streamId, frame);
          break;
        case REQUEST_STREAM:
          long streamInitialRequestN = RequestStreamFrameCodec.initialRequestN(frame);
          Payload streamPayload = payloadDecoder.apply(frame);
          handleStream(streamId, requestStream(streamPayload), streamInitialRequestN, null);
          break;
        case REQUEST_CHANNEL:
          long channelInitialRequestN = RequestChannelFrameCodec.initialRequestN(frame);
          Payload channelPayload = payloadDecoder.apply(frame);
          handleChannel(streamId, channelPayload, channelInitialRequestN);
          break;
        case METADATA_PUSH:
          handleMetadataPush(metadataPush(payloadDecoder.apply(frame)));
          break;
        case PAYLOAD:
          // TODO: Hook in receiving socket.
          break;
        case NEXT:
          receiver = channelProcessors.get(streamId);
          if (receiver != null) {
            receiver.onNext(payloadDecoder.apply(frame));
          }
          break;
        case COMPLETE:
          receiver = channelProcessors.get(streamId);
          if (receiver != null) {
            receiver.onComplete();
          }
          break;
        case ERROR:
          receiver = channelProcessors.get(streamId);
          if (receiver != null) {
            // FIXME: when https://github.com/reactor/reactor-core/issues/2176 is resolved
            //        This is workaround to handle specific Reactor related case when
            //        onError call may not return normally
            try {
              receiver.onError(io.rsocket.exceptions.Exceptions.from(streamId, frame));
            } catch (RuntimeException e) {
              if (reactor.core.Exceptions.isBubbling(e)
                  || reactor.core.Exceptions.isErrorCallbackNotImplemented(e)) {
                if (LOGGER.isDebugEnabled()) {
                  Throwable unwrapped = reactor.core.Exceptions.unwrap(e);
                  LOGGER.debug("Unhandled dropped exception", unwrapped);
                }
              }
            }
          }
          break;
        case NEXT_COMPLETE:
          receiver = channelProcessors.get(streamId);
          if (receiver != null) {
            receiver.onNext(payloadDecoder.apply(frame));
            receiver.onComplete();
          }
          break;
        case SETUP:
          handleError(streamId, new IllegalStateException("Setup frame received post setup."));
          break;
        case LEASE:
        default:
          handleError(
              streamId,
              new IllegalStateException("ServerRSocket: Unexpected frame type: " + frameType));
          break;
      }
      ReferenceCountUtil.safeRelease(frame);
    } catch (Throwable t) {
      ReferenceCountUtil.safeRelease(frame);
      throw Exceptions.propagate(t);
    }
  }

  /**
   * Subscribes to a fire-and-forget result with unbounded demand. Errors are dropped because no
   * response is sent. The upstream subscription is registered under {@code streamId} while in
   * flight so that cleanup or a CANCEL frame can cancel it.
   */
  private void handleFireAndForget(int streamId, Mono<Void> result) {
    result.subscribe(
        new BaseSubscriber<Void>() {
          @Override
          protected void hookOnSubscribe(Subscription subscription) {
            sendingSubscriptions.put(streamId, subscription);
            subscription.request(Long.MAX_VALUE);
          }

          @Override
          protected void hookOnError(Throwable throwable) {}

          @Override
          protected void hookFinally(SignalType type) {
            sendingSubscriptions.remove(streamId);
          }
        });
  }

  /**
   * Subscribes to a request-response result and reports its outcome to the peer. A value becomes
   * a NEXT_COMPLETE frame, an empty result becomes a COMPLETE frame, and an error (or an invalid
   * payload) becomes an ERROR frame.
   *
   * <p>The subscriber is registered in {@code sendingSubscriptions} before subscribing, so a CANCEL
   * frame can reach it. Every terminal path deregisters it again; otherwise the entry would stay
   * in the map until the whole responder is cleaned up.
   */
  private void handleRequestResponse(int streamId, Mono<Payload> response) {
    final BaseSubscriber<Payload> subscriber =
        new BaseSubscriber<Payload>() {
          // True until a value arrives; decides whether onComplete must send a COMPLETE frame.
          private boolean isEmpty = true;

          @Override
          protected void hookOnNext(Payload payload) {
            isEmpty = false;

            if (!PayloadValidationUtils.isValid(mtu, payload, maxFrameLength)) {
              payload.release();
              // cancel() does not reach hookOnComplete/hookOnError, so deregister explicitly.
              sendingSubscriptions.remove(streamId, this);
              cancel();
              final IllegalArgumentException t =
                  new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE);
              handleError(streamId, t);
              return;
            }

            // If encoding throws, the entry is still registered, so hookOnError reports it.
            ByteBuf byteBuf =
                PayloadFrameCodec.encodeNextCompleteReleasingPayload(allocator, streamId, payload);
            // A Mono emits at most one value, so NEXT_COMPLETE ends the stream. hookOnComplete
            // only deregisters empty results, so remove the entry here to avoid leaking it.
            sendingSubscriptions.remove(streamId, this);
            sendProcessor.onNext(byteBuf);
          }

          @Override
          protected void hookOnError(Throwable throwable) {
            // Only report if still registered (not cancelled by the peer or already finished).
            if (sendingSubscriptions.remove(streamId, this)) {
              handleError(streamId, throwable);
            }
          }

          @Override
          protected void hookOnComplete() {
            if (isEmpty) {
              if (sendingSubscriptions.remove(streamId, this)) {
                sendProcessor.onNext(PayloadFrameCodec.encodeComplete(allocator, streamId));
              }
            }
          }
        };

    sendingSubscriptions.put(streamId, subscriber);
    response.doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER).subscribe(subscriber);
  }

  /**
   * Subscribes to a stream of responses and forwards each value as a NEXT frame. Completion is
   * sent as COMPLETE and failure as ERROR. Used for request-stream and for the outbound half of
   * request-channel.
   *
   * @param streamId stream the responses belong to
   * @param response the responses produced by the request handler
   * @param initialRequestN demand requested from {@code response} on subscription
   * @param requestChannel the inbound processor of the channel, or {@code null} for
   *     request-stream. It is removed (and, on error, disposed) together with the outbound side.
   */
  private void handleStream(
      int streamId,
      Flux<Payload> response,
      long initialRequestN,
      @Nullable UnicastProcessor<Payload> requestChannel) {
    final BaseSubscriber<Payload> subscriber =
        new BaseSubscriber<Payload>() {

          @Override
          protected void hookOnSubscribe(Subscription s) {
            s.request(initialRequestN);
          }

          @Override
          protected void hookOnNext(Payload payload) {
            try {
              if (!PayloadValidationUtils.isValid(mtu, payload, maxFrameLength)) {
                payload.release();
                final IllegalArgumentException t =
                    new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE);

                cancelStream(t);
                return;
              }

              ByteBuf byteBuf =
                  PayloadFrameCodec.encodeNextReleasingPayload(allocator, streamId, payload);
              sendProcessor.onNext(byteBuf);
            } catch (Throwable e) {
              cancelStream(e);
            }
          }

          private void cancelStream(Throwable t) {
            // Cancel the output stream and send an ERROR frame but do not dispose the
            // requestChannel (i.e. close the connection) since the spec allows to leave
            // the channel in half-closed state.
            // specifically for requestChannel case so when Payload is invalid we will not be
            // sending CancelFrame and ErrorFrame
            // Note: CancelFrame is redundant and due to spec
            // (https://github.com/rsocket/rsocket/blob/master/Protocol.md#request-channel)
            // Upon receiving an ERROR[APPLICATION_ERROR|REJECTED|CANCELED|INVALID], the stream
            // is terminated on both Requester and Responder.
            // Upon sending an ERROR[APPLICATION_ERROR|REJECTED|CANCELED|INVALID], the stream is
            // terminated on both the Requester and Responder.
            if (requestChannel != null) {
              channelProcessors.remove(streamId, requestChannel);
            }
            // cancel() does not reach hookOnComplete/hookOnError, which are the only other places
            // that deregister this subscriber, so remove it here to avoid leaking the entry.
            sendingSubscriptions.remove(streamId, this);
            cancel();
            handleError(streamId, t);
          }

          @Override
          protected void hookOnComplete() {
            if (sendingSubscriptions.remove(streamId, this)) {
              sendProcessor.onNext(PayloadFrameCodec.encodeComplete(allocator, streamId));
            }
          }

          @Override
          protected void hookOnError(Throwable throwable) {
            if (sendingSubscriptions.remove(streamId, this)) {
              // specifically for requestChannel case so when Payload is invalid we will not be
              // sending CancelFrame and ErrorFrame
              // Note: CancelFrame is redundant and due to spec
              // (https://github.com/rsocket/rsocket/blob/master/Protocol.md#request-channel)
              // Upon receiving an ERROR[APPLICATION_ERROR|REJECTED|CANCELED|INVALID], the stream
              // is terminated on both Requester and Responder.
              // Upon sending an ERROR[APPLICATION_ERROR|REJECTED|CANCELED|INVALID], the stream is
              // terminated on both the Requester and Responder.
              if (requestChannel != null && !requestChannel.isDisposed()) {
                if (channelProcessors.remove(streamId, requestChannel)) {
                  try {
                    requestChannel.dispose();
                  } catch (Throwable e) {
                    // ignore to ensure it does not blows up if it racing with async
                    // cancel
                  }
                }
              }

              handleError(streamId, throwable);
            }
          }
        };

    sendingSubscriptions.put(streamId, subscriber);
    response.doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER).subscribe(subscriber);
  }

  /**
   * Opens a request-channel. The inbound half is a {@link UnicastProcessor} registered in {@code
   * channelProcessors} and seeded with the initial payload.
   *
   * <p>Downstream demand on it is forwarded to the peer as REQUEST_N frames. The first demand is
   * reduced by one to account for the locally enqueued initial payload. When the inbound half
   * terminates while still registered, a cancel sends a CANCEL frame, and an error cancels the
   * outbound response subscription. The outbound half is handled by {@link #handleStream}.
   */
  private void handleChannel(int streamId, Payload payload, long initialRequestN) {
    UnicastProcessor<Payload> frames = UnicastProcessor.create(Queues.<Payload>one().get());
    channelProcessors.put(streamId, frames);

    Flux<Payload> payloads =
        frames
            .doOnRequest(
                new LongConsumer() {
                  boolean first = true;

                  @Override
                  public void accept(long l) {
                    long n;
                    if (first) {
                      first = false;
                      n = l - 1L;
                    } else {
                      n = l;
                    }
                    if (n > 0) {
                      sendProcessor.onNext(RequestNFrameCodec.encode(allocator, streamId, n));
                    }
                  }
                })
            .doFinally(
                signalType -> {
                  if (channelProcessors.remove(streamId, frames)) {
                    if (signalType == SignalType.CANCEL) {
                      sendProcessor.onNext(CancelFrameCodec.encode(allocator, streamId));
                    } else if (signalType == SignalType.ON_ERROR) {
                      Subscription subscription = sendingSubscriptions.remove(streamId);
                      if (subscription != null) {
                        subscription.cancel();
                      }
                    }
                  }
                })
            .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER);

    // not chained, as the payload should be enqueued in the Unicast processor before this method
    // returns
    // and any later payload can be processed
    frames.onNext(payload);

    if (responderRSocket != null) {
      handleStream(streamId, requestChannel(payload, payloads), initialRequestN, frames);
    } else {
      handleStream(streamId, requestChannel(payloads), initialRequestN, frames);
    }
  }

  /**
   * Subscribes to a metadata-push result with unbounded demand. Errors are dropped because
   * metadata-push has no response.
   */
  private void handleMetadataPush(Mono<Void> result) {
    result.subscribe(
        new BaseSubscriber<Void>() {
          @Override
          protected void hookOnSubscribe(Subscription subscription) {
            subscription.request(Long.MAX_VALUE);
          }

          @Override
          protected void hookOnError(Throwable throwable) {}
        });
  }

  /**
   * Handles a CANCEL frame. The stream is deregistered from both maps, an open channel's inbound
   * half is errored with a {@link CancellationException}, and the outbound subscription is
   * cancelled.
   */
  private void handleCancelFrame(int streamId) {
    Subscription subscription = sendingSubscriptions.remove(streamId);
    Processor<Payload, Payload> processor = channelProcessors.remove(streamId);

    if (processor != null) {
      try {
        processor.onError(new CancellationException("Disposed"));
      } catch (Exception e) {
        // ignore
      }
    }

    if (subscription != null) {
      subscription.cancel();
    }
  }

  /** Queues an ERROR frame carrying {@code t} for {@code streamId}. */
  private void handleError(int streamId, Throwable t) {
    sendProcessor.onNext(ErrorFrameCodec.encode(allocator, streamId, t));
  }

  /** Forwards the demand of a REQUEST_N frame to the stream's outbound subscription, if any. */
  private void handleRequestN(int streamId, ByteBuf frame) {
    Subscription subscription = sendingSubscriptions.get(streamId);

    if (subscription != null) {
      long n = RequestNFrameCodec.requestN(frame);
      subscription.request(n);
    }
  }
}
