/************************************************************************
 * © 2021-2022 SAP SE or an SAP affiliate company. All rights reserved. *
 ************************************************************************/
package com.sap.cds.ql.impl;

import static com.sap.cds.impl.EntityCascader.cascadeDelete;
import static com.sap.cds.impl.EntityCascader.EntityKeys.keys;
import static com.sap.cds.impl.EntityCascader.EntityOperation.nop;
import static com.sap.cds.impl.EntityCascader.EntityOperation.root;
import static com.sap.cds.impl.RowImpl.row;
import static com.sap.cds.util.CdsModelUtils.isCascading;
import static com.sap.cds.util.CdsModelUtils.isReverseAssociation;
import static com.sap.cds.util.CdsModelUtils.isSingleValued;
import static com.sap.cds.util.CqnStatementUtils.hasInfixFilter;
import static com.sap.cds.util.DataUtils.generateUuidKeys;
import static com.sap.cds.util.DataUtils.hash;
import static com.sap.cds.util.DataUtils.isFkUpdate;
import static java.util.Collections.emptyMap;
import static java.util.Collections.singletonList;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toMap;

import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.collect.Iterables;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.sap.cds.CdsDataStore;
import com.sap.cds.Result;
import com.sap.cds.Row;
import com.sap.cds.SessionContext;
import com.sap.cds.impl.EntityCascader.EntityKeys;
import com.sap.cds.impl.EntityCascader.EntityOperation;
import com.sap.cds.impl.EntityCascader.EntityOperation.Operation;
import com.sap.cds.impl.EntityCascader.EntityOperations;
import com.sap.cds.ql.CQL;
import com.sap.cds.ql.CdsDataException;
import com.sap.cds.ql.Select;
import com.sap.cds.ql.StructuredType;
import com.sap.cds.ql.cqn.CqnSelect;
import com.sap.cds.ql.cqn.CqnSelectListItem;
import com.sap.cds.ql.cqn.CqnUpdate;
import com.sap.cds.reflect.CdsAssociationType;
import com.sap.cds.reflect.CdsElement;
import com.sap.cds.reflect.CdsEntity;
import com.sap.cds.reflect.impl.DraftAdapter;
import com.sap.cds.util.CdsModelUtils;
import com.sap.cds.util.CdsModelUtils.CascadeType;
import com.sap.cds.util.CqnStatementUtils;
import com.sap.cds.util.DataUtils;
import com.sap.cds.util.OnConditionAnalyzer;

public class DeepUpdateSplitter {
	private static final Logger logger = LoggerFactory.getLogger(DeepUpdateSplitter.class);
	private final CdsDataStore dataStore;
	private final SessionContext session;
	private CdsEntity entity;
	private EntityOperations operations;

	public DeepUpdateSplitter(CdsDataStore dataStore) {
		this.dataStore = dataStore;
		this.session = dataStore.getSessionContext();
	}

	public EntityOperations computeOperations(CdsEntity targetEntity, CqnUpdate update,
			Map<String, Object> targetKeys) {
		entity = targetEntity;
		operations = new EntityOperations();
		operations.entries(update.entries());
		StructuredType<?> path = targetRef(update);
		List<Map<String, Object>> updateEntries = determineUpdateEntries(targetKeys, path);
		for (Map<String, Object> entry : updateEntries) {
			EntityKeys entityId = keys(entity, entry);
			operations.add(root(entityId, session).update(entry, emptyMap()));
		}
		if (!updateEntries.isEmpty()) { // compute ops for child entities
			entity.associations().forEach(assoc -> cascade(path, assoc, updateEntries));
		}

		return operations;
	}

	private List<Map<String, Object>> determineUpdateEntries(Map<String, Object> targetKeys, StructuredType<?> path) {
		Set<String> keyElements = keyNames(entity);
		if (!hasInfixFilter(path.asRef())) {
			operations.entries().forEach(e -> addKeyValues(entity, targetKeys, e));
			if (entriesContainValuesFor(keyElements)) {
				return operations.entries();
			}
		}
		Set<String> targetKeyElements = targetKeys.keySet();
		Set<String> missingKeys = Sets.filter(keyElements, k -> !targetKeyElements.contains(k));
		if (!entriesContainValuesFor(missingKeys)) {
			return selectKeyValues(path, missingKeys);
		}
		return evaluateFilter(path, targetKeys, keyElements);
	}

	private boolean entriesContainValuesFor(Set<String> elements) {
		return operations.entries().stream().allMatch(e -> e.keySet().containsAll(elements));
	}

	private List<Map<String, Object>> selectKeyValues(StructuredType<?> path, Set<String> missingKeys) {
		if (operations.entries().size() == 1) {
			logger.warn("Update data is missing key values of entity {}. Executing query to determine key values.",
					entity.getQualifiedName());
			CqnSelect select = Select.from(path).columns(missingKeys.stream().map(CQL::get));
			Result result = dataStore.execute(select);
			Map<String, Object> singleEntry = operations.entries().get(0);
			List<Map<String, Object>> updateEntries = new ArrayList<>();
			if (result.rowCount() == 1) {
				Row keyValues = result.single();
				singleEntry.putAll(keyValues);
				updateEntries.add(singleEntry);
			} else if (result.rowCount() > 1) { // searched deep update
				result.forEach(row -> row.putAll(DataUtils.copyMap(singleEntry)));
				updateEntries.addAll(result.list());
			}
			return updateEntries;
		}
		throw new CdsDataException("Update data is missing key values " + missingKeys + " of entity " + entity);
	}

	private List<Map<String, Object>> evaluateFilter(StructuredType<?> path, Map<String, Object> targetKeys,
			Set<String> keyElements) {
		logger.debug("Executing query to evaluate update filter condition {}", path);
		if (!targetKeys.isEmpty()) {
			// add key values from infix filter to each data entry
			operations.entries().forEach(e -> addKeyValues(entity, targetKeys, e));
		}
		List<Map<String, Object>> updateEntries = new ArrayList<>();
		Map<Integer, Row> entityKeys = selectKeysMatchingPathFilter(path, keyElements, operations.entries());
		Map<Integer, Map<String, Object>> updateData = operations.entries().stream()
				.collect(toMap(row -> hash(row, keyElements), Function.identity()));
		entityKeys.forEach((hash, key) -> {
			updateEntries.add(updateData.get(hash));
		});
		logger.debug("Update filter condition fulfilled by {} entities", entityKeys.size());
		return updateEntries;
	}

	private static void addKeyValues(CdsEntity entity, Map<String, Object> keyValues, Map<String, Object> target) {
		keyValues.forEach((key, valInRef) -> {
			Object valInData = target.put(key, valInRef);
			if (valInData != null && !valInData.equals(valInRef)) {
				throw new CdsDataException("Values for key element '" + key
						+ "' in update data do not match values in update ref or where clause");
			}
		});
		target.putAll(Maps.filterValues(DataUtils.keyValues(entity, target), v -> !Objects.isNull(v)));
	}

	private static StructuredType<?> targetRef(CqnUpdate update) {
		if (!CqnStatementUtils.containsPathExpression(update.where())) {
			return CqnStatementUtils.targetRef(update);
		}
		throw new UnsupportedOperationException("Deep updates with path in where clause are not supported");
	}

	private Map<Integer, Row> selectKeysMatchingPathFilter(StructuredType<?> path, Set<String> keys,
			Iterable<Map<String, Object>> keyValueSets) {
		CqnSelect select = Select.from(path).columns(keys.stream().map(CQL::get)).byParams(keys);
		Result entries = dataStore.execute(select, keyValueSets);

		return entries.stream().collect(toMap(row -> hash(row, keys), Function.identity()));
	}

	private void cascade(StructuredType<?> path, CdsElement assoc, List<Map<String, Object>> entries) {
		Iterable<Map<String, Object>> updateEntries = Iterables.filter(entries, e -> e.containsKey(assoc.getName()));
		if (!Iterables.isEmpty(updateEntries)) {
			CdsEntity targetEntity = assoc.getType().as(CdsAssociationType.class).getTarget();
			if (isSingleValued(assoc.getType())) {
				toOne(path, assoc, targetEntity, updateEntries);
			} else {
				toMany(path, assoc, targetEntity, updateEntries);
			}
		}
	}

	/*
	 * Computes operations for to-one associations / compositions
	 * - remove / delete if association is mapped to null
	 * - insert / update if association is mapped to data map
	 */
	private void toOne(StructuredType<?> path, CdsElement assoc, CdsEntity targetEntity,
			Iterable<Map<String, Object>> parentEntries) {

		CdsEntity parentEntity = assoc.getDeclaringType();
		boolean forward = !isReverseAssociation(assoc);
		boolean composition = assoc.getType().as(CdsAssociationType.class).isComposition();
		boolean forwardAssoc = !composition && forward;
		Set<String> parentKeys = keyNames(parentEntity);
		// execute batch select to determine the association's target key values
		Map<Integer, Row> expandedParentEntries = selectPathExpandToOneTargetKeys(path, assoc, forward, parentKeys,
				parentEntries);

		List<Map<String, Object>> targetUpdateEntries = new ArrayList<>(Iterables.size(parentEntries));
		for (Map<String, Object> parentUpdateEntry : parentEntries) {
			Row parentEntry = expandedParentEntries.getOrDefault(hash(parentUpdateEntry, parentKeys), row(emptyMap()));
			Map<String, Object> targetEntry = getDataFor(assoc, parentEntry);
			Map<String, Object> targetUpdateEntry = getDataFor(assoc, parentUpdateEntry);

			if (targetEntry != null && !targetEntry.isEmpty()) {
				// assoc points to an existing target entity
				if (targetUpdateEntry == null) {
					// clear FKs & delete target when cascading
					remove(assoc, forward, targetEntity, singletonList(targetEntry));

				} else if (targetChange(assoc, forwardAssoc, targetEntity, targetEntry, targetUpdateEntry)) {
					if (isCascading(CascadeType.INSERT, assoc) || isCascading(CascadeType.UPDATE, assoc)) {
						// composition -> insert / assoc -> update or insert
						operations.add(
								xsert(parentUpdateEntry, assoc, forward, targetEntity, targetUpdateEntry, composition));
						targetUpdateEntries.add(targetUpdateEntry);
					} else { // update FK (relaxed data)
						removeNonFkValues(assoc, targetUpdateEntry);
					}
					if (composition) { // delete old target
						delete(targetEntity, targetEntry);
					}

				} else { // update target if cascading
					EntityOperation op = update(targetEntity, targetEntry, targetUpdateEntry, emptyMap());
					if (isCascading(CascadeType.UPDATE, assoc)) {
						operations.add(op);
						targetUpdateEntries.add(targetUpdateEntry);
					}
				}

			} else if (targetUpdateEntry != null) { // assoc is null
				if (forwardAssoc && !isCascading(CascadeType.INSERT, assoc)) {
					// set ref to existing target
					removeNonFkValues(assoc, targetUpdateEntry);
				} else { // insert or update target
					boolean generatedKey = generateUuidKeys(targetEntity, targetUpdateEntry);
					if (forward && generatedKey) { // update parent with target ref (FK) values
						Map<String, Object> parentKeyValues = keys(parentEntity, parentUpdateEntry);
						Map<String, Object> targetRefValues = fkValues(assoc, !forward, targetUpdateEntry);
						operations.add(update(parentEntity, parentKeyValues, new HashMap<>(), targetRefValues));
					}
					boolean insertOnly = composition || generatedKey;
					operations
							.add(xsert(parentUpdateEntry, assoc, forward, targetEntity, targetUpdateEntry, insertOnly));
					targetUpdateEntries.add(targetUpdateEntry);
				}
			}
			// cascade over associations
			targetEntity.associations().forEach(a -> cascade(path.to(assoc.getName()), a, targetUpdateEntries));
		}
	}

	private static boolean targetChange(CdsElement assoc, boolean forward, CdsEntity targetEntity,
			Map<String, Object> oldEntry, Map<String, Object> newEntry) {
		Set<String> refElements = forward ? refElements(assoc, forward) : keyNames(targetEntity);
		return refElements.stream().anyMatch(k -> {
			Object newVal = newEntry.get(k);
			return newVal != null && !newVal.equals(oldEntry.get(k));
		});
	}

	private static void removeNonFkValues(CdsElement assoc, Map<String, Object> data) {
		Set<String> assocKeys = refElements(assoc, true);
		data.keySet().retainAll(assocKeys);
	}

	private Map<Integer, Row> selectPathExpandToOneTargetKeys(StructuredType<?> path, CdsElement assoc,
			boolean forwardMapped, Set<String> parentKeys, Iterable<Map<String, Object>> keyValueSets) {
		logger.debug("Executing query to determine target entity of {}", assoc.getQualifiedName());
		List<CqnSelectListItem> slis = parentKeys.stream().map(CQL::get).collect(toList());
		Set<String> refElements = refElements(assoc, forwardMapped);
		CdsModelUtils.targetKeys(assoc).forEach(refElements::add);
		slis.add(CQL.to(assoc.getName()).expand(refElements.toArray(new String[refElements.size()])));
		CqnSelect select = Select.from(path).columns(slis).byParams(parentKeys);
		Result targetEntries = dataStore.execute(select, keyValueSets);

		return targetEntries.stream().collect(toMap(row -> hash(row, parentKeys), Function.identity()));
	}

	private Map<String, Object> fkValues(CdsElement assoc, boolean reverseMapped, Map<String, Object> data) {
		return new OnConditionAnalyzer(assoc, reverseMapped, session).getFkValues(data);
	}

	/*
	 * Computes operations for to-many associations / compositions
	 * - insert / update entities that are in the update data list
	 * - remove / delete entities that are not in the update data list
	 */
	private void toMany(StructuredType<?> path, CdsElement assoc, CdsEntity targetEntity,
			Iterable<Map<String, Object>> parentEntries) {
		boolean composition = assoc.getType().as(CdsAssociationType.class).isComposition();
		Set<String> parentKeys = keyNames(assoc.getDeclaringType());
		Set<String> targetKeys = CdsModelUtils.targetKeys(assoc);
		targetKeys.remove(DraftAdapter.IS_ACTIVE_ENTITY);
		// execute batch select to determine the association target entities
		// for deletion by keys if they are not included in the update data list
		Map<Integer, Row> targetEntries = selectTargetEntries(assoc, parentKeys, targetKeys, parentEntries);

		List<Map<String, Object>> updateEntries = new ArrayList<>(Iterables.size(parentEntries));
		for (Map<String, Object> parentEntry : parentEntries) {
			List<Map<String, Object>> targetUpdateEntries = getDataFor(assoc, parentEntry);
			if (targetUpdateEntries == null) {
				throw new CdsDataException("Value for to-many association '" + assoc.getDeclaringType() + "." + assoc
						+ "' must not be null.");
			}
			Map<String, Object> parentRefValues = new OnConditionAnalyzer(assoc, true, session)
					.getFkValues(parentEntry);
			if (parentRefValues.containsValue(null)) {
				throw new CdsDataException("Values of ref elements " + parentRefValues.keySet() + " for mapping "
						+ targetEntity + " to " + assoc.getDeclaringType() + " cannot be determined from update data.");
			}
			for (Map<String, Object> updateEntry : targetUpdateEntries) {
				Map<String, Object> updateEntryWithFks = new HashMap<>(updateEntry);
				updateEntryWithFks.putAll(parentRefValues);

				boolean targetPresent = targetEntries.remove(hash(updateEntryWithFks, targetKeys)) != null;
				boolean generatedKey = !targetPresent && generateUuidKeys(targetEntity, updateEntry);
				if (generatedKey) {
					updateEntryWithFks.putAll(updateEntry);
				}
				EntityKeys targetId = keys(targetEntity, updateEntryWithFks);

				EntityOperation operation;
				if (generatedKey) { // insert
					operation = EntityOperation.insert(targetId, updateEntry, parentRefValues, session);
				} else if (targetPresent) { // nop or update
					operation = nop(targetId, session).update(updateEntry, emptyMap());
				} else { // assoc: insert or update / composition: insert
					operation = xsert(assoc, targetEntity, updateEntry, parentRefValues, composition);
				}
				assertCascading(operation, assoc);
				operations.add(operation);
				updateEntries.add(updateEntryWithFks);
			}
		}
		// remove / delete all entities that are not included in the update data list
		remove(assoc, false, targetEntity, targetEntries.values());

		// cascade over associations
		targetEntity.associations().forEach(a -> cascade(path.to(assoc.getName()), a, updateEntries));
	}

	private Map<Integer, Row> selectTargetEntries(CdsElement assoc, Set<String> parentKeys, Set<String> targetKeys,
			Iterable<Map<String, Object>> keyValueSets) {
		logger.debug("Executing query to determine target entity of {}", assoc.getQualifiedName());
		StructuredType<?> path = CQL.entity(assoc.getDeclaringType().getQualifiedName()).filterByParams(parentKeys)
				.to(assoc.getName());
		CqnSelect select = Select.from(path).columns(targetKeys.stream().map(CQL::get));
		Result targetEntries = dataStore.execute(select, keyValueSets);

		return targetEntries.stream().collect(toMap(row -> hash(row, targetKeys), Function.identity()));
	}

	private EntityOperation update(CdsEntity target, Map<String, Object> entry, Map<String, Object> updateData, Map<String, Object> fkValues) {
		EntityOperation op = nop(keys(target, entry), entry, session).update(updateData, fkValues);
		updateData.putAll(op.targetKeys());

		return op;
	}

	private void remove(CdsElement association, boolean forwardMapped, CdsEntity targetEntity,
			Collection<? extends Map<String, Object>> keyValues) {
		if (isCascading(CascadeType.DELETE, association)) {
			keyValues.forEach(k -> delete(targetEntity, k));
		} else if (!forwardMapped) {
			assertCascading(CascadeType.UPDATE, association);
			Set<String> parentRefElements = refElements(association, forwardMapped);
			keyValues.forEach(k -> {
				operations.add(nop(keys(targetEntity, k), session).updateToNull(parentRefElements));
			});
		}
	}

	private void delete(CdsEntity targetEntity, Map<String, Object> data) {
		EntityKeys key = keys(targetEntity, data);
		operations.add(EntityOperation.delete(key, session));
		cascadeDelete(dataStore, key).forEach(operations::add);
	}

	private static Set<String> keyNames(CdsEntity entity) {
		Set<String> keyElements = CdsModelUtils.keyNames(entity);
		keyElements.remove(DraftAdapter.IS_ACTIVE_ENTITY);
		return keyElements;
	}

	private EntityOperation xsert(Map<String, Object> parentData, CdsElement association, boolean forwardMapped, CdsEntity entity,
			Map<String, Object> entityData, boolean insertOnly) {
		Map<String, Object> fkValues = emptyMap();
		if (!forwardMapped) { // set parent ref (FK) in target entity
			fkValues = fkValues(association, !forwardMapped, parentData);
		}
		return xsert(association, entity, entityData, fkValues, insertOnly);
	}

	private EntityOperation xsert(CdsElement association, CdsEntity entity, Map<String, Object> entityData,
			Map<String, Object> fkValues, boolean insertOnly) {
		Map<String, Object> data = new HashMap<>(entityData);
		data.putAll(fkValues);
		EntityKeys targetEntity = keys(entity, data);
		boolean insert = isCascading(CascadeType.INSERT, association);
		boolean update = !insertOnly && isCascading(CascadeType.UPDATE, association);
		if (insert && update && containsStream(entityData)) {
			// InputStreams can only be consumed once
			// Determine if target exists to decide between insert and update
			CqnSelect select = Select.from(entity).columns(CQL.plain("1").as("1")).matching(targetEntity);
			update = dataStore.execute(select).rowCount() > 0;
			insert = !update;
		}
		if (insert && update) {
			return EntityOperation.upsert(targetEntity, entityData, fkValues, session);
		}
		if (insert) {
			return EntityOperation.insert(targetEntity, entityData, fkValues, session);
		}
		if (update) {
			return nop(targetEntity, session).update(entityData, fkValues);
		}
		if (CdsModelUtils.managedToOne(association.getType()) && isFkUpdate(association, entityData, session)) {
			return nop(targetEntity, session);
		}
		CdsEntity target = association.getType().as(CdsAssociationType.class).getTarget();
		throw new CdsDataException(String.format(
				"UPSERT entity '%s' via association '%s.%s' is not allowed. The association does not cascade insert or update.",
				target, association.getDeclaringType(), association));
	}

	@SuppressWarnings({ "rawtypes", "unchecked" })
	private boolean containsStream(Map<String, Object> entityData) {
		return entityData.values().stream()
				.anyMatch(v -> v instanceof InputStream || v instanceof Map && containsStream((Map) v));
	}

	private static void assertCascading(EntityOperation op, CdsElement association) {
		if (op.operation() == Operation.UPDATE) {
			assertCascading(CascadeType.UPDATE, association);
		} else if (op.operation() == Operation.INSERT) {
			assertCascading(CascadeType.INSERT, association);
		} else if (op.operation() == Operation.UPSERT) {
			assertCascading(CascadeType.UPDATE, association);
			assertCascading(CascadeType.INSERT, association);
		} else if (op.operation() == Operation.DELETE) {
			assertCascading(CascadeType.DELETE, association);
		}
	}

	private static void assertCascading(CascadeType cascadeType, CdsElement association) {
		if (!isCascading(cascadeType, association)) {
			CdsEntity target = association.getType().as(CdsAssociationType.class).getTarget();
			throw new CdsDataException(String.format(
					"%s entity '%s' via association '%s.%s' is not allowed. The association does not cascade %s.",
					cascadeType.name(), target, association.getDeclaringType(), association, cascadeType));
		}
	}

	private static Set<String> refElements(CdsElement assoc, boolean forwardMapped) {
		HashMap<String, String> mapping = new HashMap<>();
		new OnConditionAnalyzer(assoc, !forwardMapped).getFkMapping().forEach((fk, val) -> {
			if (val.isRef() && !val.asRef().firstSegment().startsWith("$")) {
				mapping.put(fk, val.asRef().lastSegment());
			}
		});
		if (forwardMapped) {
			return new HashSet<>(mapping.values());
		}
		return new HashSet<>(mapping.keySet());
	}

	public static <T> T getDataFor(CdsElement assoc, Map<String, Object> data) {
		return DataUtils.getOrDefault(data, assoc.getName(), null);
	}

}
