package com.sap.cds.services.impl.cds;

import static com.sap.cds.services.impl.cds.TypedCqnServiceInvocationHandler.getCdsName;

import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Proxy;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;

import org.apache.commons.io.IOUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.annotations.VisibleForTesting;
import com.sap.cds.services.Service;
import com.sap.cds.services.cds.CqnService;
import com.sap.cds.services.impl.ServiceSPI;
import com.sap.cds.services.impl.utils.CdsServiceUtils;

/**
 * Wraps CDS-defined services into dynamic proxies that implement the generated typed service interfaces, if such
 * an interface is registered for the service's qualified CDS name.
 * <p>
 * Generated interfaces are discovered once, when this class is initialized, from all
 * {@code META-INF/cds4j-codegen/services.generated} resources visible to the resolved class loader. Each resource
 * lists one class name per line; blank lines are ignored.
 */
public class TypedCqnServiceFactory {

	private static final Logger logger = LoggerFactory.getLogger(TypedCqnServiceFactory.class);

	private static final String GENERATED_SERVICES_RESOURCE = "META-INF/cds4j-codegen/services.generated";

	// populated once during class initialization and only read afterwards
	private static final Map<String, Class<? extends CqnService>> generatedServices = loadGeneratedServices();

	private TypedCqnServiceFactory() {
		// avoid instances
	}

	/**
	 * Returns a typed proxy for the given service if a generated interface is registered for its qualified name.
	 * <p>
	 * When a proxy is created and the service exposes a {@link ServiceSPI}, the proxy is registered as its delegator.
	 *
	 * @param service the service to wrap
	 * @return the proxy, or the given {@code service} unchanged if it is not an {@link AbstractCdsDefinedService} or
	 *         no generated interface is registered for it
	 */
	public static Service createProxyIfAvailable(Service service) {
		if (service instanceof AbstractCdsDefinedService cdsDefinedService) {
			String serviceName = cdsDefinedService.getDefinition().getQualifiedName();
			// create a proxy if there is a generated POJOs for this service
			Class<? extends CqnService> generatedType = generatedServices.get(serviceName);
			if (generatedType != null) {
				Service proxy = createProxy(generatedType, cdsDefinedService);
				ServiceSPI serviceSPI = CdsServiceUtils.getServiceSPI(cdsDefinedService);
				if (serviceSPI != null) {
					serviceSPI.setDelegator(proxy);
				}
				return proxy;
			}
		}
		return service;
	}

	/**
	 * Creates a dynamic proxy of {@code type} that forwards to {@code service}.
	 * <p>
	 * Prefers a nested interface of {@code type} whose interfaces, apart from {@code type} itself, are exactly the
	 * interfaces implemented by the service's class. If no such nested interface exists, the proxy implements all of
	 * the service's interfaces plus {@code type}.
	 *
	 * @param type    the generated typed service interface
	 * @param service the service the proxy delegates to
	 * @return the proxy instance
	 */
	@VisibleForTesting
	@SuppressWarnings("unchecked")
	static <T extends CqnService> T createProxy(Class<T> type, CqnService service) {
		Set<Class<?>> serviceInterfaces = getAllInterfaces(service.getClass());
		Class<?> exactlyMatchingProxy = null;
		for (Class<?> proxyCandidate : type.getDeclaredClasses()) {
			if (proxyCandidate.isInterface()) {
				Set<Class<?>> candidateInterfaces = getAllInterfaces(proxyCandidate);
				candidateInterfaces.remove(type);
				if (candidateInterfaces.containsAll(serviceInterfaces) && serviceInterfaces.containsAll(candidateInterfaces)) {
					exactlyMatchingProxy = proxyCandidate;
					break;
				}
			}
		}

		String proxyName;
		Set<Class<?>> proxyInterfaces = new LinkedHashSet<>();
		if (exactlyMatchingProxy != null) {
			proxyInterfaces.add(exactlyMatchingProxy);
			proxyName = exactlyMatchingProxy.getName();
		} else {
			// not expected, only as a fallback
			proxyInterfaces.addAll(serviceInterfaces);
			proxyInterfaces.add(type);
			proxyName = type.getName();
		}

		logger.debug("Wrapped service {} with generated interface {}", service.getName(), proxyName);
		// unchecked: the proxy implements either 'type' directly or the selected nested interface of 'type'
		return (T) Proxy.newProxyInstance(type.getClassLoader(), proxyInterfaces.toArray(new Class<?>[0]),
				new TypedCqnServiceInvocationHandler(service, type));
	}

	/**
	 * Collects all interfaces implemented by {@code clazz}, its superclasses and, recursively, their
	 * super-interfaces, in discovery order.
	 *
	 * @param clazz the class or interface to inspect
	 * @return a new, mutable set of interfaces
	 */
	private static Set<Class<?>> getAllInterfaces(Class<?> clazz) {
		Set<Class<?>> interfaces = new LinkedHashSet<>();
		do {
			for (Class<?> i : clazz.getInterfaces()) {
				interfaces.add(i);
				interfaces.addAll(getAllInterfaces(i));
			}
			clazz = clazz.getSuperclass();
		} while (clazz != null);
		return interfaces;
	}

	/**
	 * Reads every {@value #GENERATED_SERVICES_RESOURCE} resource and indexes the listed classes by CDS name.
	 * Failures are logged and skipped so that a single broken entry or resource does not disable the others.
	 */
	private static Map<String, Class<? extends CqnService>> loadGeneratedServices() {
		Map<String, Class<? extends CqnService>> services = new HashMap<>();
		// the same loader is used to discover the resources and to load the listed classes
		ClassLoader classLoader = resolveClassLoader();
		Iterator<URL> resources;
		try {
			resources = classLoader.getResources(GENERATED_SERVICES_RESOURCE).asIterator();
		} catch (IOException e) {
			logger.warn("Could not load list of generated service classes", e);
			return services;
		}
		while (resources.hasNext()) {
			URL resource = resources.next();
			try (InputStream generatedInput = resource.openStream()) {
				for (String line : IOUtils.readLines(generatedInput, StandardCharsets.UTF_8)) {
					String className = line.strip();
					if (!className.isEmpty()) {
						registerGeneratedService(services, classLoader, className);
					}
				}
			} catch (IOException e) {
				logger.warn("Could not load list of generated service classes from '{}'", resource, e);
			}
		}
		return services;
	}

	/**
	 * Loads {@code className} and, if it carries a CDS name, registers it in {@code services}.
	 * Classes that cannot be loaded or are not {@link CqnService}s are logged and ignored.
	 */
	private static void registerGeneratedService(Map<String, Class<? extends CqnService>> services,
			ClassLoader classLoader, String className) {
		try {
			Class<? extends CqnService> clazz = classLoader.loadClass(className).asSubclass(CqnService.class);
			getCdsName(clazz).ifPresent(cdsName -> services.put(cdsName, clazz));
		} catch (ClassNotFoundException e) {
			logger.warn("Could not load generated service class '{}'", className);
		} catch (LinkageError e) {
			// e.g. NoClassDefFoundError; must not escape the static initializer
			logger.warn("Could not load generated service class '{}'", className, e);
		} catch (ClassCastException e) {
			logger.warn("Generated service class '{}' is not a {}", className, CqnService.class.getName());
		}
	}

	/**
	 * Returns the current thread's context class loader, or this class's own loader if no context loader is set.
	 */
	private static ClassLoader resolveClassLoader() {
		ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader();
		return contextClassLoader != null ? contextClassLoader : TypedCqnServiceFactory.class.getClassLoader();
	}

}
