/************************************************************************
 * © 2019-2023 SAP SE or an SAP affiliate company. All rights reserved. *
 ************************************************************************/
package com.sap.cds.impl;

import static com.sap.cds.impl.LazyRowImpl.lazyRow;
import static com.sap.cds.impl.builder.model.ExpressionImpl.matching;
import static com.sap.cds.impl.builder.model.StructuredTypeRefImpl.typeRef;
import static com.sap.cds.impl.parser.token.RefSegmentImpl.refSegment;
import static com.sap.cds.ql.cqn.CqnComparisonPredicate.Operator.EQ;
import static com.sap.cds.util.CdsModelUtils.concreteKeyNames;
import static com.sap.cds.util.CdsModelUtils.isSingleValued;
import static java.util.stream.Collectors.groupingBy;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.sap.cds.CdsDataStore;
import com.sap.cds.CdsDataStoreException;
import com.sap.cds.CdsException;
import com.sap.cds.Result;
import com.sap.cds.Row;
import com.sap.cds.impl.builder.model.ExpandBuilder;
import com.sap.cds.impl.qat.QatBuilder;
import com.sap.cds.jdbc.spi.DbContext;
import com.sap.cds.jdbc.spi.SqlMapping;
import com.sap.cds.ql.CQL;
import com.sap.cds.ql.cqn.CqnElementRef;
import com.sap.cds.ql.cqn.CqnExpand;
import com.sap.cds.ql.cqn.CqnParameter;
import com.sap.cds.ql.cqn.CqnPredicate;
import com.sap.cds.ql.cqn.CqnReference;
import com.sap.cds.ql.cqn.CqnReference.Segment;
import com.sap.cds.ql.cqn.CqnSelect;
import com.sap.cds.ql.cqn.CqnSelectListItem;
import com.sap.cds.ql.cqn.CqnSelectListValue;
import com.sap.cds.ql.cqn.CqnSortSpecification;
import com.sap.cds.ql.cqn.CqnStructuredTypeRef;
import com.sap.cds.ql.impl.ExpandProcessor;
import com.sap.cds.ql.impl.SelectBuilder;
import com.sap.cds.reflect.CdsElement;
import com.sap.cds.reflect.CdsEntity;
import com.sap.cds.reflect.CdsStructuredType;
import com.sap.cds.util.CdsModelUtils;
import com.sap.cds.util.DataUtils;

/**
 * Loads expanded associations of a root entity for parent rows that have already been read.
 *
 * <p>
 * Eager expands are resolved by follow-up queries whose first ref segment (the root entity) is
 * filtered by the parent key values, which are passed as query parameters. Lazy expands put a
 * {@link Lazy} loader into each parent row instead of loading data.
 */
public class AssociationLoader {

	private static final Logger logger = LoggerFactory.getLogger(AssociationLoader.class);
	/** Alias prefix for the parent-key columns added to bulk expand queries. */
	private static final String PK_PREFIX = "@@";
	private final CdsDataStore dataStore;
	private final CdsStructuredType root;
	/** Concrete key element names of {@link #root}; defines the order of key tuples. */
	private final List<String> keyNames;
	private final SqlMapping sqlMapping;

	public AssociationLoader(CdsDataStore dataStore, DbContext dbCtx, CdsStructuredType root) {
		this.dataStore = dataStore;
		this.root = root;
		keyNames = List.copyOf(CdsModelUtils.concreteKeyNames(root));
		sqlMapping = dbCtx.getSqlMapping(root);
	}

	/**
	 * Resolves the expand held by {@code expandProcessor} for the given parent rows and writes the
	 * loaded (or lazily loadable) association data into each row.
	 *
	 * <p>
	 * Does nothing if no data store is configured or if {@code rows} is empty.
	 *
	 * @param expandProcessor supplies the expand, the key aliases and the load-single flag
	 * @param rows            parent rows of the root entity; modified in place
	 */
	public void expand(ExpandProcessor expandProcessor, List<Map<String, Object>> rows) {
		ExpandBuilder<?> expand = expandProcessor.getExpand();
		if (dataStore == null || rows.isEmpty()) {
			return;
		}
		CqnStructuredTypeRef ref = expand.ref();
		logger.debug("Expand {} using parent-keys", ref);
		Map<String, String> mappingAliases = expandProcessor.getMappingAliases();
		if (expand.lazy()) {
			expandLazy(expand, ref, mappingAliases, rows);
		} else {
			// the inline count is the size of the loaded list, which is only the total if there is no top/skip
			boolean addCount = expand.hasInlineCount() && !expand.hasLimit();
			expandEager(expand, ref, mappingAliases, addCount, rows, expandProcessor.isLoadSingle());
		}
	}

	/**
	 * Loads the expanded association immediately, either with one query per parent row or with one
	 * bulk query for all parent rows.
	 */
	private void expandEager(CqnExpand expand, CqnStructuredTypeRef ref, Map<String, String> mappingAliases, boolean addCount,
			List<Map<String, Object>> rows, boolean enforceLoadSingle) {
		boolean singleValued = singleValued((CdsEntity) root, ref);
		CqnSelect query = queryByParams(ref, expand.items(), expand.orderBy(), expand.top(), expand.skip(), mappingAliases);
		String path = expand.alias().orElse(ref.lastSegment());

		// Query per parent row if there is only one row, if it is explicitly requested, or if a to-many
		// expand has top/skip: in a single bulk query the limit would apply to the combined result
		// instead of to each parent's children.
		boolean perRow = rows.size() == 1 || enforceLoadSingle || (!singleValued && expand.hasLimit());
		if (perRow) {
			rows.forEach(row -> loadSingle(row, query, path, singleValued, addCount));
		} else if (singleValued) {
			loadBulk(rows, query, mappingAliases, (row, list) -> putOne(row, path, list));
		} else {
			loadBulk(rows, query, mappingAliases, (row, list) -> putMany(row, path, list, addCount));
		}
	}

	/**
	 * Executes {@code query} with {@code row} as parameter values and stores the result in
	 * {@code row} under {@code path}.
	 */
	private void loadSingle(Map<String, Object> row, CqnSelect query, String path, boolean singleValued,
			boolean addCount) {
		Result result = dataStore.execute(query, row);
		if (singleValued) {
			Row map = result.first().orElse(null);
			DataUtils.putPath(row, path, map, map != null);
		} else {
			List<Row> list = result.list();
			putMany(row, path, list, addCount);
		}
	}

	/** Stores the child rows that belong to one parent row into that parent row. */
	@FunctionalInterface
	private interface Mapper {
		void map(Map<String, Object> row, List<Row> list);
	}

	/**
	 * Executes {@code query} once for all parent rows and hands each parent row its matching child
	 * rows to {@code mapper}. Child rows are matched by parent key values.
	 */
	private void loadBulk(List<Map<String, Object>> rows, CqnSelect query, Map<String, String> mappingAliases,
			Mapper mapper) {
		Map<List<Object>, List<Row>> keyToData = execAndGroupByParentKeys(rows, query);

		for (Map<String, Object> row : rows) {
			List<Object> key = keyValuesInResultData(row, mappingAliases);
			mapper.map(row, keyToData.getOrDefault(key, Collections.emptyList()));
		}
	}

	/**
	 * Stores the single child row (or {@code null} if none) under {@code path}.
	 *
	 * @throws CdsDataStoreException if more than one child row was found
	 */
	private void putOne(Map<String, Object> row, String path, List<Row> list) {
		Map<String, Object> map = switch (list.size()) {
			case 0 -> null;
			case 1 -> list.get(0);
			default -> throw new CdsDataStoreException("Failed to map result of expand " + path);
		};
		DataUtils.putPath(row, path, map, map != null);
	}

	/** Stores the list of child rows under {@code path} and, if requested, its size as inline count. */
	private void putMany(Map<String, Object> row, String path, List<? extends Map<String, Object>> list,
			boolean addCount) {
		DataUtils.putPath(row, path, list, !list.isEmpty());
		if (addCount) {
			DataUtils.putPath(row, DataUtils.countName(path), Long.valueOf(list.size()));
		}
	}

	/**
	 * Executes {@code query} with all parent rows as parameter sets and groups the result rows by
	 * the parent key values they belong to. The helper key columns are removed from the returned
	 * rows.
	 */
	private Map<List<Object>, List<Row>> execAndGroupByParentKeys(List<Map<String, Object>> rows, CqnSelect query) {
		// Select the parent keys as extra columns (e.g. T0.ID as "@@ID") so that every child row
		// can be attributed to its parent.
		for (String k : keyNames) {
			String columnName = sqlMapping.columnName(k);
			CqnSelectListValue slv = CQL.plain(QatBuilder.ROOT_ALIAS + "." + columnName).as(PK_PREFIX + k);
			((SelectBuilder<?>) query).addItem(slv);
		}
		Result result = dataStore.execute(query, rows);

		Map<List<Object>, List<Row>> grouped = result.stream().collect(groupingBy(this::keyValuesInResultRow));
		// Strip the helper key columns once, now that grouping no longer needs them.
		grouped.values().forEach(list -> list.forEach(AssociationLoader::clean));
		return grouped;
	}

	/** Reads the parent key tuple from the helper columns of a child result row. */
	private List<Object> keyValuesInResultRow(Map<String, Object> row) {
		return keyNames.stream().map(k -> row.get(PK_PREFIX + k)).toList();
	}

	/** Reads the key tuple of a parent row, using the key aliases where present. */
	private List<Object> keyValuesInResultData(Map<String, Object> row, Map<String, String> aliases) {
		return keyNames.stream().map(k -> row.get(aliases.getOrDefault(k, k))).toList();
	}

	/** Removes all helper parent-key columns from {@code r} and returns it. */
	private static Map<String, Object> clean(Map<String, Object> r) {
		r.keySet().removeIf(k -> k.startsWith(PK_PREFIX));
		return r;
	}

	/** Puts a lazy loader for the expanded association into every parent row. */
	private void expandLazy(CqnExpand expand, CqnStructuredTypeRef ref, Map<String, String> mappingAliases,
			List<Map<String, Object>> rows) {
		for (Map<String, Object> row : rows) {
			injector(row, mappingAliases).injectInto(row, ref, expand.items(), expand.orderBy(), expand.top(),
					expand.skip(), expand.alias());
		}
	}

	/**
	 * Creates an injector bound to the key values of {@code row}. Every key element of the root
	 * gets an entry, even if the row does not contain it (the value is then {@code null}).
	 */
	private LazyAssociationLoaderInjector injector(Map<String, Object> row, Map<String, String> mappingAliases) {
		Map<String, Object> pkValues = new HashMap<>();
		root.keyElements().forEach(k -> {
			String keyName = k.getName();
			String displayName = mappingAliases.getOrDefault(keyName, keyName);
			pkValues.put(keyName, row.get(displayName));
		});

		return new LazyAssociationLoaderInjector((CdsEntity) root, pkValues);
	}

	/** Builds and injects lazy loaders for the parent row with the given key values. */
	private class LazyAssociationLoaderInjector {
		private final CdsEntity entity;
		private final Map<String, Object> keyValues;

		LazyAssociationLoaderInjector(CdsEntity entity, Map<String, Object> keyValues) {
			this.entity = entity;
			this.keyValues = keyValues;
		}

		/**
		 * Puts a {@link Lazy} loader for the association {@code path} into {@code row} under the
		 * alias, or under the last path segment if no alias is given.
		 */
		private void injectInto(Map<String, Object> row, CqnStructuredTypeRef path, List<CqnSelectListItem> slis,
				List<CqnSortSpecification> orderBy, long top, long skip, Optional<String> alias) {
			CqnSelect query = queryByValues(path, slis, orderBy, top, skip, keyValues);
			Lazy loader = singleValued(entity, path) ? lazyRow(dataStore, query) : new LazyResultImpl(dataStore, query);
			String displayName = alias.orElse(path.lastSegment());
			DataUtils.putPath(row, displayName, loader);
		}

		/**
		 * Builds a select whose root segment is filtered by the literal parent key values.
		 *
		 * @throws CdsException if a value for a concrete key of the root is missing
		 */
		private CqnSelect queryByValues(CqnStructuredTypeRef path, List<CqnSelectListItem> slis,
				List<CqnSortSpecification> orderBy, long top, long skip, Map<String, Object> keyValues) {
			// The injector always adds an entry for every key element, so a key that is absent from
			// the parent row shows up as a null value rather than as a missing map key.
			boolean missingKeyValue = keyNames.stream().anyMatch(k -> keyValues.get(k) == null);
			if (missingKeyValue) {
				throw new CdsException("Missing key values for entity " + root.getQualifiedName()
						+ ". Please add all keys to the projection.");
			}
			List<CqnReference.Segment> segments = new ArrayList<>(path.segments().size() + 1);
			segments.add(refSegment(root.getQualifiedName(), matching(keyValues)));
			segments.addAll(path.segments());

			return SelectBuilder.from(typeRef(segments, null)).columns(slis).orderBy(orderBy).limit(top, skip);
		}
	}

	/**
	 * Follows the association path from {@code entity} and tells whether its last association is
	 * single-valued (to-one).
	 *
	 * @throws CdsException if the path contains no segment
	 */
	private boolean singleValued(CdsEntity entity, CqnStructuredTypeRef path) {
		CdsEntity e = entity;
		CdsElement association = null;
		for (Segment seg : path.segments()) {
			String assocName = seg.id();
			association = e.getAssociation(assocName);
			e = e.getTargetOf(assocName);
		}
		if (association == null) {
			throw new CdsException(
					"Missing association for Entity " + e.getName() + ", under Path " + path.toJson() + ".");
		}
		return isSingleValued(association.getType());
	}

	/**
	 * Builds a select whose root segment is filtered by the parent keys as named parameters. The
	 * parameter values are supplied per parent row on execution.
	 */
	private CqnSelect queryByParams(CqnStructuredTypeRef path, List<CqnSelectListItem> items,
			List<CqnSortSpecification> orderBy, long top, long skip, Map<String, String> aliases) {
		List<CqnReference.Segment> segments = new ArrayList<>(path.segments().size() + 1);
		segments.add(refSegment(root.getQualifiedName(), pkFilter(aliases)));
		segments.addAll(path.segments());

		return SelectBuilder.from(typeRef(segments)).columns(items).orderBy(orderBy).limit(top, skip);
	}

	/** Builds the key filter: {@code TRUE} without keys, else {@code key = :param} or a tuple comparison. */
	private CqnPredicate pkFilter(Map<String, String> aliases) {
		return switch (keyNames.size()) {
			case 0 -> CQL.TRUE;
			case 1 -> pkFilterSingle(aliases);
			default -> pkFilterList(aliases);
		};
	}

	/** {@code key = :alias} for an entity with a single key. */
	private CqnPredicate pkFilterSingle(Map<String, String> pkAliases) {
		String pk = keyNames.get(0);

		return CQL.comparison(CQL.get(pk), EQ, CQL.param(pkAliases.getOrDefault(pk, pk)));
	}

	/** {@code (k1, k2, ...) = (:alias1, :alias2, ...)} for an entity with a composite key. */
	private CqnPredicate pkFilterList(Map<String, String> aliases) {
		int n = keyNames.size();
		List<CqnElementRef> fkRefs = new ArrayList<>(n);
		List<CqnParameter> pkParams = new ArrayList<>(n);
		for (String pk : keyNames) {
			fkRefs.add(CQL.get(pk));
			pkParams.add(CQL.param(aliases.getOrDefault(pk, pk)));
		}

		return CQL.comparison(CQL.list(fkRefs), EQ, CQL.list(pkParams));
	}

}
