/************************************************************************
 * © 2019-2025 SAP SE or an SAP affiliate company. All rights reserved. *
 ************************************************************************/
package com.sap.cds.impl;

import static com.sap.cds.impl.LazyRowImpl.lazyRow;
import static com.sap.cds.impl.builder.model.ExpressionImpl.matching;
import static com.sap.cds.impl.builder.model.StructuredTypeRefImpl.typeRef;
import static com.sap.cds.impl.parser.token.RefSegmentImpl.refSegment;
import static com.sap.cds.ql.cqn.CqnComparisonPredicate.Operator.EQ;
import static com.sap.cds.util.CdsModelUtils.concreteKeyNames;
import static com.sap.cds.util.CdsModelUtils.isSingleValued;
import static java.util.stream.Collectors.groupingBy;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.sap.cds.CdsDataStore;
import com.sap.cds.CdsDataStoreException;
import com.sap.cds.CdsException;
import com.sap.cds.Result;
import com.sap.cds.Row;
import com.sap.cds.impl.builder.model.ExpandBuilder;
import com.sap.cds.impl.builder.model.ExpressionImpl;
import com.sap.cds.impl.qat.QatBuilder;
import com.sap.cds.jdbc.spi.DbContext;
import com.sap.cds.jdbc.spi.SqlMapping;
import com.sap.cds.ql.CQL;
import com.sap.cds.ql.Select;
import com.sap.cds.ql.cqn.CqnElementRef;
import com.sap.cds.ql.cqn.CqnExpand;
import com.sap.cds.ql.cqn.CqnParameter;
import com.sap.cds.ql.cqn.CqnPredicate;
import com.sap.cds.ql.cqn.CqnReference;
import com.sap.cds.ql.cqn.CqnReference.Segment;
import com.sap.cds.ql.cqn.CqnSelect;
import com.sap.cds.ql.cqn.CqnSelectListItem;
import com.sap.cds.ql.cqn.CqnSelectListValue;
import com.sap.cds.ql.cqn.CqnSortSpecification;
import com.sap.cds.ql.cqn.CqnStructuredTypeRef;
import com.sap.cds.ql.cqn.CqnValue;
import com.sap.cds.ql.impl.ExpandProcessor;
import com.sap.cds.ql.impl.SelectBuilder;
import com.sap.cds.reflect.CdsElement;
import com.sap.cds.reflect.CdsEntity;
import com.sap.cds.reflect.CdsStructuredType;
import com.sap.cds.util.CdsModelUtils;
import com.sap.cds.util.DataUtils;
import com.sap.cds.util.OnConditionAnalyzer;

/**
 * Resolves an expanded association for a list of already-loaded parent rows of {@code root}.
 * <p>
 * Eager expands run follow-up queries that are parameterized with the parent keys. Depending
 * on the expand, this is either one query per parent row or a single bulk query. The bulk
 * query's result is grouped back to the parents by helper key columns. Lazy expands instead
 * inject a {@code Lazy} loader into each parent row.
 * <p>
 * NOTE(review): {@code pk2Element} is mutable per-instance state that is rebuilt whenever a
 * query is constructed. An instance should therefore presumably not be shared between
 * concurrent expands.
 */
public class AssociationLoader {

	private static final Logger logger = LoggerFactory.getLogger(AssociationLoader.class);
	/** Prefix of the helper columns that carry the parent key values in bulk expand results. */
	private static final String PK_PREFIX = "@@";
	private final CdsDataStore dataStore;
	private final CdsStructuredType root;
	/** Concrete key element names of {@code root}; their order defines the grouping key order. */
	private final List<String> keyNames;
	/**
	 * Maps each parent key name to the element used in the filter/selection for it.
	 * <ul>
	 * <li>When the query navigates from the root entity (join), this is the key itself.</li>
	 * <li>When the target is queried directly, this is the target's FK element.</li>
	 * </ul>
	 * Rebuilt by {@link #pathFromRoot} and {@link #filteredTarget}.
	 */
	private final Map<String, String> pk2Element = new HashMap<>();
	private final SqlMapping sqlMapping;

	public AssociationLoader(CdsDataStore dataStore, DbContext dbCtx, CdsStructuredType root) {
		this.dataStore = dataStore;
		this.root = root;
		keyNames = new ArrayList<>(CdsModelUtils.concreteKeyNames(root));
		sqlMapping = dbCtx.getSqlMapping(root);
	}

	/**
	 * Expands the association described by {@code expandProcessor} into each of the given parent
	 * rows, either lazily or eagerly depending on the expand.
	 * <p>
	 * Does nothing if this loader has no data store or {@code rows} is empty.
	 *
	 * @param expandProcessor provides the expand, the key mapping aliases and the query hints
	 * @param rows            the parent rows; the expanded data is written into them in place
	 */
	public void expand(ExpandProcessor expandProcessor, List<Map<String, Object>> rows) {
		ExpandBuilder<?> expand = expandProcessor.getExpand();
		if (dataStore == null || rows.isEmpty()) {
			return;
		}
		if (logger.isDebugEnabled()) {
			logger.debug("Expand {} using parent-keys", expand.ref());
		}
		boolean lazy = expand.lazy();
		boolean addCount = expand.hasInlineCount();
		CqnStructuredTypeRef ref = expand.ref();
		Map<String, String> mappingAliases = expandProcessor.getMappingAliases();
		if (lazy) {
			expandLazy(expand, ref, mappingAliases, rows, expandProcessor.getQueryHints());
		} else {
			expandEager(expand, ref, mappingAliases, addCount, rows, expandProcessor.isLoadSingle(),
					expandProcessor.getQueryHints());
		}
	}

	/**
	 * Loads the expanded data immediately.
	 * <p>
	 * The query runs once per parent row when there is only one row, when single loading is
	 * enforced, or when a to-many expand has a limit. Otherwise a single bulk query is executed
	 * for all rows.
	 */
	private void expandEager(CqnExpand expand, CqnStructuredTypeRef ref, Map<String, String> mappingAliases,
			boolean addCount, List<Map<String, Object>> rows, boolean enforceLoadSingle,
			Map<String, Object> queryHints) {
		boolean singleValued = singleValued((CdsEntity) root, expand.ref());
		CqnSelect query = queryByParams(ref, expand.items(), expand.orderBy(), expand.top(), expand.skip(),
				mappingAliases, queryHints, addCount);
		String path = expand.alias().orElse(ref.lastSegment());

		// To-many expands with top/skip are loaded row by row (presumably because the limit
		// has to apply per parent, which the grouped bulk result cannot guarantee).
		boolean perRow = rows.size() == 1 || enforceLoadSingle || (!singleValued && expand.hasLimit());
		if (perRow) {
			rows.forEach(row -> loadSingle(row, query, path, singleValued, addCount));
		} else if (singleValued) {
			loadBulk(rows, query, mappingAliases, (row, list) -> putOne(row, path, list));
		} else {
			loadBulk(rows, query, mappingAliases, (row, list) -> {
				putMany(row, path, list);
				if (addCount) {
					DataUtils.putPath(row, DataUtils.countName(path), (long) list.size());
				}
			});
		}
	}

	/**
	 * Executes {@code query} with the values of one parent row as parameters and stores the
	 * result in that row.
	 * <ul>
	 * <li>To-one: the first result row, or nothing if the result is empty.</li>
	 * <li>To-many: the result list, plus the inline count if requested.</li>
	 * </ul>
	 */
	private void loadSingle(Map<String, Object> row, CqnSelect query, String path, boolean singleValued,
			boolean addCount) {
		Result result = dataStore.execute(query, row);
		if (singleValued) {
			Row map = result.first().orElse(null);
			DataUtils.putPath(row, path, map, map != null);
		} else {
			List<Row> list = result.list();
			putMany(row, path, list);
			if (addCount) {
				DataUtils.putPath(row, DataUtils.countName(path), result.inlineCount());
			}
		}
	}

	/** Stores the expanded rows that belong to one parent row into that row. */
	@FunctionalInterface
	private static interface Mapper {
		void map(Map<String, Object> row, List<Row> list);
	}

	/**
	 * Executes {@code query} once for all parent rows and hands each parent its matching result
	 * rows to {@code mapper}. The helper key columns are removed from each result row first.
	 *
	 * @throws CdsDataStoreException if a parent row lacks one of its key values
	 */
	private void loadBulk(List<Map<String, Object>> rows, CqnSelect query, Map<String, String> mappingAliases,
			Mapper mapper) {
		Map<List<Object>, List<Row>> keyToData = execAndGroupByParentKeys(rows, query);

		for (Map<String, Object> row : rows) {
			List<Object> key = keyValuesInResultData(row, mappingAliases);
			List<Row> list = keyToData.getOrDefault(key, Collections.emptyList());
			list.forEach(AssociationLoader::clean);
			mapper.map(row, list);
		}
	}

	/**
	 * Stores a to-one expand result.
	 * <ul>
	 * <li>No rows: nothing is stored.</li>
	 * <li>One row: that row is stored.</li>
	 * <li>More than one row: a {@link CdsDataStoreException} is thrown.</li>
	 * </ul>
	 */
	private void putOne(Map<String, Object> row, String path, List<Row> list) {
		Map<String, Object> map = switch (list.size()) {
			case 0 -> null;
			case 1 -> clean(list.get(0));
			default -> throw new CdsDataStoreException("Failed to map result of expand " + path);
		};
		DataUtils.putPath(row, path, map, map != null);
	}

	/** Stores a to-many expand result; the last argument to putPath is false for an empty list. */
	private void putMany(Map<String, Object> row, String path, List<? extends Map<String, Object>> list) {
		DataUtils.putPath(row, path, list, !list.isEmpty());
	}

	/**
	 * Adds the parent key columns (aliased with {@link #PK_PREFIX}) to {@code query}, executes it
	 * with all parent rows as parameter sets, and groups the result rows by parent key values.
	 * <p>
	 * Note that {@code query} is modified in place.
	 */
	private Map<List<Object>, List<Row>> execAndGroupByParentKeys(List<Map<String, Object>> rows, CqnSelect query) {
		keyNames.forEach( // TODO: Adds the parent keys: T0.ID as '@@ID'
				k -> {
					String el = pk2Element.get(k); // root PK in path mode (JOIN), FK in direct mode
					String columnName = sqlMapping.columnName(el);
					CqnSelectListValue slv = CQL.plain(QatBuilder.ROOT_ALIAS + "." + columnName).as(pkAlias(k));
					((SelectBuilder<?>) query).addItem(slv);
				});
		Result result;
		if (query.orderBy().isEmpty()) {
			result = dataStore.execute(query, rows);
		} else {
			// NOTE(review): ordered queries use the overload taking rows.size() (presumably a
			// batch size, so all parameter sets run in one batch) — confirm against CdsDataStore
			result = dataStore.execute(query, rows, rows.size());
		}

		return result.stream().collect(groupingBy(this::keyValuesInResultRow));
	}

	/** Returns the column alias under which parent key {@code k} is selected, e.g. {@code @@ID}. */
	private static String pkAlias(String k) {
		return PK_PREFIX + k;
	}

	/** Reads the parent key values from the helper columns of a bulk result row. */
	private List<Object> keyValuesInResultRow(Map<String, Object> row) {
		return keyNames.stream().map(k -> row.get(pkAlias(k))).toList();
	}

	/**
	 * Reads the key values of a parent row. Each key is looked up under its mapping alias if it
	 * has one, otherwise under its own name.
	 *
	 * @throws CdsDataStoreException if a key value is missing
	 */
	private List<Object> keyValuesInResultData(Map<String, Object> row, Map<String, String> aliases) {
		return keyNames.stream().map(k -> {
			var alias = aliases.getOrDefault(k, k);
			return DataUtils.getPathOrElseThrow(row, alias,
					() -> new CdsDataStoreException("Failed to map expand result to parent rows. Missing value for " + alias));
		}).toList();
	}

	/** Removes the helper key columns from a result row in place and returns the same row. */
	private static Map<String, Object> clean(Map<String, Object> r) {
		r.keySet().removeIf(k -> k.startsWith(PK_PREFIX));
		return r;
	}

	/** Injects a lazy loader for the expanded association into every parent row. */
	private void expandLazy(CqnExpand expand, CqnStructuredTypeRef ref, Map<String, String> mappingAliases,
			List<Map<String, Object>> rows, Map<String, Object> queryHints) {
		for (Map<String, Object> row : rows) {
			injector(row, mappingAliases, queryHints).injectInto(row, ref, expand.items(), expand.orderBy(), expand.top(),
					expand.skip(), expand.alias());
		}
	}

	/**
	 * Creates an injector bound to the key values of {@code row}.
	 * <p>
	 * Each key is read under its mapping alias if present, otherwise under its own name. Missing
	 * values are stored as {@code null}.
	 */
	private LazyAssociationLoaderInjector injector(Map<String, Object> row, Map<String, String> mappingAliases,
			Map<String, Object> queryHints) {
		Map<String, Object> pkValues = new HashMap<>();
		root.keyElements().forEach(k -> {
			String keyName = k.getName();
			String displayName = mappingAliases.getOrDefault(keyName, keyName);
			pkValues.put(keyName, row.get(displayName));
		});

		return new LazyAssociationLoaderInjector((CdsEntity) root, pkValues, queryHints);
	}

	/** Puts a lazy loader into a parent row. The loader's query is bound to that row's key values. */
	private class LazyAssociationLoaderInjector {
		private final CdsEntity entity;
		private final Map<String, Object> keyValues;
		private final Map<String, Object> queryHints;

		LazyAssociationLoaderInjector(CdsEntity entity, Map<String, Object> keyValues, Map<String, Object> queryHints) {
			this.entity = entity;
			this.keyValues = keyValues;
			this.queryHints = queryHints;
		}

		/**
		 * Stores the loader under the expand alias (or the last path segment).
		 * <ul>
		 * <li>To-one associations get a lazy row.</li>
		 * <li>To-many associations get a lazy result.</li>
		 * </ul>
		 */
		private void injectInto(Map<String, Object> row, CqnStructuredTypeRef path, List<CqnSelectListItem> slis,
				List<CqnSortSpecification> orderBy, long top, long skip, Optional<String> alias) {
			CqnSelect query = queryByValues(path, slis, orderBy, top, skip, keyValues);
			Lazy loader = singleValued(entity, path) ? lazyRow(dataStore, query) : new LazyResultImpl(dataStore, query);
			String displayName = alias.orElse(path.lastSegment());
			DataUtils.putPath(row, displayName, loader);
		}

		/**
		 * Builds {@code SELECT FROM root[keys = values].path}, using literal key values.
		 *
		 * @throws CdsException if not all concrete keys of {@code root} are present in
		 *                      {@code keyValues}
		 */
		private CqnSelect queryByValues(CqnStructuredTypeRef path, List<CqnSelectListItem> slis,
				List<CqnSortSpecification> orderBy, long top, long skip, Map<String, Object> keyValues) {
			if (!keyValues.keySet().containsAll(concreteKeyNames(root))) {
				throw new CdsException("Missing key values for entity " + root.getQualifiedName()
						+ ". Please add all keys to the projection.");
			}
			List<CqnReference.Segment> segments = new ArrayList<>();
			segments.add(refSegment(root.getQualifiedName(), matching(keyValues)));
			segments.addAll(path.segments());

			return SelectBuilder.from(typeRef(segments, null)).columns(slis).orderBy(orderBy).limit(top, skip).hints(queryHints);
		}
	}

	/**
	 * Follows {@code path} from {@code entity} and reports whether its last association is
	 * single-valued.
	 *
	 * @throws CdsException if the path yields no association (e.g. it has no segments)
	 */
	private boolean singleValued(CdsEntity entity, CqnStructuredTypeRef path) {
		CdsEntity e = entity;
		CdsElement association = null;
		for (Segment seg : path.segments()) {
			String assocName = seg.id();
			association = e.getAssociation(assocName);
			e = e.getTargetOf(assocName);
		}
		if (association == null) {
			throw new CdsException(
					"Missing association for Entity " + e.getName() + ", under Path " + path.toJson() + ".");
		}
		return isSingleValued(association.getType());
	}

	/**
	 * Builds the eager expand query, parameterized with the parent keys.
	 * <ul>
	 * <li>If {@link #fkMapping} allows it, the target is queried directly, filtered by its FK
	 * elements.</li>
	 * <li>Otherwise the query navigates from the root entity, filtered by the root keys.</li>
	 * </ul>
	 * Both variants rebuild {@code pk2Element}.
	 */
	private CqnSelect queryByParams(CqnStructuredTypeRef path, List<CqnSelectListItem> items,
			List<CqnSortSpecification> orderBy, long top, long skip, Map<String, String> aliases,
			Map<String, Object> queryHints, boolean addCount) {
		var target = fkMapping(path)
			.map(m -> filteredTarget(m, path, aliases))    // no inner join
			.orElseGet(() -> pathFromRoot(path, aliases)); // inner join

		Select<?> query = Select.from(target).columns(items).orderBy(orderBy).limit(top, skip).hints(queryHints);
		if (addCount) {
			query.inlineCount();
		}
		return query;
	}

	/**
	 * Checks whether {@code path} can be loaded by querying its target directly, filtered by the
	 * FK elements that correspond to the parent keys, instead of joining from the root entity.
	 * <p>
	 * This requires all of the following:
	 * <ul>
	 * <li>a single-segment reverse association;</li>
	 * <li>an ON condition that the analyzer can turn into an FK mapping;</li>
	 * <li>every parent key bound to exactly one target element, through single-segment refs.</li>
	 * </ul>
	 * Non-ref values (e.g. literals) are allowed and become additional value filters. Any other
	 * shape falls back to the join query, which {@link #filteredTarget} could not express
	 * completely.
	 *
	 * @param path the expanded association path
	 * @return a copy of the ON condition's FK-to-value mapping, or empty if the direct query
	 *         cannot be used
	 */
	private Optional<Map<String, CqnValue>> fkMapping(CqnStructuredTypeRef path) {
		if (path.size() != 1) {
			return Optional.empty();
		}
		CdsElement assoc = CdsModelUtils.element(root, path.segments());
		if (!CdsModelUtils.isReverseAssociation(assoc)) {
			return Optional.empty();
		}
		OnConditionAnalyzer analyzer = new OnConditionAnalyzer(assoc, true);
		Map<String, CqnValue> fkMapping;
		try {
			fkMapping = analyzer.getFkMapping();
		} catch (UnsupportedOperationException ex) {
			return Optional.empty(); // on-condition contains OR or NOT
		}
		Set<String> referencedPks = new HashSet<>();
		for (CqnValue val : fkMapping.values()) {
			if (!val.isRef()) {
				continue; // e.g. a literal: becomes a value filter in filteredTarget
			}
			var ref = val.asRef();
			// filteredTarget routes every ref through pk2Element, but pkFilter only reads the
			// entries of the parent keys. A multi-segment ref, or a second FK bound to the same
			// parent key (pk2Element keeps only one FK per key), would silently be dropped from
			// the filter. Use the join instead.
			if (ref.size() != 1 || !referencedPks.add(ref.firstSegment())) {
				return Optional.empty();
			}
		}
		if (referencedPks.size() != keyNames.size() || !referencedPks.containsAll(keyNames)) {
			return Optional.empty();
		}

		return Optional.of(new HashMap<>(fkMapping));
	}

	/**
	 * Builds {@code root[keys = :params].path}. This is evaluated with an inner join; each
	 * parent key is bound to a parameter named by its mapping alias, or by the key name itself.
	 */
	private CqnStructuredTypeRef pathFromRoot(CqnStructuredTypeRef path, Map<String, String> aliases) {
		List<CqnReference.Segment> segments = new ArrayList<>(path.segments().size() + 1);
		pk2Element.clear();
		keyNames.forEach(k -> pk2Element.put(k, k));
		segments.add(refSegment(root.getQualifiedName(), pkFilter(aliases)));
		segments.addAll(path.segments());

		return typeRef(segments);
	}

	/**
	 * Builds a ref to the association target, filtered by three combined conditions:
	 * <ul>
	 * <li>its FK elements equal to the parent key parameters;</li>
	 * <li>the filter of the association segment, if any;</li>
	 * <li>the non-ref values of the ON condition.</li>
	 * </ul>
	 * No join is needed.
	 */
	private CqnStructuredTypeRef filteredTarget(Map<String, CqnValue> fkMapping, CqnStructuredTypeRef ref,
			Map<String, String> aliases) {
		Map<String, String> fkAliases = new HashMap<>();
		Map<String, CqnValue> valFilter = new HashMap<>();
		pk2Element.clear();
		fkMapping.forEach((fk, val) -> {
			if (val.isRef()) {
				String pk = val.asRef().path();
				// The parent rows carry the key under its mapping alias if there is one,
				// otherwise under its own name. Never store null here: pkFilter's getOrDefault
				// would return it and create a parameter without a name.
				fkAliases.put(fk, aliases.getOrDefault(pk, pk));
				pk2Element.put(pk, fk);
			} else {
				valFilter.put(fk, val);
			}
		});
		CdsStructuredType targetEntity = CdsModelUtils.target(root, ref.segments());
		CqnPredicate filter = CQL.and(pkFilter(fkAliases), ref.rootSegment().filter().orElse(CQL.TRUE));
		if (!valFilter.isEmpty()) { // additional value e.g. literals filter
			filter = CQL.and(filter, ExpressionImpl.matching(valFilter));
		}

		return typeRef(refSegment(targetEntity.getQualifiedName(), filter));
	}

	/**
	 * Builds a predicate that binds the elements in {@code pk2Element} to named parameters.
	 * <ul>
	 * <li>No keys: {@code TRUE}.</li>
	 * <li>One key: {@code el = :param}.</li>
	 * <li>Several keys: {@code (el1, el2, ...) = (:p1, :p2, ...)}.</li>
	 * </ul>
	 * A parameter is named by {@code aliases} for its element, or by the element name itself.
	 */
	private CqnPredicate pkFilter(Map<String, String> aliases) {
		return switch (keyNames.size()) {
			case 0 -> CQL.TRUE;
			case 1 -> pkFilterSingle(aliases);
			default -> pkFilterList(aliases);
		};
	}

	/** Builds {@code el = :param} for the single parent key. */
	private CqnPredicate pkFilterSingle(Map<String, String> pkAliases) {
		String pk = keyNames.get(0);
		String el = pk2Element.get(pk);

		return CQL.comparison(CQL.get(el), EQ, CQL.param(pkAliases.getOrDefault(el, el)));
	}

	/** Builds a list comparison {@code (el1, ...) = (:p1, ...)} over all parent keys. */
	private CqnPredicate pkFilterList(Map<String, String> aliases) {
		int n = keyNames.size();
		List<CqnElementRef> fkRefs = new ArrayList<>(n);
		List<CqnParameter> pkParams = new ArrayList<>(n);
		keyNames.forEach(pk -> {
			String el = pk2Element.get(pk);
			fkRefs.add(CQL.get(el));
			pkParams.add(CQL.param(aliases.getOrDefault(el, el)));
		});

		return CQL.comparison(CQL.list(fkRefs), EQ, CQL.list(pkParams));
	}

}
