/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package org.mule.service.http.netty.impl.server;

import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpMessage;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpServerExpectContinueHandler;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.util.ReferenceCountUtil;

/**
 * Variant of {@link HttpServerExpectContinueHandler} that additionally rejects requests whose {@code Expect} header holds a
 * value other than {@code 100-continue}.
 * <p>
 * Expectations are only evaluated for HTTP/1.1 (or later) requests; on HTTP/1.0 they are ignored, as required by RFC 7231
 * section 5.1.1. Requests without an unsupported expectation are delegated to the parent handler.
 */
public class MuleHttpServerExpectContinueHandler extends HttpServerExpectContinueHandler {

  /**
   * Intercepts inbound messages and short-circuits requests that carry an unsupported expectation.
   * <p>
   * For such a request, the response produced by {@link #rejectResponse(HttpRequest)} is computed, the request is released,
   * and the response is written and flushed. The request is not propagated any further, and the channel is closed only if
   * that write fails. Every other message is handed to the parent implementation.
   *
   * @param ctx the context of this handler.
   * @param msg the inbound message.
   * @throws Exception if the parent implementation throws.
   */
  @Override
  public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
    final boolean unsupportedExpectation = msg instanceof HttpRequest && isUnsupportedExpectation((HttpRequest) msg);
    if (!unsupportedExpectation) {
      super.channelRead(ctx, msg);
      return;
    }

    // The expectation cannot be met: answer with the rejection and drop the request instead of forwarding it.
    final HttpResponse rejection = rejectResponse((HttpRequest) msg);
    ReferenceCountUtil.release(msg);
    ctx.writeAndFlush(rejection).addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
  }

  /**
   * Tells whether {@code message} is a request carrying an {@code Expect} header other than {@code 100-continue}.
   * <p>
   * Mirrors {@code io.netty.handler.codec.http.HttpUtil#isUnsupportedExpectation}, which is package-private and therefore
   * cannot be called from here. The header value is compared case-insensitively against {@code 100-continue}.
   *
   * @param message the message to inspect.
   * @return {@code true} if the message carries an expectation that must be rejected.
   */
  private static boolean isUnsupportedExpectation(HttpMessage message) {
    // Requests for which expectations do not apply are treated as having no expectation at all.
    final String expectation = isExpectHeaderValid(message) ? message.headers().get(HttpHeaderNames.EXPECT) : null;
    if (expectation == null) {
      return false;
    }
    return !HttpHeaderValues.CONTINUE.toString().equalsIgnoreCase(expectation);
  }

  /**
   * Tells whether the {@code Expect} header of {@code message} has to be taken into account at all.
   * <p>
   * Mirrors the private {@code io.netty.handler.codec.http.HttpUtil#isExpectHeaderValid}. Expectations only apply to
   * requests, and only from HTTP/1.1 onwards: RFC 7231 section 5.1.1 says "A server that receives a 100-continue expectation
   * in an HTTP/1.0 request MUST ignore that expectation."
   *
   * @param message the message to inspect.
   * @return {@code true} if {@code message} is an HTTP/1.1 (or later) request.
   */
  private static boolean isExpectHeaderValid(final HttpMessage message) {
    if (!(message instanceof HttpRequest)) {
      return false;
    }
    return message.protocolVersion().compareTo(HttpVersion.HTTP_1_1) >= 0;
  }
}
