/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package org.mule.service.http.netty.impl.client;

import static org.mule.runtime.api.util.DataUnit.KB;
import static org.mule.runtime.http.api.HttpHeaders.Names.CONTENT_LENGTH;
import static org.mule.runtime.http.api.HttpHeaders.Names.CONTENT_TRANSFER_ENCODING;
import static org.mule.runtime.http.api.HttpHeaders.Names.CONTENT_TYPE;
import static org.mule.service.http.netty.impl.server.util.HttpParser.fromMultipartEntity;
import static org.mule.service.http.netty.impl.util.HttpUtils.buildUriString;
import static org.mule.service.http.netty.impl.util.HttpUtils.ensureSchemeAndHost;
import static org.mule.service.http.netty.impl.util.MuleToNettyUtils.calculateShouldRemoveContentLength;
import static org.mule.service.http.netty.impl.util.ReactorNettyUtils.onErrorMap;

import static java.lang.Integer.parseInt;
import static java.net.URI.create;
import static java.util.Collections.singletonMap;

import static org.slf4j.LoggerFactory.getLogger;
import static reactor.core.publisher.Mono.empty;

import org.mule.runtime.api.exception.MuleRuntimeException;
import org.mule.runtime.api.streaming.Cursor;
import org.mule.runtime.api.util.MultiMap;
import org.mule.runtime.api.util.Reference;
import org.mule.runtime.http.api.client.HttpRequestOptions;
import org.mule.runtime.http.api.domain.entity.HttpEntity;
import org.mule.runtime.http.api.domain.entity.multipart.MultipartHttpEntity;
import org.mule.runtime.http.api.domain.message.request.HttpRequest;
import org.mule.runtime.http.api.domain.message.response.HttpResponse;
import org.mule.service.http.common.client.sse.ProgressiveBodyDataListener;
import org.mule.service.http.netty.impl.message.HttpResponseCreator;
import org.mule.service.http.netty.impl.streaming.BlockingBidirectionalStream;
import org.mule.service.http.netty.impl.streaming.CancelableOutputStream;
import org.mule.service.http.netty.impl.util.ReactorNettyUtils;
import org.mule.service.http.netty.impl.util.RedirectHelper;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.time.Duration;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.RejectedExecutionException;
import java.util.function.BiFunction;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.util.AttributeKey;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.MDC;
import reactor.core.publisher.Flux;
import reactor.netty.ByteBufFlux;
import reactor.netty.http.client.HttpClient;
import reactor.netty.http.client.HttpClientResponse;

/**
 * Reactor Netty based client that sends requests asynchronously and receives the responses either in a streamed fashion or
 * fully aggregated in memory, depending on the {@code streamingEnabled} flag given at construction time.
 */
public class ReactorNettyClient {

  private static final Logger LOGGER = getLogger(ReactorNettyClient.class);
  private static final int CONTENT_LENGTH_TO_HANDLE_AGGREGATED = KB.toBytes(10);
  private static final int MAX_REDIRECTS = 5;
  private static final byte[] EMPTY_BYTES = new byte[0];

  // Channel attributes populated when the connection is established. Their consumers are not part of this class.
  static final AttributeKey<HttpEntity> REQUEST_ENTITY_KEY = AttributeKey.valueOf("REQUEST_ENTITY");
  static final AttributeKey<Boolean> ALWAYS_SEND_BODY_KEY = AttributeKey.valueOf("ALWAYS_SEND_BODY");
  static final AttributeKey<Boolean> REDIRECT_CHANGE_METHOD = AttributeKey.valueOf("REDIRECT_CHANGE_METHOD");

  private final String clientName;
  private final HttpClient httpClient;
  private final Executor ioExecutor;
  private final boolean streamingEnabled;

  /**
   * @param clientName       name of this client, only used in log messages.
   * @param httpClient       the Reactor Netty client that performs the actual exchange.
   * @param ioExecutor       executor used to complete the response future in the streaming mode.
   * @param streamingEnabled whether response bodies are exposed as they arrive ({@code true}) or only once fully aggregated
   *                         ({@code false}).
   */
  public ReactorNettyClient(String clientName, HttpClient httpClient, Executor ioExecutor, boolean streamingEnabled) {
    this.clientName = clientName;
    this.httpClient = httpClient;
    this.ioExecutor = ioExecutor;
    this.streamingEnabled = streamingEnabled;
  }

  /**
   * Builds the reactive pipeline that sends {@code request} and hands the response over to {@code responseFunction}.
   * <p>
   * Errors are mapped with {@link ReactorNettyUtils#onErrorMap(Throwable)}, reported through {@code result} and then swallowed
   * by the returned {@link Flux}, so subscribers of the flux only ever see a completion signal.
   *
   * @param request          the request to send.
   * @param options          the options of the request (response timeout, redirect policy, whether to always send the body).
   * @param headersToAdd     the headers to add to the outgoing request.
   * @param responseFunction the function that consumes the response and its content.
   * @param result           the future that is completed exceptionally if the exchange fails.
   * @return the flux that performs the exchange once subscribed.
   */
  public Flux<ByteBuf> sendAsyncRequest(HttpRequest request, HttpRequestOptions options, HttpHeaders headersToAdd,
                                        BiFunction<HttpClientResponse, ByteBufFlux, Publisher<ByteBuf>> responseFunction,
                                        CompletableFuture<HttpResponse> result) {

    RedirectHelper redirectHelper = new RedirectHelper(headersToAdd);
    URI uri = uriWithQueryParams(request);
    Map<String, String> propagatedMdc = MDC.getCopyOfContextMap();
    Reference<Map<String, String>> overriddenMDC = new Reference<>();
    LOGGER.debug("Sending request to {} with headers {}", uri, headersToAdd);
    return httpClient
        .followRedirect((req, res) -> {
          boolean shouldFollow = redirectHelper.isRedirectStatusCode(res.status().code()) && options.isFollowsRedirect();
          // The limit only applies when yet another redirect would be followed. A non-redirect response received after the
          // allowed number of hops is a valid final response and must not be rejected.
          if (shouldFollow && req.redirectedFrom().length >= MAX_REDIRECTS) {
            throw new MuleRuntimeException(new MaxRedirectException("Max redirects limit reached."));
          }
          return shouldFollow;
        }, redirectHelper::addCookiesToRedirectedRequest)
        .doOnRedirect((response, connection) -> {
          redirectHelper.handleRedirectResponse(response);
          // The body may have been consumed by the previous attempt, so it has to be rewound before it is sent again.
          rewindStreamContent(request.getEntity());
        })
        .doOnConnected(connection -> {
          // Remember the MDC of the thread running this callback so it can be restored once the exchange completes.
          overriddenMDC.set(MDC.getCopyOfContextMap());
          setMdcContext(propagatedMdc);
          connection.channel().attr(ALWAYS_SEND_BODY_KEY).set(options.shouldSendBodyAlways());
          connection.channel().attr(REDIRECT_CHANGE_METHOD).set(redirectHelper.shouldChangeMethod());
          connection.channel().attr(AttributeKey.valueOf("removeContentLength"))
              .set(calculateShouldRemoveContentLength(request));
          if (request.getEntity() != null && isExpect(headersToAdd)) {
            connection.channel().attr(REQUEST_ENTITY_KEY).set(request.getEntity());
          }
        })

        .responseTimeout(Duration.ofMillis(options.getResponseTimeout()))
        .headers(h -> h.add(headersToAdd))
        .request(HttpMethod.valueOf(request.getMethod()))
        // Pass a URI instead of a String to avoid inefficient Netty regex matching.
        // see https://github.com/reactor/reactor-netty/issues/829
        .uri(uri)
        .send((httpClientRequest, nettyOutbound) -> {
          if (isExpect(httpClientRequest.requestHeaders())) {
            // No body is written here for requests carrying an Expect header.
            return nettyOutbound.send(empty());
          } else if (request.getEntity() != null) {
            // The predicate returning false is needed because if no predicate is passed, `nettyOutbound` (which is an instance of
            // `HttpClientOperations`) will wait until the whole response content is available. False is returned to avoid forcing
            // a flush and let it happen async.
            return nettyOutbound.send(entityPublisher(request), buffer -> false);
          }
          // There is nothing to write, but a null Publisher must never be handed back to the reactive pipeline.
          return nettyOutbound.send(empty());
        })
        .response(responseFunction)
        .onErrorMap(ReactorNettyUtils::onErrorMap)
        .doOnError(result::completeExceptionally)
        .doOnComplete(() -> setMdcContext(overriddenMDC.get()))
        .onErrorComplete();
  }

  /**
   * Replaces the MDC of the current thread with {@code contextMap}. {@link MDC#getCopyOfContextMap()} may return
   * {@code null}, and not every logging backend accepts {@code null} in {@link MDC#setContextMap(Map)}, so a {@code null} map
   * is applied by clearing the MDC instead.
   */
  private static void setMdcContext(Map<String, String> contextMap) {
    if (contextMap == null) {
      MDC.clear();
    } else {
      MDC.setContextMap(contextMap);
    }
  }

  /**
   * @return whether {@code entries} contains an {@code expect} header.
   */
  private boolean isExpect(HttpHeaders entries) {
    return entries.contains("expect");
  }

  /**
   * @return the URI of {@code request}, including its query parameters if there are any.
   */
  private static URI uriWithQueryParams(HttpRequest request) {
    // Only create a new URI if there is some query parameter to add.
    MultiMap<String, String> queryParams = request.getQueryParams();
    if (queryParams.isEmpty()) {
      return ensureSchemeAndHost(request.getUri());
    } else {
      return create(buildUriString(request.getUri(), queryParams));
    }
  }

  /**
   * Creates the publisher that emits the body of {@code request}. The entity of the request must not be {@code null}.
   *
   * @return an empty publisher when the entity declares a length of zero, otherwise a {@link ChunkedHttpEntityPublisher} over
   *         the entity (converted with {@code fromMultipartEntity} first when the entity is composed).
   */
  private Publisher<? extends ByteBuf> entityPublisher(HttpRequest request) {
    HttpEntity entity = request.getEntity();
    if (entity.getBytesLength().isPresent() && entity.getBytesLength().getAsLong() == 0) {
      return empty();
    } else if (entity.isComposed()) {
      try {
        return new ChunkedHttpEntityPublisher(fromMultipartEntity(request.getHeaderValue(CONTENT_TYPE),
                                                                  (MultipartHttpEntity) entity,
                                                                  ct -> {
                                                                  },
                                                                  singletonMap(CONTENT_TRANSFER_ENCODING, "binary")));
      } catch (IOException e) {
        // Best effort: keep the fallback of sending the entity as is, but leave a trace of why the conversion failed.
        LOGGER.warn("Unable to convert the multipart entity being sent by {}, the entity will be sent as is", clientName, e);
        return new ChunkedHttpEntityPublisher(entity);
      }
    } else {
      return new ChunkedHttpEntityPublisher(entity);
    }
  }

  /**
   * Consumes the content of a response, either progressively or aggregated depending on whether streaming is enabled.
   *
   * @param response     the response as received from Reactor Netty.
   * @param content      the content of the response.
   * @param result       the future to complete with the resulting {@link HttpResponse}, or with the failure.
   * @param dataListener listener notified about the body stream creation, the available data and the end of the stream.
   * @return the publisher that drives the consumption of the content.
   */
  public Publisher<ByteBuf> receiveContent(HttpClientResponse response, ByteBufFlux content,
                                           CompletableFuture<HttpResponse> result, ProgressiveBodyDataListener dataListener) {
    LOGGER.debug("Received response with headers {} and status {}", response.responseHeaders(), response.status());

    if (streamingEnabled) {
      return handleResponseStreaming(response, content, result, dataListener);
    } else {
      return handleResponseNonStreaming(response, content, result, dataListener);
    }
  }

  /**
   * Resets and then marks the content of a streaming {@code entity} whose stream supports mark, so that a later
   * {@link InputStream#reset()} goes back to the current position. Does nothing for any other entity, including
   * {@code null}.
   */
  void prepareContentForRepeatability(HttpEntity entity) {
    if (entity != null && entity.isStreaming() && entity.getContent().markSupported()) {
      doReset(entity.getContent());
      entity.getContent().mark(0);
    }
  }

  /**
   * Tries to rewind the content of a streaming {@code entity} so it can be sent again. A {@link Cursor} is moved back to
   * position zero, a stream supporting mark is reset, and for any other stream a warning is logged. Failures are logged and
   * not propagated.
   */
  void rewindStreamContent(HttpEntity entity) {
    if (entity != null && entity.isStreaming()) {
      if (entity.getContent() instanceof Cursor cursor) {
        try {
          cursor.seek(0);
        } catch (IOException e) {
          LOGGER.warn("Unable to perform seek(0) on input stream being sent by {}: {}", clientName, e.getMessage());
        }
      } else if (entity.getContent().markSupported()) {
        doReset(entity.getContent());
      } else {
        LOGGER.warn("Stream '{}' cannot be rewinded, payload cannot be resent by {}", entity.getContent().getClass(),
                    clientName);
      }
    }
  }

  /**
   * Resets {@code getContent}, logging a warning instead of failing if the reset is not possible.
   */
  private void doReset(InputStream getContent) {
    try {
      getContent.reset();
    } catch (IOException e) {
      LOGGER.warn("Unable to reset the input stream: {}", e.getMessage());
    }
  }

  /**
   * Pipes every received chunk into a {@link BlockingBidirectionalStream} whose input side is exposed as the body of the
   * response, and completes {@code result} without waiting for the whole content to arrive.
   */
  private Publisher<ByteBuf> handleResponseStreaming(HttpClientResponse response, ByteBufFlux content,
                                                     CompletableFuture<HttpResponse> result,
                                                     ProgressiveBodyDataListener dataListener) {
    try {
      BlockingBidirectionalStream bidirectionalStream = new BlockingBidirectionalStream();
      InputStream in = bidirectionalStream.getInputStream();
      dataListener.onStreamCreated(in);
      CancelableOutputStream out = bidirectionalStream.getOutputStream();
      Flux<ByteBuf> contentFlux = content.retain().doOnNext(data -> {
        try {
          byte[] bytes = new byte[data.readableBytes()];
          data.readBytes(bytes);
          out.write(bytes); // Write bytes to the output stream
          if (!result.isDone()) {
            LOGGER.debug("Marked response as completed but still waiting on content");
            tryCompleteAsyncWithCallerRunsFallback(result, response, in);
          }
          dataListener.onDataAvailable(bytes.length);
        } catch (Exception e) {
          result.completeExceptionally(e);
        } finally {
          // Balance the retain() above exactly once, no matter where the processing of the chunk failed.
          data.release();
        }
      }).doOnComplete(() -> {
        try {
          LOGGER.debug("Marked response as completed");
          out.close();
          if (!result.isDone()) {
            tryCompleteAsyncWithCallerRunsFallback(result, response, in);
          }
          dataListener.onEndOfStream();
        } catch (Exception e) {
          result.completeExceptionally(e);
        }
      }).doOnError(error -> {
        Throwable mappedError = onErrorMap(error);
        out.cancel(mappedError);
        result.completeExceptionally(mappedError);
        dataListener.onEndOfStream();
      });
      // Complete the response right away so the body can be consumed while it is still being received.
      if (!result.isDone()) {
        tryCompleteAsyncWithCallerRunsFallback(result, response, in);
      }
      return contentFlux;
    } catch (Exception e) {
      result.completeExceptionally(e);
    }
    return content;
  }

  /**
   * Completes {@code result} with a response built from {@code response} and {@code in}, using the IO executor. If the
   * executor rejects the task, the completion is done in the calling thread instead.
   */
  private void tryCompleteAsyncWithCallerRunsFallback(CompletableFuture<HttpResponse> result, HttpClientResponse response,
                                                      InputStream in) {
    try {
      result.completeAsync(() -> new HttpResponseCreator().create(response, in), ioExecutor);
    } catch (RejectedExecutionException e) {
      // On scheduler rejection fall back to complete in the current thread
      result.complete(new HttpResponseCreator().create(response, in));
    }
  }

  /**
   * Aggregates the whole content in memory and, once it is available, completes {@code result} with a response whose body is
   * a {@link ByteArrayInputStream} over a copy of the received bytes.
   */
  private Publisher<ByteBuf> handleResponseNonStreaming(HttpClientResponse response, ByteBufFlux content,
                                                        CompletableFuture<HttpResponse> result,
                                                        ProgressiveBodyDataListener dataListener) {
    return content.aggregate()
        .doOnError(error -> result.completeExceptionally(onErrorMap(error)))
        .doOnSuccess(byteBuf -> {
          int readable = (byteBuf == null) ? 0 : byteBuf.readableBytes();
          // Always copy the bytes out of the buffer. The stream is read after this callback returns, when the aggregated
          // buffer may already have been released, so exposing its backing array directly could hand recycled memory to the
          // reader.
          byte[] bytes = (readable == 0) ? EMPTY_BYTES : ByteBufUtil.getBytes(byteBuf);
          ByteArrayInputStream inputStream = new ByteArrayInputStream(bytes);
          result.complete(new HttpResponseCreator().create(response, inputStream));
          dataListener.onStreamCreated(inputStream);
          dataListener.onDataAvailable(readable);
          dataListener.onEndOfStream();
        });
  }

}
