/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package org.mule.service.http.netty.impl.client;

import static org.mule.runtime.http.api.HttpHeaders.Names.CONTENT_TRANSFER_ENCODING;
import static org.mule.runtime.http.api.HttpHeaders.Names.CONTENT_TYPE;
import static org.mule.service.http.netty.impl.server.util.HttpParser.fromMultipartEntity;
import static org.mule.service.http.netty.impl.util.HttpResponseCreatorUtils.trailersAsFuture;
import static org.mule.service.http.netty.impl.util.HttpUtils.buildUriString;
import static org.mule.service.http.netty.impl.util.HttpUtils.ensureSchemeAndHost;
import static org.mule.service.http.netty.impl.util.MuleToNettyUtils.calculateShouldRemoveContentLength;
import static org.mule.service.http.netty.impl.util.NettyUtils.toNioBuffer;
import static org.mule.service.http.netty.impl.util.ReactorNettyUtils.onErrorMap;

import static java.net.URI.create;
import static java.util.Collections.singletonMap;

import static io.netty.buffer.Unpooled.wrappedBuffer;
import static io.netty.util.ReferenceCountUtil.release;
import static org.slf4j.LoggerFactory.getLogger;
import static reactor.core.publisher.FluxSink.OverflowStrategy.ERROR;
import static reactor.core.publisher.Mono.empty;

import org.mule.runtime.api.exception.MuleRuntimeException;
import org.mule.runtime.api.streaming.Cursor;
import org.mule.runtime.api.util.MultiMap;
import org.mule.runtime.api.util.Reference;
import org.mule.runtime.http.api.client.HttpRequestOptions;
import org.mule.runtime.http.api.domain.entity.HttpEntity;
import org.mule.runtime.http.api.domain.entity.multipart.MultipartHttpEntity;
import org.mule.runtime.http.api.domain.message.request.HttpRequest;
import org.mule.runtime.http.api.domain.message.response.HttpResponse;
import org.mule.service.http.netty.impl.message.HttpResponseCreator;
import org.mule.service.http.netty.impl.message.ReactorNettyResponseWrapper;
import org.mule.service.http.netty.impl.util.ReactorNettyUtils;
import org.mule.service.http.netty.impl.util.RedirectHelper;

import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.time.Duration;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.RejectedExecutionException;
import java.util.function.BiFunction;

import io.netty.buffer.ByteBuf;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.util.AttributeKey;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.MDC;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;
import reactor.netty.ByteBufFlux;
import reactor.netty.http.client.HttpClient;
import reactor.netty.http.client.HttpClientResponse;

/**
 * Reactor Netty based client that sends requests asynchronously and receives their responses, either in a streamed fashion or
 * fully aggregated depending on how it was constructed.
 */
public class ReactorNettyClient {

  private static final Logger LOGGER = getLogger(ReactorNettyClient.class);
  private static final int MAX_REDIRECTS = 5;

  /** Channel attribute holding the result of {@code trailersAsFuture(request.getEntity())} for the request being sent. */
  static final AttributeKey<CompletableFuture<MultiMap<String, String>>> TRAILERS_FUTURE =
      AttributeKey.valueOf("TRAILERS_FUTURE");

  /**
   * Channel attribute holding the request entity. Only set when the request has an entity and the headers to add contain an
   * {@code Expect} header.
   */
  static final AttributeKey<HttpEntity> REQUEST_ENTITY_KEY = AttributeKey.valueOf("REQUEST_ENTITY");

  /** Channel attribute holding {@link HttpRequestOptions#shouldSendBodyAlways()} for the request being sent. */
  static final AttributeKey<Boolean> ALWAYS_SEND_BODY_KEY = AttributeKey.valueOf("ALWAYS_SEND_BODY");

  /** Channel attribute holding the value of {@code RedirectHelper#shouldChangeMethod()} when the connection is established. */
  static final AttributeKey<Boolean> REDIRECT_CHANGE_METHOD = AttributeKey.valueOf("REDIRECT_CHANGE_METHOD");

  private final String clientName;
  private final HttpClient httpClient;
  private final Executor ioExecutor;
  private final boolean streamingEnabled;

  /**
   * @param clientName       name of this client, used in log messages.
   * @param httpClient       the Reactor Netty client used to perform the requests.
   * @param ioExecutor       executor used to complete the response future when streaming is enabled.
   * @param streamingEnabled whether response bodies are streamed ({@code true}) or aggregated before the response future is
   *                         completed ({@code false}).
   */
  public ReactorNettyClient(String clientName, HttpClient httpClient, Executor ioExecutor, boolean streamingEnabled) {
    this.clientName = clientName;
    this.httpClient = httpClient;
    this.ioExecutor = ioExecutor;
    this.streamingEnabled = streamingEnabled;
  }

  /**
   * Sends {@code request} and returns the flux produced by {@code responseFunction} for its response.
   * <p>
   * Errors are mapped with {@link ReactorNettyUtils#onErrorMap}, reported by completing {@code result} exceptionally and then
   * swallowed, so the returned flux completes instead of failing.
   *
   * @param request          the request to send.
   * @param options          request options: redirect following, response timeout and whether the body is always sent.
   * @param headersToAdd     headers added to the outgoing request.
   * @param responseFunction maps the response and its content to the published buffers.
   * @param result           future completed exceptionally if the exchange fails.
   * @return the buffers published by {@code responseFunction}.
   */
  public Flux<ByteBuf> sendAsyncRequest(HttpRequest request, HttpRequestOptions options, HttpHeaders headersToAdd,
                                        BiFunction<HttpClientResponse, ByteBufFlux, Publisher<ByteBuf>> responseFunction,
                                        CompletableFuture<HttpResponse> result) {

    RedirectHelper redirectHelper = new RedirectHelper(headersToAdd);
    URI uri = uriWithQueryParams(request);
    Map<String, String> propagatedMdc = MDC.getCopyOfContextMap();
    // MDC the connection thread had before propagatedMdc replaced it. Stays null until a connection is established.
    Reference<Map<String, String>> overriddenMDC = new Reference<>();
    LOGGER.debug("Sending request to {} with headers {}", uri, headersToAdd);
    return httpClient
        .followRedirect((req, res) -> {
          // This predicate may also be evaluated for responses that are not redirects, so the limit is only enforced when this
          // response would actually be followed. Otherwise, a chain of exactly MAX_REDIRECTS redirects ending in a
          // non-redirect response would be rejected.
          boolean followThisRedirect =
              redirectHelper.isRedirectStatusCode(res.status().code()) && options.isFollowsRedirect();
          if (followThisRedirect && req.redirectedFrom().length >= MAX_REDIRECTS) {
            throw new MuleRuntimeException(new MaxRedirectException("Max redirects limit reached."));
          }
          return followThisRedirect;
        }, redirectHelper::addCookiesToRedirectedRequest)
        .doOnRedirect((response, connection) -> {
          redirectHelper.handleRedirectResponse(response);
          // Rewind the stream if:
          // 1. Method doesn't change (307, 308) - body will be resent, OR
          // 2. shouldSendBodyAlways is true - body will be sent even for GET redirects
          boolean shouldPreserveMethod = Boolean.FALSE.equals(redirectHelper.shouldChangeMethod());
          if (shouldPreserveMethod || options.shouldSendBodyAlways()) {
            rewindStreamContent(request.getEntity());
          }
        })
        .doOnConnected(connection -> {
          Map<String, String> connectionMdc = MDC.getCopyOfContextMap();
          // Never store null: a null reference is how restoreMdc recognizes that no connection was ever made.
          overriddenMDC.set(connectionMdc != null ? connectionMdc : Map.of());
          MDC.setContextMap(propagatedMdc);
          connection.channel().attr(ALWAYS_SEND_BODY_KEY).set(options.shouldSendBodyAlways());
          connection.channel().attr(REDIRECT_CHANGE_METHOD).set(redirectHelper.shouldChangeMethod());
          connection.channel().attr(AttributeKey.valueOf("removeContentLength"))
              .set(calculateShouldRemoveContentLength(request));
          if (request.getEntity() != null && isExpect(headersToAdd)) {
            connection.channel().attr(REQUEST_ENTITY_KEY).set(request.getEntity());
          }
          connection.channel().attr(TRAILERS_FUTURE).set(trailersAsFuture(request.getEntity()));
        })
        .responseTimeout(Duration.ofMillis(options.getResponseTimeout()))
        .headers(h -> h.add(headersToAdd))
        .request(HttpMethod.valueOf(request.getMethod()))
        // Pass a URI instead of a String to avoid inefficient Netty regex matching.
        // see https://github.com/reactor/reactor-netty/issues/829
        .uri(uri)
        .send((httpClientRequest, nettyOutbound) -> {
          if (isExpect(httpClientRequest.requestHeaders()) || request.getEntity() == null) {
            // Always return a publisher. Reactor Netty subscribes to whatever is returned here, so null is not a valid way of
            // saying "no body".
            return nettyOutbound.send(empty());
          }
          // The predicate returning false is needed because if no predicate is passed, `nettyOutbound` (which is an instance of
          // `HttpClientOperations`) will wait until the whole response content is available. False is returned to avoid forcing
          // a flush and let it happen async.
          return nettyOutbound.send(entityPublisher(request), buffer -> false);
        })
        .response(responseFunction)
        .onErrorMap(ReactorNettyUtils::onErrorMap)
        .doOnError(result::completeExceptionally)
        // Restore on error as well as on completion, so a failed exchange doesn't leave this request's MDC on the thread.
        .doOnTerminate(() -> restoreMdc(overriddenMDC.get()))
        .onErrorComplete();
  }

  /**
   * Restores the MDC captured when the connection was established.
   *
   * @param previousMdc the captured MDC, or {@code null} if no connection was established; in that case the current thread's MDC
   *                    was never replaced and is left untouched.
   */
  private static void restoreMdc(Map<String, String> previousMdc) {
    if (previousMdc != null) {
      MDC.setContextMap(previousMdc);
    }
  }

  /** Whether {@code entries} contains an {@code expect} header. */
  private static boolean isExpect(HttpHeaders entries) {
    return entries.contains("expect");
  }

  /**
   * Builds the URI to send the request to, appending the request's query parameters when there are any.
   */
  private static URI uriWithQueryParams(HttpRequest request) {
    // Only create a new URI if there is some query parameter to add.
    MultiMap<String, String> queryParams = request.getQueryParams();
    if (queryParams.isEmpty()) {
      return ensureSchemeAndHost(request.getUri());
    } else {
      return create(buildUriString(request.getUri(), queryParams));
    }
  }

  /**
   * Chooses the publisher for the (non-null) entity of {@code request}:
   * <ul>
   * <li>an empty one when the entity declares a length of zero;</li>
   * <li>a multipart-encoded one for composed entities;</li>
   * <li>a flux bridging the entity callbacks for reactive entities;</li>
   * <li>a chunked publisher over the entity otherwise.</li>
   * </ul>
   */
  private Publisher<? extends ByteBuf> entityPublisher(HttpRequest request) {
    HttpEntity entity = request.getEntity();
    if (entity.getBytesLength().orElse(-1L) == 0) {
      return empty();
    }
    if (entity.isComposed()) {
      try {
        // The no-op lambda is the callback argument expected by fromMultipartEntity (presumably a content-type consumer).
        return new ChunkedHttpEntityPublisher(fromMultipartEntity(request.getHeaderValue(CONTENT_TYPE),
                                                                  (MultipartHttpEntity) entity,
                                                                  ct -> {
                                                                  },
                                                                  singletonMap(CONTENT_TRANSFER_ENCODING, "binary")));
      } catch (IOException e) {
        // Best-effort fallback to sending the entity as-is, but keep a trace of why the multipart encoding was skipped.
        LOGGER.debug("Unable to encode multipart entity sent by {}, sending it without re-encoding", clientName, e);
        return new ChunkedHttpEntityPublisher(entity);
      }
    }
    if (entity.isReactive()) {
      return reactiveEntityToByteBufFlux(entity);
    }
    return new ChunkedHttpEntityPublisher(entity);
  }

  /**
   * Handles a received response, streaming or aggregating its content depending on how this client was configured.
   *
   * @param response the received response.
   * @param content  the response body.
   * @param result   future completed with the Mule response, or exceptionally on failure.
   * @return the publisher of the response body to subscribe to.
   */
  public Publisher<ByteBuf> receiveContent(HttpClientResponse response, ByteBufFlux content,
                                           CompletableFuture<HttpResponse> result) {
    LOGGER.debug("Received response with headers {} and status {}", response.responseHeaders(), response.status());

    if (streamingEnabled) {
      return handleResponseStreaming(response, content, result);
    } else {
      return handleResponseNonStreaming(response, content, result);
    }
  }

  /**
   * For non-reactive streaming entities whose content supports mark/reset, resets the content and marks the current position.
   * <p>
   * NOTE(review): the reset happens before any mark is set on the first call; streams that reject that only produce a warning
   * via doReset. {@code mark(0)} may also let streams that honor the read limit invalidate the mark after reading — confirm
   * intended.
   */
  void prepareContentForRepeatability(HttpEntity entity) {
    if (entity == null || entity.isReactive() || !entity.isStreaming()) {
      return;
    }
    InputStream content = entity.getContent();
    if (content.markSupported()) {
      doReset(content);
      content.mark(0);
    }
  }

  /**
   * Rewinds the streamed content of {@code entity} so it can be sent again: seeks cursors to the start, resets streams that
   * support mark/reset, and logs a warning for any other stream.
   */
  void rewindStreamContent(HttpEntity entity) {
    if (entity == null || !entity.isStreaming()) {
      return;
    }
    InputStream content = entity.getContent();
    if (content instanceof Cursor cursor) {
      try {
        cursor.seek(0);
      } catch (IOException e) {
        LOGGER.warn("Unable to perform seek(0) on input stream being sent by {}: {}", clientName, e.getMessage());
      }
    } else if (content.markSupported()) {
      doReset(content);
    } else {
      LOGGER.warn("Stream '{}' cannot be rewinded, payload cannot be resent by {}", content.getClass(),
                  clientName);
    }
  }

  /** Resets {@code content}, logging a warning instead of failing if the reset is rejected. */
  private void doReset(InputStream content) {
    try {
      content.reset();
    } catch (IOException e) {
      LOGGER.warn("Unable to reset the input stream: {}", e.getMessage());
    }
  }

  /**
   * Completes {@code result} with a feedable response, then feeds it each received chunk and finishes it (with trailers if
   * present) or fails it when the body ends.
   */
  private Publisher<ByteBuf> handleResponseStreaming(HttpClientResponse response, ByteBufFlux content,
                                                     CompletableFuture<HttpResponse> result) {
    try {
      ReactorNettyResponseWrapper feedableResponse = new HttpResponseCreator().createFeedable(response);
      tryCompleteAsyncWithCallerRunsFallback(result, feedableResponse);

      return content.retain().doOnNext(data -> {
        try {
          feedableResponse.feed(toNioBuffer(data));
        } catch (Exception e) {
          result.completeExceptionally(e);
          feedableResponse.error(e);
        } finally {
          // Balances the retain() above.
          release(data);
        }
      }).doOnComplete(() -> {
        try {
          trailersAsFuture(response).whenComplete((trailers, throwable) -> {
            try {
              if (throwable != null) {
                // The future may fail with a Throwable that is not an Exception: wrap it instead of casting so the real cause
                // is reported rather than a ClassCastException.
                feedableResponse.error(throwable instanceof Exception ex ? ex : new MuleRuntimeException(throwable));
                return;
              }

              if (trailers.isEmpty()) {
                feedableResponse.complete();
                return;
              }

              feedableResponse.completeWithTrailers(trailers);
            } catch (Exception e) {
              feedableResponse.error(e);
            }
          });
        } catch (Exception e) {
          feedableResponse.error(e);
        }
      }).doOnError(error -> {
        var mappedError = onErrorMap(error);
        feedableResponse.error(mappedError);
      });
    } catch (Exception e) {
      result.completeExceptionally(e);
    }
    return content;
  }

  /**
   * Completes {@code result} with {@code response} on the IO executor, or on the current thread if the executor rejects the
   * task.
   */
  private void tryCompleteAsyncWithCallerRunsFallback(CompletableFuture<HttpResponse> result,
                                                      ReactorNettyResponseWrapper response) {
    try {
      result.completeAsync(() -> response, ioExecutor);
    } catch (RejectedExecutionException e) {
      // On scheduler rejection fall back to complete in the current thread
      result.complete(response);
    }
  }

  /**
   * Aggregates the whole body, then completes {@code result} with the built response, or exceptionally with the mapped error.
   */
  private Publisher<ByteBuf> handleResponseNonStreaming(HttpClientResponse response, ByteBufFlux content,
                                                        CompletableFuture<HttpResponse> result) {
    return content.aggregate()
        .doOnError(error -> result.completeExceptionally(onErrorMap(error)))
        .doOnSuccess(byteBuf -> result.complete(new HttpResponseCreator().create(response, byteBuf)));
  }

  /** Adapts a reactive entity to a flux of buffers; the flux fails if the sink overflows (ERROR strategy). */
  private Flux<ByteBuf> reactiveEntityToByteBufFlux(HttpEntity entity) {
    return Flux.create(sink -> propagateDataFromEntityToSink(entity, sink), ERROR);
  }

  /** Forwards the entity's data chunks, completion and failure to {@code sink}. */
  private void propagateDataFromEntityToSink(HttpEntity entity, FluxSink<ByteBuf> sink) {
    entity.onData(data -> sink.next(wrappedBuffer(data)));
    entity.onComplete((ts, err) -> {
      if (err != null) {
        sink.error(err);
      } else {
        sink.complete();
      }
    });
  }
}
