/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package org.mule.service.http.netty.impl.util;

import static io.netty.handler.codec.http.HttpMethod.GET;
import static io.netty.handler.codec.http.HttpResponseStatus.OK;
import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1;
import static io.netty.util.AttributeKey.valueOf;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.contains;
import static org.mockito.Mockito.when;
import static org.mockito.junit.MockitoJUnit.rule;

import org.mule.tck.junit4.AbstractMuleTestCase;

import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.List;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.util.AttributeKey;
import io.netty.util.AttributeMap;
import io.netty.util.DefaultAttributeMap;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.junit.MockitoRule;
import reactor.netty.Connection;

public class HttpLoggingHandlerTestCase extends AbstractMuleTestCase {

  private static final AttributeKey<Connection> CONNECTION = valueOf("$CONNECTION");
  private static final String LEGACY_WIRE_LOGGING_CLASS_NAME = "org.mule.service.http.impl.service.HttpMessageLogger";

  @Rule
  public MockitoRule rule = rule();

  @Rule
  public NettyLoggerRule loggerRule = new NettyLoggerRule();

  @Mock
  private ChannelHandlerContext ctx;

  @Mock
  private Channel channel;

  @Mock
  private Connection connection;

  private HttpLoggingHandler loggingHandler;

  @Before
  public void setup() {
    AttributeMap attributeMap = new DefaultAttributeMap();
    attributeMap.attr(CONNECTION).set(connection);

    when(channel.hasAttr(CONNECTION)).thenReturn(true);
    when(channel.attr(CONNECTION)).thenReturn(attributeMap.attr(CONNECTION));
    when(channel.toString()).thenReturn("[id: 0x1234]");

    when(ctx.channel()).thenReturn(channel);

    loggingHandler = new HttpLoggingHandler();
  }

  @Test
  public void logChannelRead() throws Exception {
    HttpRequest request = new DefaultFullHttpRequest(HTTP_1_1, GET, "https://www.salesforce.org/test", asByteBuf("Hello test!"));
    loggingHandler.channelRead(ctx, request);

    List<String> debugLogs = loggerRule.getDebugLogs(LEGACY_WIRE_LOGGING_CLASS_NAME);
    assertThat(debugLogs, contains("[1234] READ: 11B Hello test!"));
  }

  @Test
  public void logWrite() throws Exception {
    HttpResponse response = new DefaultFullHttpResponse(HTTP_1_1, OK, asByteBuf("Hello test!"));
    loggingHandler.write(ctx, response, new DefaultChannelPromise(channel));

    List<String> debugLogs = loggerRule.getDebugLogs(LEGACY_WIRE_LOGGING_CLASS_NAME);
    assertThat(debugLogs, contains("[1234] WRITE: 11B Hello test!"));
  }

  @Test
  public void logFlush() throws Exception {
    loggingHandler.flush(ctx);

    List<String> debugLogs = loggerRule.getDebugLogs(LEGACY_WIRE_LOGGING_CLASS_NAME);
    assertThat(debugLogs, contains("[1234] FLUSH"));
  }

  @Test
  public void logConnect() throws Exception {
    SocketAddress localAddress = new InetSocketAddress(8080);
    SocketAddress remoteAddress = new InetSocketAddress(8081);
    ChannelPromise promise = new DefaultChannelPromise(channel);
    loggingHandler.connect(ctx, remoteAddress, localAddress, promise);

    List<String> debugLogs = loggerRule.getDebugLogs(LEGACY_WIRE_LOGGING_CLASS_NAME);
    assertThat(debugLogs, contains("[1234] CONNECT: 0.0.0.0/0.0.0.0:8081, 0.0.0.0/0.0.0.0:8080"));
  }

  @Test
  public void logClose() throws Exception {
    ChannelPromise promise = new DefaultChannelPromise(channel);
    loggingHandler.close(ctx, promise);

    List<String> debugLogs = loggerRule.getDebugLogs(LEGACY_WIRE_LOGGING_CLASS_NAME);
    assertThat(debugLogs, contains("[1234] CLOSE"));
  }

  private static ByteBuf asByteBuf(String string) {
    ByteBuf byteBuf = Unpooled.buffer(string.length());
    byteBuf.writeBytes(string.getBytes());
    return byteBuf;
  }
}
