/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package org.mule.service.http.netty.impl.client;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.sameInstance;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertThrows;
import static org.junit.rules.ExpectedException.none;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import org.mule.runtime.api.exception.MuleRuntimeException;
import org.mule.runtime.api.lifecycle.InitialisationException;
import org.mule.runtime.api.scheduler.Scheduler;
import org.mule.runtime.api.scheduler.SchedulerService;
import org.mule.runtime.api.tls.TlsContextFactory;
import org.mule.runtime.http.api.client.HttpClient;
import org.mule.runtime.http.api.client.HttpClientConfiguration;
import org.mule.tck.junit4.AbstractMuleTestCase;

import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.function.Supplier;

import javax.net.ssl.SSLContext;

import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;

public class HttpClientConnectionManagerTestCase extends AbstractMuleTestCase {

  // This port won't be bound during this test, so it doesn't need to be a DynamicPort. However, we let it with a low value so
  // that if at some moment in the future it's needed to be bound, it will fail and the maintainer will use a DynamicPort rule
  // here.
  private static final int TEST_PORT = 80;

  private HttpClientConnectionManager clientConnectionManager;

  private Supplier<HttpClientConfiguration> configurationSupplier;

  @Rule
  public ExpectedException expected = none();

  @Before
  public void setUp() {
    clientConnectionManager = new HttpClientConnectionManager();

    configurationSupplier = spy(mock(Supplier.class));
    HttpClientConfiguration mockConfiguration = spy(mock(HttpClientConfiguration.class));
    // when(mockConfiguration.getServerAddress()).thenReturn(new InetSocketAddress(TEST_PORT));
    when(configurationSupplier.get()).thenReturn(mockConfiguration);
  }

  @Test
  public void getOrCreateClientReturnsSameInstanceIfCalledTwiceWithSameName() {
    HttpClient firstCallClient = clientConnectionManager.getOrCreateClient("ClientName", configurationSupplier);
    HttpClient secondCallClient = clientConnectionManager.getOrCreateClient("ClientName", configurationSupplier);

    assertThat(secondCallClient, is(sameInstance(firstCallClient)));
  }

  @Test
  public void getOrCreateClientCallsConfigurationSupplierOnlyOnceIfCalledTwiceWithSameName() {
    clientConnectionManager.getOrCreateClient("ClientName", configurationSupplier);
    clientConnectionManager.getOrCreateClient("ClientName", configurationSupplier);

    verify(configurationSupplier, times(1)).get();
  }

  @Test
  public void getOrCreateClientReturnsDifferentInstancesIfCalledTwiceWithDifferentNames() {
    HttpClient firstCallClient = clientConnectionManager.getOrCreateClient("SomeClientName", configurationSupplier);
    HttpClient secondCallClient = clientConnectionManager.getOrCreateClient("OtherClientName", configurationSupplier);

    assertThat(secondCallClient, is(not(sameInstance(firstCallClient))));
  }

  @Test
  public void getOrCreateClientCallsConfigurationSupplierTwiceIfCalledTwiceWithDifferentNames() {
    clientConnectionManager.getOrCreateClient("SomeClientName", configurationSupplier);
    clientConnectionManager.getOrCreateClient("OtherClientName", configurationSupplier);

    verify(configurationSupplier, times(2)).get();
  }

  @Test
  public void clientNameCantBeNull() {
    expected.expect(IllegalArgumentException.class);
    expected.expectMessage("Client name can't be null");
    clientConnectionManager.getOrCreateClient(null, configurationSupplier);
  }

  @Test
  public void suppliedConfigurationCantBeNull() {
    when(configurationSupplier.get()).thenReturn(null);

    expected.expect(MuleRuntimeException.class);
    expected.expectCause(instanceOf(IllegalArgumentException.class));
    expected.expectMessage("java.lang.IllegalArgumentException: Client configuration can't be null");
    clientConnectionManager.getOrCreateClient("SomeName", configurationSupplier);
  }

  @Test
  public void testGetOrCreateClientWithTlsConfiguration() throws Exception {

    configurationSupplier = spy(Supplier.class);
    HttpClientConfiguration mockConfiguration = mock(HttpClientConfiguration.class);
    when(configurationSupplier.get()).thenReturn(mockConfiguration);
    TlsContextFactory tlsContextFactory = buildSetupTlsContextFactory(new String[] {"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"},
                                                                      new String[] {"TLSv1.2", "TLSv1.3"});
    when(mockConfiguration.getTlsContextFactory()).thenReturn(tlsContextFactory);

    HttpClient client = clientConnectionManager.getOrCreateClient("ClientName", configurationSupplier);

    assertNotNull(client);
    verify(configurationSupplier, times(1)).get();
  }

  @Test
  public void testValidClientConfiguration() throws Exception {

    TlsContextFactory tlsContextFactory = buildSetupTlsContextFactory(new String[] {"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"},
                                                                      new String[] {"TLSv1.2", "TLSv1.3"});
    HttpClientConfiguration clientConfig = mock(HttpClientConfiguration.class);
    when(clientConfig.getTlsContextFactory()).thenReturn(tlsContextFactory);

    HttpClientConnectionManager manager = new HttpClientConnectionManager();
    HttpClient client = manager.getOrCreateClient("TestClient", () -> clientConfig);

    assertNotNull(client);
    verify(clientConfig, times(1)).getTlsContextFactory();
    verify(tlsContextFactory, times(1)).getEnabledCipherSuites();
    verify(tlsContextFactory, times(1)).getEnabledProtocols();
  }

  @Test
  public void testNullConfiguration() {
    HttpClientConnectionManager manager = new HttpClientConnectionManager();
    assertThrows(MuleRuntimeException.class, () -> manager.getOrCreateClient("TestClient", () -> null));
  }

  @Test
  public void testSslContextCreationException() throws Exception {

    TlsContextFactory tlsContextFactory = mock(TlsContextFactory.class);
    when(tlsContextFactory.createSslContext()).thenThrow(new KeyManagementException("SSL Context Error"));
    HttpClientConfiguration clientConfig = mock(HttpClientConfiguration.class);
    when(clientConfig.getTlsContextFactory()).thenReturn(tlsContextFactory);

    HttpClientConnectionManager manager = new HttpClientConnectionManager();

    assertThrows(MuleRuntimeException.class, () -> manager.getOrCreateClient("TestClient", () -> clientConfig));
  }

  @Test
  public void testSchedulerCreation() throws InitialisationException {
    SchedulerService schedulerService = mock(SchedulerService.class);
    when(schedulerService.customScheduler(any(), anyInt())).thenReturn(mock(Scheduler.class));

    HttpClientConnectionManager manager = new HttpClientConnectionManager(schedulerService);
    manager.initialise();

    HttpClientConfiguration clientConfig = mock(HttpClientConfiguration.class);
    HttpClient client = manager.getOrCreateClient("TestClient", () -> clientConfig);

    assertNotNull(client);
    verify(schedulerService).customScheduler(any(), anyInt());
  }

  @Test
  public void testCipherSuitePreservedWhenNotSupported() throws Exception {

    SSLContext realSSLContext = SSLContext.getInstance("TLS");
    realSSLContext.init(null, null, new SecureRandom());
    TlsContextFactory tlsContextFactory =
        buildSetupTlsContextFactory(new String[] {"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"}, new String[] {"TLSv1.2", "TLSv1.3"});
    HttpClientConfiguration clientConfig = mock(HttpClientConfiguration.class);
    when(clientConfig.getTlsContextFactory()).thenReturn(tlsContextFactory);
    TlsContextFactory serverTlsContextFactory =
        buildSetupTlsContextFactory(new String[] {"NonSupportedCipherSuite"}, new String[] {"NonSupportedProtocol"});
    HttpClientConfiguration serverConfig = mock(HttpClientConfiguration.class);
    when(serverConfig.getTlsContextFactory()).thenReturn(serverTlsContextFactory);

    HttpClient client = clientConnectionManager.getOrCreateClient("ClientName", () -> clientConfig);
    HttpClient server = clientConnectionManager.getOrCreateClient("ServerName", () -> serverConfig);

    assertNotNull(client);
    assertNotNull(server);
    assertArrayEquals(new String[] {"NonSupportedCipherSuite"}, serverConfig.getTlsContextFactory().getEnabledCipherSuites());
    assertArrayEquals(new String[] {"NonSupportedProtocol"}, serverConfig.getTlsContextFactory().getEnabledProtocols());
    verify(serverConfig.getTlsContextFactory(), times(2)).getEnabledCipherSuites();
    verify(serverConfig.getTlsContextFactory(), times(2)).getEnabledProtocols();
  }

  @Test
  public void testMultipleClientsWithDifferentConfigurations() throws Exception {

    TlsContextFactory clientTlsContextFactory1 =
        buildSetupTlsContextFactory(new String[] {"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"}, new String[] {"TLSv1.2"});
    HttpClientConfiguration clientConfig1 = mock(HttpClientConfiguration.class);
    when(clientConfig1.getTlsContextFactory()).thenReturn(clientTlsContextFactory1);
    TlsContextFactory clientTlsContextFactory2 =
        buildSetupTlsContextFactory(new String[] {"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384"}, new String[] {"TLSv1.3"});
    HttpClientConfiguration clientConfig2 = mock(HttpClientConfiguration.class);
    when(clientConfig2.getTlsContextFactory()).thenReturn(clientTlsContextFactory2);

    HttpClientConnectionManager manager = new HttpClientConnectionManager();
    HttpClient client1 = manager.getOrCreateClient("Client1", () -> clientConfig1);
    HttpClient client2 = manager.getOrCreateClient("Client2", () -> clientConfig2);

    assertNotNull(client1);
    assertNotNull(client2);
    assertNotSame(client1, client2);
    verify(clientConfig1).getTlsContextFactory();
    verify(clientTlsContextFactory1).getEnabledCipherSuites();
    verify(clientTlsContextFactory1).getEnabledProtocols();
    verify(clientConfig2).getTlsContextFactory();
    verify(clientTlsContextFactory2).getEnabledCipherSuites();
    verify(clientTlsContextFactory2).getEnabledProtocols();
  }

  private TlsContextFactory buildSetupTlsContextFactory(String[] cipherSuites, String[] protocols)
      throws NoSuchAlgorithmException, KeyManagementException {
    SSLContext realSSLContext = SSLContext.getInstance("TLS");
    realSSLContext.init(null, null, new SecureRandom());
    TlsContextFactory mockTlsContextFactory = mock(TlsContextFactory.class);
    when(mockTlsContextFactory.getEnabledCipherSuites()).thenReturn(cipherSuites);
    when(mockTlsContextFactory.getEnabledProtocols()).thenReturn(protocols);
    when(mockTlsContextFactory.createSslContext()).thenReturn(realSSLContext);
    return mockTlsContextFactory;
  }
}
