/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package org.mule.service.http.netty.impl.client;

import static org.mule.runtime.api.i18n.I18nMessageFactory.createStaticMessage;
import static org.mule.service.http.netty.impl.client.ReactorNettyClient.REQUEST_ENTITY_KEY;

import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH;
import static io.netty.handler.codec.http.HttpHeaderNames.TRANSFER_ENCODING;
import static io.netty.handler.codec.http.HttpHeaderValues.CHUNKED;
import static io.netty.handler.codec.http.LastHttpContent.EMPTY_LAST_CONTENT;
import static org.slf4j.LoggerFactory.getLogger;

import org.mule.runtime.api.exception.MuleRuntimeException;
import org.mule.runtime.http.api.domain.entity.HttpEntity;
import org.mule.service.http.netty.impl.streaming.StatusCallback;
import org.mule.service.http.netty.impl.streaming.StreamingEntitySender;

import java.io.IOException;
import java.io.InputStream;
import java.util.OptionalLong;
import java.util.concurrent.ExecutorService;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.DefaultHttpContent;
import io.netty.handler.codec.http.DefaultHttpRequest;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.LastHttpContent;
import io.netty.util.ReferenceCountUtil;
import org.slf4j.Logger;

/**
 * Handles HTTP {@code 100 Continue} interim responses in a Netty HTTP client. Once a 100-Continue response has been fully
 * received, the request payload (the {@link HttpEntity} stored in the channel attributes) is written to the channel. This
 * supports scenarios where the server asks the client to send the request body only after an initial acknowledgment.
 *
 * <p>
 * The request entity is retrieved from the channel attributes using {@link ReactorNettyClient#REQUEST_ENTITY_KEY}. Streaming
 * entities are sent through a {@link StreamingEntitySender}; non-streaming entities are read and written in chunks of
 * {@code CHUNK_SIZE} bytes.
 * </p>
 *
 * <p>
 * Inbound messages that belong to the 100-Continue response are consumed (and released) by this handler; any other inbound
 * message is passed through the pipeline.
 * </p>
 *
 * <p>
 * On the outbound side, when the channel carries a request entity, the outgoing {@link HttpRequest} headers are adapted (the
 * actual {@code Content-Length}, or {@code Transfer-Encoding: chunked} when the length is unknown). A request that would
 * otherwise be sent as complete is converted into a headers-only request, so that the body can follow later.
 * </p>
 *
 * <p>
 * If the request entity is not found, a warning is logged. If the content of a non-streaming entity cannot be read, a
 * {@link MuleRuntimeException} is thrown. Failures reported by the streaming sender are logged.
 * </p>
 */
public class ClientExpectContinueHandler extends ChannelDuplexHandler {

  private static final Logger LOGGER = getLogger(ClientExpectContinueHandler.class);
  private static final int CHUNK_SIZE = 8192;

  private final ExecutorService ioExecutor;

  // True from the 100-Continue status line until its LastHttpContent has been consumed.
  private boolean receiving100ContinueResponse = false;
  // True when the current read batch carried a 100-Continue response, whose read-complete event is not propagated.
  private boolean suppressChannelReadComplete = false;

  /**
   * @param ioExecutor executor handed to the {@link StreamingEntitySender} used for streaming request entities.
   */
  public ClientExpectContinueHandler(ExecutorService ioExecutor) {
    this.ioExecutor = ioExecutor;
  }

  /**
   * Handles the incoming response and streams the request entity once a 100-Continue response has been completely received.
   * Messages belonging to the 100-Continue response are consumed and released here; any other message is passed down the
   * pipeline.
   *
   * @param ctx the {@link ChannelHandlerContext} which provides access to the pipeline.
   * @param msg the incoming message, typically an {@link HttpResponse} or HTTP content.
   * @throws IOException if the streaming entity sender fails to send its first chunk.
   */
  @Override
  public void channelRead(ChannelHandlerContext ctx, Object msg) throws IOException {
    if (msg instanceof HttpResponse res && res.status().code() == 100) {
      LOGGER.debug("Received 100 Continue response {}", msg);
      this.receiving100ContinueResponse = true;
      this.suppressChannelReadComplete = true;
    }

    if (!this.receiving100ContinueResponse) {
      // we are not handling a 100-Continue response, the rest of the pipeline will know how to handle this...
      ctx.fireChannelRead(msg);
      return;
    }

    // Every message of the interim 100-Continue response is consumed by this handler, so it has to be released here.
    // Note that the same message may be both the 100 status and its last content (e.g. an aggregated full response).
    try {
      if (msg instanceof LastHttpContent) {
        // Reset before sending: if sending the entity fails, the final response must not be swallowed as if it still
        // belonged to the 100-Continue response.
        this.receiving100ContinueResponse = false;
        sendRequestEntity(ctx);
      }
    } finally {
      ReferenceCountUtil.release(msg);
    }
  }

  /**
   * Looks up the request entity in the channel attributes and writes it to the channel, or logs a warning if there is none.
   *
   * @param ctx the {@link ChannelHandlerContext} to write the entity to.
   */
  private void sendRequestEntity(ChannelHandlerContext ctx) throws IOException {
    HttpEntity requestEntity = ctx.channel().attr(REQUEST_ENTITY_KEY).get();
    LOGGER.debug("Request entity from the channel is {}", requestEntity);
    if (requestEntity != null) {
      writeEntityToChannel(ctx, requestEntity);
      LOGGER.debug("Streaming request entity payload to channel");
    } else {
      LOGGER.warn("No request entity found for 100 Continue response");
    }
  }

  /**
   * Writes the content of the {@link HttpEntity} to the channel in chunks.
   *
   * @param ctx           the {@link ChannelHandlerContext} to write the streamed content to the channel.
   * @param requestEntity the {@link HttpEntity} containing the payload to be streamed.
   */
  private void writeEntityToChannel(ChannelHandlerContext ctx, HttpEntity requestEntity) throws IOException {
    if (requestEntity.isStreaming()) {
      sendStreamingEntity(ctx, requestEntity);
    } else {
      sendNonStreamingEntity(ctx, requestEntity);
    }
  }

  /**
   * Delegates sending of a streaming entity to a {@link StreamingEntitySender}, starting with its first chunk.
   */
  private void sendStreamingEntity(ChannelHandlerContext ctx, HttpEntity requestEntity) throws IOException {
    StreamingEntitySender entitySender =
        new StreamingEntitySender(requestEntity, ctx, () -> LOGGER.debug("Starting to write chunk to channel context"),
                                  new EntitySenderStatusCallback(), ioExecutor);
    entitySender.sendNextChunk();
  }

  /**
   * Reads a non-streaming entity's content and writes it to the channel, wrapping read failures in a
   * {@link MuleRuntimeException}.
   */
  private void sendNonStreamingEntity(ChannelHandlerContext ctx, HttpEntity requestEntity) {
    try (InputStream inputStream = validateInputStream(requestEntity)) {
      streamEntityChunks(ctx, inputStream);
    } catch (IOException e) {
      throw new MuleRuntimeException(createStaticMessage("Failed to read content from HttpEntity:"), e);
    }
  }

  /**
   * Returns the entity's content stream, failing with a {@link MuleRuntimeException} if it is {@code null}.
   */
  private InputStream validateInputStream(HttpEntity requestEntity) throws IOException {
    InputStream inputStream = requestEntity.getContent();
    if (inputStream == null) {
      throw new MuleRuntimeException(new IOException("InputStream is null"));
    }
    return inputStream;
  }

  /**
   * Writes the stream as a sequence of HTTP content chunks, followed by an empty last content, and flushes.
   */
  private void streamEntityChunks(ChannelHandlerContext ctx, InputStream inputStream) throws IOException {
    byte[] buffer = new byte[CHUNK_SIZE];
    int bytesRead;
    while ((bytesRead = inputStream.read(buffer)) != -1) {
      // Copy rather than wrap: the writes are only flushed after the loop, and a wrapped buffer would share the array that
      // the next read overwrites, corrupting chunks that are still queued.
      ByteBuf chunk = Unpooled.copiedBuffer(buffer, 0, bytesRead);
      ctx.write(new DefaultHttpContent(chunk));
    }
    ctx.writeAndFlush(EMPTY_LAST_CONTENT);
    LOGGER.debug("Successfully streamed HttpEntity in chunks");
  }

  /**
   * Swallows the read-complete event of a batch that carried a 100-Continue response, requesting another read when
   * auto-read is disabled. Otherwise the event is propagated.
   */
  @Override
  public void channelReadComplete(ChannelHandlerContext ctx) {
    if (this.suppressChannelReadComplete) {
      this.suppressChannelReadComplete = false;
      if (!ctx.channel().config().isAutoRead()) {
        ctx.read();
      }
    } else {
      ctx.fireChannelReadComplete();
    }
  }

  /**
   * Adapts outgoing {@link HttpRequest}s when the channel carries a request entity; other messages are written unchanged.
   */
  @Override
  public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
    if (shouldAdaptRequest(ctx, msg)) {
      handleAdaptedRequest(ctx, (HttpRequest) msg, promise);
    } else {
      super.write(ctx, msg, promise);
    }
  }

  private boolean shouldAdaptRequest(ChannelHandlerContext ctx, Object msg) {
    return ctx.channel().hasAttr(REQUEST_ENTITY_KEY) && msg instanceof HttpRequest;
  }

  /**
   * Writes the adapted version of {@code httpRequest}. If the entity attribute is present but empty, the request is forwarded
   * untouched, so that it is never silently dropped with an incomplete promise.
   */
  private void handleAdaptedRequest(ChannelHandlerContext ctx, HttpRequest httpRequest, ChannelPromise promise) {
    HttpEntity httpEntity = ctx.channel().attr(REQUEST_ENTITY_KEY).get();
    if (httpEntity == null) {
      ctx.write(httpRequest, promise);
      return;
    }
    HttpRequest adaptedRequest = adaptRequest(httpRequest, httpEntity.getBytesLength());
    ctx.write(adaptedRequest, promise);
  }

  /**
   * Fixes the framing headers of the request and, if it is a "last" message, replaces it with a headers-only request.
   *
   * @param httpRequest the request produced upstream.
   * @param bytesLength the entity length, when known.
   * @return the request to write downstream.
   */
  private HttpRequest adaptRequest(HttpRequest httpRequest, OptionalLong bytesLength) {
    HttpHeaders headers = httpRequest.headers();
    if (headers.contains(CONTENT_LENGTH)) {
      if (bytesLength.isPresent()) {
        LOGGER.debug("Setting the actual content-length header to the request");
        headers.set(CONTENT_LENGTH, bytesLength.getAsLong());
      } else {
        // At this point, we know that reactor-netty replaced the transfer encoding header with a Content-Length=0 because we
        // removed the entity from the request. We need to remove the entity from the request because otherwise reactor-netty
        // doesn't wait for the 100-Continue before sending the entity. That is precisely the purpose of this whole class.
        LOGGER.debug("Setting Transfer-Encoding=chunked header");
        headers.remove(CONTENT_LENGTH);
        headers.add(TRANSFER_ENCODING, CHUNKED);
      }
    }
    if (!(httpRequest instanceof LastHttpContent)) {
      // Just return the same request.
      return httpRequest;
    }
    LOGGER.debug("LastHttpContent received, writing DefaultHttpRequest to the channel");
    // If this is a "last" content, then Netty will interpret that the request ends with this object, and that's not
    // the case.
    HttpRequest headersOnlyRequest =
        new DefaultHttpRequest(httpRequest.protocolVersion(), httpRequest.method(), httpRequest.uri(), headers);
    // The original message is replaced and never written, so this handler owns it and must release its content.
    ReferenceCountUtil.release(httpRequest);
    return headersOnlyRequest;
  }

  /**
   * Logs the outcome reported by the {@link StreamingEntitySender}.
   */
  private static class EntitySenderStatusCallback implements StatusCallback {

    @Override
    public void onFailure(Throwable exception) {
      logStreamingError(exception);
    }

    @Override
    public void onSuccess() {
      LOGGER.debug("Request sent successfully to server in chunks.");
    }

    private void logStreamingError(Throwable exception) {
      LOGGER.warn("Error while sending streaming request to server: {}", exception.getMessage());
      if (LOGGER.isDebugEnabled()) {
        LOGGER.debug("Exception thrown while sending streaming request to server", exception);
      }
    }
  }
}
