/*
 * © 2024-2025 SAP SE or an SAP affiliate company. All rights reserved.
 */
package com.sap.cds.feature.mt.lib.runtime;

import com.sap.cds.feature.mt.lib.subscription.DataSourceInfo;
import com.sap.db.jdbc.ConnectionSapDB;
import com.sap.db.util.Base64Utils;
import com.sap.db.util.StringUtils;
import com.sap.db.util.security.X509Authentication;
import java.sql.Connection;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.function.BiConsumer;
import java.util.function.Function;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link ConnectionProvider} for SAP HANA.
 *
 * <p>Borrows a connection from the data source of a {@link DataSourceAndInfo}. Unless the
 * connection already runs as the target database user, it switches the session to that user
 * with a {@code CONNECT} statement and then sets the target schema.
 *
 * <p>Credential type {@code X509} (case-insensitive) is handled with the HANA JDBC driver's X509
 * helpers (certificate plus proof token). Every other credential type uses user/password.
 */
public class HanaConnectionProvider implements ConnectionProvider {
  private static final Logger logger = LoggerFactory.getLogger(HanaConnectionProvider.class);

  /**
   * Returns a connection running as the user described by {@code dataSourceAndInfo}.
   *
   * @param tenantId tenant the connection is requested for (not used by this implementation)
   * @param dataSourceAndInfo data source to borrow from, plus the target credentials and schema
   * @param libContainerInfo not used by this implementation
   * @return the borrowed connection, switched to the target user if necessary
   * @throws SQLException if borrowing, preparing, switching the user or setting the schema fails.
   *     Once borrowed, the connection is closed before the exception is thrown, and a failure
   *     while closing is attached as a suppressed exception.
   */
  @Override
  public Connection getConnection(
      String tenantId, DataSourceAndInfo dataSourceAndInfo, DataSourceInfo libContainerInfo)
      throws SQLException {
    DataSourceInfo info = dataSourceAndInfo.getDataSourceInfo();
    Connection connection = dataSourceAndInfo.getDataSource().getConnection(); // NOSONAR
    try {
      // Inside the try so that the connection is closed if this call fails.
      connection.setClientInfo("LOCALE", null); // workaround for HANA bug
      if ("X509".equalsIgnoreCase(info.getCredentialType())) {
        doUserSwitch(
            connection,
            info,
            DataSourceInfo::getUserName,
            (stmt, i) -> connectWithX509Certificate(connection, stmt, i));
      } else {
        doUserSwitch(
            connection, info, DataSourceInfo::getUser, HanaConnectionProvider::connectWithPassword);
      }
    } catch (SQLException e) {
      closeAfterFailure(connection, e);
      throw e;
    } catch (RuntimeException e) {
      // The CONNECT callbacks are BiConsumers and tunnel SQLExceptions as RuntimeExceptions.
      // Unwrap them here so callers see the original checked exception.
      if (e.getCause() instanceof SQLException) {
        SQLException cause = (SQLException) e.getCause();
        closeAfterFailure(connection, cause);
        throw cause;
      }
      closeAfterFailure(connection, e);
      throw e;
    }
    return connection;
  }

  /**
   * Switches {@code connection} to the user returned by {@code userProvider} and sets the schema
   * of {@code info}.
   *
   * <p>If the connection already runs as that user, nothing is done; in particular the schema is
   * not (re)set. Otherwise {@code consumer} executes the {@code CONNECT}, the resulting session
   * user is verified, and the schema is set. Closing the connection on failure is left to the
   * caller.
   *
   * @param connection connection to switch
   * @param info target credentials and schema
   * @param userProvider extracts the expected database user name from {@code info}
   * @param consumer executes the {@code CONNECT} statement; wraps any SQLException in a
   *     RuntimeException
   * @throws SQLException if reading metadata fails, the session user does not match after the
   *     reconnect, or setting the schema fails
   */
  private static void doUserSwitch(
      Connection connection,
      DataSourceInfo info,
      Function<DataSourceInfo, String> userProvider,
      BiConsumer<Statement, DataSourceInfo> consumer)
      throws SQLException {
    String targetUser = userProvider.apply(info);
    if (targetUser.equals(connection.getMetaData().getUserName())) {
      logger.debug("Connection belongs to requested user, no reconnect needed");
      return;
    }
    try (Statement statement = connection.createStatement()) {
      long start = System.nanoTime();
      consumer.accept(statement, info);

      // Guard against a CONNECT that completed without switching to the expected user.
      String userName = connection.getMetaData().getUserName();
      if (!targetUser.equals(userName)) {
        throw new SQLException(
            "User should be %s after reconnect but is %s".formatted(targetUser, userName));
      }
      long end = System.nanoTime();
      logger.debug(
          "Reconnect connection for user {} took {} nanoseconds.", targetUser, (end - start));

      logger.debug("Set schema to {}", info.getSchema());
      connection.setSchema(info.getSchema());
      assert connection.getSchema().equals(info.getSchema());
    }
  }

  /**
   * Executes {@code CONNECT WITH X509 CERTIFICATE <certificates> PROOF <proof>} on {@code stmt}.
   *
   * <p>Both values are derived from {@code info.getPassword()} using the HANA driver's X509
   * helpers. They are quoted with the driver's {@code StringUtils.quote}.
   *
   * @param connection connection {@code stmt} belongs to; unwrapped to the driver class so the
   *     proof token can be generated for it
   * @param stmt statement used to execute the {@code CONNECT}
   * @param info credentials whose password value carries the X509 input
   * @throws RuntimeException wrapping any {@link SQLException}, because this runs as a BiConsumer
   */
  private static void connectWithX509Certificate(
      Connection connection, Statement stmt, DataSourceInfo info) {
    try {
      // NOTE(review): the original author noted this call probably returns info.getPassword()
      // unchanged. It is kept so that any driver-side resolution still applies; confirm and
      // drop it if it is redundant.
      String x509 = X509Authentication.resolveX509FileName(info.getPassword());
      String certificates = Base64Utils.getCertificates(x509);
      String proof =
          X509Authentication.generateJWTTokenForProof(
              connection.unwrap(ConnectionSapDB.class), x509);
      String sql =
          "CONNECT WITH X509 CERTIFICATE "
              + StringUtils.quote(certificates, '\'', '\u0000', true, true)
              + " PROOF "
              + StringUtils.quote(proof, '\'', '\u0000', true, true);
      stmt.execute(sql);
    } catch (SQLException e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * Executes {@code connect <user> password "<password>"} on {@code stmt}.
   *
   * <p>The password is embedded between double quotes, so embedded double quotes are doubled.
   * Otherwise a password containing {@code "} would end the quoted value early and break or alter
   * the statement. The user is left unquoted, as in the original statement.
   *
   * @param stmt statement used to execute the {@code CONNECT}
   * @param info credentials providing user and password
   * @throws RuntimeException wrapping any {@link SQLException}, because this runs as a BiConsumer
   */
  private static void connectWithPassword(Statement stmt, DataSourceInfo info) {
    try {
      stmt.execute(
          "connect %s password \"%s\" "
              .formatted(info.getUser(), escapeDoubleQuotes(info.getPassword())));
    } catch (SQLException e) {
      throw new RuntimeException(e);
    }
  }

  /** Doubles every {@code "} in {@code value}; a {@code null} value is returned unchanged. */
  private static String escapeDoubleQuotes(String value) {
    return value == null ? null : value.replace("\"", "\"\"");
  }

  /**
   * Closes a connection whose setup failed.
   *
   * <p>If closing fails, that failure is recorded as suppressed on {@code failure} instead of
   * replacing it.
   */
  private static void closeAfterFailure(Connection connection, Exception failure) {
    try {
      connection.close();
    } catch (SQLException closeFailure) {
      failure.addSuppressed(closeFailure);
    }
  }
}
