/**************************************************************************
 * (C) 2019-2024 SAP SE or an SAP affiliate company. All rights reserved. *
 **************************************************************************/
package com.sap.cds.adapter.odata.v4.utils;

import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.commons.lang3.StringUtils;
import org.apache.olingo.server.api.uri.UriInfo;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.sap.cds.CdsData;
import com.sap.cds.Result;
import com.sap.cds.Row;
import com.sap.cds.adapter.odata.v4.query.LimitLookup;
import com.sap.cds.adapter.odata.v4.query.NextLinkInfo;
import com.sap.cds.ql.CQL;
import com.sap.cds.ql.CdsDataException;
import com.sap.cds.ql.ElementRef;
import com.sap.cds.ql.Literal;
import com.sap.cds.ql.Predicate;
import com.sap.cds.ql.Select;
import com.sap.cds.ql.cqn.CqnElementRef;
import com.sap.cds.ql.cqn.CqnPredicate;
import com.sap.cds.ql.cqn.CqnSelect;
import com.sap.cds.ql.cqn.CqnSortSpecification;
import com.sap.cds.ql.cqn.CqnSortSpecification.Order;
import com.sap.cds.ql.cqn.CqnValue;
import com.sap.cds.ql.impl.SelectBuilder;
import com.sap.cds.reflect.CdsElement;
import com.sap.cds.reflect.CdsEntity;
import com.sap.cds.reflect.CdsModel;
import com.sap.cds.reflect.CdsService;
import com.sap.cds.reflect.CdsSimpleType;
import com.sap.cds.services.environment.CdsProperties.Query.Limit;
import com.sap.cds.services.request.RequestContext;
import com.sap.cds.services.runtime.CdsRuntime;
import com.sap.cds.services.utils.CdsErrorStatuses;
import com.sap.cds.services.utils.ErrorStatusException;
import com.sap.cds.services.utils.TenantAwareCache;
import com.sap.cds.util.CdsTypeUtils;
import com.sap.cds.util.transformations.LimitCalculator;

import static com.sap.cds.ql.cqn.CqnComparisonPredicate.Operator.GT;

/**
 * Computes the effective {@code $top}/{@code $skip} of an OData read request, applies them to the
 * CQN statements and produces the {@code $skiptoken} for server-driven paging.
 * <p>
 * Two kinds of skip tokens are produced and understood:
 * <ul>
 * <li>a plain number: the count of rows already read on previous pages (offset paging)</li>
 * <li>a Base64-encoded JSON object (only with reliable paging enabled): the count of rows already
 * read plus the order-by values of the last returned row, which are turned into a keyset filter on
 * the next page</li>
 * </ul>
 */
public final class QueryLimitUtils {

	private static final ObjectMapper objectMapper = new ObjectMapper();

	// Keys of the JSON skip token: {"r": alreadyRead, "c": [{"k": element, "v": value, "a": ascending}, ...]}
	private static final String VALUE = "v";
	private static final String KEY = "k";
	private static final String HAS_ASCENDING_ORDER = "a";
	private static final String ALREADY_READ = "r";
	private static final String TOKEN_CONTENT = "c";

	// set by initialize(CdsRuntime); read by every constructor call
	private static TenantAwareCache<LimitLookup, CdsModel> limitLookup;

	private final CdsEntity entity;

	private final boolean isReliablePaging;

	// derived from request
	private int top;
	private int skip;
	private boolean serverDrivenPaging;

	// derived from skipToken
	private int alreadyRead;
	private List<Map<String, Object>> sortedValues;

	/**
	 * Sets up the tenant-aware cache of {@link LimitLookup}s built from the query limit
	 * configuration of the runtime. Must be called before any instance is constructed, as the
	 * constructor reads default and maximum limits from this cache.
	 *
	 * @param runtime the CDS runtime providing configuration, tenant and model
	 */
	public static void initialize(CdsRuntime runtime) {
		Limit config = runtime.getEnvironment().getCdsProperties().getQuery().getLimit();
		limitLookup = TenantAwareCache.create(
			() -> RequestContext.getCurrent(runtime).getUserInfo().getTenant(),
			() -> new LimitLookup(config),
			() -> RequestContext.getCurrent(runtime).getModel());
	}

	/**
	 * Derives the limits of the current request.
	 *
	 * @param service    the service addressed by the request
	 * @param entity     the entity being read; skip token elements are validated against it
	 * @param uriInfo    the parsed request URI ($top, $skip, $skiptoken)
	 * @param properties the query limit configuration (reliable paging switch)
	 * @throws ErrorStatusException MALFORMED_SKIPTOKEN if the $skiptoken cannot be interpreted
	 */
	public QueryLimitUtils(CdsService service, CdsEntity entity, UriInfo uriInfo, Limit properties) {
		this.entity = entity;
		// must be set before the skip token is parsed
		this.isReliablePaging = properties.getReliablePaging().isEnabled();
		initializeLimits(uriInfo, service, entity);
	}

	/**
	 * Applies the computed limits to the given statements. For a keyset skip token, it also restricts
	 * the statement to rows sorting after the last row of the previous page.
	 * <p>
	 * The limits pass through a {@link LimitCalculator} created from each statement before being set
	 * on it. The keyset filter is only added when exactly one statement is given.
	 *
	 * @param selects the statements to modify in place
	 */
	public void handlePagination(List<Select<?>> selects) {
		if (top < Integer.MAX_VALUE || skip > 0) {
			for (Select<?> select : selects) {
				LimitCalculator calc = LimitCalculator.of(select);
				calc.skip(skip);
				calc.top(top);
				select.limit(calc.top(), calc.skip());
			}
		}

		if (selects.size() == 1 && sortedValues != null && !sortedValues.isEmpty()) {
			// WHERE (FIRST > @FIRST OR (FIRST = @FIRST AND SECOND > @SECOND) ...)
			final CqnPredicate result;

			if (sortedValues.size() == 1) {
				result = comparison(sortedValues.get(0));
			} else {
				boolean allAscending = sortedValues.stream().allMatch(sv -> Boolean.TRUE.equals(sv.get(HAS_ASCENDING_ORDER)));

				if (allAscending) {
					result = rowValueComparison();
				} else {
					result = mixedComparison();
				}
			}

			SelectBuilder<?> select = (SelectBuilder<?>) selects.get(0);
			select.filter(result);
		}
	}

	/**
	 * Builds the keyset predicate for order-by lists with mixed sort directions:
	 * {@code k1 > v1 OR (k1 = v1 AND k2 > v2) OR (k1 = v1 AND k2 = v2 AND k3 > v3) ...}.
	 * For descending elements, {@code >} becomes {@code <}. Each alternative requires equality on
	 * <em>all</em> preceding order-by elements, not only the directly preceding one.
	 */
	private CqnPredicate mixedComparison() {
		Predicate result = CQL.FALSE;
		// conjunction of "k_i = v_i" for all entries processed so far; null before the first entry
		Predicate equalPrefix = null;
		for (Map<String, Object> entry : sortedValues) {
			CqnPredicate next = comparison(entry);
			if (equalPrefix != null) {
				next = equalPrefix.and(next);
			}
			result = result.or(next);

			Predicate equal = CQL.get((String) entry.get(KEY)).eq(entry.get(VALUE));
			equalPrefix = equalPrefix == null ? equal : equalPrefix.and(equal);
		}
		return result;
	}

	/**
	 * Builds {@code key > value} for an ascending entry, or {@code key < value} for a descending one.
	 * The ascending flag has been validated to be a boolean when the token was parsed.
	 */
	private static CqnPredicate comparison(Map<String, Object> sv) {
		ElementRef<Object> ref = CQL.get((String) sv.get(KEY));
		Object value = sv.get(VALUE);
		boolean asc = (boolean) sv.get(HAS_ASCENDING_ORDER);

		return asc ? ref.gt(value) : ref.lt(value);
	}

	/**
	 * Builds the row-value comparison {@code (k1, k2, ...) > (v1, v2, ...)}. Only used when all
	 * order-by elements are sorted ascending.
	 */
	private CqnPredicate rowValueComparison() {
		int n = sortedValues.size();
		List<CqnElementRef> refs = new ArrayList<>(n);
		List<CqnValue> vals = new ArrayList<>(n);
		sortedValues.forEach(sv -> {
			CqnElementRef ref = CQL.get((String) sv.get(KEY));
			Literal<Object> val = CQL.val(sv.get(VALUE));
			refs.add(ref);
			vals.add(val);
		});

		return CQL.comparison(CQL.list(refs), GT, CQL.list(vals));
	}

	/**
	 * Creates the next-link information for server-driven paging.
	 * <p>
	 * With reliable paging enabled, it emits a keyset token when every order-by item of
	 * {@code select} references a simple, non-virtual element contained in the last row. Otherwise
	 * it falls back to the numeric token.
	 *
	 * @param result the result of the current page
	 * @param select the executed statement, used for its order-by; may be {@code null}
	 * @return the next-link info, or {@code null} if server-driven paging is not active or the
	 *         result holds fewer rows than the page size
	 */
	public NextLinkInfo generateNextLink(Result result, CqnSelect select) {
		if (result.rowCount() < top || !serverDrivenPaging) {
			return null;
		}
		int nextAlreadyRead = alreadyRead + (int) result.rowCount();
		if (isReliablePaging && result.rowCount() > 0 && select != null) {
			String keysetToken = createKeysetToken(result, select, nextAlreadyRead);
			if (keysetToken != null) {
				return new NextLinkInfo(keysetToken);
			}
		}
		return new NextLinkInfo(String.valueOf(nextAlreadyRead));
	}

	/**
	 * Encodes the order-by values of the last row of {@code result} as a Base64 JSON skip token.
	 * Returns {@code null} if the order-by is empty or cannot be fully represented in the token.
	 */
	private String createKeysetToken(Result result, CqnSelect select, int nextAlreadyRead) {
		if (select.orderBy().isEmpty()) {
			// a token without content would drop the offset on the next page without adding a filter
			return null;
		}
		Row row = result.list().get((int) result.rowCount() - 1);
		List<Map<String, Object>> values = new ArrayList<>();
		for (CqnSortSpecification spec : select.orderBy()) {
			if (spec.value().isRef()) {
				String key = spec.value().asRef().displayName();
				if (getValidElement(key) != null && row.containsKey(key)) {
					Map<String, Object> value = new HashMap<>();
					value.put(HAS_ASCENDING_ORDER, Order.ASC == spec.order());
					value.put(KEY, key);
					value.put(VALUE, row.get(key));
					values.add(value);
				}
			}
		}

		if (values.size() != select.orderBy().size()) {
			return null;
		}
		CdsData tokenMap = CdsData.create();
		tokenMap.put(ALREADY_READ, nextAlreadyRead);
		tokenMap.put(TOKEN_CONTENT, values);
		return Base64.getEncoder().encodeToString(tokenMap.toJson().getBytes(StandardCharsets.UTF_8));
	}

	/**
	 * Computes top, skip and serverDrivenPaging from $top, $skip, $skiptoken and the configured
	 * default and maximum limits of the service/entity.
	 */
	private void initializeLimits(UriInfo uriInfo, CdsService service, CdsEntity entity) {
		serverDrivenPaging = false;
		top = Integer.MAX_VALUE;
		skip = 0;

		if (uriInfo.getTopOption() != null) {
			top = uriInfo.getTopOption().getValue();
		}

		if (uriInfo.getSkipOption() != null) {
			skip = uriInfo.getSkipOption().getValue();
		}

		if (uriInfo.getSkipTokenOption() != null) {
			initializeSkipToken(uriInfo.getSkipTokenOption().getValue());

			// only add to skip, if skipToken carries no keyset (same condition as in handlePagination)
			if (sortedValues == null || sortedValues.isEmpty()) {
				// saturate instead of overflowing on large client-supplied values
				skip = (int) Math.min((long) skip + alreadyRead, Integer.MAX_VALUE);
			} else {
				skip = 0; // $skip is applied only on first page
			}

			// always reduce top, by already read data on previous pages
			if (top != Integer.MAX_VALUE) {
				top = Math.max(top - alreadyRead, 0);
			}
		}

		LimitLookup lookup = limitLookup.findOrCreate();

		int defaultTop = lookup.getDefaultValue(service, entity);
		if (defaultTop > 0 && top == Integer.MAX_VALUE) {
			top = defaultTop;
			serverDrivenPaging = true;
		}

		int maxTop = lookup.getMaxValue(service, entity);
		if (maxTop > 0 && top > maxTop) {
			top = maxTop;
			serverDrivenPaging = true;
		}
	}

	/**
	 * Parses the client-supplied $skiptoken into alreadyRead and, for keyset tokens, sortedValues
	 * (with values converted to their element types). A non-numeric token is ignored when reliable
	 * paging is disabled.
	 *
	 * @throws ErrorStatusException MALFORMED_SKIPTOKEN if the token cannot be decoded or validated
	 */
	@SuppressWarnings("unchecked")
	private void initializeSkipToken(String skipToken) {
		try {
			if (StringUtils.isNumeric(skipToken)) {
				alreadyRead = Integer.parseInt(skipToken);
			} else if (isReliablePaging) {
				TypeReference<Map<String, Object>> typeRef = new TypeReference<Map<String, Object>>() {};
				Map<String, Object> token = objectMapper.readValue(Base64.getDecoder().decode(skipToken), typeRef);
				alreadyRead = (int) token.get(ALREADY_READ);
				sortedValues = (List<Map<String, Object>>) token.get(TOKEN_CONTENT);
				for (Map<String, Object> sortedValue : sortedValues) {
					// validated here so that comparison() cannot fail later on a tampered token
					if (!(sortedValue.get(HAS_ASCENDING_ORDER) instanceof Boolean)) {
						throw new ErrorStatusException(CdsErrorStatuses.MALFORMED_SKIPTOKEN);
					}
					String elementName = (String) sortedValue.get(KEY);
					Object value = sortedValue.get(VALUE);
					Object converted = convert(elementName, value);
					sortedValue.put(VALUE, converted);
				}
			}
		} catch (ErrorStatusException e) {
			// already reports MALFORMED_SKIPTOKEN; avoid wrapping it a second time
			throw e;
		} catch (Exception e) {
			throw new ErrorStatusException(CdsErrorStatuses.MALFORMED_SKIPTOKEN, e);
		}
		if (alreadyRead < 0) {
			throw new ErrorStatusException(CdsErrorStatuses.MALFORMED_SKIPTOKEN);
		}
	}

	/**
	 * Parses a value read from the JSON skip token into the type of the given element.
	 *
	 * @throws ErrorStatusException MALFORMED_SKIPTOKEN if the element is not a simple, non-virtual
	 *         element of the entity or the value cannot be parsed
	 */
	private Object convert(String elementName, Object value) {
		CdsElement element = getValidElement(elementName);
		if (element != null) {
			try {
				// toString() is safe, because we used a simple ObjectMapper to serialize to JSON
				// all primitive JSON types are correctly represented after running them through toString()
				String stringValue = value == null ? null : value.toString();
				return CdsTypeUtils.parse(element.getType().as(CdsSimpleType.class).getType(), stringValue);
			} catch (CdsDataException e) {
				throw new ErrorStatusException(CdsErrorStatuses.MALFORMED_SKIPTOKEN, e);
			}
		} else {
			throw new ErrorStatusException(CdsErrorStatuses.MALFORMED_SKIPTOKEN);
		}
	}

	/**
	 * Returns the element of the entity with the given name if it is simple-typed and not virtual,
	 * otherwise {@code null}.
	 */
	private CdsElement getValidElement(String elementName) {
		return entity.findElement(elementName).filter(e -> e.getType().isSimple() && !e.isVirtual()).orElse(null);
	}

}
