/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package org.mule.service.http.netty.impl.client.proxy;

import static io.netty.handler.codec.http.HttpHeaderNames.PROXY_AUTHENTICATE;
import static io.netty.handler.codec.http.HttpHeaderNames.PROXY_AUTHORIZATION;

import org.mule.runtime.http.api.client.proxy.ProxyConfig;
import org.mule.service.http.netty.impl.message.HttpResponseCreator;
import org.mule.service.http.netty.impl.streaming.BlockingBidirectionalStream;

import java.io.IOException;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.SocketAddress;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.HttpContent;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.LastHttpContent;
import io.netty.util.ReferenceCountUtil;

/**
 * Handler that establishes a connection with a <a href="https://datatracker.ietf.org/doc/html/rfc7230#page-10">message-forwarding
 * HTTP proxy agent</a>.
 * <p>
 * Outbound, every connection attempt is redirected to the proxy address, and every {@link HttpRequest} written through this
 * handler is remembered and sent with the {@code Proxy-Authorization} header produced by the {@link ProxyAuthenticator} (if any).
 * <p>
 * Inbound, while the authentication exchange (the "dance") is not finished, responses are consumed by this handler: when a
 * response carries a {@code Proxy-Authenticate} header, its body is captured and, once the last content arrives, the remembered
 * request is sent again with the next {@code Proxy-Authorization} header. The dance ends when a 200 response is received or when
 * the authenticator reports that it has finished; from then on every inbound message is forwarded untouched.
 * <p>
 * HTTP users who need to establish a blind forwarding proxy tunnel using
 * <a href="https://datatracker.ietf.org/doc/html/rfc7231#section-4.3.6">HTTP/1.1 CONNECT</a> request should use
 * {@link BlindTunnelingProxyClientHandler} instead.
 */
public class MessageForwardingProxyClientHandler extends ChannelDuplexHandler {

  /** Auth scheme whose authenticator is fed the whole previous response instead of just the challenge header. */
  private static final String AUTH_CUSTOM = "custom";

  /** Status code that ends the dance regardless of the authenticator state. */
  private static final int STATUS_OK = 200;

  private final InetSocketAddress proxyAddress;
  private final ProxyAuthenticator proxyAuthenticator;

  /** Last request head seen by {@link #write}; it is the one sent again on every authentication round. */
  private HttpRequest originalRequest = null;
  /** Value of the {@code Proxy-Authenticate} header of the last non-200 response, or {@code null} if it had none. */
  private String lastProxyAuthHeader = null;
  /** Last response that carried a {@code Proxy-Authenticate} header. */
  private HttpResponse lastResponse = null;
  /** Once {@code true}, inbound messages are no longer inspected, only forwarded. */
  private boolean isDanceFinished = false;
  /** Whether the next {@code channelReadComplete} must be hidden from the following handlers. */
  private boolean suppressChannelReadComplete = false;
  /** Holds the body of the challenge response currently being received. */
  private BlockingBidirectionalStream responseStream;
  /** Write side of {@link #responseStream}; {@code null} when no body is being captured. */
  private OutputStream responseOutput;

  /**
   * @param proxyHost          host of the proxy every connection is redirected to.
   * @param proxyPort          port of the proxy.
   * @param proxyAuthenticator source of the {@code Proxy-Authorization} header values.
   */
  public MessageForwardingProxyClientHandler(String proxyHost, int proxyPort, ProxyAuthenticator proxyAuthenticator) {
    this.proxyAuthenticator = proxyAuthenticator;
    this.proxyAddress = new InetSocketAddress(proxyHost, proxyPort);
  }

  /**
   * Creates the handler from a {@link ProxyConfig}, using its host and port and a {@link ProxyAuthenticator} built from it.
   *
   * @param proxyConfig the proxy configuration.
   */
  public MessageForwardingProxyClientHandler(ProxyConfig proxyConfig) {
    this(proxyConfig.getHost(), proxyConfig.getPort(), new ProxyAuthenticator(proxyConfig));
  }

  /**
   * Forwards the message when the dance is finished (or finishes with this message); otherwise consumes it as part of the
   * authentication exchange.
   */
  @Override
  public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
    if (isDanceFinished) {
      // From now on, the messages will be forwarded to the next handler (the rest of the client pipeline).
      suppressChannelReadComplete = false;
      ctx.fireChannelRead(msg);
      return;
    }

    if (msg instanceof HttpResponse okResponse && okResponse.status().code() == STATUS_OK) {
      isDanceFinished = true;
      // This message reaches the next handler, so the channelReadComplete of this read batch has to reach it as well.
      suppressChannelReadComplete = false;
      ctx.fireChannelRead(msg);
      return;
    }

    // The message is consumed here. As it is never passed on, this handler is the one that has to release it.
    // NOTE(review): a non-200 response without a Proxy-Authenticate header is also dropped on this path, so the rest of the
    // pipeline never sees it - confirm that is intended.
    suppressChannelReadComplete = true;
    try {
      if (msg instanceof HttpResponse challenge) {
        onNonOkResponse(challenge);
      }
      // No early return above: a message may be response head, content and last content at the same time.
      if (msg instanceof HttpContent content) {
        captureBody(content);
      }
      if (msg instanceof LastHttpContent && lastProxyAuthHeader != null) {
        answerChallenge(ctx);
      }
    } finally {
      ReferenceCountUtil.release(msg);
    }
  }

  /** Records the challenge of a non-200 response and, if it has one, prepares the stream that will capture its body. */
  private void onNonOkResponse(HttpResponse response) {
    lastProxyAuthHeader = extractServerProxyAuthHeader(response);
    if (lastProxyAuthHeader == null) {
      return;
    }
    lastResponse = response;
    responseStream = new BlockingBidirectionalStream();
    responseOutput = responseStream.getOutputStream();
  }

  /** Copies the readable bytes of the content into the capture stream, when a capture is in progress. */
  private void captureBody(HttpContent content) {
    if (responseOutput == null) {
      return;
    }
    ByteBuf contentBuf = content.content();
    if (!contentBuf.isReadable()) {
      return;
    }
    try {
      // Indexed read: the reader index of the buffer is left untouched.
      byte[] bytes = new byte[contentBuf.readableBytes()];
      contentBuf.getBytes(contentBuf.readerIndex(), bytes);
      // NOTE(review): this write runs on the thread that delivered the message - confirm BlockingBidirectionalStream cannot
      // block here for large bodies.
      responseOutput.write(bytes);
    } catch (IOException e) {
      throw new RuntimeException("Failed to capture response body", e);
    }
  }

  /** Ends the body capture and sends the remembered request again with the next authorization header. */
  private void answerChallenge(ChannelHandlerContext ctx) throws Exception {
    OutputStream capturedOutput = responseOutput;
    responseOutput = null;
    try {
      if (capturedOutput != null) {
        // Closed before the authenticator runs, so the body it may read is complete.
        capturedOutput.close();
      }
      sendAuthenticatedRequestToProxy(ctx, ctx.newPromise());
    } finally {
      // reset for next potential auth round, even if closing or sending failed
      responseStream = null;
    }
  }

  /**
   * @param response a response received from the proxy.
   * @return the value of its {@code Proxy-Authenticate} header, or {@code null} if the header is absent.
   */
  private String extractServerProxyAuthHeader(HttpResponse response) {
    return response.headers().get(PROXY_AUTHENTICATE);
  }

  /** Ignores {@code remoteAddress} and connects to the proxy address instead. */
  @Override
  public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise)
      throws Exception {
    // Connect to proxy instead...
    ctx.connect(proxyAddress, localAddress, promise);
  }

  /**
   * Remembers request heads and sends them (flushing) with the authorization header, if any. Any other message is passed
   * through unchanged.
   */
  @Override
  public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
    if (msg instanceof HttpRequest request) {
      this.originalRequest = request;
      sendAuthenticatedRequestToProxy(ctx, promise);
    } else {
      // NOTE(review): these messages are not retained, so a later authentication round only sends the request head again.
      ctx.write(msg, promise);
    }
  }

  /**
   * Hides the event from the following handlers when the messages of the batch were consumed by the dance. In that case, if
   * auto-read is disabled, it requests the next read itself, as no other handler was given the chance to do it.
   */
  @Override
  public void channelReadComplete(ChannelHandlerContext ctx) {
    if (this.suppressChannelReadComplete) {
      this.suppressChannelReadComplete = false;
      if (!ctx.channel().config().isAutoRead()) {
        ctx.read();
      }
    } else {
      ctx.fireChannelReadComplete();
    }
  }

  /** Writes and flushes the remembered request, after adding the authorization header when needed. */
  private void sendAuthenticatedRequestToProxy(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
    ctx.writeAndFlush(addProxyAuthHeaderIfNeeded(originalRequest), promise);
  }

  /**
   * Asks the authenticator for the next {@code Proxy-Authorization} value and sets it on the request. The dance is marked as
   * finished as soon as the authenticator reports it has finished.
   *
   * @param request the request to send; it is modified in place.
   * @return the same request instance.
   */
  private HttpRequest addProxyAuthHeaderIfNeeded(HttpRequest request) throws Exception {
    if (proxyAuthenticator.hasFinished()) {
      // Nothing else can be answered to a challenge: without this, a new challenge would make the very same request be sent
      // again on every round.
      isDanceFinished = true;
      return request;
    }

    String proxyAuthorizationHeader = AUTH_CUSTOM.equals(proxyAuthenticator.getAuthScheme())
        ? proxyAuthenticator.getNextHeader(nettyToMuleResponse(lastResponse))
        : proxyAuthenticator.getNextHeader(lastProxyAuthHeader);

    if (proxyAuthorizationHeader != null) {
      request.headers().set(PROXY_AUTHORIZATION, proxyAuthorizationHeader);
    }

    if (proxyAuthenticator.hasFinished()) {
      isDanceFinished = true;
    }

    return request;
  }

  /**
   * @param response the last challenge response, or {@code null} if none has been received yet.
   * @return the response converted through {@link HttpResponseCreator}, with the captured body as its content stream when there
   *         is one, or {@code null} if {@code response} is {@code null}.
   */
  private org.mule.runtime.http.api.domain.message.response.HttpResponse nettyToMuleResponse(HttpResponse response) {
    if (response == null) {
      return null;
    }
    // provide the captured response body
    HttpResponseCreator httpResponseCreator = new HttpResponseCreator();
    return httpResponseCreator.create(response, responseStream != null ? responseStream.getInputStream() : null);
  }
}
