/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 * The software in this package is published under the terms of the CPAL v1.0
 * license, a copy of which has been included with this distribution in the
 * LICENSE.txt file.
 */
package org.mule.test.module.tls;

import static org.mule.runtime.api.config.MuleRuntimeFeature.HONOUR_INSECURE_TLS_CONFIGURATION;
import static org.mule.runtime.core.api.lifecycle.LifecycleUtils.initialiseIfNeeded;
import static org.mule.runtime.module.tls.internal.TlsConfiguration.DEFAULT_SECURITY_MODEL;
import static org.mule.runtime.module.tls.internal.TlsConfiguration.PROPERTIES_FILE_PATTERN;

import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Collections.emptyMap;

import static javax.net.ssl.TrustManagerFactory.getDefaultAlgorithm;
import static org.apache.commons.lang3.StringUtils.split;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.arrayContaining;
import static org.hamcrest.Matchers.arrayContainingInAnyOrder;
import static org.hamcrest.Matchers.arrayWithSize;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.isA;

import static org.junit.jupiter.api.Assertions.assertThrows;

import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;

import org.mule.runtime.api.config.FeatureFlaggingService;
import org.mule.runtime.api.lifecycle.InitialisationException;
import org.mule.runtime.api.tls.TlsContextFactory;
import org.mule.runtime.module.tls.internal.DefaultTlsContextFactory;
import org.mule.tck.junit4.AbstractMuleTestCase;

import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.net.URISyntaxException;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
import java.util.stream.Stream;

import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;

import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

/**
 * Tests for {@link DefaultTlsContextFactory}: key/trust store validation, the insecure trust store feature flag, and how the
 * enabled protocols and cipher suites are resolved against a TLS properties file that this test writes before running.
 */
public class DefaultTlsContextFactoryTestCase extends AbstractMuleTestCase {

  /**
   * Writes the TLS properties file for the default security model into the test classes location, declaring the protocols and
   * cipher suites returned by {@link #getFileEnabledProtocols()} and {@link #getFileEnabledCipherSuites()}. The tests below
   * expect {@link DefaultTlsContextFactory} to pick these values up.
   */
  @BeforeAll
  public static void createTlsPropertiesFile() throws Exception {
    // try-with-resources closes (and flushes) the writer even if writing fails part-way.
    try (PrintWriter writer = new PrintWriter(getTlsPropertiesFile(), UTF_8)) {
      writer.println("enabledCipherSuites=" + getFileEnabledCipherSuites());
      writer.println("enabledProtocols=" + getFileEnabledProtocols());
    }
  }

  /**
   * Deletes the TLS properties file created by {@link #createTlsPropertiesFile()}. This is best effort: the result of
   * {@link File#delete()} is ignored.
   */
  @AfterAll
  public static void removeTlsPropertiesFile() {
    getTlsPropertiesFile().delete();
  }

  /**
   * Resolves the location of the TLS properties file for {@code DEFAULT_SECURITY_MODEL}, inside the code source location of this
   * test class.
   *
   * @return the properties file (it may not exist yet)
   * @throws IllegalStateException if the code source location cannot be converted to a URI
   */
  private static File getTlsPropertiesFile() {
    File codeSourceDir;
    try {
      // Convert through URI instead of URL#getPath(): getPath() keeps percent-encoding (e.g. "%20" for a space), which would
      // point the File at a directory that does not exist.
      codeSourceDir = new File(DefaultTlsContextFactoryTestCase.class.getProtectionDomain().getCodeSource().getLocation()
          .toURI());
    } catch (URISyntaxException e) {
      throw new IllegalStateException("Cannot resolve the code source location of the test classes", e);
    }
    return new File(codeSourceDir, String.format(PROPERTIES_FILE_PATTERN, DEFAULT_SECURITY_MODEL));
  }

  /**
   * @return the comma-separated protocols written to the TLS properties file
   */
  public static String getFileEnabledProtocols() {
    return "TLSv1.2, TLSv1.3";
  }

  /**
   * @return the comma-separated cipher suites written to the TLS properties file
   */
  public static String getFileEnabledCipherSuites() {
    return "TLS_RSA_WITH_AES_128_GCM_SHA256, TLS_AES_128_GCM_SHA256";
  }

  /**
   * Splits a comma-separated list, in the format written to the properties file, into its trimmed elements.
   */
  private static String[] splitAndTrim(String commaSeparated) {
    return Stream.of(split(commaSeparated, ",")).map(String::trim).toArray(String[]::new);
  }

  /** A key store with no key entries must be rejected on initialisation. */
  @Test
  void failIfKeyStoreHasNoKey() throws Exception {
    DefaultTlsContextFactory tlsContextFactory = new DefaultTlsContextFactory(emptyMap());
    tlsContextFactory.setKeyStorePath("trustStore");
    tlsContextFactory.setKeyStorePassword("mulepassword");
    tlsContextFactory.setKeyPassword("mulepassword");
    var thrown = assertThrows(InitialisationException.class, tlsContextFactory::initialise);
    assertThat(thrown.getCause().getCause(), isA(IllegalArgumentException.class));
    assertThat(thrown.getCause().getCause().getMessage(), containsString("No key entries found."));
  }

  /** A configured key alias that does not refer to a key entry must be rejected on initialisation. */
  @Test
  void failIfKeyStoreAliasIsNotAKey() throws Exception {
    DefaultTlsContextFactory tlsContextFactory = new DefaultTlsContextFactory(emptyMap());
    tlsContextFactory.setKeyStorePath("serverKeystore");
    tlsContextFactory.setKeyAlias("muleclient");
    tlsContextFactory.setKeyStorePassword("mulepassword");
    tlsContextFactory.setKeyPassword("mulepassword");
    var thrown = assertThrows(InitialisationException.class, tlsContextFactory::initialise);
    assertThat(thrown.getCause().getCause(), isA(IllegalArgumentException.class));
    assertThat(thrown.getCause().getCause().getMessage(), containsString("Keystore entry for alias 'muleclient' is not a key."));
  }

  /** Setting a trust store path that cannot be resolved fails immediately with an {@link IOException}. */
  @Test
  void failIfTrustStoreIsNonexistent() throws Exception {
    DefaultTlsContextFactory tlsContextFactory = new DefaultTlsContextFactory(emptyMap());
    var thrown = assertThrows(IOException.class, () -> tlsContextFactory.setTrustStorePath("non-existent-trust-store"));
    assertThat(thrown.getMessage(), containsString("Resource non-existent-trust-store could not be found"));
  }

  /** With HONOUR_INSECURE_TLS_CONFIGURATION enabled, an insecure trust store is not reported as configured. */
  @Test
  void insecureTrustStoreShouldNotBeConfiguredIfFFIsEnabled() throws IOException, InitialisationException {
    FeatureFlaggingService featureFlaggingService = getFeatureFlaggingService();
    // Precondition: the inherited feature flagging service must have the flag on for this test to be meaningful.
    assertTrue(featureFlaggingService.isEnabled(HONOUR_INSECURE_TLS_CONFIGURATION));
    DefaultTlsContextFactory tlsContextFactory = new DefaultTlsContextFactory(emptyMap(), featureFlaggingService);
    tlsContextFactory.setTrustStorePath("trustStore");
    tlsContextFactory.setTrustStoreInsecure(true);
    assertFalse(tlsContextFactory.isTrustStoreConfigured());
  }

  /** With HONOUR_INSECURE_TLS_CONFIGURATION disabled, an insecure trust store is still reported as configured. */
  @Test
  void insecureTrustStoreShouldBeConfiguredIfFFIsDisabled() throws IOException, InitialisationException {
    FeatureFlaggingService featureFlaggingService = getFeatureFlaggingServiceWithFFDisabled();
    assertFalse(featureFlaggingService.isEnabled(HONOUR_INSECURE_TLS_CONFIGURATION));
    DefaultTlsContextFactory tlsContextFactory = new DefaultTlsContextFactory(emptyMap(), featureFlaggingService);
    tlsContextFactory.setTrustStorePath("trustStore");
    tlsContextFactory.setTrustStoreInsecure(true);
    assertTrue(tlsContextFactory.isTrustStoreConfigured());
  }

  /** "default" protocols/cipher suites (case-insensitive) resolve to the values from the properties file. */
  @Test
  void useConfigFileIfDefaultProtocolsAndCipherSuites() throws Exception {
    DefaultTlsContextFactory tlsContextFactory = new DefaultTlsContextFactory(emptyMap());
    tlsContextFactory.setEnabledCipherSuites("DEFAULT");
    tlsContextFactory.setEnabledProtocols("default");
    tlsContextFactory.initialise();

    assertThat(tlsContextFactory.getEnabledCipherSuites(), is(splitAndTrim(getFileEnabledCipherSuites())));
    assertThat(tlsContextFactory.getEnabledProtocols(), is(splitAndTrim(getFileEnabledProtocols())));
  }

  /** Without explicit configuration, the trust manager algorithm is the JVM's default. */
  @Test
  void trustStoreAlgorithmInTlsContextIsDefaultTrustManagerAlgorithm() {
    DefaultTlsContextFactory tlsContextFactory = new DefaultTlsContextFactory(emptyMap());
    assertThat(tlsContextFactory.getTrustManagerAlgorithm(), equalTo(getDefaultAlgorithm()));
  }

  /** Explicit protocols/cipher suites that are a subset of the properties file take precedence over it. */
  @Test
  void overrideConfigFile() throws Exception {
    DefaultTlsContextFactory tlsContextFactory = new DefaultTlsContextFactory(emptyMap());
    tlsContextFactory.setEnabledCipherSuites("TLS_RSA_WITH_AES_128_GCM_SHA256");
    tlsContextFactory.setEnabledProtocols("TLSv1.2");
    tlsContextFactory.initialise();

    String[] enabledCipherSuites = tlsContextFactory.getEnabledCipherSuites();
    assertThat(enabledCipherSuites, arrayWithSize(1));
    assertThat(enabledCipherSuites, is(arrayContaining("TLS_RSA_WITH_AES_128_GCM_SHA256")));

    String[] enabledProtocols = tlsContextFactory.getEnabledProtocols();
    assertThat(enabledProtocols, arrayWithSize(1));
    assertThat(enabledProtocols, is(arrayContaining("TLSv1.2")));
  }

  /** Each protocol not declared in the properties file makes initialisation fail. */
  @Test
  void failIfProtocolsDoNotMatchConfigFile() throws Exception {
    String[] invalidProtocols = {"TLSv1", "SSLv3", "TLSv1.1"};
    for (String protocol : invalidProtocols) {
      DefaultTlsContextFactory tlsContextFactory = new DefaultTlsContextFactory(emptyMap());
      tlsContextFactory.setEnabledProtocols(protocol);
      // The message identifies which protocol failed the check if the assertion trips.
      var thrown = assertThrows(InitialisationException.class, tlsContextFactory::initialise,
                                "Expected initialisation to fail for protocol " + protocol);
      assertThat(thrown.getMessage(), containsString("protocols are invalid"));
    }
  }

  /** A cipher suite not declared in the properties file makes initialisation fail. */
  @Test
  void failIfCipherSuitesDoNotMatchConfigFile() throws Exception {
    DefaultTlsContextFactory tlsContextFactory = new DefaultTlsContextFactory(emptyMap());
    tlsContextFactory.setEnabledCipherSuites("SSL_DHE_DSS_WITH_3DES_EDE_CBC_SHA");
    var thrown = assertThrows(InitialisationException.class, tlsContextFactory::initialise);
    assertThat(thrown.getMessage(), containsString("cipher suites are invalid"));
  }

  /** Writing into the array returned by getEnabledProtocols() must not alter the factory's state. */
  @Test
  void cannotMutateEnabledProtocols() throws InitialisationException {
    TlsContextFactory tlsContextFactory = new DefaultTlsContextFactory(emptyMap());
    initialiseIfNeeded(tlsContextFactory);
    tlsContextFactory.getEnabledProtocols()[0] = "TLSv1";
    assertThat(tlsContextFactory.getEnabledProtocols(), arrayWithSize(2));
    assertThat(tlsContextFactory.getEnabledProtocols(), arrayContaining("TLSv1.2", "TLSv1.3"));
  }

  /** Writing into the array returned by getEnabledCipherSuites() must not alter the factory's state. */
  @Test
  void cannotMutateEnabledCipherSuites() throws InitialisationException {
    TlsContextFactory tlsContextFactory = new DefaultTlsContextFactory(emptyMap());
    initialiseIfNeeded(tlsContextFactory);
    tlsContextFactory.getEnabledCipherSuites()[0] = "TLS_DHE_RSA_WITH_AES_256_CBC_SHA256";
    assertThat(tlsContextFactory.getEnabledCipherSuites(), arrayWithSize(2));
    assertThat(tlsContextFactory.getEnabledCipherSuites(),
               arrayContaining("TLS_RSA_WITH_AES_128_GCM_SHA256", "TLS_AES_128_GCM_SHA256"));
  }

  /** The default SSL context exposes the same default cipher suites as a plain TLSv1.3 context. */
  @Test
  void defaultIncludesTls13Ciphers() throws Exception {
    defaultIncludesDefaultTlsVersionCiphers("TLSv1.3");
  }

  /**
   * Returns a feature flagging service that reports every feature as enabled except HONOUR_INSECURE_TLS_CONFIGURATION.
   */
  private FeatureFlaggingService getFeatureFlaggingServiceWithFFDisabled() {
    return feature -> !feature.equals(HONOUR_INSECURE_TLS_CONFIGURATION);
  }

  /**
   * Asserts that the default cipher suites of an SSL context created by an unconfigured, initialised
   * {@link DefaultTlsContextFactory} match, in any order, those of a plain JDK {@link SSLContext} for {@code sslVersion}.
   *
   * @param sslVersion the JDK SSLContext protocol name to compare against, e.g. "TLSv1.3"
   */
  private void defaultIncludesDefaultTlsVersionCiphers(String sslVersion)
      throws InitialisationException, KeyManagementException, NoSuchAlgorithmException {
    DefaultTlsContextFactory tlsContextFactory = new DefaultTlsContextFactory(emptyMap());
    tlsContextFactory.initialise();
    SSLSocketFactory defaultFactory = tlsContextFactory.createSslContext().getSocketFactory();

    SSLContext referenceContext = SSLContext.getInstance(sslVersion);
    // Null managers/random: use the JDK defaults, which is exactly the baseline being compared against.
    referenceContext.init(null, null, null);
    SSLSocketFactory referenceFactory = referenceContext.getSocketFactory();

    assertThat(defaultFactory.getDefaultCipherSuites(), arrayContainingInAnyOrder(referenceFactory.getDefaultCipherSuites()));
  }

}
