/************************************************************************
 * © 2019-2022 SAP SE or an SAP affiliate company. All rights reserved. *
 ************************************************************************/
package com.sap.cds.ql.impl;

import static com.google.common.collect.Lists.reverse;
import static com.sap.cds.util.CdsModelUtils.isCascading;
import static com.sap.cds.util.CdsModelUtils.isReverseAssociation;
import static com.sap.cds.util.CdsModelUtils.CascadeType.INSERT;
import static com.sap.cds.util.DataUtils.isFkUpdate;
import static java.util.Collections.singletonList;
import static java.util.stream.Collectors.toList;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import com.sap.cds.SessionContext;
import com.sap.cds.ql.CdsDataException;
import com.sap.cds.ql.Insert;
import com.sap.cds.reflect.CdsAssociationType;
import com.sap.cds.reflect.CdsElement;
import com.sap.cds.reflect.CdsEntity;
import com.sap.cds.reflect.CdsModel;
import com.sap.cds.reflect.CdsStructuredType;
import com.sap.cds.util.CdsModelUtils;
import com.sap.cds.util.OnConditionAnalyzer;

/**
 * Splits the (possibly nested) data of a deep insert into flat rows that are
 * grouped by entity, and turns every group into one {@link Insert} statement.
 * <p>
 * Rows are accumulated in an instance field that is never cleared, so rows
 * collected by an earlier {@code split} call on the same instance are part of
 * the statements returned by a later call.
 */
public class DeepInsertSplitter {

	private final CdsEntity entity;
	// qualified entity name -> flat rows; LinkedHashMap keeps the order in which
	// entities were first collected, which computeInserts() relies on
	private final Map<String, List<Map<String, Object>>> insertEntries = new LinkedHashMap<>();
	private final SessionContext sessionContext;
	private final Map<String, Object> hints;

	/**
	 * Creates a splitter for the entity with the given name, using no hints.
	 *
	 * @param model          the model used to look up the entity
	 * @param entityName     the name of the root entity of the insert
	 * @param sessionContext the session context handed to the on-condition analysis
	 */
	public DeepInsertSplitter(CdsModel model, String entityName, SessionContext sessionContext) {
		this(model.getEntity(entityName), sessionContext, Collections.emptyMap());
	}

	/**
	 * Creates a splitter for the given root entity.
	 *
	 * @param entity         the root entity of the insert
	 * @param sessionContext the session context handed to the on-condition analysis
	 * @param hints          the hints that are set on every generated {@link Insert}
	 */
	public DeepInsertSplitter(CdsEntity entity, SessionContext sessionContext, Map<String, Object> hints) {
		this.entity = entity;
		this.sessionContext = sessionContext;
		this.hints = hints;
	}

	/**
	 * Flattens all given entries of the root entity and returns the resulting
	 * insert statements.
	 *
	 * @param entries the (deep) entries of the root entity; the maps may be
	 *                modified while they are processed
	 * @return one {@link Insert} per collected entity
	 */
	public List<Insert> split(List<Map<String, Object>> entries) {
		entries.forEach(entry -> flattenEntry(entity, entry, true));

		return computeInserts();
	}

	/**
	 * Flattens a single entry of the root entity and returns the resulting
	 * insert statements.
	 *
	 * @param entry the (deep) entry of the root entity; the map may be modified
	 *              while it is processed
	 * @return one {@link Insert} per collected entity
	 */
	public List<Insert> split(Map<String, Object> entry) {
		flattenEntry(entity, entry, true);

		return computeInserts();
	}

	/**
	 * Builds one {@link Insert} (with the configured hints) per collected entity
	 * and returns the statements in reverse order of collection.
	 */
	private List<Insert> computeInserts() {
		List<Insert> inserts = insertEntries.entrySet().stream()
				.map(e -> Insert.into(e.getKey()).entries(e.getValue()).hints(hints))
				.collect(toList());

		return reverse(inserts);
	}

	/**
	 * Creates the flat row for one entry of the given entity.
	 * <p>
	 * Values of elements that are neither associations nor structured (including
	 * keys that do not match any element) are copied as they are. Non-null
	 * association and structured values are processed afterwards, so that the
	 * values derived from them override plain values of the same name.
	 * Association and structured values that are {@code null} are left out.
	 *
	 * @param entity     the entity the entry belongs to
	 * @param entry      the (deep) data of the entry
	 * @param addInserts whether the flat row (and the rows of cascaded targets)
	 *                   are registered for insertion
	 * @return the flat row
	 */
	private Map<String, Object> flattenEntry(CdsEntity entity, Map<String, Object> entry, boolean addInserts) {
		Map<String, Object> flatEntry = new HashMap<>();
		List<CdsElement> toFlatten = new ArrayList<>();
		// first write all flat entries and keep elements to be flattened
		entry.forEach((k, v) -> {
			Optional<CdsElement> nested = entity.findElement(k)
					.filter(e -> e.getType().isAssociation() || e.getType().isStructured());
			if (!nested.isPresent()) {
				flatEntry.put(k, v);
			} else if (v != null) {
				toFlatten.add(nested.get());
			}
		});

		// flatten associations and structured types
		for (CdsElement element : toFlatten) {
			Object value = entry.get(element.getName());
			if (element.getType().isAssociation()) {
				if (isReverseAssociation(element)) {
					handleReverseAssociation(entity, element, entry, value, addInserts);
				} else {
					handleForwardAssociation(flatEntry, entry, entity.getName(), value, element, addInserts);
				}
			} else if (element.getType().isStructured()) {
				flatEntry.put(element.getName(), collectStructElement(value, element, addInserts));
			}
		}

		if (addInserts) {
			insertEntries.computeIfAbsent(entity.getQualifiedName(), k -> new ArrayList<>()).add(flatEntry);
		}

		return flatEntry;
	}

	/**
	 * Processes the value of a forward association.
	 * <p>
	 * The foreign key values computed from the target data are written to
	 * {@code flatEntry}; foreign keys that are already contained in
	 * {@code original} are updated there as well. If the association cascades on
	 * insert, the target data is flattened (and registered for insertion if
	 * {@code addInserts} is set); otherwise the target data is reduced to the
	 * association keys.
	 *
	 * @param flatEntry   the flat row that receives the foreign key values
	 * @param original    the input data containing the association value
	 * @param entityName  the name used in the error message for bad values
	 * @param value       the association value, expected to be a map
	 * @param association the association element
	 * @param addInserts  whether cascaded target rows are registered for insertion
	 * @throws CdsDataException if the value is not a map, or if the association
	 *                          cascades and the target data lacks key values
	 */
	private void handleForwardAssociation(Map<String, Object> flatEntry, Map<String, Object> original,
			String entityName, Object value, CdsElement association, boolean addInserts) {
		Map<String, Object> targetValues = asMap(entityName, association.getName(), value);

		CdsAssociationType assocType = association.getType();
		// flattened without registering inserts: only needed to compute the FK values
		Map<String, Object> flatTargetValues = flattenEntry(assocType.getTarget(), targetValues, false);
		Map<String, Object> fkValues = new OnConditionAnalyzer(association, false, sessionContext)
				.getFkValues(flatTargetValues, false);
		// override flat FK with struct value
		flatEntry.putAll(fkValues);
		// for result
		fkValues.forEach((fk, fkValue) -> {
			if (original.containsKey(fk)) {
				original.put(fk, fkValue);
			}
		});

		if (cascadeInsert(association)) {
			assertInputDataContainsKeys(association, targetValues);
			flattenEntry(assocType.getTarget(), targetValues, addInserts);
		} else { // remove non-FK values from result
			Set<String> assocKeys = CdsModelUtils.assocKeys(association);
			targetValues.keySet().retainAll(assocKeys);
		}
	}

	/**
	 * Creates the flat representation of a structured value.
	 * <p>
	 * Only elements for which the value map contains a key are considered.
	 * Nested structured values and association values that are {@code null} are
	 * left out, in the same way as {@code flattenEntry} leaves them out on entity
	 * level; without this a {@code null} value would cause a
	 * {@link NullPointerException}.
	 *
	 * @param value      the structured value, expected to be a non-null map
	 * @param element    the structured element
	 * @param addInserts whether cascaded target rows are registered for insertion
	 * @return the flat representation of the structured value
	 */
	@SuppressWarnings("unchecked")
	private Map<String, Object> collectStructElement(Object value, CdsElement element, boolean addInserts) {
		Map<String, Object> flatEntry = new HashMap<>();
		Map<String, Object> valueMap = (Map<String, Object>) value;
		CdsStructuredType structElement = element.getType().as(CdsStructuredType.class);
		structElement.nonAssociationElements().filter(f -> valueMap.containsKey(f.getName())).forEach(e -> {
			Object nestedValue = valueMap.get(e.getName());
			if (!e.getType().isStructured()) {
				flatEntry.put(e.getName(), nestedValue);
			} else if (nestedValue != null) {
				flatEntry.put(e.getName(), collectStructElement(nestedValue, e, addInserts));
			}
		});
		// structured fks have priority and are written last
		structElement.associations().filter(f -> valueMap.get(f.getName()) != null).forEach(assoc ->
			handleForwardAssociation(flatEntry, valueMap, element.getDeclaringType().getName(), valueMap.get(assoc.getName()), assoc, addInserts)
		);
		return flatEntry;
	}

	/**
	 * Processes the value of a reverse association.
	 * <p>
	 * If the association cascades on insert, the foreign key values computed
	 * from the parent entry are added to every child, and every child is
	 * flattened. A non-cascading single-valued association is rejected; the
	 * value of any other non-cascading association is ignored.
	 *
	 * @param entity      the entity that declares the association
	 * @param association the reverse association element
	 * @param entry       the data of the parent entry
	 * @param targetVal   the association value, a map or a list of maps
	 * @param addInserts  whether the child rows are registered for insertion
	 * @throws UnsupportedOperationException if the association is single-valued
	 *                                       and does not cascade on insert
	 * @throws CdsDataException              if the value is neither a list nor a map
	 */
	private void handleReverseAssociation(CdsEntity entity, CdsElement association, Map<String, Object> entry,
			Object targetVal, boolean addInserts) {
		if (cascadeInsert(association)) {
			CdsAssociationType assoc = association.getType();
			Map<String, Object> fkValues = new OnConditionAnalyzer(association, true, sessionContext)
					.getFkValues(entry);
			List<Map<String, Object>> targetValues = asList(entity.getName(), association.getName(), targetVal);
			for (Map<String, Object> child : targetValues) {
				child.putAll(fkValues);
				flattenEntry(assoc.getTarget(), child, addInserts);
			}
		} else if (CdsModelUtils.isSingleValued(association.getType())) {
			throw new UnsupportedOperationException("Cannot set reference " + entity.getQualifiedName() + "."
					+ association.getName() + ". Reverse associations are not supported.");
		}
	}

	/** Returns whether inserts cascade over the given element. */
	private static boolean cascadeInsert(CdsElement element) {
		return isCascading(INSERT, element);
	}

	/**
	 * Ensures that the data for the target of a cascading association contains
	 * values for all concrete keys of the target entity, unless
	 * {@code isFkUpdate} reports the data as a foreign key update.
	 *
	 * @param association  the association element
	 * @param targetValues the data given for the association target
	 * @throws CdsDataException if key values are missing
	 */
	private void assertInputDataContainsKeys(CdsElement association, Map<String, Object> targetValues) {
		CdsAssociationType assoc = association.getType();
		Set<String> keyNames = CdsModelUtils.concreteKeyNames(assoc.getTarget());
		if (targetValues.keySet().containsAll(keyNames)) {
			return;
		}
		if (isFkUpdate(association, targetValues, sessionContext)) {
			return; // update of FK targeting non-key element
		}
		throw new CdsDataException(
				"Data set " + targetValues.keySet() + " for association " + association.getQualifiedName()
						+ " does not contain values for all target entity keys " + keyNames + ".");
	}

	/**
	 * Casts the value of a to-one association to a map. A {@code null} value is
	 * returned as {@code null}.
	 *
	 * @throws CdsDataException if the value is not a map
	 */
	@SuppressWarnings("unchecked")
	private static Map<String, Object> asMap(String entityName, String associationName, Object value) {
		try {
			return (Map<String, Object>) value;
		} catch (ClassCastException ex) {
			// the ClassCastException is kept as cause, as it names the offending type
			throw badValue(entityName, associationName, ex);
		}
	}

	/**
	 * Returns the value of an association as a list of maps: a list is returned
	 * as it is, a single map is wrapped into a singleton list.
	 *
	 * @throws CdsDataException if the value is neither a list nor a map
	 */
	@SuppressWarnings("unchecked")
	private static List<Map<String, Object>> asList(String entityName, String associationName, Object value) {
		if (value instanceof List) {
			return (List<Map<String, Object>>) value;
		}
		if (value instanceof Map) {
			return singletonList((Map<String, Object>) value);
		}
		throw badValue(entityName, associationName, null);
	}

	/** Creates the exception that reports an unexpected association value. */
	private static RuntimeException badValue(String entityName, String associationName, RuntimeException cause) {
		return new CdsDataException("Unexpected value: Entity '" + entityName
				+ "' contains unexpected value for the association '" + associationName + "'. ", cause);
	}
}
