/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package org.mule.service.http.netty.impl.client;

import static org.mule.runtime.api.util.DataUnit.KB;
import static org.mule.runtime.http.api.HttpHeaders.Names.CONTENT_LENGTH;
import static org.mule.service.http.netty.impl.util.HttpUtils.buildUriString;
import static org.mule.service.http.netty.impl.util.MuleToNettyUtils.calculateShouldRemoveContentLength;
import static org.mule.service.http.netty.impl.util.ReactorNettyUtils.onErrorMap;

import static java.lang.Integer.parseInt;
import static java.net.URI.create;

import static io.netty.buffer.ByteBufUtil.getBytes;
import static org.slf4j.LoggerFactory.getLogger;
import static reactor.core.publisher.Mono.empty;

import org.mule.runtime.api.util.MultiMap;
import org.mule.runtime.http.api.client.HttpRequestOptions;
import org.mule.runtime.http.api.domain.entity.HttpEntity;
import org.mule.runtime.http.api.domain.message.request.HttpRequest;
import org.mule.runtime.http.api.domain.message.response.HttpResponse;
import org.mule.service.http.netty.impl.message.HttpResponseCreator;
import org.mule.service.http.netty.impl.util.ReactorNettyUtils;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.PipedInputStream;
import java.io.PipedOutputStream;
import java.net.URI;
import java.time.Duration;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.function.BiFunction;

import io.netty.buffer.ByteBuf;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.util.AttributeKey;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import reactor.core.publisher.Flux;
import reactor.netty.ByteBufFlux;
import reactor.netty.http.client.HttpClient;
import reactor.netty.http.client.HttpClientResponse;

/**
 * Reactor Netty based client that has operations to send async request and receive responses on a streamed fashion
 */
/**
 * Reactor Netty based client that sends requests asynchronously and receives responses in a streamed fashion.
 * <p>
 * Responses that declare a {@code Content-Length} below 10 KB are aggregated in memory. Every other response is handed to the
 * caller as soon as it arrives, with a body backed by a piped stream that is fed from the Netty content {@link Flux}.
 */
public class ReactorNettyClient {

  private static final Logger LOGGER = getLogger(ReactorNettyClient.class);
  private static final int DEFAULT_RESPONSE_RECEPTION_BUFFER_SIZE = KB.toBytes(10);
  private static final int CONTENT_LENGTH_TO_HANDLE_AGGREGATED = KB.toBytes(10);

  /**
   * Channel attribute set to the request entity when the outgoing headers contain {@code Expect}. In that case the body is not
   * sent from {@link #sendAsyncRequest}. NOTE(review): the component that reads this attribute and sends the body lives outside
   * this class — confirm.
   */
  static final AttributeKey<HttpEntity> REQUEST_ENTITY_KEY = AttributeKey.valueOf("REQUEST_ENTITY");

  // TODO: Maybe we need to tune this
  // NOTE(review): each client instance creates its own pool and never shuts it down; the pool's default (non-daemon)
  // threads outlive the client.
  private final Executor resultExecutor = Executors.newFixedThreadPool(10);
  private final HttpClient httpClient;

  /**
   * @param httpClient the Reactor Netty client used to perform the requests.
   */
  public ReactorNettyClient(HttpClient httpClient) {
    this.httpClient = httpClient;
  }

  /**
   * Sends {@code request} and returns the flux produced by {@code responseFunction} for the received response.
   * <p>
   * Errors are mapped with {@code ReactorNettyUtils.onErrorMap}, used to complete {@code result} exceptionally and then
   * swallowed, so the returned flux completes normally even when the exchange fails.
   *
   * @param request          the request to send; query parameters are appended to its URI.
   * @param options          redirect and response-timeout options.
   * @param headersToAdd     headers added to the outgoing request.
   * @param responseFunction maps the response and its content into the returned publisher.
   * @param result           future completed exceptionally if the exchange fails.
   * @return the publisher returned by {@code responseFunction}, with errors handled as described above.
   */
  public Flux<ByteBuf> sendAsyncRequest(HttpRequest request, HttpRequestOptions options, HttpHeaders headersToAdd,
                                        BiFunction<HttpClientResponse, ByteBufFlux, Publisher<ByteBuf>> responseFunction,
                                        CompletableFuture<HttpResponse> result) {
    URI uri = uriWithQueryParams(request);
    LOGGER.debug("Sending request to {} with headers {}", uri, headersToAdd);
    return httpClient
        .followRedirect(options.isFollowsRedirect())
        .doOnConnected(connection -> {
          connection.channel().attr(AttributeKey.valueOf("removeContentLength"))
              .set(calculateShouldRemoveContentLength(request));
          if (request.getEntity() != null && isExpect(headersToAdd)) {
            connection.channel().attr(REQUEST_ENTITY_KEY).set(request.getEntity());
          }
        })
        .responseTimeout(Duration.ofMillis(options.getResponseTimeout()))
        .headers(h -> h.add(headersToAdd))
        .request(HttpMethod.valueOf(request.getMethod()))
        // Pass a URI instead of a String to avoid inefficient Netty regex matching.
        // see https://github.com/reactor/reactor-netty/issues/829
        .uri(uri)
        .send((httpClientRequest, nettyOutbound) -> {
          HttpEntity entity = request.getEntity();
          // With an Expect header the body is deferred (see REQUEST_ENTITY_KEY). Without an entity there is nothing to
          // send. Both cases send an empty publisher rather than returning null, which is not a valid Publisher.
          if (entity == null || isExpect(httpClientRequest.requestHeaders())) {
            return nettyOutbound.send(empty());
          }
          return nettyOutbound.send(entityPublisher(entity));
        })
        .response(responseFunction)
        .onErrorMap(ReactorNettyUtils::onErrorMap)
        .doOnError(result::completeExceptionally)
        .onErrorComplete();
  }

  /**
   * @return whether {@code entries} contains an {@code expect} header.
   */
  private static boolean isExpect(HttpHeaders entries) {
    return entries.contains("expect");
  }

  /**
   * Returns the request URI with the request's query parameters appended. A new URI is only built when there is at least
   * one parameter.
   */
  private static URI uriWithQueryParams(HttpRequest request) {
    MultiMap<String, String> queryParams = request.getQueryParams();
    if (queryParams.isEmpty()) {
      return request.getUri();
    }
    return create(buildUriString(request.getUri(), queryParams));
  }

  /**
   * Returns an empty publisher for entities with a known length of zero. Otherwise returns a chunked publisher over the
   * entity.
   */
  private Publisher<? extends ByteBuf> entityPublisher(HttpEntity entity) {
    if (entity.getBytesLength().isPresent() && entity.getBytesLength().getAsLong() == 0) {
      return empty();
    }
    return new ChunkedHttpEntityPublisher(entity);
  }

  /**
   * Response callback. Completes {@code result} with the received response, either aggregated in memory or streamed,
   * depending on its declared content length.
   *
   * @param response the received response metadata.
   * @param content  the response body.
   * @param result   the future to complete with the built {@link HttpResponse}.
   * @return the publisher that must be subscribed for the body to be consumed.
   */
  public Publisher<ByteBuf> receiveContent(HttpClientResponse response, ByteBufFlux content,
                                           CompletableFuture<HttpResponse> result) {
    LOGGER.debug("Received response with headers {} and status {}", response.responseHeaders(), response.status());

    if (responseIsShortEnough(response)) {
      return handleShortResponse(response, content, result);
    } else {
      return handleResponseStreaming(response, content, result);
    }
  }

  /**
   * Completes {@code result} asynchronously with a response whose body is read from a pipe. The returned flux writes every
   * received chunk into that pipe.
   * <p>
   * NOTE(review): {@link PipedOutputStream#write(byte[])} may block while the pipe buffer is full, that is until the consumer
   * of the response body reads. That blocking happens on whichever thread emits {@code content}.
   */
  private Publisher<ByteBuf> handleResponseStreaming(HttpClientResponse response, ByteBufFlux content,
                                                     CompletableFuture<HttpResponse> result) {
    final PipedOutputStream out = new PipedOutputStream();
    final PipedInputStream in;
    try {
      in = new PipedInputStream(out, DEFAULT_RESPONSE_RECEPTION_BUFFER_SIZE);
    } catch (IOException e) {
      result.completeExceptionally(e);
      return content;
    }

    Flux<ByteBuf> contentFlux = content.retain().doOnNext(data -> {
      try {
        byte[] bytes = new byte[data.readableBytes()];
        data.readBytes(bytes);
        out.write(bytes); // Write bytes to the output stream
      } catch (IOException e) {
        result.completeExceptionally(e);
      } finally {
        data.release();
      }
    }).doOnComplete(() -> {
      try {
        LOGGER.debug("Marked response as completed");
        out.close();
      } catch (IOException e) {
        result.completeExceptionally(e);
      }
    }).doOnError(error -> {
      result.completeExceptionally(onErrorMap(error));
      // No more data will arrive. Closing the write end lets readers of `in` reach end-of-stream instead of waiting
      // forever, because the pipe cannot carry the error itself.
      closeWriteEnd(out);
    }).doOnCancel(() -> closeWriteEnd(out));

    // A single task is enough: it completes `result` with the streamed response, or does nothing if `result` was already
    // completed, e.g. exceptionally by an error.
    if (!result.isDone()) {
      completeAsync(result, () -> new HttpResponseCreator().create(response, in), resultExecutor);
    }
    return contentFlux;
  }

  /**
   * Closes the pipe's write end, logging instead of propagating a failure.
   */
  private static void closeWriteEnd(PipedOutputStream out) {
    try {
      out.close();
    } catch (IOException e) {
      LOGGER.debug("Could not close the response content pipe", e);
    }
  }

  /**
   * Aggregates the whole body in memory and completes {@code result} with a response backed by a byte array. An empty body
   * yields an empty array.
   */
  private Publisher<ByteBuf> handleShortResponse(HttpClientResponse response, ByteBufFlux content,
                                                 CompletableFuture<HttpResponse> result) {
    return content.aggregate().doOnError(error -> {
      Throwable mappedError = onErrorMap(error);
      result.completeExceptionally(mappedError);
    }).doOnSuccess(byteBuf -> {
      byte[] bytes = byteBuf != null ? getBytes(byteBuf) : new byte[0];
      result.complete(new HttpResponseCreator().create(response, new ByteArrayInputStream(bytes)));
    });
  }

  /**
   * Returns whether the response declares a {@code Content-Length} small enough to be aggregated in memory.
   * <p>
   * A missing or unparseable header returns {@code false}, so the response is streamed. The value is parsed as a
   * {@code long} so bodies larger than {@link Integer#MAX_VALUE} bytes do not fail with a {@link NumberFormatException}.
   */
  private boolean responseIsShortEnough(HttpClientResponse response) {
    String contentLength = response.responseHeaders().get(CONTENT_LENGTH);
    if (contentLength == null) {
      return false;
    }

    try {
      return Long.parseLong(contentLength.trim()) < CONTENT_LENGTH_TO_HANDLE_AGGREGATED;
    } catch (NumberFormatException e) {
      LOGGER.debug("Unparseable Content-Length '{}', response will be streamed", contentLength);
      return false;
    }
  }

  /**
   * Completes {@code result} asynchronously on {@code executor} with the value returned by {@code callable}. This keeps the
   * content flowing while the response is built in the streaming case. Java 9+ provides this as
   * {@code CompletableFuture.completeAsync}, but Java 8 must still be supported.
   * <p>
   * {@code callable} is skipped when {@code result} is already done by the time the task runs. Completion of either future
   * is propagated to the other.
   *
   * @return {@code result}.
   * @throws NullPointerException if {@code result} or {@code callable} is {@code null}.
   */
  public static <T> CompletableFuture<T> completeAsync(CompletableFuture<T> result, Callable<? extends T> callable,
                                                       Executor executor) {
    if (result == null || callable == null) {
      throw new NullPointerException();
    }

    // Re-checking isDone() inside the task avoids running the callable when result was completed concurrently.
    CompletableFuture<T> delegate = callAsync(() -> result.isDone() ? null : callable.call(), executor);

    result.whenComplete((v, t) -> {
      if (t == null) {
        delegate.complete(v);
      } else {
        delegate.completeExceptionally(t);
      }
    });

    delegate.whenComplete((v, t) -> {
      if (t == null) {
        result.complete(v);
      } else {
        result.completeExceptionally(t);
      }
    });
    return result;
  }

  /**
   * Runs {@code callable} on {@code executor}. Unchecked exceptions and errors are propagated as-is. Checked exceptions are
   * wrapped in a {@link CompletionException}.
   *
   * @throws NullPointerException if {@code callable} or {@code executor} is {@code null}.
   */
  public static <U> CompletableFuture<U> callAsync(Callable<? extends U> callable, Executor executor) {
    return CompletableFuture.supplyAsync(callable == null ? null : () -> {
      try {
        return callable.call();
      } catch (Error | RuntimeException e) {
        throw e; // Also avoids double wrapping CompletionExceptions below.
      } catch (Throwable t) {
        throw new CompletionException(t);
      }
    }, executor);
  }
}
