/************************************************************************
 * © 2019-2022 SAP SE or an SAP affiliate company. All rights reserved. *
 ************************************************************************/
package com.sap.cds.impl;

import static com.sap.cds.ResultBuilder.deletedRows;
import static com.sap.cds.ResultBuilder.insertedRows;
import static com.sap.cds.impl.builder.model.CqnParam.params;
import static com.sap.cds.impl.docstore.DocStoreUtils.targetsDocStore;
import static com.sap.cds.impl.parser.token.CqnBoolLiteral.TRUE;
import static com.sap.cds.util.CdsModelUtils.entity;
import static com.sap.cds.util.CdsModelUtils.keyNames;
import static com.sap.cds.util.CdsModelUtils.CascadeType.DELETE;
import static com.sap.cds.util.CqnStatementUtils.inlineCountQuery;
import static com.sap.cds.util.CqnStatementUtils.isSelectStar;
import static com.sap.cds.util.CqnStatementUtils.moveKeyValuesToWhere;
import static com.sap.cds.util.CqnStatementUtils.rowType;
import static com.sap.cds.util.DataUtils.isDeep;
import static com.sap.cds.util.PathExpressionResolver.resolvePath;
import static java.lang.Math.min;
import static java.util.Arrays.stream;
import static java.util.Collections.emptyList;
import static java.util.Collections.singletonList;
import static java.util.Objects.hash;
import static java.util.stream.Collectors.groupingBy;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toMap;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Iterables;
import com.google.common.collect.Iterators;
import com.google.common.collect.Streams;
import com.sap.cds.CdsDataStore;
import com.sap.cds.CdsException;
import com.sap.cds.DataStoreConfiguration;
import com.sap.cds.Result;
import com.sap.cds.ResultBuilder;
import com.sap.cds.Row;
import com.sap.cds.SessionContext;
import com.sap.cds.impl.EntityCascader.EntityKeys;
import com.sap.cds.impl.EntityCascader.EntityOperation;
import com.sap.cds.impl.EntityCascader.EntityOperation.Operation;
import com.sap.cds.impl.EntityCascader.EntityOperations;
import com.sap.cds.impl.builder.model.Conjunction;
import com.sap.cds.ql.CQL;
import com.sap.cds.ql.Delete;
import com.sap.cds.ql.Insert;
import com.sap.cds.ql.Select;
import com.sap.cds.ql.Update;
import com.sap.cds.ql.cqn.CqnAnalyzer;
import com.sap.cds.ql.cqn.CqnDelete;
import com.sap.cds.ql.cqn.CqnInsert;
import com.sap.cds.ql.cqn.CqnPredicate;
import com.sap.cds.ql.cqn.CqnSelect;
import com.sap.cds.ql.cqn.CqnUpdate;
import com.sap.cds.ql.cqn.CqnUpsert;
import com.sap.cds.ql.cqn.CqnXsert;
import com.sap.cds.ql.impl.CqnNormalizer;
import com.sap.cds.ql.impl.DeepInsertSplitter;
import com.sap.cds.ql.impl.DeepUpdateSplitter;
import com.sap.cds.ql.impl.DeleteBuilder;
import com.sap.cds.ql.impl.XsertBuilder.UpsertBuilder;
import com.sap.cds.reflect.CdsEntity;
import com.sap.cds.reflect.CdsModel;
import com.sap.cds.reflect.CdsStructuredType;
import com.sap.cds.util.CqnStatementUtils;
import com.sap.cds.util.CqnStatementUtils.Count;
import com.sap.cds.util.DataUtils;
import com.sap.cds.util.ProjectionProcessor;

/**
 * {@link CdsDataStore} implementation that normalizes, validates and resolves
 * CQN statements and executes them through a {@link ConnectedClient}.
 * <p>
 * Deep inserts and deep updates are split into flat statements. Deletes are
 * cascaded using {@code DeleteCascader}, falling back to {@link EntityCascader}.
 * When an operation consists of several statements, a failing statement marks
 * the connected client as rollback-only.
 */
public class CdsDataStoreImpl implements CdsDataStore {
	private final CqnValidator cqnValidator;
	private final ConnectedClient connectedClient;
	private final CqnNormalizer cqnNormalizer;
	private final CqnAnalyzer cqnAnalyzer;
	private final DataUtils dataUtils;
	private final Context context;
	private final CdsModel model;
	private final ProjectionProcessor projectionProcessor;
	private static final Logger logger = LoggerFactory.getLogger(CdsDataStoreImpl.class);
	private static final TimingLogger timed = new TimingLogger(logger);

	/**
	 * Creates a data store for the given context. The connected client is created
	 * by {@code connector} and bound to the context's current session context.
	 *
	 * @param context   provides the CDS model, DB capabilities, session context and
	 *                  data store configuration
	 * @param connector factory for the {@link ConnectedClient} used to execute
	 *                  statements
	 */
	public CdsDataStoreImpl(Context context, ConnectedDataStoreConnector connector) {
		this.context = context;
		this.model = context.getCdsModel();
		this.cqnValidator = CqnValidator.create(context);
		this.connectedClient = connector.create(context);
		this.dataUtils = DataUtils.create(context::getSessionContext,
				context.getDbContext().getCapabilities().timestampPrecision());
		this.connectedClient.setSessionContext(context.getSessionContext());
		this.cqnNormalizer = new CqnNormalizer(context);
		this.cqnAnalyzer = CqnAnalyzer.create(model);
		this.projectionProcessor = ProjectionProcessor.create(model, cqnAnalyzer, dataUtils);
	}

	/**
	 * Executes a select with positional parameters. The values are bound to the
	 * parameter names "0", "1", ... in the order given.
	 */
	@Override
	public Result execute(CqnSelect select, Object... paramValues) {
		return execute(select, toIndexMap(paramValues));
	}

	/**
	 * Normalizes, validates and executes a select with named parameter values.
	 * If an inline count is requested, it is either taken from the result's row
	 * count or computed with an additional count query.
	 */
	@Override
	public Result execute(CqnSelect select, Map<String, Object> cqnParameterValues) {
		return timed.debug(() -> {
			// "select *" (without excludes) from an entity ref: the row type is the entity itself
			CdsStructuredType rowType = null;
			if (select.from().isRef() && isSelectStar(select.items()) && select.excluding().isEmpty()) {
				rowType = entity(model, select.from().asRef());
			}
			CqnSelect normSelect = cqnNormalizer.normalize(select);
			cqnValidator.validate(normSelect, connectedClient.capabilities());
			if (rowType == null) {
				rowType = rowType(model, normSelect);
			}
			normSelect = cqnNormalizer.resolveForwardMappedAssocs(normSelect);
			PreparedCqnStatement pcqn = connectedClient.prepare(normSelect);
			ResultBuilder result = connectedClient
					.executeQuery(pcqn, cqnParameterValues, this, normSelect.getLock().isPresent()).rowType(rowType);

			if (normSelect.hasInlineCount()) {
				long rowCount = result.result().rowCount();
				if (normSelect.hasLimit() && requiresInlineCountQuery(normSelect.top(), normSelect.skip(), rowCount)) {
					result.inlineCount(getInlineCount(normSelect, cqnParameterValues));
				} else {
					result.inlineCount(rowCount);
				}
			}

			return result.result();
		}, "CQN >>{}<<", () -> new Object[] { safeToJson(select, context.getDataStoreConfiguration()) });
	}

	/**
	 * Executes a select once per value set. The value sets are sent in batches of
	 * at most {@code maxBatchSize} and the rows of all batches are concatenated.
	 *
	 * @throws UnsupportedOperationException if the select has an order by and
	 *                                       more than one batch would be needed
	 */
	@Override
	public Result execute(CqnSelect select, Iterable<Map<String, Object>> valueSets, int maxBatchSize) {
		int valueSetSize = Iterables.size(valueSets);
		if (valueSetSize == 1) {
			return execute(select, valueSets.iterator().next());
		}
		// the ordering cannot be preserved across independently executed batches
		if (!select.orderBy().isEmpty() && valueSetSize > maxBatchSize) {
			throw new UnsupportedOperationException(
					"Order by is not supported when query is executed in multiple batches");
		}
		List<Row> rows = new ArrayList<>(valueSetSize);
		Iterator<List<Map<String, Object>>> partitions = Iterators.partition(valueSets.iterator(), maxBatchSize);
		while (partitions.hasNext()) {
			CqnSelect batchSelect = CqnStatementUtils.batchSelect(select, partitions.next());
			List<Row> result = execute(batchSelect).list();
			rows.addAll(result);
		}
		// TODO Support streaming
		return ResultBuilder.selectedRows(rows).result();
	}

	/** Executes the inline-count variant of {@code select} and returns its count. */
	private long getInlineCount(CqnSelect select, Map<String, Object> cqnParameterValues) {
		CqnSelect inlineCountQuery = inlineCountQuery(select);
		PreparedCqnStatement pcqn = connectedClient.prepare(inlineCountQuery);
		Result result = connectedClient.executeQuery(pcqn, cqnParameterValues, this, false).result();

		return result.single(Count.class).getCount();
	}

	/**
	 * Serializes {@code select} to JSON for logging. Values are anonymized unless
	 * {@link DataStoreConfiguration#LOG_CQN_VALUES} is enabled. A serialization
	 * failure is logged and replaced by a placeholder instead of being propagated.
	 */
	@VisibleForTesting
	static String safeToJson(CqnSelect select, DataStoreConfiguration config) {
		boolean doLogValues = config.getProperty(DataStoreConfiguration.LOG_CQN_VALUES, false);
		try {
			if (!doLogValues) {
				return CqnStatementUtils.anonymizeStatement(select).toJson();
			}
			return select.toJson();
		} catch (RuntimeException ex) {
			logger.error("cannot serialize CQN statement");
			return "Unserializable CQN";
		}
	}

	/**
	 * Decides whether the row count of a limited result can differ from the
	 * total count. That is the case when rows were skipped, or when the page is
	 * full ({@code rowCount} reached {@code top}).
	 */
	@VisibleForTesting
	static boolean requiresInlineCountQuery(long top, long skip, long rowCount) {
		return skip > 0 || top <= rowCount;
	}

	/**
	 * Executes a delete with positional parameters. The values are bound to the
	 * parameter names "0", "1", ....
	 */
	@Override
	public Result execute(CqnDelete delete, Object... paramValues) {
		return execute(delete, toIndexMap(paramValues));
	}

	/** Executes a delete with a single set of named parameter values. */
	@Override
	public Result execute(CqnDelete delete, Map<String, Object> namedValues) {
		return execute(delete, singletonList(namedValues));
	}

	/**
	 * Executes a delete once per value set, cascading it to dependent entities.
	 * If {@code DeleteCascader} throws {@link UnsupportedOperationException},
	 * the keys of dependent rows are collected via {@link EntityCascader} first
	 * and deleted after the target rows.
	 */
	@Override
	public Result execute(CqnDelete delete, Iterable<Map<String, Object>> valueSets) {
		delete = projectionProcessor.resolve(delete);
		delete = cqnNormalizer.normalize(delete);
		CdsEntity target = entity(model, delete.ref());
		try {
			DeleteCascader.create(target).from(delete.ref()).where(delete.where())
					.cascade(d -> bulkDelete(d, valueSets, true));
			return bulkDelete(delete, valueSets, false);
		} catch (UnsupportedOperationException e) {
			// fallback for cyclic models and subqueries in where
			target = model.getEntity(delete.ref().firstSegment());
			Set<EntityKeys> entities = cascade(target, delete.where(), valueSets);
			Result result = bulkDelete(delete, valueSets, false);
			delete(entities.stream().map(k -> EntityOperation.delete(k, context.getSessionContext())), true);
			return result;
		}
	}

	/**
	 * Collects the keys of rows reached by a DELETE cascade from {@code target}.
	 * The cascade is run once per value set, or once without values if
	 * {@code valueSets} is empty.
	 */
	private Set<EntityKeys> cascade(CdsEntity target, Optional<CqnPredicate> filter,
			Iterable<? extends Map<String, Object>> valueSets) {
		EntityCascader cascader = EntityCascader.from(this, target).where(filter);
		Set<EntityKeys> keySets;
		if (valueSets.iterator().hasNext()) {
			keySets = Streams.stream(valueSets).flatMap(v -> cascader.with(v).cascade(DELETE).stream())
					.collect(Collectors.toSet());
		} else {
			keySets = cascader.cascade(DELETE);
		}
		return keySets;
	}

	/** Updates the given operations and inserts those that matched no row. */
	private void upsert(Stream<EntityOperation> operations) {
		List<EntityOperation> notInDatastore = update(operations);
		insert(notInDatastore.stream());
	}

	/**
	 * Inserts the operations, one insert statement per target entity, and reports
	 * each inserted row back to its operation.
	 */
	private void insert(Stream<EntityOperation> ops) {
		ops.collect(groupingBy(EntityOperation::targetEntity)).forEach((entity, data) -> {
			Iterator<Row> rows = execute(Insert.into(entity).entries(data)).iterator();
			for (EntityOperation op : data) {
				op.inserted(rows.next());
			}
		});
	}

	/**
	 * Executes the operations as batch updates. Operations are grouped by target
	 * entity and by the set of updated element names.
	 *
	 * @return the operations whose update affected no row
	 */
	private List<EntityOperation> update(Stream<EntityOperation> ops) {
		List<EntityOperation> notFound = new ArrayList<>();
		// Composite group key (entity name, snapshot of updated element names).
		// An int hash as key could merge operations of different entities on collision.
		ops.collect(groupingBy(o -> Arrays.asList(o.targetEntity().getQualifiedName(),
				o.updateValues().keySet().stream().collect(Collectors.toSet()))))
				.forEach((g, updates) -> {
					CdsEntity entity = updates.get(0).targetEntity();
					List<Map<String, Object>> entries = updates.stream().map(EntityOperation::updateValues)
							.collect(toList());
					Result result = execute(Update.entity(entity).entries(entries));
					if (result.rowCount() > 0) {
						Iterator<Row> resultIter = result.iterator();
						Iterator<EntityOperation> updateIter = updates.iterator();
						for (int i = 0; i < result.batchCount(); i++) {
							EntityOperation u = updateIter.next();
							long rowCount = result.rowCount(i);
							u.updated(resultIter.next(), rowCount);
							if (rowCount == 0) {
								notFound.add(u);
							}
						}
					} else {
						notFound.addAll(updates);
					}
				});
		return notFound;
	}

	/**
	 * Deletes the operations' rows by key, one bulk delete per target entity, and
	 * marks each operation as deleted.
	 */
	private void delete(Stream<EntityOperation> ops, boolean rollbackOnFail) {
		Map<CdsEntity, List<EntityOperation>> keyMap = ops.collect(groupingBy(EntityOperation::targetEntity));
		keyMap.forEach((entity, op) -> {
			if (!op.isEmpty()) {
				Set<String> keyNames = op.iterator().next().targetKeys().keySet();
				CqnDelete delete = DeleteBuilder.from(entity.getQualifiedName()).matching(params(keyNames));
				bulkDelete(delete, op, rollbackOnFail);
				op.forEach(EntityOperation::deleted);
			}
		});
	}

	/**
	 * Validates, resolves and executes a delete with one parameter set per value
	 * set.
	 *
	 * @param rollbackOnFail whether a failure marks the connected client as
	 *                       rollback-only before the exception is rethrown
	 */
	private Result bulkDelete(CqnDelete delete, Iterable<? extends Map<String, Object>> valueSets,
			boolean rollbackOnFail) {
		cqnValidator.validate(delete);
		delete = resolvePath(model, delete);
		delete = projectionProcessor.resolve(delete);
		PreparedCqnStatement pcqn = connectedClient.prepare(delete);
		List<Map<String, Object>> parameterValues = new ArrayList<>();
		valueSets.forEach(parameterValues::add);
		try {
			int[] deleteCount = connectedClient.executeUpdate(pcqn, parameterValues);

			return deletedRows(deleteCount).result();
		} catch (Exception e) {
			if (rollbackOnFail) {
				connectedClient.setRollbackOnly();
			}
			throw e;
		}
	}

	/** Normalizes and executes a (possibly deep) insert. */
	@Override
	public Result execute(CqnInsert insert) {
		insert = isDraftEnabled(insert);
		insert = cqnNormalizer.normalize(insert);
		return deepInsert(insert, false);
	}

	/**
	 * Resolves projections of {@code insert} if its target entity is annotated
	 * with {@code odata.draft.enabled}. Otherwise returns {@code insert} unchanged.
	 */
	private CqnInsert isDraftEnabled(CqnInsert insert) {
		CdsEntity entity = model.getEntity(insert.ref().firstSegment());
		if (entity.findAnnotation("odata.draft.enabled").isPresent()) {
			insert = projectionProcessor.resolve(insert);
		}
		return insert;
	}

	/**
	 * Executes an upsert, either as delete-by-key followed by a deep insert or as
	 * a flat database upsert. An upsert without entries returns an empty result.
	 */
	@Override
	public Result execute(CqnUpsert upsert) {
		if (upsert.entries().isEmpty()) {
			return insertedRows(emptyList()).result();
		}
		upsert = cqnNormalizer.normalize(upsert);
		CdsEntity entity = entity(model, upsert.ref());

		if (((UpsertBuilder) upsert).byDeleteAndInsert()) {
			deleteByKeys(entity, upsert.entries());
			return deepInsert(upsert, true);
		}
		return upsert(upsert, entity);
	}

	/**
	 * Executes a flat database upsert.
	 *
	 * @throws UnsupportedOperationException if the upsert data is deep
	 */
	private Result upsert(CqnUpsert upsert, CdsEntity entity) {
		List<Map<String, Object>> entries = upsert.entries();
		dataUtils.prepareForInsert(entity, entries);
		if (isDeep(entity, upsert.entries())) {
			throw new UnsupportedOperationException("Deep DB Upserts are not supported");
		}
		PreparedCqnStatement pcqn = connectedClient.prepare(upsert);
		connectedClient.executeUpdate(pcqn, upsert.entries());

		return insertedRows(entries).result();
	}

	/**
	 * Deletes the rows of {@code entity} whose key elements match the given key
	 * values (one parameter set per map).
	 */
	private Result deleteByKeys(CdsEntity entity, Iterable<Map<String, Object>> keyValues) {
		CqnDelete delete = Delete.from(entity).matching(keyNames(entity) //
				.stream().collect(toMap(k -> k, k -> CQL.param(k))));
		DataUtils.normalizedUuidKeys(entity, keyValues);
		return execute(delete, keyValues);
	}

	/**
	 * Splits the entries of {@code xsert} into flat inserts and executes them.
	 * An insert that is still deep after projection resolution is processed
	 * recursively.
	 *
	 * @param rollbackOnFail whether a failing insert marks the connected client
	 *                       as rollback-only; this also happens whenever more
	 *                       than one insert results from the split
	 * @return the original entries as inserted rows
	 */
	private Result deepInsert(CqnXsert xsert, boolean rollbackOnFail) {
		CdsEntity entity = model.getEntity(xsert.ref().firstSegment());
		List<Map<String, Object>> entries = xsert.entries();
		dataUtils.prepareForInsert(entity, entries);
		List<Insert> inserts = new DeepInsertSplitter(entity, context.getSessionContext()).split(entries);
		boolean isRollbackOnly = rollbackOnFail || inserts.size() > 1;
		inserts.forEach(insert -> {
			cqnValidator.validate(insert);
			insert = projectionProcessor.resolve(insert);
			CdsEntity target = model.getEntity(insert.ref().firstSegment());
			DataUtils.resolvePaths(target, insert.entries());
			if (isDeep(target, insert.entries())) {
				deepInsert(insert, true);
				return;
			}

			// process @cds.on.insert again as element might be excluded by projection
			dataUtils.processOnInsert(target, insert.entries());

			PreparedCqnStatement pcqn = connectedClient.prepare(insert);
			try {
				connectedClient.executeUpdate(pcqn, insert.entries());
			} catch (Exception e) {
				if (isRollbackOnly) {
					connectedClient.setRollbackOnly();
				}
				throw e;
			}
		});

		return insertedRows(entries).result();
	}

	/**
	 * Executes an update with positional parameters. The values are bound to the
	 * parameter names "0", "1", ....
	 */
	@Override
	public Result execute(CqnUpdate update, Object... paramValues) {
		return execute(update, toIndexMap(paramValues));
	}

	/** Executes an update with one set of named parameter values, or none if empty. */
	@Override
	public Result execute(CqnUpdate update, Map<String, Object> namedValues) {
		return execute(update, namedValues.isEmpty() ? emptyList() : singletonList(namedValues));
	}

	/**
	 * Executes an update. It is handled as one of the following:
	 * <ul>
	 * <li>a row count only, when the data has no non-key values (the update
	 * entries are cleared in this case);</li>
	 * <li>a deep update, when the data is deep or not uniform, before or after
	 * projection resolution;</li>
	 * <li>a flat (batch) update otherwise.</li>
	 * </ul>
	 */
	@Override
	public Result execute(CqnUpdate update, Iterable<Map<String, Object>> valueSets) {
		cqnValidator.validate(update);
		CdsEntity entity = cqnAnalyzer.analyze(update.ref()).targetEntity();
		dataUtils.prepareForUpdate(entity, update.entries());
		if (!DataUtils.hasNonKeyValues(entity, update.data())) {
			Result count = selectCountAll(update, entity, update);
			update.entries().forEach(Map::clear);
			return count;
		}
		CqnUpdate normUpdate = cqnNormalizer.normalize(update);
		if (isDeep(entity, normUpdate.entries()) || !DataUtils.uniformData(entity, normUpdate.entries())) {
			Map<String, Object> targetKeys = cqnAnalyzer.analyze(update).targetKeyValues();
			return deepUpdate(entity, normUpdate, targetKeys);
		}
		CqnUpdate resolvedUpdate = projectionProcessor.resolve(normUpdate);
		if (resolvedUpdate != normUpdate) {
			// projection was resolved: re-prepare data for the underlying entity
			entity = cqnAnalyzer.analyze(resolvedUpdate.ref()).targetEntity();
			dataUtils.prepareForUpdate(entity, resolvedUpdate.entries());
			if (isDeep(entity, resolvedUpdate.entries())) {
				Map<String, Object> targetKeys = cqnAnalyzer.analyze(update).targetKeyValues();
				return deepUpdate(entity, resolvedUpdate, targetKeys);
			}
		}
		int[] updateCount = flatUpdate(resolvedUpdate, entity, valueSets);
		List<Map<String, Object>> entries = update.entries();
		if (entries.size() == 1) {
			// single data set applied to several parameter sets: one result entry per count
			entries = filledList(updateCount.length, entries.get(0));
		}

		return batchUpdateResult(entries, updateCount);
	}

	/** Returns a mutable list containing {@code entry} {@code length} times. */
	private static List<Map<String, Object>> filledList(int length, Map<String, Object> entry) {
		List<Map<String, Object>> entries;
		entries = new ArrayList<>(length);
		for (int i = 0; i < length; i++) {
			entries.add(entry);
		}
		return entries;
	}

	/** Builds a batch update result from {@code int} update counts. */
	private static Result batchUpdateResult(List<Map<String, Object>> entries, int[] updateCount) {
		return batchUpdateResult(entries, stream(updateCount).asLongStream().toArray());
	}

	/**
	 * Builds a batch update result with one batch per update count. Batch
	 * {@code i} is paired with entry {@code i}. If there are more counts than
	 * entries, the last entry is reused for the remaining batches.
	 */
	private static Result batchUpdateResult(List<Map<String, Object>> entries, long[] updateCount) {
		int lastEntry = entries.size() - 1;
		ResultBuilder builder = ResultBuilder.batchUpdate();
		for (int i = 0; i < updateCount.length; i++) {
			// clamp to the last valid index (size - 1), not size
			builder.addUpdatedRows(updateCount[i], entries.get(min(lastEntry, i)));
		}

		return builder.result();
	}

	/**
	 * Executes a flat update. The update data is merged into the parameter sets,
	 * and key values are moved into the where clause unless the entity targets
	 * the document store.
	 *
	 * @return the update count per parameter set
	 */
	private int[] flatUpdate(CqnUpdate update, CdsEntity entity, Iterable<Map<String, Object>> valueSets) {
		List<Map<String, Object>> parameterValues = mergeParams(update.entries(), valueSets);
		if (!targetsDocStore(entity)) {
			moveKeyValuesToWhere(entity, update, true);
		}

		PreparedCqnStatement pcqn = connectedClient.prepare(update);

		return connectedClient.executeUpdate(pcqn, parameterValues);
	}

	/**
	 * Splits a deep update into entity operations and executes the deletes,
	 * inserts, upserts and updates in that order. Any failure marks the connected
	 * client as rollback-only.
	 */
	private Result deepUpdate(CdsEntity entity, CqnUpdate update, Map<String, Object> targetKeys) {
		DeepUpdateSplitter updateSplitter = new DeepUpdateSplitter(this);
		EntityOperations operations = updateSplitter.computeOperations(entity, update, targetKeys);
		try {
			delete(operations.filter(Operation.DELETE), false);
			insert(operations.filter(Operation.INSERT));
			upsert(operations.filter(Operation.UPSERT));
			update(operations.filter(Operation.UPDATE));
		} catch (Exception e) {
			connectedClient.setRollbackOnly();
			throw e;
		}
		if (operations.entries().size() == 1 && operations.updateCount().length > 1) {
			return searchedUpdateResult(operations);
		}
		return batchUpdateResult(operations.entries(), operations.updateCount());
	}

	/** Builds a single update result whose row count is the sum of all update counts. */
	private Result searchedUpdateResult(EntityOperations operations) {
		Map<String, Object> data = operations.entries().get(0);
		return ResultBuilder.updatedRows(Arrays.stream(operations.updateCount()).sum(), data).result();
	}

	/**
	 * Counts the rows matched by each update entry, for updates without
	 * non-key values. The count query uses the update's where clause plus an
	 * equality parameter for each key element in the update data.
	 */
	private Result selectCountAll(CqnUpdate update, CdsEntity entity, CqnUpdate resolvedUpdate) {
		Set<String> keys = keyNames(entity);
		Select<?> countQuery = CqnStatementUtils.countAll(update);
		// the filter does not depend on the individual entry: build and set it once
		CqnPredicate where = Conjunction.and(update.where().orElse(TRUE), update.elements().filter(keys::contains)
				.map(key -> CQL.get(key).eq(CQL.param(key))).collect(Conjunction.and()));
		countQuery.where(where);
		long[] rowCount = new long[resolvedUpdate.entries().size()];
		int i = 0;
		for (Map<String, Object> entry : update.entries()) {
			rowCount[i] = execute(countQuery, entry).single().as(Count.class).getCount();
			i++;
		}
		return batchUpdateResult(filledList(rowCount.length, new HashMap<>()), rowCount);
	}

	/**
	 * Builds the parameter sets for a flat update from copies of the inputs:
	 * <ul>
	 * <li>no value sets: one copy per update entry;</li>
	 * <li>one update entry: the entry is merged into every value set;</li>
	 * <li>as many entries as value sets: the entries are merged pairwise.</li>
	 * </ul>
	 *
	 * @throws CdsException if there are several entries and their number differs
	 *                      from the number of value sets
	 */
	private static List<Map<String, Object>> mergeParams(List<Map<String, Object>> updateData,
			Iterable<Map<String, Object>> valueSets) {
		List<Map<String, Object>> paramVals = new ArrayList<>();
		if (!valueSets.iterator().hasNext()) {
			// no parameter set
			updateData.forEach(v -> paramVals.add(DataUtils.copyMap(v)));
			return paramVals;
		}
		valueSets.forEach(v -> paramVals.add(DataUtils.copyMap(v)));
		if (updateData.size() == 1) {
			// (mass) update with one data set
			Map<String, Object> data = updateData.get(0);
			paramVals.forEach(p -> p.putAll(data));
			return paramVals;
		}
		// batch update of multiple entities
		if (updateData.size() == paramVals.size()) {
			// with parameter set for each entry
			Iterator<Map<String, Object>> keyIter = updateData.iterator();
			paramVals.forEach(p -> p.putAll(keyIter.next()));
			return paramVals;
		}
		throw new CdsException("Batch update failed: Parameter value list size (" + paramVals.size()
				+ ") does not match batch size (" + updateData.size() + ")");
	}

	/** Maps positional values to the parameter names "0", "1", .... */
	private static Map<String, Object> toIndexMap(Object... paramValues) {
		Map<String, Object> parameters = new HashMap<>();
		for (int i = 0; i < paramValues.length; i++) {
			parameters.put(String.valueOf(i), paramValues[i]);
		}
		return parameters;
	}

	@Override
	public SessionContext getSessionContext() {
		return context.getSessionContext();
	}

	/** Sets the session context on both the context and the connected client. */
	@Override
	public void setSessionContext(SessionContext session) {
		this.context.setSessionContext(session);
		this.connectedClient.setSessionContext(session);
	}
}
