/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package org.mule.service.http.netty.impl.server;

import static io.netty.handler.codec.http.HttpHeaderNames.CONNECTION;
import static io.netty.handler.codec.http2.Http2CodecUtil.HTTP_UPGRADE_PROTOCOL_NAME;
import static io.netty.util.AsciiString.containsAllContentEqualsIgnoreCase;
import static io.netty.util.AsciiString.contentEquals;

import org.mule.service.http.netty.impl.server.RejectFailedUpgradeHandler.FailedHttpUpgrade;
import org.mule.service.http.netty.impl.server.util.HttpListenerRegistry;

import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.concurrent.Executor;

import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpServerUpgradeHandler;
import io.netty.handler.codec.http2.Http2FrameCodecBuilder;
import io.netty.handler.codec.http2.Http2MultiplexHandler;
import io.netty.handler.codec.http2.Http2ServerUpgradeCodec;

/**
 * {@link HttpServerUpgradeHandler.UpgradeCodecFactory} for HTTP/1.1 {@code Upgrade} requests to HTTP/2 over cleartext.
 * <p>
 * The codec it creates wraps Netty's {@link Http2ServerUpgradeCodec} so that a rejected upgrade is reported by firing a
 * {@link FailedHttpUpgrade} user event on the pipeline.
 */
class UpgradeToHttp2CleartextCodecFactory implements HttpServerUpgradeHandler.UpgradeCodecFactory {

  private final HttpListenerRegistry httpListenerRegistry;
  private final Executor ioExecutor;

  /**
   * @param httpListenerRegistry registry passed to the {@code MultiplexerChannelInitializer} of every created codec.
   * @param ioExecutor           executor passed to the {@code MultiplexerChannelInitializer} of every created codec.
   */
  public UpgradeToHttp2CleartextCodecFactory(HttpListenerRegistry httpListenerRegistry, Executor ioExecutor) {
    this.httpListenerRegistry = httpListenerRegistry;
    this.ioExecutor = ioExecutor;
  }

  /**
   * Creates the upgrade codec for the requested protocol.
   *
   * @param protocol the protocol requested by the client.
   * @return a newly built HTTP/2 upgrade codec when {@code protocol} equals {@code HTTP_UPGRADE_PROTOCOL_NAME}, or
   *         {@code null} for any other protocol.
   */
  @Override
  public HttpServerUpgradeHandler.UpgradeCodec newUpgradeCodec(CharSequence protocol) {
    if (!contentEquals(HTTP_UPGRADE_PROTOCOL_NAME, protocol)) {
      return null;
    }

    // Fresh handler instances are built on every call, so each upgrade gets its own frame codec and multiplexer.
    var frameCodec = Http2FrameCodecBuilder.forServer().build();
    // NOTE(review): the second MultiplexerChannelInitializer argument is passed as null; confirm this is intended.
    var multiplexHandler =
        new Http2MultiplexHandler(new MultiplexerChannelInitializer(httpListenerRegistry, null, ioExecutor));
    return new UpgradeCodecWithErrorResponse(new Http2ServerUpgradeCodec(frameCodec, multiplexHandler));
  }

  /**
   * Decorates an {@link HttpServerUpgradeHandler.UpgradeCodec} so that every rejected upgrade fires a
   * {@link FailedHttpUpgrade} user event carrying the reason of the failure.
   */
  private static final class UpgradeCodecWithErrorResponse implements HttpServerUpgradeHandler.UpgradeCodec {

    private static final String MISSING_CONNECTION_HEADER_VALUE =
        "HTTP Upgrade request to HTTP2 failed because Connection header doesn't contain required values";
    private static final String UPGRADE_FAILED = "HTTP Upgrade request to HTTP2 failed";

    private final HttpServerUpgradeHandler.UpgradeCodec delegate;

    private UpgradeCodecWithErrorResponse(HttpServerUpgradeHandler.UpgradeCodec delegate) {
      this.delegate = delegate;
    }

    /**
     * Reports no required headers so that Netty skips its own validation of them; the equivalent checks are done in
     * {@link #prepareUpgradeResponse}, where a failure can be reported with a {@link FailedHttpUpgrade} event.
     *
     * @return an empty, immutable collection.
     */
    @Override
    public Collection<CharSequence> requiredUpgradeHeaders() {
      return List.of();
    }

    /**
     * Validates the upgrade request against the delegate's required headers, then lets the delegate prepare the
     * upgrade response.
     *
     * @param ctx            context used to fire the {@link FailedHttpUpgrade} event on failure.
     * @param upgradeRequest the HTTP/1.1 request asking for the upgrade.
     * @param upgradeHeaders headers of the upgrade response, filled in by the delegate.
     * @return {@code true} if the upgrade can proceed; {@code false} after a {@link FailedHttpUpgrade} event was fired.
     */
    @Override
    public boolean prepareUpgradeResponse(ChannelHandlerContext ctx, FullHttpRequest upgradeRequest, HttpHeaders upgradeHeaders) {
      var requiredHeaders = delegate.requiredUpgradeHeaders();

      if (!connectionHeaderListsAll(upgradeRequest, requiredHeaders)) {
        return reject(ctx, upgradeRequest, MISSING_CONNECTION_HEADER_VALUE);
      }

      // Returning an empty list from requiredUpgradeHeaders() also disables Netty's check that each required header
      // is actually present in the request, so that check is repeated here before the delegate sees the request.
      if (!requestContainsAll(upgradeRequest, requiredHeaders)
          || !delegate.prepareUpgradeResponse(ctx, upgradeRequest, upgradeHeaders)) {
        return reject(ctx, upgradeRequest, UPGRADE_FAILED);
      }

      return true;
    }

    /**
     * Performs the upgrade by delegating to the wrapped codec.
     */
    @Override
    public void upgradeTo(ChannelHandlerContext ctx, FullHttpRequest upgradeRequest) {
      delegate.upgradeTo(ctx, upgradeRequest);
    }

    /**
     * Fires a {@link FailedHttpUpgrade} event carrying {@code reason}.
     *
     * @return always {@code false}, so callers can {@code return reject(...)} directly.
     */
    private static boolean reject(ChannelHandlerContext ctx, FullHttpRequest upgradeRequest, String reason) {
      ctx.fireUserEventTriggered(new FailedHttpUpgrade(upgradeRequest, reason));
      return false;
    }

    /**
     * Collects the comma-separated, trimmed tokens of every {@code Connection} header of the request.
     */
    private static Collection<CharSequence> getConnectionHeaders(FullHttpRequest upgradeRequest) {
      var connectionHeaderValues = upgradeRequest.headers().getAll(CONNECTION);
      var connectionHeaders = new HashSet<CharSequence>(connectionHeaderValues.size() * 2);
      for (String connectionHeaderValue : connectionHeaderValues) {
        for (String token : connectionHeaderValue.split(",")) {
          connectionHeaders.add(token.trim());
        }
      }
      return connectionHeaders;
    }

    /**
     * Tells whether the request's {@code Connection} header lists every name in {@code required}, ignoring case.
     */
    private static boolean connectionHeaderListsAll(FullHttpRequest upgradeRequest, Collection<CharSequence> required) {
      return containsAllContentEqualsIgnoreCase(getConnectionHeaders(upgradeRequest), required);
    }

    /**
     * Tells whether the request carries a header for every name in {@code required}.
     */
    private static boolean requestContainsAll(FullHttpRequest upgradeRequest, Collection<CharSequence> required) {
      var headers = upgradeRequest.headers();
      for (CharSequence headerName : required) {
        if (!headers.contains(headerName)) {
          return false;
        }
      }
      return true;
    }
  }
}
