/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package org.mule.service.http.netty.impl.server;

import static java.nio.charset.StandardCharsets.US_ASCII;

import static org.mule.runtime.http.api.HttpHeaders.Names.CONNECTION;
import static org.mule.runtime.http.api.HttpHeaders.Names.CONTENT_LENGTH;

import static io.netty.handler.codec.http.HttpHeaderValues.KEEP_ALIVE;
import static io.netty.handler.codec.http.HttpResponseStatus.REQUEST_TIMEOUT;
import static io.netty.handler.codec.http.HttpUtil.isContentLengthSet;
import static io.netty.handler.codec.http.HttpUtil.isKeepAlive;
import static io.netty.handler.codec.http.HttpUtil.isTransferEncodingChunked;
import static io.netty.handler.codec.http.HttpVersion.HTTP_1_0;
import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1;
import static io.netty.handler.timeout.IdleState.ALL_IDLE;
import static io.netty.handler.timeout.IdleState.READER_IDLE;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.channel.socket.ChannelInputShutdownReadComplete;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpStatusClass;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.LastHttpContent;
import io.netty.handler.timeout.IdleStateEvent;

/**
 * Handler to close server-side connections when needed. It also adds the corresponding {@code "Connection"} headers.
 * <p>
 * It's mainly based on Netty's built-in {@link io.netty.handler.codec.http.HttpServerKeepAliveHandler}, but with two differences:
 * <ul>
 * <li>Drops the connection under certain status codes</li>
 * <li>Handles the "Connection" headers in a Grizzly compatible fashion to preserve backwards compatibility with the old HTTP
 * Service implementation</li>
 * </ul>
 * Additionally, it reacts to idle-state events (answering a stalled in-flight request with a {@code 408 Request Timeout}, or
 * closing a connection that stays idle with no in-flight request) and to the input side of the channel being shut down.
 */
public class KeepAliveHandler extends ChannelDuplexHandler {

  private static final String MULTIPART_PREFIX = "multipart";
  public static final String TIMEOUT_READING_REQUEST = "Timeout reading request";

  // Body of the 408 response, encoded with an explicit charset so it does not depend on the JVM default charset, and so
  // the Content-Length header can be derived from the bytes actually written instead of the String's char count.
  // Never mutated: it is only copied into a fresh buffer for each response.
  private static final byte[] TIMEOUT_READING_REQUEST_BYTES = TIMEOUT_READING_REQUEST.getBytes(US_ASCII);

  // Whether the connection may stay open once the in-flight responses are written. It starts with the configured value and
  // can only turn false afterwards (a non keep-alive request or a connection-dropping response).
  private boolean persistentConnection;
  // Track pending responses to support client pipelining: https://tools.ietf.org/html/rfc7230#section-6.3.2
  // Incremented per request read, decremented per final (non 1xx) response written, reset when the connection must close.
  private int pendingResponses;
  // Set once the peer has shut down its output (we received ChannelInputShutdownReadComplete).
  private boolean isInputShutdown = false;
  // Protocol version of the last request read; used for the version of the timeout response.
  private HttpVersion protocolVersion = HTTP_1_1;

  /**
   * @param usePersistentConnections whether connections may be kept open between requests.
   */
  public KeepAliveHandler(boolean usePersistentConnections) {
    this.persistentConnection = usePersistentConnections;
  }

  /**
   * Tracks every outgoing {@link HttpResponse} and fixes its {@code Connection} header. When the connection must not be kept
   * alive, the channel is closed once the last content of the message has been written.
   */
  @Override
  public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
    // modify message on way out to add headers if needed
    if (msg instanceof HttpResponse response) {
      trackResponse(response);

      // For versions that are not persistent by default (HTTP/1.0), advertise keep-alive explicitly unless the response
      // writer already chose a Connection header.
      if (persistentConnection && !response.headers().contains(CONNECTION) && !response.protocolVersion().isKeepAliveDefault()) {
        response.headers().set(CONNECTION, KEEP_ALIVE);
      }

      if (!isKeepAlive(response) || !isSelfDefinedMessageLength(response) || statusDropsConnection(response.status().code())) {
        // No longer keep alive: the response asks to close, the client can't tell when the message is done unless we close
        // the connection, or the status code is one that must drop the connection.
        pendingResponses = 0;
        persistentConnection = false;
      }
      // Server might think it can keep connection alive, but we should fix response header if we know better
      if (!shouldKeepAlive()) {
        response.headers().set(CONNECTION, HttpHeaderValues.CLOSE);
      }
    }
    if (msg instanceof LastHttpContent && !shouldKeepAlive()) {
      // unvoid() because listeners cannot be attached to a void promise
      promise = promise.unvoid().addListener(ChannelFutureListener.CLOSE);
    }
    super.write(ctx, msg, promise);
  }

  /**
   * Counts every incoming {@link HttpRequest} as a pending response, downgrades persistence if the request is not keep-alive,
   * and remembers its protocol version. The message is always propagated.
   */
  @Override
  public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
    // read message and track if it was keepAlive
    if (msg instanceof HttpRequest request) {
      pendingResponses += 1;
      if (persistentConnection) {
        persistentConnection = isKeepAlive(request);
      }
      protocolVersion = request.protocolVersion();
    }
    ctx.fireChannelRead(msg);
  }

  /**
   * Marks one request as answered, unless the response is informational (1xx), which precedes the final response.
   */
  private void trackResponse(HttpResponse response) {
    if (!isInformational(response)) {
      pendingResponses -= 1;
    }
  }

  /**
   * @return {@code false} once the input was shut down; otherwise {@code true} while responses are still pending or the
   *         connection is still considered persistent.
   */
  private boolean shouldKeepAlive() {
    if (isInputShutdown) {
      return false;
    }
    return pendingResponses > 0 || persistentConnection;
  }

  /**
   * Keep-alive only works if the client can detect when the message has ended without relying on the connection being closed.
   * For HTTP/1.0 responses only an explicit {@code Content-Length} qualifies; for other versions chunked transfer encoding,
   * multipart content, informational (1xx) and {@code 204 No Content} responses qualify as well.
   * <ul>
   * <li>See <a href="https://tools.ietf.org/html/rfc7230#section-6.3">RFC 7230, section 6.3</a></li>
   * <li>See <a href="https://tools.ietf.org/html/rfc7230#section-3.3.2">RFC 7230, section 3.3.2</a></li>
   * <li>See <a href="https://tools.ietf.org/html/rfc7230#section-3.3.3">RFC 7230, section 3.3.3</a></li>
   * </ul>
   *
   * @param response The HttpResponse to check
   *
   * @return true if the response has a self defined message length.
   */
  private static boolean isSelfDefinedMessageLength(HttpResponse response) {
    if (isContentLengthSet(response)) {
      return true;
    }
    if (response.protocolVersion().equals(HTTP_1_0)) {
      return false;
    }
    return isTransferEncodingChunked(response)
        || isMultipart(response)
        || isInformational(response)
        || response.status().code() == HttpResponseStatus.NO_CONTENT.code();
  }

  /**
   * @return whether the response has a 1xx status.
   */
  private static boolean isInformational(HttpResponse response) {
    return response.status().codeClass() == HttpStatusClass.INFORMATIONAL;
  }

  /**
   * @return whether the {@code Content-Type} header starts with {@code "multipart"}, ignoring case.
   */
  private static boolean isMultipart(HttpResponse response) {
    String contentType = response.headers().get(HttpHeaderNames.CONTENT_TYPE);
    return contentType != null &&
        contentType.regionMatches(true, 0, MULTIPART_PREFIX, 0, MULTIPART_PREFIX.length());
  }

  /**
   * Determine if we must drop the connection because of the HTTP status code. The list is based on the one used by
   * Apache/httpd, plus nginx's non-standard 499.
   */
  private static boolean statusDropsConnection(int status) {
    return switch (status) {
      case 400 /* SC_BAD_REQUEST */,
          408 /* SC_REQUEST_TIMEOUT */,
          411 /* SC_LENGTH_REQUIRED */,
          413 /* SC_REQUEST_ENTITY_TOO_LARGE */,
          414 /* SC_REQUEST_URI_TOO_LARGE */,
          417 /* SC_EXPECTATION_FAILED */,
          499 /* nginx: Client Closed Request */,
          500 /* SC_INTERNAL_SERVER_ERROR */,
          501 /* SC_NOT_IMPLEMENTED */,
          503 /* SC_SERVICE_UNAVAILABLE */,
          505 /* SC_VERSION_NOT_SUPPORTED */ -> true;
      default -> false;
    };
  }

  /**
   * Handles connection lifecycle events:
   * <ul>
   * <li>{@code READER_IDLE} with a request in flight: answers with {@code 408} and closes the channel.</li>
   * <li>{@code ALL_IDLE} with no request in flight: closes the channel.</li>
   * <li>{@link ChannelInputShutdownReadComplete}: stops keeping the connection alive, closing it right away if nothing is
   * pending.</li>
   * </ul>
   * Events consumed by the first two cases (and the immediate close of the third) are not propagated; everything else is.
   */
  @Override
  public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
    if (evt instanceof IdleStateEvent idleStateEvent) {
      if (READER_IDLE == idleStateEvent.state() && pendingResponses > 0) {
        // The reader idle timeout should only apply if there is an inflight request
        sendTimeoutResponse(ctx);
        return;
      }
      if (ALL_IDLE == idleStateEvent.state() && pendingResponses == 0) {
        // The connection idle timeout should only apply if there is NOT an inflight request
        ctx.close();
        return;
      }
    }
    if (evt instanceof ChannelInputShutdownReadComplete) {
      isInputShutdown = true;
      if (pendingResponses == 0) {
        ctx.close();
        return;
      }
    }
    ctx.fireUserEventTriggered(evt);
  }

  /**
   * Writes a {@code 408 Request Timeout} response (with {@code Connection: close}) using the protocol version of the last
   * request read, and closes the channel once it has been flushed.
   */
  private void sendTimeoutResponse(ChannelHandlerContext ctx) {
    ByteBuf body = ctx.alloc().buffer(TIMEOUT_READING_REQUEST_BYTES.length);
    body.writeBytes(TIMEOUT_READING_REQUEST_BYTES);
    HttpResponse rejection = new DefaultFullHttpResponse(protocolVersion, REQUEST_TIMEOUT, body);
    rejection.headers().set(CONNECTION, HttpHeaderValues.CLOSE);
    // Derived from the encoded bytes so the header always matches the payload
    rejection.headers().set(CONTENT_LENGTH, body.readableBytes());
    ctx.writeAndFlush(rejection).addListener(ChannelFutureListener.CLOSE);
  }
}
