/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package org.mule.service.http.netty.impl.client.proxy;

import static io.netty.handler.codec.http.HttpHeaderNames.PROXY_AUTHENTICATE;
import static io.netty.handler.codec.http.HttpHeaderNames.PROXY_AUTHORIZATION;

import org.mule.runtime.http.api.client.proxy.ProxyConfig;

import java.net.InetSocketAddress;
import java.net.SocketAddress;

import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.LastHttpContent;
import io.netty.util.ReferenceCountUtil;

/**
 * Handler that establishes a connection with a <a href="https://datatracker.ietf.org/doc/html/rfc7230#page-10">message-forwarding
 * HTTP proxy agent</a>.
 * <p>
 * HTTP users who need to establish a blind forwarding proxy tunnel using
 * <a href="https://datatracker.ietf.org/doc/html/rfc7231#section-4.3.6">HTTP/1.1 CONNECT</a> request should use
 * {@link BlindTunnelingProxyClientHandler} instead.
 */
public class MessageForwardingProxyClientHandler extends ChannelDuplexHandler {

  private final InetSocketAddress proxyAddress;
  private final ProxyAuthenticator proxyAuthenticator;

  // Last request head written by the client; re-sent (with a fresh Proxy-Authorization header) when the proxy challenges.
  private HttpRequest originalRequest = null;
  // Most recent Proxy-Authenticate value received from the proxy, passed to the authenticator to compute the next step.
  private String lastProxyAuthHeader = null;
  // Once true, every inbound message is passed straight through to the rest of the pipeline.
  private boolean isDanceFinished = false;
  // True while the current read batch contained only messages consumed by this handler, so that
  // channelReadComplete is not propagated for data the downstream handlers never saw.
  private boolean suppressChannelReadComplete = false;

  /**
   * Creates a handler that sends every connection to the given proxy.
   *
   * @param proxyHost          host of the proxy to connect to instead of the requested remote address.
   * @param proxyPort          port of the proxy.
   * @param proxyAuthenticator produces the {@code Proxy-Authorization} header values for each handshake step.
   */
  public MessageForwardingProxyClientHandler(String proxyHost, int proxyPort, ProxyAuthenticator proxyAuthenticator) {
    this.proxyAuthenticator = proxyAuthenticator;
    this.proxyAddress = new InetSocketAddress(proxyHost, proxyPort);
  }

  /**
   * Creates a handler whose proxy address and authenticator are taken from {@code proxyConfig}.
   *
   * @param proxyConfig proxy host, port and the configuration used to build the {@link ProxyAuthenticator}.
   */
  public MessageForwardingProxyClientHandler(ProxyConfig proxyConfig) {
    this(proxyConfig.getHost(), proxyConfig.getPort(), new ProxyAuthenticator(proxyConfig));
  }

  /**
   * Consumes the proxy's authentication challenges and forwards everything else to the next handler.
   * <p>
   * While the handshake is in progress, a response carrying a {@code Proxy-Authenticate} header (other than a 200) is
   * consumed, together with its content. When its last content arrives, the original request is re-sent with the next
   * {@code Proxy-Authorization} value. Any other response (a 200, or one the proxy gave no challenge for) ends the
   * handshake and is forwarded to the client. Consumed messages are released here because nobody downstream will see
   * them.
   */
  @Override
  public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
    if (isDanceFinished) {
      // From now on, the messages will be forwarded to the next handler (the rest of the client pipeline).
      forward(ctx, msg);
      return;
    }

    if (msg instanceof HttpResponse) {
      HttpResponse response = (HttpResponse) msg;
      String challenge = extractServerProxyAuthHeader(response);
      if (response.status().code() == 200 || challenge == null) {
        // Either the request went through, or the proxy gave nothing to continue the handshake with: in both cases
        // the response (and its content) belongs to the client, which would otherwise wait forever for it.
        isDanceFinished = true;
        forward(ctx, msg);
        return;
      }
      this.lastProxyAuthHeader = challenge;
    }

    // The message is part of the proxy challenge: consume it. Note that a FullHttpResponse is also a LastHttpContent,
    // so the check below must run for it as well (no early return above).
    suppressChannelReadComplete = true;
    boolean challengeReceived = msg instanceof LastHttpContent && this.lastProxyAuthHeader != null;
    ReferenceCountUtil.release(msg);
    if (challengeReceived) {
      sendAuthenticatedRequestToProxy(ctx, ctx.newPromise());
    }
  }

  /**
   * Passes {@code msg} to the next inbound handler, making sure the matching channelReadComplete is propagated too.
   */
  private void forward(ChannelHandlerContext ctx, Object msg) {
    suppressChannelReadComplete = false;
    ctx.fireChannelRead(msg);
  }

  /**
   * Returns the {@code Proxy-Authenticate} header of {@code response}, or {@code null} if absent.
   */
  private String extractServerProxyAuthHeader(HttpResponse response) {
    return response.headers().get(PROXY_AUTHENTICATE);
  }

  /**
   * Connects to the proxy address instead of {@code remoteAddress}; the request itself still targets the original host.
   */
  @Override
  public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise)
      throws Exception {
    // Connect to proxy instead...
    ctx.connect(proxyAddress, localAddress, promise);
  }

  /**
   * Remembers each outgoing request head and writes it with the current {@code Proxy-Authorization} header (if any);
   * any other outbound message is written unchanged.
   */
  @Override
  public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
    if (msg instanceof HttpRequest) {
      HttpRequest request = (HttpRequest) msg;
      this.originalRequest = request;
      sendAuthenticatedRequestToProxy(ctx, promise);
    } else {
      ctx.write(msg, promise);
    }
  }

  /**
   * Swallows the read-complete event of a batch that only contained consumed challenge messages. When auto-read is
   * disabled, it requests another read instead, so the proxy's next response is still picked up.
   */
  @Override
  public void channelReadComplete(ChannelHandlerContext ctx) {
    if (!this.suppressChannelReadComplete) {
      ctx.fireChannelReadComplete();
      return;
    }
    this.suppressChannelReadComplete = false;
    if (!ctx.channel().config().isAutoRead()) {
      ctx.read();
    }
  }

  // NOTE(review): only the stored request head is re-sent. If the client writes the body as separate HttpContent
  // messages, or uses a FullHttpRequest whose content is released by the encoder after the first write, the retry may
  // be incomplete — confirm against how the client writes requests.
  private void sendAuthenticatedRequestToProxy(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
    ctx.writeAndFlush(addProxyAuthHeaderIfNeeded(originalRequest), promise);
  }

  /**
   * Sets the next {@code Proxy-Authorization} header on {@code request} while the authenticator has not finished, and
   * marks the handshake as finished as soon as the authenticator reports it has nothing more to add (including when it
   * was already finished on entry; otherwise responses would be swallowed or challenges retried forever).
   *
   * @param request request to decorate; its headers are mutated in place.
   * @return the same {@code request} instance.
   */
  private HttpRequest addProxyAuthHeaderIfNeeded(HttpRequest request) throws Exception {
    if (!proxyAuthenticator.hasFinished()) {
      String proxyAuthorizationHeader = proxyAuthenticator.getNextHeader(lastProxyAuthHeader);
      if (proxyAuthorizationHeader != null) {
        request.headers().set(PROXY_AUTHORIZATION, proxyAuthorizationHeader);
      }
    }

    if (proxyAuthenticator.hasFinished()) {
      isDanceFinished = true;
    }

    return request;
  }
}
