/*
 * © 2023-2024 SAP SE or an SAP affiliate company. All rights reserved.
 */
package com.sap.cds.services.utils.cert;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.annotations.VisibleForTesting;
import com.sap.cds.services.ErrorStatuses;
import com.sap.cds.services.ServiceException;
import com.sap.cds.services.request.RequestContext;
import com.sap.cds.services.utils.ErrorStatusException;
import com.sap.cloud.environment.servicebinding.api.ServiceBinding;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Certification validator for CAP internal usage.
 *
 * <p>Reads a client certificate from a configurable request header and checks that the string
 * forms of its issuer and subject principals contain the expected attribute values ({@code C},
 * {@code O}, {@code OU}, {@code L}, {@code CN}).
 */
public class CertValidator {
  private static final Logger logger = LoggerFactory.getLogger(CertValidator.class);

  private static final String CERTIFICATE_HEADER = "-----BEGIN CERTIFICATE-----";
  private static final String CERTIFICATE_HEADER_ENCODED = "-----BEGIN%20CERTIFICATE-----";
  private static final String CALLBACK_CERTIFICATE_ISSUER = "callback_certificate_issuer";
  private static final String CALLBACK_CERTIFICATE_SUBJECT = "callback_certificate_subject";
  private static final String X_509 = "X.509";

  private static final ObjectMapper mapper = new ObjectMapper();
  private static final String XFCC_HEADER_CERT_KEY = "cert=";

  private final CertificateFactory certFactory = getX509CertFactory();

  /** Expected issuer attributes, as produced by {@link #jsonStrToFlatMap} or {@link #strToFlatMap}. */
  private final Map<String, Object> expectedIssuer;

  /** Expected subject attributes, as produced by {@link #jsonStrToFlatMap} or {@link #strToFlatMap}. */
  private final Map<String, Object> expectedSubject;

  /** Name of the request header that carries the client certificate. */
  private final String clientCertificateHeaderName;

  CertValidator(
      String clientCertificateHeaderName,
      Map<String, Object> expectedIssuer,
      Map<String, Object> expectedSubject) {
    this.clientCertificateHeaderName = clientCertificateHeaderName;
    this.expectedIssuer = expectedIssuer;
    this.expectedSubject = expectedSubject;
  }

  /**
   * Creates a validator whose expectations are read from the binding credentials {@code
   * callback_certificate_issuer} and {@code callback_certificate_subject}, each given as a JSON
   * object string.
   *
   * @param clientCertificateHeaderName name of the header carrying the client certificate
   * @param binding service binding providing the expected issuer and subject
   * @return the configured validator
   * @throws RuntimeException if one of the credentials cannot be parsed as JSON
   */
  public static CertValidator create(String clientCertificateHeaderName, ServiceBinding binding) {
    try {
      Map<String, Object> expectedIssuer =
          jsonStrToFlatMap((String) binding.getCredentials().get(CALLBACK_CERTIFICATE_ISSUER));
      Map<String, Object> expectedSubject =
          jsonStrToFlatMap((String) binding.getCredentials().get(CALLBACK_CERTIFICATE_SUBJECT));
      return new CertValidator(clientCertificateHeaderName, expectedIssuer, expectedSubject);
    } catch (IOException e) {
      throw new RuntimeException(
          "Failed to parse expected callback certificate issuer or subject from service binding",
          e);
    }
  }

  /**
   * Creates a validator from comma-separated {@code key=value} strings, e.g. {@code "C=DE,
   * O=Example, OU=a, OU=b, L=*, CN=host"}. Repeated keys are collected into a list.
   *
   * @param clientCertificateHeaderName name of the header carrying the client certificate
   * @param expectedIssuer expected issuer attributes
   * @param expectedSubject expected subject attributes
   * @return the configured validator
   */
  public static CertValidator create(
      String clientCertificateHeaderName, String expectedIssuer, String expectedSubject) {
    return new CertValidator(
        clientCertificateHeaderName, strToFlatMap(expectedIssuer), strToFlatMap(expectedSubject));
  }

  private static CertificateFactory getX509CertFactory() {
    try {
      return CertificateFactory.getInstance(X_509);
    } catch (CertificateException e) {
      // should not happen as there is a provider for "X.509"
      throw new ServiceException("No X.509 certificate provider found", e);
    }
  }

  /**
   * Parses a certificate given as PEM, URL-encoded PEM, or otherwise as a base64 string.
   *
   * @param cert certificate text taken from the request header
   * @return the parsed certificate
   * @throws Exception if the text cannot be decoded or parsed
   */
  @VisibleForTesting
  X509Certificate getX509Certificate(String cert) throws Exception {
    ByteArrayInputStream is;
    if (cert.startsWith(CERTIFICATE_HEADER)) {
      logger.debug("Certificate starts with {}", CERTIFICATE_HEADER);
      is = new ByteArrayInputStream(cert.getBytes(StandardCharsets.UTF_8));
    } else if (cert.startsWith(CERTIFICATE_HEADER_ENCODED)) {
      logger.debug("Certificate starts with {}", CERTIFICATE_HEADER_ENCODED);
      // in some cases the certificate is URL encoded (e.g. istio)
      is =
          new ByteArrayInputStream(
              URLDecoder.decode(cert, StandardCharsets.UTF_8).getBytes(StandardCharsets.UTF_8));
    } else {
      logger.debug("Certificate is base64 encoded");
      is = new ByteArrayInputStream(Base64.getDecoder().decode(cert));
    }
    return (X509Certificate) certFactory.generateCertificate(is);
  }

  /**
   * Validates the client certificate passed in the configured request header.
   *
   * <p>The header may carry the certificate directly, or as the {@code Cert} field of an XFCC
   * header
   * (https://www.envoyproxy.io/docs/envoy/latest/configuration/http/http_conn_man/headers#x-forwarded-client-cert).
   *
   * @param requestContext the current request context providing the request headers
   * @throws ErrorStatusException with {@code UNAUTHORIZED} if the header is missing, contains no
   *     certificate, or the certificate cannot be parsed; with {@code FORBIDDEN} if issuer or
   *     subject do not match the expectations
   */
  public void validateCertFromRequestContext(RequestContext requestContext) {
    String headerValue = requestContext.getParameterInfo().getHeader(clientCertificateHeaderName);
    if (headerValue == null) {
      logger.debug("Request did not provide header '{}'.", clientCertificateHeaderName);
      throw new ErrorStatusException(ErrorStatuses.UNAUTHORIZED);
    }

    String cert = extractCertificate(headerValue);
    X509Certificate certificate;
    try {
      certificate = getX509Certificate(cert);
    } catch (Exception e) {
      // pass the cause to the logger so parse failures remain diagnosable
      logger.debug(
          "Header '{}' did not include a valid certificate.", clientCertificateHeaderName, e);
      throw new ErrorStatusException(ErrorStatuses.UNAUTHORIZED);
    }

    if (!isValidCertIssuer(certificate.getIssuerX500Principal().toString())
        || !isValidCertSubject(certificate.getSubjectX500Principal().toString())) {
      logger.debug(
          "Subject or issuer of certificate from header '{}' not as expected.",
          clientCertificateHeaderName);
      throw new ErrorStatusException(ErrorStatuses.FORBIDDEN);
    }
  }

  /**
   * Returns the certificate text contained in the header value. A value with several
   * {@code ;}-separated fields, or a single field starting with {@code cert=} (case-insensitive),
   * is treated as XFCC and the quotes-stripped {@code Cert} value is returned. Any other value is
   * returned unchanged.
   */
  private String extractCertificate(String headerValue) {
    String[] parts = headerValue.split(";");
    if (parts.length == 1 && !isXfccCertField(parts[0].trim())) {
      return headerValue;
    }
    return Arrays.stream(parts)
        .map(String::trim)
        .filter(CertValidator::isXfccCertField)
        .map(p -> p.substring(XFCC_HEADER_CERT_KEY.length()).replace("\"", ""))
        .findFirst()
        .orElseThrow(
            () -> {
              logger.debug(
                  "No key 'cert' found in header '{}' that includes a certificate.",
                  clientCertificateHeaderName);
              return new ErrorStatusException(ErrorStatuses.UNAUTHORIZED);
            });
  }

  /** Case-insensitive, locale-independent check whether an XFCC field is the {@code Cert} key. */
  private static boolean isXfccCertField(String field) {
    return field.regionMatches(true, 0, XFCC_HEADER_CERT_KEY, 0, XFCC_HEADER_CERT_KEY.length());
  }

  /**
   * Checks the certificate subject against the expected subject attributes.
   *
   * @param subjectFromCert string form of the certificate's subject principal
   * @return {@code true} if all expected attributes are present
   */
  @VisibleForTesting
  boolean isValidCertSubject(String subjectFromCert) {
    return matchesExpected(subjectFromCert, expectedSubject);
  }

  /**
   * Checks the certificate issuer against the expected issuer attributes.
   *
   * @param issuerFromCert string form of the certificate's issuer principal
   * @return {@code true} if all expected attributes are present
   */
  @VisibleForTesting
  boolean isValidCertIssuer(String issuerFromCert) {
    return matchesExpected(issuerFromCert, expectedIssuer);
  }

  /**
   * Returns whether {@code principal} contains every expected {@code C}, {@code OU}, {@code CN},
   * {@code L} and {@code O} attribute. An expected {@code L} of {@code "*"} accepts any locality.
   * A missing expected attribute makes the check fail.
   *
   * <p>NOTE(review): matching is substring-based, so e.g. an expected {@code CN=foo} is also
   * satisfied by {@code CN=foobar}; consider exact RDN matching.
   */
  private static boolean matchesExpected(String principal, Map<String, Object> expected) {
    if (expected == null) {
      return false;
    }
    Object locality = expected.get("L");
    return containsAttribute(principal, "C", expected.get("C"))
        && containsAttribute(principal, "OU", expected.get("OU"))
        && containsAttribute(principal, "CN", expected.get("CN"))
        && ("*".equals(locality) || containsAttribute(principal, "L", locality))
        && containsAttribute(principal, "O", expected.get("O"));
  }

  /**
   * Returns whether {@code principal} contains {@code key=value}. For a list value, every element
   * must be contained; an empty list imposes no constraint. A {@code null} value is rendered as
   * {@code "null"} and therefore normally does not match.
   */
  private static boolean containsAttribute(String principal, String key, Object expectedValue) {
    if (expectedValue instanceof List) {
      for (Object item : (List<?>) expectedValue) {
        if (!principal.contains(key + "=" + item)) {
          return false;
        }
      }
      return true;
    }
    return principal.contains(key + "=" + expectedValue);
  }

  /**
   * Parses a JSON object string into a map; JSON arrays become lists.
   *
   * @param json JSON object text
   * @return the parsed map
   * @throws IOException if the text is not valid JSON
   */
  @VisibleForTesting
  static Map<String, Object> jsonStrToFlatMap(String json) throws IOException {
    return mapper.readValue(json, new TypeReference<HashMap<String, Object>>() {});
  }

  /**
   * Parses comma-separated {@code key=value} pairs into a map with trimmed keys and values. Pairs
   * that do not split into exactly two parts on {@code '='} are ignored. Values of repeated keys
   * are collected into a mutable list.
   *
   * @param input the pairs to parse; must not be {@code null}
   * @return the parsed map
   */
  @VisibleForTesting
  @SuppressWarnings("unchecked")
  static Map<String, Object> strToFlatMap(String input) {
    return Arrays.stream(input.split(","))
        .map(pair -> pair.split("="))
        .filter(keyValue -> keyValue.length == 2)
        .collect(
            Collectors.toMap(
                keyValue -> keyValue[0].trim(),
                keyValue -> keyValue[1].trim(),
                (existing, replacement) -> {
                  if (existing instanceof List) {
                    ((List<String>) existing).add((String) replacement);
                    return existing;
                  }
                  return new ArrayList<>(Arrays.asList((String) existing, (String) replacement));
                },
                HashMap::new));
  }
}
