/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package org.mule.service.http.test.netty.impl.client;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.sameInstance;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNotSame;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import org.mule.runtime.api.exception.MuleRuntimeException;
import org.mule.runtime.api.lifecycle.InitialisationException;
import org.mule.runtime.api.scheduler.Scheduler;
import org.mule.runtime.api.scheduler.SchedulerService;
import org.mule.runtime.api.tls.TlsContextFactory;
import org.mule.runtime.http.api.Http1ProtocolConfig;
import org.mule.runtime.http.api.Http2ProtocolConfig;
import org.mule.runtime.http.api.client.HttpClient;
import org.mule.runtime.http.api.client.HttpClientConfiguration;
import org.mule.service.http.netty.impl.client.HttpClientConnectionManager;
import org.mule.service.http.test.common.AbstractHttpTestCase;

import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.function.Supplier;

import javax.net.ssl.SSLContext;

import org.junit.Before;
import org.junit.Test;

/**
 * Tests for {@link HttpClientConnectionManager}: per-name client caching, argument validation, TLS configuration handling
 * and scheduler creation on client start.
 */
public class HttpClientConnectionManagerTestCase extends AbstractHttpTestCase {

  private HttpClientConnectionManager clientConnectionManager;

  private Supplier<HttpClientConfiguration> configurationSupplier;

  @Before
  @SuppressWarnings("unchecked") // mock(Supplier.class) yields a raw Supplier
  public void setUp() {
    clientConnectionManager = new HttpClientConnectionManager(null);

    configurationSupplier = mock(Supplier.class);
    HttpClientConfiguration mockConfiguration = mockHttp1Configuration(null);
    doReturn(mockConfiguration).when(configurationSupplier).get();
  }

  @Test
  public void getOrCreateClientReturnsSameInstanceIfCalledTwiceWithSameName() {
    HttpClient firstCallClient = clientConnectionManager.getOrCreateClient("ClientName", configurationSupplier);
    HttpClient secondCallClient = clientConnectionManager.getOrCreateClient("ClientName", configurationSupplier);

    assertThat(secondCallClient, is(sameInstance(firstCallClient)));
  }

  @Test
  public void getOrCreateClientCallsConfigurationSupplierOnlyOnceIfCalledTwiceWithSameName() {
    clientConnectionManager.getOrCreateClient("ClientName", configurationSupplier);
    clientConnectionManager.getOrCreateClient("ClientName", configurationSupplier);

    verify(configurationSupplier, times(1)).get();
  }

  @Test
  public void getOrCreateClientReturnsDifferentInstancesIfCalledTwiceWithDifferentNames() {
    HttpClient firstCallClient = clientConnectionManager.getOrCreateClient("SomeClientName", configurationSupplier);
    HttpClient secondCallClient = clientConnectionManager.getOrCreateClient("OtherClientName", configurationSupplier);

    assertThat(secondCallClient, is(not(sameInstance(firstCallClient))));
  }

  @Test
  public void getOrCreateClientCallsConfigurationSupplierTwiceIfCalledTwiceWithDifferentNames() {
    clientConnectionManager.getOrCreateClient("SomeClientName", configurationSupplier);
    clientConnectionManager.getOrCreateClient("OtherClientName", configurationSupplier);

    verify(configurationSupplier, times(2)).get();
  }

  @Test
  public void clientNameCantBeNull() {
    var thrown =
        assertThrows(IllegalArgumentException.class,
                     () -> clientConnectionManager.getOrCreateClient(null, configurationSupplier));
    assertThat(thrown.getMessage(), containsString("Client name can't be null"));
  }

  @Test
  public void suppliedConfigurationCantBeNull() {
    when(configurationSupplier.get()).thenReturn(null);

    // The manager wraps the IllegalArgumentException for a null configuration in a MuleRuntimeException.
    var thrown =
        assertThrows(MuleRuntimeException.class,
                     () -> clientConnectionManager.getOrCreateClient("SomeName", configurationSupplier));
    assertThat(thrown.getCause(), instanceOf(IllegalArgumentException.class));
    assertThat(thrown.getMessage(),
               containsString("java.lang.IllegalArgumentException: Client configuration can't be null"));
  }

  @Test
  @SuppressWarnings("unchecked") // spy(Supplier.class) yields a raw Supplier
  public void testGetOrCreateClientWithTlsConfiguration() throws Exception {
    // Swap the fixture supplier for a dedicated one that hands out a TLS-enabled configuration.
    configurationSupplier = spy(Supplier.class);
    HttpClientConfiguration tlsConfiguration = mockHttp1Configuration(buildDefaultTlsContextFactory());
    doReturn(tlsConfiguration).when(configurationSupplier).get();

    HttpClient client = clientConnectionManager.getOrCreateClient("ClientName", configurationSupplier);

    assertNotNull(client);
    verify(configurationSupplier, times(1)).get();
  }

  @Test
  public void testValidClientConfiguration() throws Exception {
    TlsContextFactory tlsContextFactory = buildDefaultTlsContextFactory();
    HttpClientConfiguration clientConfig = mockHttp1Configuration(tlsContextFactory);

    HttpClientConnectionManager manager = new HttpClientConnectionManager(null);
    HttpClient client = manager.getOrCreateClient("TestClient", () -> clientConfig);

    assertNotNull(client);
    verify(clientConfig, times(1)).getTlsContextFactory();
    verify(tlsContextFactory, times(1)).getEnabledCipherSuites();
    verify(tlsContextFactory, times(1)).getEnabledProtocols();
  }

  @Test
  public void testNullConfiguration() {
    HttpClientConnectionManager manager = new HttpClientConnectionManager(null);
    assertThrows(MuleRuntimeException.class, () -> manager.getOrCreateClient("TestClient", () -> null));
  }

  @Test
  public void testSslContextCreationException() throws Exception {
    TlsContextFactory tlsContextFactory = mock(TlsContextFactory.class);
    when(tlsContextFactory.createSslContext()).thenThrow(new KeyManagementException("SSL Context Error"));
    HttpClientConfiguration clientConfig = mockHttp1Configuration(tlsContextFactory);

    HttpClientConnectionManager manager = new HttpClientConnectionManager(null);

    // A checked failure while building the SSL context surfaces as a MuleRuntimeException.
    assertThrows(MuleRuntimeException.class, () -> manager.getOrCreateClient("TestClient", () -> clientConfig));
  }

  @Test
  public void testSchedulerCreation() throws InitialisationException {
    SchedulerService schedulerService = mock(SchedulerService.class);
    when(schedulerService.customScheduler(any(), anyInt())).thenReturn(mock(Scheduler.class));

    HttpClientConnectionManager manager = new HttpClientConnectionManager(schedulerService);
    manager.initialise();

    HttpClientConfiguration clientConfig = mockHttp1Configuration(null);
    HttpClient client = manager.getOrCreateClient("TestClient", () -> clientConfig);
    assertNotNull(client);

    // The Scheduler will be created on start.
    client.start();
    try {
      verify(schedulerService).customScheduler(any(), anyInt());
    } finally {
      // Release whatever the started client acquired so it does not leak into other tests.
      client.stop();
    }
  }

  @Test
  public void testCipherSuitePreservedWhenNotSupported() throws Exception {
    HttpClientConfiguration supportedConfig = mockHttp1Configuration(buildDefaultTlsContextFactory());

    TlsContextFactory unsupportedTlsContextFactory =
        buildSetupTlsContextFactory(new String[] {"NonSupportedCipherSuite"}, new String[] {"NonSupportedProtocol"});
    HttpClientConfiguration unsupportedConfig = mockHttp1Configuration(unsupportedTlsContextFactory);

    HttpClient supportedClient = clientConnectionManager.getOrCreateClient("ClientName", () -> supportedConfig);
    HttpClient unsupportedClient = clientConnectionManager.getOrCreateClient("ServerName", () -> unsupportedConfig);

    assertNotNull(supportedClient);
    assertNotNull(unsupportedClient);
    // Verify before this test reads the values itself, so the counts reflect only the manager's reads.
    verify(unsupportedTlsContextFactory, times(1)).getEnabledCipherSuites();
    verify(unsupportedTlsContextFactory, times(1)).getEnabledProtocols();
    // The mock returns the very array instances it was stubbed with, so these checks would catch the manager
    // modifying them in place.
    assertArrayEquals(new String[] {"NonSupportedCipherSuite"}, unsupportedTlsContextFactory.getEnabledCipherSuites());
    assertArrayEquals(new String[] {"NonSupportedProtocol"}, unsupportedTlsContextFactory.getEnabledProtocols());
  }

  @Test
  public void testMultipleClientsWithDifferentConfigurations() throws Exception {
    TlsContextFactory clientTlsContextFactory1 =
        buildSetupTlsContextFactory(new String[] {"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"}, new String[] {"TLSv1.2"});
    HttpClientConfiguration clientConfig1 = mockHttp1Configuration(clientTlsContextFactory1);

    TlsContextFactory clientTlsContextFactory2 =
        buildSetupTlsContextFactory(new String[] {"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384"}, new String[] {"TLSv1.3"});
    HttpClientConfiguration clientConfig2 = mockHttp1Configuration(clientTlsContextFactory2);

    HttpClientConnectionManager manager = new HttpClientConnectionManager(null);
    HttpClient client1 = manager.getOrCreateClient("Client1", () -> clientConfig1);
    HttpClient client2 = manager.getOrCreateClient("Client2", () -> clientConfig2);

    assertNotNull(client1);
    assertNotNull(client2);
    assertNotSame(client1, client2);
    verify(clientConfig1).getTlsContextFactory();
    verify(clientTlsContextFactory1).getEnabledCipherSuites();
    verify(clientTlsContextFactory1).getEnabledProtocols();
    verify(clientConfig2).getTlsContextFactory();
    verify(clientTlsContextFactory2).getEnabledCipherSuites();
    verify(clientTlsContextFactory2).getEnabledProtocols();
  }

  /**
   * Creates a mocked {@link HttpClientConfiguration} whose protocol configs are {@code new Http1ProtocolConfig(true)} and
   * {@code new Http2ProtocolConfig(false)}.
   *
   * @param tlsContextFactory factory returned by {@link HttpClientConfiguration#getTlsContextFactory()}; when {@code null}
   *                          the method is left unstubbed, so the mock returns {@code null}
   * @return the mocked configuration
   */
  private HttpClientConfiguration mockHttp1Configuration(TlsContextFactory tlsContextFactory) {
    HttpClientConfiguration configuration = mock(HttpClientConfiguration.class);
    if (tlsContextFactory != null) {
      when(configuration.getTlsContextFactory()).thenReturn(tlsContextFactory);
    }
    when(configuration.getHttp1ProtocolConfig()).thenReturn(new Http1ProtocolConfig(true));
    when(configuration.getHttp2ProtocolConfig()).thenReturn(new Http2ProtocolConfig(false));
    return configuration;
  }

  /**
   * Builds a mocked {@link TlsContextFactory} with the cipher suite and protocols shared by most tests in this class.
   * Fresh arrays are created on every call, so tests never share mutable state.
   */
  private TlsContextFactory buildDefaultTlsContextFactory() throws NoSuchAlgorithmException, KeyManagementException {
    return buildSetupTlsContextFactory(new String[] {"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"},
                                       new String[] {"TLSv1.2", "TLSv1.3"});
  }

  /**
   * Builds a mocked {@link TlsContextFactory} that reports the given cipher suites and protocols and creates a real,
   * default-initialised {@code TLS} {@link SSLContext}.
   *
   * @param cipherSuites value returned by {@link TlsContextFactory#getEnabledCipherSuites()} (same array instance)
   * @param protocols    value returned by {@link TlsContextFactory#getEnabledProtocols()} (same array instance)
   * @return the mocked factory
   * @throws NoSuchAlgorithmException if the JVM provides no {@code TLS} SSLContext
   * @throws KeyManagementException   if the SSLContext cannot be initialised
   */
  private TlsContextFactory buildSetupTlsContextFactory(String[] cipherSuites, String[] protocols)
      throws NoSuchAlgorithmException, KeyManagementException {
    SSLContext realSSLContext = SSLContext.getInstance("TLS");
    // null key/trust managers make the JVM use its default ones.
    realSSLContext.init(null, null, new SecureRandom());
    TlsContextFactory mockTlsContextFactory = mock(TlsContextFactory.class);
    when(mockTlsContextFactory.getEnabledCipherSuites()).thenReturn(cipherSuites);
    when(mockTlsContextFactory.getEnabledProtocols()).thenReturn(protocols);
    when(mockTlsContextFactory.createSslContext()).thenReturn(realSSLContext);
    return mockTlsContextFactory;
  }
}
