/************************************************************************
 * © 2019-2022 SAP SE or an SAP affiliate company. All rights reserved. *
 ************************************************************************/
package com.sap.cds.impl;

import static com.sap.cds.impl.AssociationAnalyzer.refElements;
import static com.sap.cds.impl.EntityCascader.EntityKeys.keys;
import static com.sap.cds.impl.builder.model.ExpressionImpl.matching;
import static com.sap.cds.util.CdsModelUtils.isCascading;
import static com.sap.cds.util.CdsModelUtils.isReverseAssociation;
import static com.sap.cds.util.CdsModelUtils.isSingleValued;
import static java.util.Collections.emptyMap;
import static java.util.Collections.unmodifiableSet;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.BiPredicate;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Maps;
import com.sap.cds.CdsDataStore;
import com.sap.cds.Result;
import com.sap.cds.SessionContext;
import com.sap.cds.impl.EntityCascader.EntityOperation.Operation;
import com.sap.cds.impl.builder.model.ElementRefImpl;
import com.sap.cds.ql.CQL;
import com.sap.cds.ql.CdsDataException;
import com.sap.cds.ql.ElementRef;
import com.sap.cds.ql.Select;
import com.sap.cds.ql.StructuredType;
import com.sap.cds.ql.cqn.CqnPredicate;
import com.sap.cds.ql.cqn.CqnSelect;
import com.sap.cds.ql.cqn.CqnSelectListItem;
import com.sap.cds.ql.impl.SelectBuilder;
import com.sap.cds.reflect.CdsAssociationType;
import com.sap.cds.reflect.CdsElement;
import com.sap.cds.reflect.CdsEntity;
import com.sap.cds.reflect.impl.DraftAdapter;
import com.sap.cds.util.CdsModelUtils.CascadeType;
import com.sap.cds.util.DataUtils;
import com.sap.cds.util.OnConditionAnalyzer;

public class EntityCascader {

	// Data store on which the key-selecting queries are executed.
	private final CdsDataStore dataStore;
	// Entity at which cascading starts.
	private final CdsEntity rootEntity;
	// Parameter values passed along with the queries that start at the root entity;
	// replaced as a whole by with(...).
	private final Map<String, Object> paramValues = new HashMap<>();
	// Keys of all target entities reached so far; doubles as the "already seen" check
	// that keeps an entity from being cascaded twice.
	private final Set<EntityKeys> visited = new HashSet<>();
	// Filter applied to the root entity; remains null until where(...) is called.
	private CqnPredicate rootFilter;

	/**
	 * Creates a cascader that starts at the given root entity. Instances are
	 * obtained through {@link #from(CdsDataStore, CdsEntity)}.
	 *
	 * @param dataStore  the data store used to execute the key queries
	 * @param rootEntity the entity at which cascading starts
	 */
	private EntityCascader(CdsDataStore dataStore, CdsEntity rootEntity) {
		this.rootEntity = rootEntity;
		this.dataStore = dataStore;
	}

	/**
	 * Static factory: creates a cascader for the given root entity.
	 *
	 * @param dataStore the data store used to execute the key queries
	 * @param entity    the entity at which cascading starts
	 * @return a new cascader without filter and without parameter values
	 */
	public static EntityCascader from(CdsDataStore dataStore, CdsEntity entity) {
		EntityCascader cascader = new EntityCascader(dataStore, entity);
		return cascader;
	}

	/**
	 * Sets the filter on the root entity from an optional predicate. An empty
	 * optional resets the filter to {@code null}.
	 *
	 * @param pred the optional root filter
	 * @return this cascader
	 */
	public EntityCascader where(Optional<CqnPredicate> pred) {
		CqnPredicate filter = pred.isPresent() ? pred.get() : null;
		return where(filter);
	}

	/**
	 * Sets the filter that is applied to the root entity when cascading.
	 *
	 * @param pred the root filter, stored as given (may be {@code null})
	 * @return this cascader
	 */
	public EntityCascader where(CqnPredicate pred) {
		rootFilter = pred;
		return this;
	}

	/**
	 * Replaces the parameter values that are passed to the queries starting at
	 * the root entity. Previously set values are discarded; the given map is
	 * copied.
	 *
	 * @param paramValues the new parameter values
	 * @return this cascader
	 */
	public EntityCascader with(Map<String, Object> paramValues) {
		Map<String, Object> own = this.paramValues;
		own.clear();
		own.putAll(paramValues);
		return this;
	}

	/**
	 * Cascades from the root entity over all associations for which
	 * {@code isCascading(cascadeType, association)} holds.
	 *
	 * @param cascadeType the cascade type the associations are checked against
	 * @return an unmodifiable view of the keys of all target entities reached
	 */
	public Set<EntityKeys> cascade(CascadeType cascadeType) {
		BiPredicate<Map<String, Object>, CdsElement> cascading = (data, assoc) -> isCascading(cascadeType, assoc);
		cascadeRoot(cascading, emptyMap());

		return unmodifiableSet(visited);
	}

	/**
	 * Cascades from the root entity over all associations whose association type
	 * is accepted by the given predicate.
	 *
	 * @param assocFilter decides per association type whether it is followed
	 * @return an unmodifiable view of the keys of all target entities reached
	 */
	@VisibleForTesting
	public Set<EntityKeys> cascade(Predicate<CdsAssociationType> assocFilter) {
		BiPredicate<Map<String, Object>, CdsElement> byType = (data, assoc) -> assocFilter.test(assoc.getType());
		cascadeRoot(byType, emptyMap());

		return unmodifiableSet(visited);
	}

	/**
	 * Determines the entities reached from the given entity instance via
	 * associations that cascade {@link CascadeType#DELETE} and maps each of them
	 * to a delete operation.
	 *
	 * @param dataStore    the data store used to execute the key queries
	 * @param targetEntity the keys of the entity instance at which cascading starts
	 * @return a stream of delete operations, one per entity reached
	 */
	public static Stream<EntityOperation> cascadeDelete(CdsDataStore dataStore, EntityKeys targetEntity) {
		EntityCascader cascader = EntityCascader.from(dataStore, targetEntity.entity).where(matching(targetEntity));
		Set<EntityKeys> reached = cascader.cascade(CascadeType.DELETE);

		// the session context is looked up lazily, once per element of the stream
		return reached.stream().map(reachedKeys -> EntityOperation.delete(reachedKeys, dataStore.getSessionContext()));
	}

	/**
	 * Starts cascading at the root entity, using the root filter and the
	 * parameter values configured on this cascader.
	 *
	 * @param assocFilter decides, given the data and an association, whether the
	 *                    association is followed
	 * @param data        the data passed to the association filter
	 */
	private void cascadeRoot(BiPredicate<Map<String, Object>, CdsElement> assocFilter, Map<String, Object> data) {
		this.cascade(this.rootEntity, this.rootFilter, this.paramValues, assocFilter, data);
	}

	/**
	 * Follows every association of the given entity that the association filter
	 * accepts for the given data.
	 *
	 * @param entity      the entity whose associations are inspected
	 * @param filter      the filter selecting the instances of the entity
	 * @param paramValues the parameter values for the queries
	 * @param assocFilter decides whether an association is followed
	 * @param data        the data passed to the association filter
	 */
	private void cascade(CdsEntity entity, CqnPredicate filter, Map<String, Object> paramValues,
			BiPredicate<Map<String, Object>, CdsElement> assocFilter, Map<String, Object> data) {
		Stream<CdsElement> accepted = entity.associations().filter(assoc -> assocFilter.test(data, assoc));
		accepted.forEach(assoc -> cascade(entity, filter, paramValues, assocFilter, assoc, data));
	}

	/**
	 * Follows one association of the given entity. Builds the path from the
	 * filtered entity to the association and dispatches to the to-one or to-many
	 * handling depending on the association type.
	 *
	 * @param entity      the entity owning the association
	 * @param filter      the filter selecting the instances of the entity
	 * @param paramValues the parameter values for the queries
	 * @param assocFilter decides whether an association is followed
	 * @param association the association to follow
	 * @param data        the data of the entity; the value stored under the
	 *                    association name is handed on as nested data
	 */
	private void cascade(CdsEntity entity, CqnPredicate filter, Map<String, Object> paramValues,
			BiPredicate<Map<String, Object>, CdsElement> assocFilter, CdsElement association,
			Map<String, Object> data) {

		String name = association.getName();
		CdsAssociationType type = association.getType();
		// path expression: <entity>[filter].<association>
		StructuredType<?> path = CQL.entity(entity.getQualifiedName()).filter(filter).to(name);
		CdsEntity target = type.getTarget();

		if (!isSingleValued(type)) {
			List<Map<String, Object>> nested = getList(data, name);
			cascadeToMany(path, target, association, paramValues, assocFilter, nested);
			return;
		}
		cascadeToOne(path, entity, target, association, paramValues, assocFilter, data);
	}

	/**
	 * Reads the nested map stored under the given key.
	 * <p>
	 * Returns an empty map both when the key is absent and when it is explicitly
	 * mapped to {@code null}. ({@code Map.getOrDefault} only applies its default
	 * for absent keys, so a null value would otherwise leak through.)
	 *
	 * @param data the data to read from
	 * @param key  the key of the nested map
	 * @return the nested map, never {@code null}
	 * @throws ClassCastException if the value stored under the key is not a map
	 */
	@SuppressWarnings("unchecked")
	private static Map<String, Object> getMap(Map<String, Object> data, String key) {
		Object value = data.get(key);
		if (value == null) {
			return emptyMap();
		}
		return (Map<String, Object>) value;
	}

	/**
	 * Reads the list of nested maps stored under the given key.
	 * <p>
	 * Returns an empty list when the key is absent or mapped to {@code null},
	 * so callers never receive {@code null}.
	 *
	 * @param data the data to read from
	 * @param key  the key of the nested list
	 * @return the nested list, never {@code null}
	 * @throws ClassCastException if the value stored under the key is not a list
	 */
	@SuppressWarnings("unchecked")
	private static List<Map<String, Object>> getList(Map<String, Object> data, String key) {
		Object value = data.get(key);
		if (value == null) {
			return Collections.emptyList();
		}
		return (List<Map<String, Object>>) value;
	}

	/**
	 * Handles a single-valued association: selects the keys of the target
	 * entities along the given path and continues cascading from every target
	 * that has not been visited before.
	 *
	 * @param path        the path from the filtered parent to the association
	 * @param parent      the entity owning the association (not used here)
	 * @param target      the target entity of the association
	 * @param association the association that is followed
	 * @param paramValues the parameter values for the key query
	 * @param assocFilter decides whether an association of the target is followed
	 * @param parentData  the data of the parent; the map stored under the
	 *                    association name is passed on as data of the target
	 */
	private void cascadeToOne(StructuredType<?> path, CdsEntity parent, CdsEntity target, CdsElement association,
			Map<String, Object> paramValues, BiPredicate<Map<String, Object>, CdsElement> assocFilter,
			Map<String, Object> parentData) {

		Result rows = dataStore.execute(selectTargetKeys(association, path), paramValues);
		Map<String, Object> nestedData = getMap(parentData, association.getName());
		// all keys are extracted up front, before the first nested cascade runs
		List<EntityKeys> targetKeys = rows.stream().map(row -> keys(target, row)).collect(Collectors.toList());

		for (EntityKeys current : targetKeys) {
			if (notVisited(current)) {
				// nested queries are filtered by key values and take no parameters
				cascade(target, matching(current), emptyMap(), assocFilter, nestedData);
			}
		}
	}

	/**
	 * Handles a multi-valued association: selects the keys of the target
	 * entities along the given path and continues cascading from every target
	 * that has not been visited before. Nested cascades are started with empty
	 * data.
	 *
	 * @param path        the path from the filtered parent to the association
	 * @param target      the target entity of the association
	 * @param association the association that is followed
	 * @param paramValues the parameter values for the key query
	 * @param assocFilter decides whether an association of the target is followed
	 * @param newData     the nested data of the association (not used here)
	 */
	private void cascadeToMany(StructuredType<?> path, CdsEntity target, CdsElement association,
			Map<String, Object> paramValues, BiPredicate<Map<String, Object>, CdsElement> assocFilter,
			List<Map<String, Object>> newData) {

		Result rows = dataStore.execute(selectTargetKeys(association, path), paramValues);
		Set<EntityKeys> targetKeys = rows.stream().map(row -> keys(target, row)).collect(Collectors.toSet());

		for (EntityKeys current : targetKeys) {
			if (notVisited(current)) {
				cascade(target, matching(current), emptyMap(), assocFilter, emptyMap());
			}
		}
	}

	/**
	 * Builds the query that selects the key elements of the association's target
	 * entity along the given path.
	 *
	 * @param association the association whose target keys are selected
	 * @param path        the path that is used as source of the query
	 * @return the select statement
	 */
	private Select<?> selectTargetKeys(CdsElement association, StructuredType<?> path) {
		CdsAssociationType type = association.getType();
		CdsEntity target = type.getTarget();
		return SelectBuilder.from(path).columns(target.keyElements().flatMap(key -> slis(key)));
	}

	/**
	 * Creates the select list items for a key element. A key element that is an
	 * association is expanded recursively into its referenced elements, with the
	 * element names collected along the way forming the path of each item.
	 *
	 * @param keyElement  the key element
	 * @param association the path segments leading to the key element
	 * @return the select list items for the key element
	 */
	private static Stream<CqnSelectListItem> slis(CdsElement keyElement, String... association) {
		int depth = association.length;
		String[] path = new String[depth + 1];
		System.arraycopy(association, 0, path, 0, depth);
		path[depth] = keyElement.getName();

		if (!keyElement.getType().isAssociation()) {
			return Stream.of(sli(path));
		}
		return refElements(keyElement).flatMap(ref -> slis(ref, path));
	}

	/**
	 * Creates a select list item referencing the element at the given path. A
	 * path with more than one segment gets the dot-separated path as alias.
	 *
	 * @param path the path segments of the element
	 * @return the select list item
	 */
	private static CqnSelectListItem sli(String... path) {
		ElementRef<?> ref = ElementRefImpl.element(path);
		if (path.length <= 1) {
			return ref;
		}
		return ref.as(String.join(".", path));
	}

	/**
	 * Records the given keys as visited.
	 *
	 * @param keys the keys of the entity that is about to be visited
	 * @return {@code true} if the keys have not been recorded before
	 */
	private boolean notVisited(EntityKeys keys) {
		boolean firstVisit = visited.add(keys);
		return firstVisit;
	}

	/**
	 * Collects entity operations together with a list of data entries.
	 */
	public static class EntityOperations {
		private final List<EntityOperation> operations = new ArrayList<>();
		private final List<Map<String, Object>> entries = new ArrayList<>();

		/**
		 * Appends the given entries.
		 *
		 * @param entries the entries to append
		 */
		public void entries(List<Map<String, Object>> entries) {
			List<Map<String, Object>> collected = this.entries;
			collected.addAll(entries);
		}

		/**
		 * @return the entries collected so far (the internal list, not a copy)
		 */
		public List<Map<String, Object>> entries() {
			return this.entries;
		}

		/**
		 * Appends an operation.
		 *
		 * @param op the operation to append
		 * @return the result of adding the operation to the underlying list
		 */
		public boolean add(EntityOperation op) {
			boolean added = operations.add(op);
			return added;
		}

		/**
		 * @return the operations that are flagged as root operations
		 */
		public Stream<EntityOperation> rootOps() {
			return operations.stream().filter(op -> op.isRootOp());
		}

		/**
		 * @param opType the operation type to select
		 * @return the operations of the given type
		 */
		public Stream<EntityOperation> filter(Operation opType) {
			return operations.stream().filter(op -> opType == op.operation());
		}

		/**
		 * @return the update counts of the root operations, or a single {@code 0}
		 *         if there is no root operation
		 */
		public long[] updateCount() {
			long[] counts = rootOps().mapToLong(op -> op.updateCount()).toArray();
			return counts.length > 0 ? counts : new long[] { 0 };
		}
	}

	/**
	 * One {@link Operation} on the entity instance identified by its target keys.
	 * <p>
	 * The operation is itself a map (backed by {@link #delegate()}) of the
	 * flattened values collected for the entity: the key values, the element
	 * values and the foreign key values computed from the data of non-reverse
	 * associations.
	 * <p>
	 * Note: {@link #equals(Object)} and {@link #hashCode()} include the mutable
	 * value map, so the hash of an operation changes when values are changed.
	 */
	public static class EntityOperation extends com.google.common.collect.ForwardingMap<String, Object> {
		private final EntityKeys targetKeys;
		// flattened values of the entity, initialized with the key values
		private final Map<String, Object> data = new HashMap<>();
		// names of the values that were set or changed via update/updateToNull
		private final Set<String> updated = new HashSet<>();
		private final boolean root;
		// the data map handed in by the caller; values reported via inserted/updated
		// are written back into it. Null for operations created without data.
		private Map<String, Object> updateData;
		private SessionContext sessionContext;
		private long updateCount = 0;

		private Operation operation;

		private EntityOperation(EntityKeys targetKeys, Operation operation, SessionContext sessionContext, boolean root) {
			this.targetKeys = targetKeys;
			this.data.putAll(targetKeys);
			this.operation = operation;
			this.sessionContext = sessionContext;
			this.root = root;
		}

		/**
		 * Creates an UPDATE operation that is flagged as root operation.
		 */
		public static EntityOperation root(EntityKeys entity, SessionContext sessionContext) {
			return new EntityOperation(entity, Operation.UPDATE, sessionContext, true);
		}

		/**
		 * Creates a NOP operation that only carries the key values.
		 */
		public static EntityOperation nop(EntityKeys entity, SessionContext sessionContext) {
			return new EntityOperation(entity, Operation.NOP, sessionContext, false);
		}

		/**
		 * Creates a NOP operation carrying the key values and the flattened data.
		 */
		public static EntityOperation nop(EntityKeys entity, Map<String, Object> data, SessionContext sessionContext) {
			return nop(entity, sessionContext).data(data, Collections.emptyMap());
		}

		/**
		 * Creates an UPSERT operation; the data and foreign key values are applied
		 * via {@link #update(Map, Map)} and hence tracked as updated values.
		 */
		public static EntityOperation upsert(EntityKeys entity, Map<String, Object> data, Map<String, Object> fkValues,
				SessionContext sessionContext) {
			return new EntityOperation(entity, Operation.UPSERT, sessionContext, false).update(data, fkValues);
		}

		/**
		 * Creates an INSERT operation carrying the flattened data and the foreign
		 * key values.
		 */
		public static EntityOperation insert(EntityKeys entity, Map<String, Object> data, Map<String, Object> fkValues,
				SessionContext sessionContext) {
			return new EntityOperation(entity, Operation.INSERT, sessionContext, false).data(data, fkValues);
		}

		/**
		 * Creates a DELETE operation for the entity with the given keys.
		 */
		public static EntityOperation delete(EntityKeys targetKeys, SessionContext sessionContext) {
			return new EntityOperation(targetKeys, Operation.DELETE, sessionContext, false);
		}

		/**
		 * @return the entity this operation targets
		 */
		public CdsEntity targetEntity() {
			return targetKeys.entity;
		}

		/**
		 * @return the current operation type
		 */
		public Operation operation() {
			return operation;
		}

		/**
		 * @return {@code true} if this operation was created via
		 *         {@link #root(EntityKeys, SessionContext)}
		 */
		public boolean isRootOp() {
			return root;
		}

		/**
		 * @return the update count reported via {@link #inserted(Map)} or
		 *         {@link #updated(Map, long)}; {@code 0} if nothing was reported
		 */
		public long updateCount() {
			return updateCount;
		}

		// Remembers the caller's data map and adds its flattened values plus the
		// given foreign key values, without tracking them as updated.
		private EntityOperation data(Map<String, Object> data, Map<String, Object> fkValues) {
			this.updateData = data;
			this.data.putAll(flattenData(targetKeys.entity, data));
			this.data.putAll(fkValues);
			return this;
		}

		/**
		 * Reports the data returned for an executed insert. For non-empty data the
		 * update count is set to 1 and the values not yet known to this operation
		 * are written back to the caller's data map.
		 *
		 * @param data the data returned by the insert
		 * @return {@code true} if the data is not empty
		 */
		public boolean inserted(Map<String, Object> data) {
			if (data.isEmpty()) {
				return false;
			}
			writeBackNewValues(data);
			this.updateCount = 1;
			return true;
		}

		/**
		 * Reports the data and update count of an executed update. The update
		 * count is always taken over; for non-empty data the values not yet known
		 * to this operation are written back to the caller's data map.
		 *
		 * @param data        the data returned by the update
		 * @param updateCount the number of updated rows
		 * @return {@code true} if the data is not empty
		 */
		public boolean updated(Map<String, Object> data, long updateCount) {
			this.updateCount = updateCount;
			if (data.isEmpty()) {
				return false;
			}
			writeBackNewValues(data);
			return true;
		}

		// Copies the entries whose key is not yet contained in this operation's
		// values into the caller's data map. Operations created without data
		// (root, delete, nop without data) have no such map; nothing is written
		// back for them instead of failing with a NullPointerException.
		private void writeBackNewValues(Map<String, Object> values) {
			if (updateData == null) {
				return;
			}
			updateData.putAll(Maps.filterKeys(values, k -> !this.data.containsKey(k)));
		}

		/**
		 * Hook for a completed delete; currently has no effect.
		 *
		 * @return this operation
		 */
		public EntityOperation deleted() {
			return this;
		}

		/**
		 * Sets the given elements to {@code null} and tracks them as updated. If
		 * at least one element is given, the operation type becomes UPDATE.
		 *
		 * @param keys the names of the elements to set to {@code null}
		 * @return this operation
		 */
		public EntityOperation updateToNull(Set<String> keys) {
			for (String key : keys) {
				data.put(key, null);
				updated.add(key);
				operation = Operation.UPDATE;
			}
			return this;
		}

		/**
		 * Applies the flattened update data and the foreign key values to this
		 * operation. Every value that is new or differs from the current value is
		 * tracked as updated; a NOP operation thereby turns into an UPDATE.
		 *
		 * @param updateData the data to apply; also merged into the caller's data map
		 * @param fkValues   the foreign key values to apply
		 * @return this operation
		 * @throws CdsDataException if a changed value belongs to a key element and
		 *                          differs from the key value
		 */
		public EntityOperation update(Map<String, Object> updateData, Map<String, Object> fkValues) {
			mergeData(updateData);
			Map<String, Object> flattenedData = flattenData(targetEntity(), updateData);
			flattenedData.putAll(fkValues);
			flattenedData.forEach((key, newValue) -> {
				boolean valuePresent = data.containsKey(key);
				Object oldVal = data.put(key, newValue);
				if (!valuePresent || !Objects.equals(oldVal, newValue)) {
					updated.add(key);
					if (operation == Operation.NOP) {
						operation = Operation.UPDATE;
					}
					Object keyVal = targetKeys.get(key);
					if (keyVal != null && !keyVal.equals(newValue)) {
						throw new CdsDataException("Key values cannot be changed");
					}
				}
			});
			return this;
		}

		// Adopts the given map as the caller's data map, or merges it into the
		// one adopted earlier.
		private void mergeData(Map<String, Object> data) {
			if (updateData == null) {
				updateData = data;
			} else {
				updateData.putAll(data);
			}
		}

		// Returns a copy of the data in which the entry of every association
		// present in the data is replaced according to flattenData(CdsElement, Map).
		private Map<String, Object> flattenData(CdsEntity entity, Map<String, Object> updateData) {
			Map<String, Object> copy = new HashMap<>(updateData);
			associationsInData(entity, updateData).forEach(a -> flattenData(a, copy));
			return copy;
		}

		// The associations of the entity for which the data contains an entry.
		private static Stream<CdsElement> associationsInData(CdsEntity e, Map<String, Object> d) {
			return e.associations().filter(a -> d.containsKey(a.getName()));
		}

		// Removes the association's entry from the data. For a non-reverse
		// association the foreign key values are added in its place.
		private void flattenData(CdsElement association, Map<String, Object> data) {
			if (isReverseAssociation(association)) {
				data.remove(association.getName());
			} else {
				@SuppressWarnings("unchecked")
				Map<String, Object> targetValues = (Map<String, Object>) data.remove(association.getName());
				if (targetValues == null) {
					// set FKs to null
					data.putAll(computeFkValues(association, emptyMap(), sessionContext));
				} else {
					Map<String, Object> fkValues = computeFkValues(association, targetValues, sessionContext);
					// foreign key values are only taken over if all of them are non-null
					if (!fkValues.containsValue(null)) {
						data.putAll(fkValues);
					}
				}
			}
		}

		// Derives the foreign key values of the association from the target values.
		private static Map<String, Object> computeFkValues(CdsElement association, Map<String, Object> targetValues,
				SessionContext sessionContext) {
			return new OnConditionAnalyzer(association, isReverseAssociation(association), sessionContext)
					.getFkValues(targetValues);
		}

		/**
		 * @return the keys of the entity this operation targets
		 */
		public EntityKeys targetKeys() {
			return targetKeys;
		}

		/**
		 * @return a new map with the values tracked as updated plus the key values
		 */
		public Map<String, Object> updateValues() {
			Map<String, Object> values = new HashMap<>(Maps.filterKeys(data, updated::contains));
			values.putAll(targetKeys);

			return values;
		}

		@Override
		protected Map<String, Object> delegate() {
			return data;
		}

		@Override
		public int hashCode() {
			return Objects.hash(targetKeys, data);
		}

		@Override
		public boolean equals(Object obj) {
			if (this == obj)
				return true;
			if (obj == null) {
				return false;
			}
			if (obj.getClass() != this.getClass())
				return false;
			EntityOperation other = (EntityOperation) obj;
			if (!targetKeys.equals(other.targetKeys))
				return false;
			return data.equals(other.data);
		}

		@Override
		public String toString() {
			return operation + " " + targetKeys + ": " + data.toString();
		}

		/**
		 * The type of an entity operation.
		 */
		public enum Operation {
			NOP, INSERT, UPDATE, UPSERT, DELETE
		}

		/**
		 * State constants; not referenced within this class.
		 */
		public enum State {
			UNCHANGED, INSERTED, UPDATED, DELETED
		}
	}

	/**
	 * The key values of an entity instance together with its entity. The key
	 * values are accessible through the map interface (backed by
	 * {@link #delegate()}). Two instances are equal if the qualified names of
	 * their entities and their key values are equal.
	 */
	public static class EntityKeys extends com.google.common.collect.ForwardingMap<String, Object> {
		private final CdsEntity entity;
		private final Map<String, Object> keys;

		private EntityKeys(CdsEntity entity, Map<String, Object> keys) {
			this.keys = keys;
			this.entity = entity;
		}

		/**
		 * Extracts the key values of the entity from the given data. The
		 * {@code DraftAdapter.IS_ACTIVE_ENTITY} entry is dropped from the key
		 * values.
		 *
		 * @param entity the entity
		 * @param data   the data containing the key values
		 * @return the entity keys
		 * @throws CdsDataException if one of the key values is {@code null}
		 */
		public static EntityKeys keys(CdsEntity entity, Map<String, Object> data) {
			Map<String, Object> extracted = DataUtils.keyValues(entity, data);
			extracted.remove(DraftAdapter.IS_ACTIVE_ENTITY);
			if (!extracted.containsValue(null)) {
				return new EntityKeys(entity, extracted);
			}
			throw new CdsDataException("Key values of entity " + entity + " must not be null");
		}

		/**
		 * @return an unmodifiable view of the key values
		 */
		public Map<String, Object> keys() {
			return Collections.unmodifiableMap(this.keys);
		}

		/**
		 * @return the entity the keys belong to
		 */
		public CdsEntity entity() {
			return this.entity;
		}

		@Override
		protected Map<String, Object> delegate() {
			return this.keys;
		}

		@Override
		public int hashCode() {
			String qualifiedName = entity.getQualifiedName();
			return Objects.hash(qualifiedName, keys);
		}

		@Override
		public boolean equals(Object obj) {
			if (this == obj) {
				return true;
			}
			if (obj == null || obj.getClass() != this.getClass()) {
				return false;
			}
			EntityKeys that = (EntityKeys) obj;
			return entity.getQualifiedName().equals(that.entity.getQualifiedName()) && keys.equals(that.keys);
		}

		@Override
		public String toString() {
			return entity.getQualifiedName() + "[" + keys + "]";
		}
	}

}
