/**************************************************************************
 * (C) 2019-2024 SAP SE or an SAP affiliate company. All rights reserved. *
 **************************************************************************/
package com.sap.cds.adapter.sms;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.annotations.VisibleForTesting;
import com.sap.cds.services.ErrorStatuses;
import com.sap.cds.services.request.RequestContext;
import com.sap.cds.services.runtime.CdsRuntime;
import com.sap.cds.services.utils.ErrorStatusException;
import com.sap.cloud.environment.servicebinding.api.ServiceBinding;

/**
 * Certificate validator used by {@link SmsProvisioningServlet}.
 * <p>
 * Reads a client certificate from a configurable request header. It then checks that the certificate's
 * issuer and subject distinguished names contain the attribute values configured in the SMS service
 * binding credentials ({@code callback_certificate_issuer} / {@code callback_certificate_subject}).
 * <p>
 * NOTE(review): this class only parses the certificate and compares DN strings. It does not verify the
 * certificate's signature or chain. Presumably an upstream component (e.g. a TLS-terminating router) has
 * already validated the certificate before forwarding it in the header — confirm.
 */
class CertValidator {

	private static final String CERTIFICATE_HEADER = "-----BEGIN CERTIFICATE-----";
	private static final String CALLBACK_CERTIFICATE_ISSUER = "callback_certificate_issuer";
	private static final String CALLBACK_CERTIFICATE_SUBJECT = "callback_certificate_subject";
	private static final String X_509 = "X.509";
	/** Expected value of the {@code L} (locality) attribute that accepts any locality. */
	private static final String ANY_LOCALITY = "*";

	private final ObjectMapper mapper = new ObjectMapper();
	private final CertificateFactory certFactory = getX509CertFactory();

	// attribute name (C, O, OU, L, CN) -> expected value; a value may be a JSON array (e.g. several OUs)
	private final Map<String, Object> expectedIssuer;
	private final Map<String, Object> expectedSubject;
	private final String clientCertificateHeaderName;

	/**
	 * Creates a validator.
	 *
	 * @param clientCertificateHeaderName name of the request header carrying the client certificate
	 * @param expectedIssuer              JSON object with the expected issuer DN attributes
	 * @param expectedSubject             JSON object with the expected subject DN attributes
	 * @throws UncheckedIOException if either JSON string cannot be parsed
	 */
	CertValidator(String clientCertificateHeaderName, String expectedIssuer, String expectedSubject) {
		this.clientCertificateHeaderName = clientCertificateHeaderName;
		try {
			this.expectedIssuer = jsonStrToFlatMap(expectedIssuer);
			this.expectedSubject = jsonStrToFlatMap(expectedSubject);
		} catch (IOException e) {
			throw new UncheckedIOException("Failed to parse the expected callback certificate issuer/subject", e);
		}
	}

	/**
	 * Creates a validator from the CDS configuration and the SMS service binding.
	 *
	 * @param runtime    the CDS runtime, used to read the configured client certificate header name
	 * @param smsBinding the SMS service binding whose credentials hold the expected issuer and subject
	 * @return a new validator
	 */
	static CertValidator create(CdsRuntime runtime, ServiceBinding smsBinding) {
		String clientCertificateHeaderName = runtime.getEnvironment().getCdsProperties().getMultiTenancy()
				.getSubscriptionManager().getClientCertificateHeader();
		String expectedIssuer = (String) smsBinding.getCredentials().get(CALLBACK_CERTIFICATE_ISSUER);
		String expectedSubject = (String) smsBinding.getCredentials().get(CALLBACK_CERTIFICATE_SUBJECT);
		return new CertValidator(clientCertificateHeaderName, expectedIssuer, expectedSubject);
	}

	/**
	 * Returns the X.509 certificate factory.
	 *
	 * @throws IllegalStateException if the runtime provides no X.509 factory
	 */
	private static CertificateFactory getX509CertFactory() {
		try {
			return CertificateFactory.getInstance(X_509);
		} catch (CertificateException e) {
			// Every Java platform must support X.509, so getting here means the runtime is broken.
			// Fail fast instead of returning null and failing later with a misleading NullPointerException.
			throw new IllegalStateException("No CertificateFactory available for " + X_509, e);
		}
	}

	/**
	 * Parses a certificate as sent in the request header.
	 *
	 * @param cert either a PEM string (starting with the BEGIN CERTIFICATE marker) or the Base64
	 *             encoding of the certificate bytes
	 * @return the parsed certificate
	 * @throws Exception if the value is not valid Base64 or cannot be parsed as a certificate
	 */
	@VisibleForTesting
	X509Certificate getX509Certificate(String cert) throws Exception {
		// PEM text is handed to the factory as-is; any other value is Base64-decoded first
		byte[] encoded = cert.startsWith(CERTIFICATE_HEADER)
				? cert.getBytes(StandardCharsets.UTF_8)
				: Base64.getDecoder().decode(cert);
		try (ByteArrayInputStream is = new ByteArrayInputStream(encoded)) {
			return (X509Certificate) certFactory.generateCertificate(is);
		}
	}

	/**
	 * Validates the client certificate sent in the configured request header.
	 *
	 * @param requestContext the current request context
	 * @throws ErrorStatusException with {@link ErrorStatuses#UNAUTHORIZED} if the header is missing or
	 *                              does not contain a parsable certificate, or with
	 *                              {@link ErrorStatuses#FORBIDDEN} if the certificate's issuer or subject
	 *                              does not match the expected values
	 */
	void validateCertFromRequestContext(RequestContext requestContext) {
		String cert = requestContext.getParameterInfo().getHeader(clientCertificateHeaderName);
		if (cert == null) {
			// request did not present any certificate => 401
			throw new ErrorStatusException(ErrorStatuses.UNAUTHORIZED);
		}
		X509Certificate certificate;
		try {
			certificate = getX509Certificate(cert);
		} catch (Exception e) {
			// request did not present a parsable certificate => 401
			throw new ErrorStatusException(ErrorStatuses.UNAUTHORIZED);
		}
		// certificate was parsed, but its issuer or subject does not match the binding => 403
		if (!isValidCertIssuer(certificate.getIssuerX500Principal().toString())
				|| !isValidCertSubject(certificate.getSubjectX500Principal().toString())) {
			throw new ErrorStatusException(ErrorStatuses.FORBIDDEN);
		}
	}

	/**
	 * Checks a subject DN string against the expected subject attributes.
	 *
	 * @param subjectFromCert the subject DN in {@code X500Principal#toString()} form
	 * @return {@code true} if all expected attributes are present
	 */
	@VisibleForTesting
	boolean isValidCertSubject(String subjectFromCert) {
		return matchesExpected(subjectFromCert, expectedSubject);
	}

	/**
	 * Checks an issuer DN string against the expected issuer attributes.
	 *
	 * @param issuerFromCert the issuer DN in {@code X500Principal#toString()} form
	 * @return {@code true} if all expected attributes are present
	 */
	@VisibleForTesting
	boolean isValidCertIssuer(String issuerFromCert) {
		return matchesExpected(issuerFromCert, expectedIssuer);
	}

	/**
	 * Checks that a DN string contains every expected C, OU, CN, L and O attribute.
	 * <p>
	 * An expected {@code L} of {@code "*"} accepts any locality. A missing expected attribute is looked up
	 * as the literal text {@code "null"}, so it fails closed.
	 * <p>
	 * NOTE(review): matching uses substring containment on the DN text. For example, an expected CN of
	 * {@code foo} also accepts {@code CN=foobar}. Consider comparing parsed RDNs (e.g. via
	 * {@code javax.naming.ldap.LdapName}) for exact matching.
	 */
	private static boolean matchesExpected(String dn, Map<String, Object> expected) {
		return containsAttribute(dn, "C", expected.get("C"))
				&& containsAttribute(dn, "OU", expected.get("OU"))
				&& containsAttribute(dn, "CN", expected.get("CN"))
				&& (ANY_LOCALITY.equals(expected.get("L")) || containsAttribute(dn, "L", expected.get("L")))
				&& containsAttribute(dn, "O", expected.get("O"));
	}

	/**
	 * Checks that {@code dn} contains {@code key=value} for the expected value.
	 * <p>
	 * If the expected value is a list (a JSON array in the binding), every element must be present. An
	 * empty list therefore imposes no constraint.
	 */
	private static boolean containsAttribute(String dn, String key, Object expectedValue) {
		if (expectedValue instanceof List) {
			for (Object item : (List<?>) expectedValue) {
				if (!dn.contains(key + "=" + item)) {
					return false;
				}
			}
			return true;
		}
		return dn.contains(key + "=" + expectedValue);
	}

	/**
	 * Parses a JSON object into a map of attribute name to value.
	 * <p>
	 * Nested values keep Jackson's default representation; for example, JSON arrays become lists.
	 *
	 * @param json the JSON object text
	 * @return the parsed map
	 * @throws IOException if the text is not valid JSON for a map
	 */
	@VisibleForTesting
	Map<String, Object> jsonStrToFlatMap(String json) throws IOException {
		return mapper.readValue(json, new TypeReference<HashMap<String, Object>>() {
		});
	}

}
