/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package org.mule.service.http.netty.impl.server;

import static io.netty.buffer.ByteBufUtil.getBytes;
import static org.slf4j.LoggerFactory.getLogger;

import org.mule.runtime.http.api.server.RequestHandler;
import org.mule.service.http.netty.impl.message.NettyHttpRequestAdapter;
import org.mule.service.http.netty.impl.server.util.DefaultServerAddress;
import org.mule.service.http.netty.impl.server.util.HttpListenerRegistry;
import org.mule.service.http.netty.impl.streaming.BlockingBidirectionalStream;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.URISyntaxException;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufInputStream;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpContent;
import io.netty.handler.codec.http.HttpObject;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpUtil;
import io.netty.handler.codec.http.LastHttpContent;
import io.netty.handler.ssl.SslHandler;
import io.netty.util.ReferenceCountUtil;
import org.slf4j.Logger;

/**
 * Implementation of a Netty inbound handler that adapts incoming HTTP objects to the Mule HTTP API and forwards them to the
 * {@link RequestHandler} resolved through a {@link HttpListenerRegistry}.
 * <p>
 * Note: When the Mule HTTP Service is used to implement a connector, this is the point where the request is forwarded to the flow
 * processors chain.
 * <p>
 * A {@link FullHttpRequest} exposes its body directly. Any other {@link HttpRequest} gets a streamed body: a
 * {@link BlockingBidirectionalStream} is created for it, and every following {@link HttpContent} chunk is written into its output
 * side until a {@link LastHttpContent} closes it.
 */
public class ForwardingToListenerHandler extends SimpleChannelInboundHandler<HttpObject> {

  private static final Logger LOGGER = getLogger(ForwardingToListenerHandler.class);

  private final HttpListenerRegistry httpListenerRegistry;
  private final SslHandler sslHandler;

  // Write side of the streamed body of the request currently being received. It is null when no streamed body is in progress:
  // the request was aggregated, its last chunk was already received, or the body was discarded because the request was rejected.
  private OutputStream currentRequestContentSink;

  public ForwardingToListenerHandler(HttpListenerRegistry httpListenerRegistry, SslHandler sslHandler) {
    this.httpListenerRegistry = httpListenerRegistry;
    this.sslHandler = sslHandler;

    this.currentRequestContentSink = null;
  }

  /**
   * Dispatches each decoded HTTP object.
   * <p>
   * A request head is converted to a Mule request and handed to the matching {@link RequestHandler}. A request whose URI cannot
   * be parsed is answered with {@code 400 Bad Request}. Body chunks of a non-aggregated request are appended to the streamed body
   * of the current request.
   */
  @Override
  protected void channelRead0(ChannelHandlerContext ctx, HttpObject httpObject) throws Exception {
    if (httpObject instanceof HttpRequest) {
      HttpRequest httpRequest = (HttpRequest) httpObject;

      InetSocketAddress socketAddress = (InetSocketAddress) ctx.channel().localAddress();
      DefaultServerAddress serverAddress = new DefaultServerAddress(socketAddress.getAddress(), socketAddress.getPort());
      org.mule.runtime.http.api.domain.message.request.HttpRequest muleRequest = nettyToMuleRequest(httpRequest, socketAddress);

      try {
        RequestHandler requestHandler = getRequestHandler(httpRequest, serverAddress, muleRequest);
        requestHandler.handleRequest(new NettyHttpRequestContext(muleRequest, ctx, sslHandler),
                                     new NettyHttp1RequestReadyCallback(ctx, muleRequest));
      } catch (Exception exception) {
        if (exception instanceof RuntimeException && exception.getCause() instanceof URISyntaxException) {
          // if the URL is malformed we want to get the correct response
          handleMalformedUri(ctx, httpObject, ((URISyntaxException) exception.getCause()));
          // No handler will ever read the body of a rejected request, so stop buffering its remaining chunks.
          discardCurrentRequestContent();
        } else {
          throw exception;
        }
      }
    }

    if (httpObject instanceof HttpContent && !(httpObject instanceof FullHttpRequest)) {
      HttpContent content = (HttpContent) httpObject;
      handleContent(content);
    }
  }

  /**
   * Copies a body chunk into the streamed body of the current request, and closes the stream on the last chunk.
   *
   * @param content the chunk received from the decoder.
   * @throws IOException if writing to or closing the stream fails for a reason other than the reader having closed it.
   */
  private void handleContent(HttpContent content) throws IOException {
    if (currentRequestContentSink == null) {
      // No consumer for this body (e.g. the request was rejected). The chunk is dropped here;
      // SimpleChannelInboundHandler releases it after channelRead0 returns.
      LOGGER.debug("Discarding HTTP content that has no request body waiting for it");
      return;
    }

    byte[] frameData = getBytes(content.content());
    try {
      currentRequestContentSink.write(frameData, 0, frameData.length);
    } catch (IOException exception) {
      // NOTE(review): this matches on the exception message thrown by BlockingBidirectionalStream, which is fragile if that
      // message ever changes.
      if ("Trying to write in a closed buffer".equals(exception.getMessage())) {
        LOGGER.info("Nobody is reading the payload, so we are ignoring part of the content...");
      } else {
        throw exception;
      }
    }

    if (content instanceof LastHttpContent) {
      currentRequestContentSink.close();
      currentRequestContentSink = null;
    }
  }

  /**
   * Closes and forgets the streamed body of the current request, if there is one. A failure while closing is only logged,
   * because the body is being abandoned anyway.
   */
  private void discardCurrentRequestContent() {
    if (currentRequestContentSink == null) {
      return;
    }
    try {
      currentRequestContentSink.close();
    } catch (IOException closeException) {
      LOGGER.debug("Error closing the content of a rejected request", closeException);
    } finally {
      currentRequestContentSink = null;
    }
  }

  /**
   * Resolves the handler for a request.
   * <p>
   * If Netty failed to decode the request and the registry has an error handler for that failure cause, that handler is
   * returned. Otherwise the regular handler is looked up by server address and Mule request.
   */
  private RequestHandler getRequestHandler(HttpRequest nettyRequest,
                                           DefaultServerAddress serverAddress,
                                           org.mule.runtime.http.api.domain.message.request.HttpRequest muleRequest) {
    if (nettyRequest.decoderResult().isFailure()) {
      RequestHandler requestHandler = httpListenerRegistry.getErrorHandler(nettyRequest.decoderResult().cause());
      if (requestHandler != null) {
        return requestHandler;
      }
    }
    return httpListenerRegistry.getRequestHandler(serverAddress, muleRequest);
  }

  /**
   * Logs the failure and closes the channel.
   * <p>
   * NOTE(review): a streamed body still in progress ({@code currentRequestContentSink}) is not closed here. Confirm that its
   * reader cannot wait indefinitely once the channel is closed.
   */
  @Override
  public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
    LOGGER.error("Exception caught", cause);
    ctx.close();
  }

  private org.mule.runtime.http.api.domain.message.request.HttpRequest nettyToMuleRequest(HttpRequest httpRequest,
                                                                                          InetSocketAddress localAddress) {
    return new NettyHttpRequestAdapter(httpRequest, localAddress, createContent(httpRequest));
  }

  /**
   * Builds the body stream for a request.
   * <p>
   * For a {@link FullHttpRequest}, this is a stream over a retained duplicate of its content. For any other request, it is the
   * input side of a new {@link BlockingBidirectionalStream}. In that case the output side becomes
   * {@code currentRequestContentSink}, which receives the chunks that follow.
   */
  private InputStream createContent(HttpRequest httpRequest) {
    if (httpRequest instanceof FullHttpRequest) {
      FullHttpRequest fullHttpRequest = (FullHttpRequest) httpRequest;
      // NOTE(review): ByteBufInputStream(ByteBuf) does not release the buffer on close. Confirm who releases this retained
      // duplicate on the non-rejected path.
      return new ByteBufInputStream((fullHttpRequest).content().retainedDuplicate());
    } else {
      BlockingBidirectionalStream blockingBuffer = new BlockingBidirectionalStream();
      currentRequestContentSink = blockingBuffer.getOutputStream();
      return blockingBuffer.getInputStream();
    }
  }

  // Added for testing purposes.
  SslHandler getSslHandler() {
    return sslHandler;
  }

  /**
   * Answers a request whose URI could not be parsed with {@code 400 Bad Request}. The response body carries the parse error.
   */
  private static void handleMalformedUri(ChannelHandlerContext ctx, HttpObject httpObject, URISyntaxException exception) {
    HttpRequest httpRequest = (HttpRequest) httpObject;
    ByteBuf responseMsg = Unpooled.buffer();
    ByteBufUtil.writeUtf8(responseMsg,
                          String.format("HTTP request parsing failed with error: \"%s\"",
                                        exception.getMessage()));
    HttpResponse rejection =
        new DefaultFullHttpResponse(httpRequest.protocolVersion(), HttpResponseStatus.BAD_REQUEST, responseMsg);
    // The connection is only closed on write failure. Without an explicit length, a keep-alive client cannot tell where this
    // body ends.
    HttpUtil.setContentLength(rejection, responseMsg.readableBytes());
    // For an aggregated request, this offsets the retainedDuplicate() taken in createContent, since the Mule request holding
    // that duplicate is discarded here. SimpleChannelInboundHandler still performs its own release afterwards. This is a no-op
    // for a request object that is not reference counted.
    ReferenceCountUtil.release(httpObject);
    ctx.writeAndFlush(rejection).addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
  }
}
