/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package org.mule.service.http.netty.impl.server;

import static io.netty.buffer.Unpooled.EMPTY_BUFFER;

import org.mule.runtime.api.util.MultiMap;
import org.mule.runtime.http.api.domain.message.response.HttpResponse;
import org.mule.runtime.http.api.server.async.HttpResponseReadyCallback;
import org.mule.runtime.http.api.server.async.ResponseStatusCallback;
import org.mule.runtime.http.api.sse.server.SseClient;
import org.mule.runtime.http.api.sse.server.SseClientConfig;

import java.io.IOException;
import java.io.InputStream;
import java.io.Writer;
import java.nio.charset.Charset;
import java.util.Locale;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http2.DefaultHttp2DataFrame;
import io.netty.handler.codec.http2.DefaultHttp2Headers;
import io.netty.handler.codec.http2.DefaultHttp2HeadersFrame;
import io.netty.handler.codec.http2.Http2FrameStream;
import io.netty.handler.codec.http2.Http2Headers;

/**
 * {@link HttpResponseReadyCallback} that writes a Mule {@link HttpResponse} to a single HTTP/2 stream as Netty
 * {@link DefaultHttp2HeadersFrame} / {@link DefaultHttp2DataFrame} frames.
 * <p>
 * Every frame write is attached to a promise that reports failures through
 * {@link ResponseStatusCallback#onErrorSendingResponse(Throwable)}. Success is reported through
 * {@link ResponseStatusCallback#responseSendSuccessfully()} only once the frame that ends the stream has been written.
 */
public class NettyHttp2RequestReadyCallback implements HttpResponseReadyCallback {

  /** Number of bytes requested per read while draining the response entity into a buffer. */
  private static final int READ_CHUNK_SIZE = 8 * 1024;

  private final ChannelHandlerContext ctx;
  private final Http2FrameStream frameStream;

  public NettyHttp2RequestReadyCallback(ChannelHandlerContext ctx, Http2FrameStream stream) {
    this.ctx = ctx;
    this.frameStream = stream;
  }

  /**
   * Sends the whole response (headers followed by the fully-buffered body) on this stream.
   * <p>
   * Any exception raised while preparing or writing the response is reported to {@code responseStatusCallback}
   * instead of being thrown.
   */
  @Override
  public void responseReady(HttpResponse response, ResponseStatusCallback responseStatusCallback) {
    try {
      sendResponse(ctx, frameStream, response, responseStatusCallback);
    } catch (Exception e) {
      responseStatusCallback.onErrorSendingResponse(e);
    }
  }

  /**
   * Writes the response headers and returns a {@link Writer} that streams the body as DATA frames.
   * <p>
   * Each {@code write} queues one DATA frame (encoded with {@code encoding}), {@code flush} flushes the channel, and
   * {@code close} ends the stream with an empty DATA frame. Closing more than once has no further effect, and writing
   * after close fails with an {@link IOException}, as required by the {@link Writer} contract.
   */
  @Override
  public Writer startResponse(HttpResponse response, ResponseStatusCallback responseStatusCallback, Charset encoding) {
    Http2Headers headers = extractHeaders(response);
    ctx.write(new DefaultHttp2HeadersFrame(headers, false).stream(frameStream),
              promiseToCallback(responseStatusCallback, false));
    return new Writer() {

      // Prevents a second end-of-stream frame and writes after the stream has been ended.
      private boolean closed;

      @Override
      public void write(char[] cbuf, int off, int len) throws IOException {
        if (closed) {
          throw new IOException("Writer already closed");
        }
        // Build the String first so an out-of-range (off, len) cannot leak an already-allocated buffer.
        String chunk = new String(cbuf, off, len);
        ByteBuf content = ctx.alloc().buffer();
        content.writeCharSequence(chunk, encoding);
        ctx.write(new DefaultHttp2DataFrame(content, false).stream(frameStream),
                  promiseToCallback(responseStatusCallback, false));
      }

      @Override
      public void flush() {
        ctx.flush();
      }

      @Override
      public void close() {
        if (closed) {
          return;
        }
        closed = true;
        ctx.writeAndFlush(new DefaultHttp2DataFrame(EMPTY_BUFFER, true).stream(frameStream),
                          promiseToCallback(responseStatusCallback, true));
      }
    };
  }

  @Override
  public SseClient startSseResponse(SseClientConfig config) {
    // TODO: If SSE over HTTP/2 is needed, check if this works:
    // return new SseResponseStarter().startResponse(config, this);
    throw new UnsupportedOperationException("SSE is not supported over HTTP/2");
  }

  /**
   * Writes the headers frame, then reads the entire entity into one buffer and writes it as the final DATA frame.
   * <p>
   * The entity stream is closed once read. If reading fails, the body buffer is released and the exception
   * propagates to the caller.
   */
  private void sendResponse(ChannelHandlerContext ctx, Http2FrameStream stream, HttpResponse response,
                            ResponseStatusCallback responseStatusCallback)
      throws IOException {
    Http2Headers headers = extractHeaders(response);
    ctx.write(new DefaultHttp2HeadersFrame(headers, false).stream(stream),
              promiseToCallback(responseStatusCallback, false));

    ByteBuf content = ctx.alloc().buffer();
    try (InputStream contentAsInputStream = response.getEntity().getContent()) {
      readFully(content, contentAsInputStream);
    } catch (IOException | RuntimeException e) {
      // The buffer never reached the pipeline, so nobody else will release it.
      content.release();
      throw e;
    }
    ctx.writeAndFlush(new DefaultHttp2DataFrame(content, true).stream(stream),
                      promiseToCallback(responseStatusCallback, true));
  }

  /**
   * Appends every remaining byte of {@code in} to {@code target}.
   * <p>
   * {@link InputStream#available()} is only an estimate (it may legitimately return 0), so the stream is read until
   * {@link ByteBuf#writeBytes(InputStream, int)} reports EOF with {@code -1}.
   */
  private static void readFully(ByteBuf target, InputStream in) throws IOException {
    while (target.writeBytes(in, READ_CHUNK_SIZE) != -1) {
      // keep reading until end of stream
    }
  }

  /**
   * Builds the HTTP/2 header block: the {@code :status} pseudo-header plus every header of the Mule response.
   * <p>
   * Names are lowercased, because HTTP/2 requires lowercase field names and the validating
   * {@link DefaultHttp2Headers} rejects upper-case ones. Connection-specific fields are dropped because they are not
   * permitted in HTTP/2.
   */
  private static Http2Headers extractHeaders(HttpResponse response) {
    Http2Headers headers = new DefaultHttp2Headers().status(String.valueOf(response.getStatusCode()));

    MultiMap<String, String> responseHeaders = response.getHeaders();
    for (String key : responseHeaders.keySet()) {
      String name = key.toLowerCase(Locale.ROOT);
      if (isConnectionSpecificHeader(name)) {
        continue;
      }
      for (String value : responseHeaders.getAll(key)) {
        headers.add(name, value);
      }
    }
    return headers;
  }

  /**
   * Tells whether {@code lowerCaseName} is an HTTP/1.x connection-specific field that must not appear in an HTTP/2
   * message.
   */
  private static boolean isConnectionSpecificHeader(String lowerCaseName) {
    switch (lowerCaseName) {
      case "connection":
      case "keep-alive":
      case "proxy-connection":
      case "transfer-encoding":
      case "upgrade":
        return true;
      default:
        return false;
    }
  }

  /**
   * Creates a promise for one frame write.
   * <p>
   * On failure the promise reports the cause to {@code responseStatusCallback}. On success it notifies completion
   * only when {@code isLast} is {@code true}, i.e. for the frame that ends the stream.
   */
  private ChannelPromise promiseToCallback(ResponseStatusCallback responseStatusCallback, boolean isLast) {
    ChannelFutureListener futureListener = channelFuture -> {
      if (channelFuture.isSuccess()) {
        if (isLast) {
          responseStatusCallback.responseSendSuccessfully();
        }
      } else {
        responseStatusCallback.onErrorSendingResponse(channelFuture.cause());
      }
    };
    ChannelPromise channelPromise = ctx.newPromise();
    channelPromise.addListener(futureListener);
    return channelPromise;
  }
}
