/**************************************************************************
 * (C) 2019-2024 SAP SE or an SAP affiliate company. All rights reserved. *
 **************************************************************************/
package com.sap.cds.adapter.odata.v2.utils;

import static com.sap.cds.ql.CQL.sort;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

import org.apache.olingo.odata2.api.edm.EdmException;
import org.apache.olingo.odata2.api.uri.SelectItem;
import org.apache.olingo.odata2.api.uri.UriInfo;

import com.sap.cds.ql.CQL;
import com.sap.cds.ql.ElementRef;
import com.sap.cds.ql.Select;
import com.sap.cds.ql.Value;
import com.sap.cds.ql.cqn.CqnElementRef;
import com.sap.cds.ql.cqn.CqnSelectListValue;
import com.sap.cds.ql.cqn.CqnSortSpecification;
import com.sap.cds.ql.cqn.CqnValue;
import com.sap.cds.reflect.CdsAnnotatable;
import com.sap.cds.reflect.CdsAnnotation;
import com.sap.cds.reflect.CdsBaseType;
import com.sap.cds.reflect.CdsElement;
import com.sap.cds.reflect.CdsEntity;
import com.sap.cds.reflect.CdsSimpleType;
import com.sap.cds.reflect.CdsType;
import com.sap.cds.services.ServiceException;
import com.sap.cds.services.utils.CdsErrorStatuses;
import com.sap.cds.services.utils.ErrorStatusException;
import com.sap.cds.util.CdsModelUtils;

/**
 * Rewrites a CQN {@link Select} built for an OData V2 request into an aggregating query when the target entity is
 * annotated with {@code @Aggregation.ApplySupported.PropertyRestrictions}.
 * <p>
 * Each property requested via {@code $select} is handled as follows:
 * <ul>
 * <li>the technical {@value #AGGREGATE_ID} property is skipped if it is not part of the CDS model;</li>
 * <li>measures ({@code @Analytics.Measure} or {@code @sap.aggregation-role: 'measure'}) are replaced by the aggregate
 * function named in {@code @Aggregation.default}, aliased with the property name;</li>
 * <li>all other properties are added both to the select list and to the group-by clause.</li>
 * </ul>
 * Order-by items that refer to an aliased select list value are rewritten to sort by that value's expression.
 */
public class AggregateTransformation {

	private static final String AGGREGATION_DEFAULT = "Aggregation.default";
	private static final String ANALYTICS_MEASURE = "Analytics.Measure";
	private static final String SUPPORTED_RESTRICTIONS = "Aggregation.ApplySupported.PropertyRestrictions";
	private static final String SAP_AGGREGATION_ROLE = "sap.aggregation-role";

	/** Name of the technical key property that is excluded from the select list when not modeled in CDS. */
	public static final String AGGREGATE_ID = "ID__";

	private final CdsEntity target;
	private final Select<?> select;
	private final UriInfo uriInfo;

	// populated by applyAggregation() from the $select items
	private List<CqnValue> dimensions;
	private List<CqnSelectListValue> selectListItems;

	/**
	 * @param target  the CDS entity the request is addressed to
	 * @param select  the CQN select statement that is modified in place
	 * @param uriInfo the parsed OData V2 request URI
	 */
	public AggregateTransformation(CdsEntity target, Select<?> select, UriInfo uriInfo) {
		this.target = target;
		this.select = select;
		this.uriInfo = uriInfo;
	}

	/**
	 * Applies the aggregation rewrite to the select statement if the target is an aggregate entity, and adds a
	 * {@code count} column if the request is a {@code $count} request.
	 *
	 * @return {@code true} if the target entity is an aggregate entity and columns, group-by and order-by were
	 *         rewritten; {@code false} otherwise (the {@code $count} handling is applied in both cases)
	 */
	public boolean applyAggregation() {
		boolean applied = false;
		if (isAggregateEntity()) {
			dimensions = new ArrayList<>();
			selectListItems = new ArrayList<>();
			uriInfo.getSelect().forEach(this::collectListItems);

			select.columns(selectListItems);
			select.groupBy(dimensions);
			// must run after columns(...) so that order-by items can be matched against the new select list
			select.orderBy(getOrderBy());

			applied = true;
		}
		// $count
		if (uriInfo.isCount()) {
			select.columns(CQL.count().as("count"));
		}
		return applied;
	}

	/**
	 * Adds one {@code $select} item either as aggregated measure or as group-by dimension.
	 */
	private void collectListItems(SelectItem i) {
		try {
			String itemName = i.getProperty().getName();
			// exclude technical ID__ from select list as it is not in CDS Model
			if (isAggregateID(itemName, target)) {
				return;
			}
			CqnElementRef ref = CQL.get(itemName);
			CdsElement element = target.getElement(itemName);

			if (isMeasure(element)) {
				selectListItems.add(toFunctionCall(element).as(itemName));
			} else {
				// if element is not an aggregate, add it to group by
				selectListItems.add(ref);
				dimensions.add(ref);
			}
		} catch (EdmException e) {
			throw new ServiceException(e);
		}
	}

	/**
	 * Builds the aggregate expression for a measure element according to its {@code @Aggregation.default}.
	 *
	 * @throws ErrorStatusException with {@code UNKONWN_AGGREGATION_METHOD} if the method is missing or unsupported
	 */
	private static Value<?> toFunctionCall(CdsElement element) {
		ElementRef<?> ref = CQL.get(element.getName());
		String methodName = getAggregation(element);
		switch (methodName) {
			case "AVG":
			case "AVERAGE":
				return ref.average();
			case "COUNT":
				return CQL.func(methodName, ref).type(Long.class);
			case "COUNT_DISTINCT":
				return ref.countDistinct();
			case "MAX":
			case "MIN":
			case "SUM": {
				Value<?> aggregate = CQL.func(methodName, ref);
				Optional<CdsBaseType> baseType = baseType(element);
				// use the returned value instead of relying on type(...) mutating the function call in place
				return baseType.isPresent() ? aggregate.type(baseType.get()) : aggregate;
			}
			default:
				throw new ErrorStatusException(CdsErrorStatuses.UNKONWN_AGGREGATION_METHOD, methodName);
		}
	}

	/**
	 * Returns the CDS base type of a simple-typed element, or empty for structured/other types.
	 */
	private static Optional<CdsBaseType> baseType(CdsElement element) {
		CdsType type = element.getType();
		if (type.isSimple()) {
			return Optional.ofNullable(type.as(CdsSimpleType.class).getType());
		}
		return Optional.empty();
	}

	private boolean isAggregateEntity() {
		return getAnnotatedValue(target, SUPPORTED_RESTRICTIONS, Boolean.class, false);
	}

	/**
	 * An element is a measure if it is annotated with {@code @Analytics.Measure: true} or with
	 * {@code @sap.aggregation-role: 'measure'}.
	 */
	private static boolean isMeasure(CdsElement element) {
		if (getAnnotatedValue(element, ANALYTICS_MEASURE, Boolean.class, false)) {
			return true;
		}
		// compared leniently: a non-String role value simply does not denote a measure
		return "measure".equals(annotationValue(element, SAP_AGGREGATION_ROLE));
	}

	/**
	 * Reads the aggregation method from the {@code "#"} entry of {@code @Aggregation.default}.
	 *
	 * @return the method name, or {@code "#"} if the annotation or its {@code "#"} entry is missing, which is then
	 *         reported as an unknown aggregation method
	 */
	private static String getAggregation(CdsElement element) {
		Map<?, ?> annotatedValue = getAnnotatedValue(element, AGGREGATION_DEFAULT, Map.class,
				Collections.emptyMap());
		Object method = annotatedValue.get("#");
		return method != null ? method.toString() : "#";
	}

	/**
	 * Returns the raw value of an annotation, or {@code null} if it is absent or has no value.
	 */
	private static Object annotationValue(CdsAnnotatable annotatable, String annotation) {
		return annotatable.<Object>findAnnotation(annotation).map(CdsAnnotation::getValue).orElse(null);
	}

	/**
	 * Returns the value of an annotation checked against the expected type.
	 * <p>
	 * The type check is done explicitly because an unchecked generic cast would only fail at the caller, bypassing
	 * the error reporting here.
	 *
	 * @throws ServiceException if the annotation value is not an instance of {@code type}
	 */
	private static <T> T getAnnotatedValue(CdsAnnotatable annotatable, String annotation, Class<T> type,
			T fallBackValue) {
		Object value = annotationValue(annotatable, annotation);
		if (value == null) {
			return fallBackValue;
		}
		if (!type.isInstance(value)) {
			throw new ServiceException("The type of annotation value of @" + annotation + " for " + annotatable
					+ " is not a " + type.getName());
		}
		return type.cast(value);
	}

	private List<CqnSortSpecification> getOrderBy() {
		return select.orderBy().stream().map(this::getItem).collect(Collectors.toList());
	}

	/**
	 * Rewrites an order-by item that references a select list value by its display name (e.g. an aliased aggregate)
	 * to sort by that value's expression. All other order-by items are returned unchanged.
	 */
	private CqnSortSpecification getItem(CqnSortSpecification orderByItem) {
		CqnValue sortValue = orderByItem.value();
		if (!sortValue.isRef()) {
			// only plain references can point to a select list alias
			return orderByItem;
		}
		String sortName = sortValue.asRef().displayName();
		return select.items().stream()
				.filter(item -> item.isValue())
				.map(item -> item.asValue())
				.filter(slv -> sortName.equals(slv.displayName()))
				.findFirst()
				.map(slv -> sort(slv.value(), orderByItem.order()))
				.orElse(orderByItem);
	}

	/**
	 * Checks whether {@code element} is the technical {@value #AGGREGATE_ID} property that does not exist in the
	 * CDS model of {@code entity}.
	 */
	public static boolean isAggregateID(String element, CdsEntity entity) {
		return AGGREGATE_ID.equals(element) && !(entity.findElement(element).isPresent());
	}
}
