/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package org.mule.service.http.test.netty.impl.server;

import static org.mule.runtime.api.util.MuleSystemProperties.SYSTEM_PROPERTY_PREFIX;
import static org.mule.runtime.http.api.HttpConstants.HttpStatus.REQUEST_TOO_LONG;
import static org.mule.service.http.netty.impl.server.NettyHttp1ResponseReadyCallback.getMaxServerResponseHeaders;
import static org.mule.service.http.netty.impl.server.NettyHttp1ResponseReadyCallback.refreshMaxServerResponseHeaders;
import static org.mule.service.http.netty.impl.server.util.HttpListenerRegistry.getMaxServerRequestHeaders;
import static org.mule.service.http.netty.impl.server.util.HttpListenerRegistry.refreshMaxServerRequestHeaders;
import static org.mule.service.http.test.netty.utils.TestUtils.createServerSslContext;

import static java.lang.System.clearProperty;
import static java.lang.System.setProperty;
import static java.util.Collections.singleton;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertTrue;

import org.mule.runtime.api.util.MultiMap;
import org.mule.runtime.http.api.domain.message.response.HttpResponse;
import org.mule.runtime.http.api.server.HttpServer;
import org.mule.service.http.netty.impl.server.AcceptedConnectionChannelInitializer;
import org.mule.service.http.netty.impl.server.NettyHttpServer;
import org.mule.service.http.netty.impl.server.util.HttpListenerRegistry;
import org.mule.service.http.test.netty.tck.ExecutorRule;
import org.mule.service.http.test.netty.utils.TestHttp2RequestHandler;
import org.mule.service.http.test.netty.utils.client.TestSSLNettyClient;
import org.mule.tck.junit4.rule.DynamicPort;

import java.net.InetSocketAddress;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.ExecutionException;

import io.netty.handler.ssl.SslContext;
import io.qameta.allure.Issue;
import org.apache.commons.io.IOUtils;
import org.junit.After;
import org.junit.Before;
import org.junit.ClassRule;
import org.junit.Rule;
import org.junit.Test;

/**
 * Exercises the header-count limits of {@link NettyHttpServer} that are driven by the
 * {@code http.MAX_SERVER_REQUEST_HEADERS} and {@code http.MAX_SERVER_RESPONSE_HEADERS} system properties.
 * <p>
 * Each test overrides one property, refreshes the cached limit, performs a request over TLS and finally restores the
 * property. The restoration happens in a {@code finally} block so that a failing assertion cannot leak a reduced limit into
 * other tests running in the same JVM.
 */
public class NettyHttpServerHeaderLimitTestCase {

  @ClassRule
  public static ExecutorRule executorRule = new ExecutorRule();

  private HttpServer server;
  private static final int MAX_NUM_HEADERS_DEFAULT = 100;
  private static final String MAX_SERVER_REQUEST_HEADERS_KEY = SYSTEM_PROPERTY_PREFIX + "http.MAX_SERVER_REQUEST_HEADERS";
  private static final String MAX_SERVER_RESPONSE_HEADERS_KEY = SYSTEM_PROPERTY_PREFIX + "http.MAX_SERVER_RESPONSE_HEADERS";

  @Rule
  public DynamicPort serverPort = new DynamicPort("serverPort");

  @Rule
  public TestSSLNettyClient testClient = new TestSSLNettyClient("localhost", serverPort.getNumber(), executorRule.getExecutor());

  /**
   * Builds and starts an SSL-enabled server on the dynamic port and registers the request handlers used by the tests.
   *
   * @throws Exception if the SSL context cannot be created or the server fails to start.
   */
  @Before
  public void setUp() throws Exception {
    HttpListenerRegistry listenerRegistry = new HttpListenerRegistry();
    SslContext serverSslContext = createServerSslContext();

    server = NettyHttpServer.builder()
        .withName("test-server")
        .withServerAddress(new InetSocketAddress(serverPort.getNumber()))
        .withHttpListenerRegistry(listenerRegistry)
        .withSslContext(serverSslContext)
        .withShutdownTimeout(() -> 5000L)
        .withClientChannelHandler(new AcceptedConnectionChannelInitializer(listenerRegistry, "test-server", true, 30000, 10000L,
                                                                           serverSslContext, executorRule.getExecutor()))
        .build();
    server.start();
    server.addRequestHandler("/path", new TestHttp2RequestHandler());
    server.addRequestHandler(singleton("GET"), "/only-get", new TestHttp2RequestHandler());
  }

  /**
   * Stops and disposes the server, if {@link #setUp()} got far enough to create it.
   */
  @After
  public void tearDown() {
    // setUp() may have failed before assigning the server; avoid masking that failure with an NPE.
    if (server != null) {
      server.stop().dispose();
    }
  }

  /**
   * Sends a request with six custom headers while the request header limit is set to five, and expects the server to reply
   * with {@code REQUEST_TOO_LONG}.
   *
   * @throws Exception if sending the request or reading the response fails.
   */
  @Issue("W-15642768")
  @Test
  public void testMaxServerRequestHeaders() throws Exception {
    assertThat(getMaxServerRequestHeaders(), equalTo(MAX_NUM_HEADERS_DEFAULT));
    setProperty(MAX_SERVER_REQUEST_HEADERS_KEY, "5");
    try {
      refreshMaxServerRequestHeaders();

      // One more header than the configured limit.
      MultiMap.StringMultiMap headers = new MultiMap.StringMultiMap();
      headers.put("testheader1", "testvalue1");
      headers.put("testheader2", "testvalue2");
      headers.put("testheader3", "testvalue3");
      headers.put("testheader4", "testvalue4");
      headers.put("testheader5", "testvalue5");
      headers.put("testheader6", "testvalue6");

      HttpResponse response = testClient.sendGet("/path", headers);
      String responseAsString = IOUtils.toString(response.getEntity().getContent(), StandardCharsets.UTF_8);
      assertThat(getMaxServerRequestHeaders(), equalTo(5));
      assertThat(response.getStatusCode(), is(REQUEST_TOO_LONG.getStatusCode()));
      assertThat(responseAsString, containsString("Request entity too large"));
    } finally {
      // Always restore the limit, even when an assertion above fails, so other tests are not affected.
      clearProperty(MAX_SERVER_REQUEST_HEADERS_KEY);
      refreshMaxServerRequestHeaders();
    }
  }

  /**
   * Lowers the response header limit to one and performs a request; if the client call fails with an
   * {@link ExecutionException}, verifies that its cause reports the exceeded response header limit.
   *
   * @throws Exception if sending the request fails in an unexpected way.
   */
  @Issue("W-15642768")
  @Test
  public void testMaxServerResponseHeaders() throws Exception {
    assertThat(getMaxServerResponseHeaders(), equalTo(MAX_NUM_HEADERS_DEFAULT));
    // A limit of 1 is used because the response is expected to carry two headers,
    // e.g. {content-length=[9], connection=[close]} (per the original author's note; not verifiable from this file).
    setProperty(MAX_SERVER_RESPONSE_HEADERS_KEY, "1");
    try {
      refreshMaxServerResponseHeaders();

      assertThat(getMaxServerResponseHeaders(), equalTo(1));

      // NOTE(review): when no ExecutionException is thrown this test passes without asserting anything about the
      // response. Confirm whether the request is always expected to fail, and if so add a fail() after sendGet.
      try {
        testClient.sendGet("/path");
      } catch (ExecutionException e) {
        assertTrue(e.getCause() instanceof IllegalArgumentException);
        assertThat(e.getCause().getMessage(), containsString("Exceeded max server response headers limit"));
      }
    } finally {
      // Always restore the limit, even when an assertion above fails, so other tests are not affected.
      clearProperty(MAX_SERVER_RESPONSE_HEADERS_KEY);
      refreshMaxServerResponseHeaders();
    }
  }
}
