/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 * The software in this package is published under the terms of the CPAL v1.0
 * license, a copy of which has been included with this distribution in the
 * LICENSE.txt file.
 */
package org.mule.tls.sni;

import static org.glassfish.grizzly.Grizzly.DEFAULT_ATTRIBUTE_BUILDER;

import java.io.File;
import java.io.IOException;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.concurrent.atomic.AtomicReference;

import javax.net.ssl.SSLContext;

import org.apache.commons.lang3.StringUtils;
import org.glassfish.grizzly.attributes.Attribute;
import org.glassfish.grizzly.filterchain.FilterChainBuilder;
import org.glassfish.grizzly.http.server.AddOn;
import org.glassfish.grizzly.http.server.HttpServer;
import org.glassfish.grizzly.http.server.NetworkListener;
import org.glassfish.grizzly.sni.SNIConfig;
import org.glassfish.grizzly.sni.SNIFilter;
import org.glassfish.grizzly.ssl.SSLBaseFilter;
import org.glassfish.grizzly.ssl.SSLContextConfigurator;
import org.glassfish.grizzly.ssl.SSLEngineConfigurator;

/**
 * Embedded HTTPS test server that refuses to serve clients which do not send the SNI (Server Name
 * Indication) TLS extension.
 *
 * <p>The server listens on {@code localhost} at the configured port and uses the keystore
 * {@code tls/sni-server-keystore.jks} and truststore {@code tls/sni-server-truststore.jks} found
 * on the classpath. The SNI host name received in the most recent handshake is exposed through
 * {@link #getHostname()}.
 */
public class HttpsTestServerEnforcingSNI {

  private static final String TRUSTSTORE_RESOURCE = "tls/sni-server-truststore.jks";
  private static final String KEYSTORE_RESOURCE = "tls/sni-server-keystore.jks";
  private static final String STORE_PASSWORD = "changeit";

  private final int port;

  private HttpServer httpServer;
  private SniAddOn sniAddOn;

  /**
   * Creates the server; nothing is opened until {@link #startServer()} is called.
   *
   * @param port TCP port on {@code localhost} to listen on
   */
  public HttpsTestServerEnforcingSNI(int port) {
    this.port = port;
  }

  /**
   * Builds the TLS configuration, installs the SNI-enforcing filter and starts the server.
   *
   * @throws IOException if the Grizzly server fails to start
   * @throws URISyntaxException if a keystore/truststore resource URL cannot be converted to a URI
   */
  protected void startServer() throws IOException, URISyntaxException {
    NetworkListener networkListener = new NetworkListener("sample-listener", "localhost", port);

    // throwErrors=true: keystore/truststore problems are reported with their original cause
    // instead of createSSLContext returning null and the failure showing up later, out of context.
    SSLContext sslContext = createSSLContextConfigurator().createSSLContext(true);
    var sslServerEngineConfig = new SSLEngineConfigurator(sslContext, false, false, false);
    networkListener.setSSLEngineConfig(sslServerEngineConfig);

    httpServer = new HttpServer();
    httpServer.addListener(networkListener);
    networkListener.setSecure(true);
    sniAddOn = new SniAddOn(sslServerEngineConfig);
    networkListener.registerAddOn(sniAddOn);
    httpServer.start();
  }

  /**
   * Resets the recorded SNI host name and shuts the server down immediately.
   *
   * <p>Safe to call even if {@link #startServer()} was never invoked or failed before the server
   * was created, so it can be used unconditionally in test teardown.
   */
  protected void stopServer() {
    if (sniAddOn != null) {
      sniAddOn.stop();
    }
    if (httpServer != null) {
      httpServer.shutdownNow();
    }
  }

  /**
   * Creates an SSL context configurator from the test keystore and truststore on the classpath.
   * A store whose resource cannot be found is left unconfigured.
   *
   * @return the populated configurator
   * @throws URISyntaxException if a resource URL cannot be converted to a URI
   */
  private SSLContextConfigurator createSSLContextConfigurator() throws URISyntaxException {
    SSLContextConfigurator sslContextConfigurator = new SSLContextConfigurator();
    // Resolve resources through this class rather than an unrelated test class.
    ClassLoader cl = HttpsTestServerEnforcingSNI.class.getClassLoader();

    URL cacertsUrl = cl.getResource(TRUSTSTORE_RESOURCE);
    if (cacertsUrl != null) {
      sslContextConfigurator.setTrustStoreFile(new File(cacertsUrl.toURI()).getPath());
      sslContextConfigurator.setTrustStorePass(STORE_PASSWORD);
    }

    URL keystoreUrl = cl.getResource(KEYSTORE_RESOURCE);
    if (keystoreUrl != null) {
      sslContextConfigurator.setKeyStoreFile(new File(keystoreUrl.toURI()).getPath());
      sslContextConfigurator.setKeyStorePass(STORE_PASSWORD);
      sslContextConfigurator.setKeyPass(STORE_PASSWORD);
    }

    return sslContextConfigurator;
  }

  /**
   * Returns the SNI host name seen in the most recent handshake.
   *
   * @return the last received host name. This is {@code null} or empty if the last client sent
   *     none, and empty after {@link #stopServer()}.
   */
  public String getHostname() {
    return sniAddOn.getHostname();
  }

  /**
   * Grizzly add-on that replaces the listener's SSL filter with an {@link SNIFilter}. The filter
   * records each handshake's SNI host name and rejects handshakes that carry none.
   */
  private static class SniAddOn implements AddOn {

    /**
     * Connection attribute holding the SNI host name. It is created once, not on every
     * {@link #getSniFilter()} call, so repeated listener setups do not keep registering new
     * attributes with the static {@code DEFAULT_ATTRIBUTE_BUILDER}.
     */
    private static final Attribute<String> SNI_HOST_ATTR =
        DEFAULT_ATTRIBUTE_BUILDER.createAttribute("sni-host-attr");

    private final SSLEngineConfigurator sslServerEngineConfig;
    private final AtomicReference<String> sniHostname = new AtomicReference<>();

    /**
     * @param sslServerEngineConfig engine configuration returned for every accepted SNI host name
     */
    public SniAddOn(SSLEngineConfigurator sslServerEngineConfig) {
      this.sslServerEngineConfig = sslServerEngineConfig;
    }

    /** Replaces the SSL filter in the chain (if one is present) with the SNI-enforcing filter. */
    @Override
    public void setup(NetworkListener networkListener, FilterChainBuilder builder) {
      final int idx = builder.indexOfType(SSLBaseFilter.class);
      if (idx != -1) {
        builder.set(idx, getSniFilter());
      }
    }

    /** Builds an SNI filter whose resolver records the host name and rejects a missing one. */
    private SNIFilter getSniFilter() {
      SNIFilter sniFilter = new SNIFilter();
      sniFilter.setServerSSLConfigResolver((connection, hostname) -> {
        // Recorded before validation, so getHostname() also reflects a rejected null/empty value.
        SNI_HOST_ATTR.set(connection, hostname);
        sniHostname.set(hostname);
        if (StringUtils.isEmpty(hostname)) {
          // Throwing here is how the server refuses clients that did not send SNI.
          throw new IllegalArgumentException("SNI Has not been sent");
        }
        return SNIConfig.newServerConfig(sslServerEngineConfig);
      });
      return sniFilter;
    }

    /** Clears the recorded host name to the empty string. */
    public void stop() {
      sniHostname.set(StringUtils.EMPTY);
    }

    /** @return the host name recorded by the most recent handshake, possibly null or empty */
    public String getHostname() {
      return sniHostname.get();
    }
  }
}
