/**************************************************************************
 * (C) 2019-2025 SAP SE or an SAP affiliate company. All rights reserved. *
 **************************************************************************/
package com.sap.cds.adapter.odata.v4.query.apply;

import java.util.List;
import java.util.Map;
import java.util.Optional;

import org.apache.commons.lang3.StringUtils;
import org.apache.olingo.server.api.uri.UriResource;
import org.apache.olingo.server.api.uri.queryoption.apply.AggregateExpression.StandardMethod;
import org.apache.olingo.server.api.uri.queryoption.expression.Expression;

import com.sap.cds.adapter.odata.v4.query.ExpressionParser;
import com.sap.cds.ql.CQL;
import com.sap.cds.ql.ElementRef;
import com.sap.cds.ql.Value;
import com.sap.cds.reflect.CdsBaseType;
import com.sap.cds.reflect.CdsElement;
import com.sap.cds.reflect.CdsSimpleType;
import com.sap.cds.reflect.CdsType;
import com.sap.cds.services.utils.CdsErrorStatuses;
import com.sap.cds.services.utils.ErrorStatusException;
import com.sap.cds.services.utils.model.CdsAnnotations;
import com.sap.cds.util.CdsModelUtils;

/**
 * Builds CQL aggregate values for OData aggregation requests.
 * <p>
 * Two flavours are supported:
 * <ul>
 * <li>standard aggregation methods ({@link StandardMethod}) applied to an arbitrary expression, and</li>
 * <li>custom aggregates, where the aggregation method is taken from annotations on the aggregated
 * element (or derived from currency-code / unit-of-measure semantics).</li>
 * </ul>
 */
public class ElementAggregator {
	/**
	 * Internal pseudo method name for elements annotated as currency code or unit of measure. It is never
	 * read from an annotation; it is only produced by {@link #getAggregationMethodName(CdsElement)}.
	 */
	private static final String UNIT_OR_CURRENCY = "unit or currency";

	private final ExpressionParser expressionParser;

	/**
	 * @param expressionParser parser used to turn OData expressions and paths into CQL values and to
	 *                         resolve element references against its root type
	 */
	public ElementAggregator(ExpressionParser expressionParser) {
		this.expressionParser = expressionParser;
	}

	/**
	 * Aggregates an arbitrary OData expression with one of the standard aggregation methods.
	 *
	 * @param expr           the expression to aggregate
	 * @param standardMethod the standard aggregation method to apply
	 * @return the CQL aggregate value
	 */
	public Value<?> genericAggregate(Expression expr, StandardMethod standardMethod) {
		Value<?> value = (Value<?>) expressionParser.parseValue(expr);
		return toFunctionCall(value, standardMethod);
	}

	/**
	 * Builds the custom aggregate for the element addressed by the given OData resource path.
	 *
	 * @param path the resource path segments addressing the element (or {@code $count})
	 * @return the CQL aggregate value
	 */
	public Value<?> customAggregate(List<UriResource> path) {
		ElementRef<Object> ref = CQL.get(expressionParser.toSegmentList(path));

		return customAggregate(ref);
	}

	/**
	 * Builds the custom aggregate for the given element reference. A reference whose last segment is
	 * {@code $count} is translated to {@code count()}.
	 *
	 * @param ref reference to the aggregated element, relative to the parser's root type
	 * @return the CQL aggregate value
	 * @throws ErrorStatusException if the element defines no custom aggregate or an unknown aggregation
	 *                              method
	 */
	public Value<?> customAggregate(ElementRef<Object> ref) {
		if ("$count".equals(ref.lastSegment())) {
			return CQL.count();
		}

		CdsElement element = CdsModelUtils.element(expressionParser.getRootType(), ref);
		CustomAggregate aggMethod = getAggregationMethodName(element);

		return toFunctionCall(ref, aggMethod);
	}

	/**
	 * Determines the custom aggregation method of an element: the default-aggregation annotation wins
	 * (together with its optional unit-or-currency reference); otherwise currency-code / unit-of-measure
	 * elements get the {@link #UNIT_OR_CURRENCY} pseudo method.
	 */
	private CustomAggregate getAggregationMethodName(CdsElement element) {
		String methodName = getAnnoValue(element, CdsAnnotations.AGGREGATION_DEFAULT);
		if (methodName != null) {
			String unitOrCurrencyRef = getAnnoValue(element, CdsAnnotations.SEMANTICS_UNIT_OR_CURRENCY_REF);
			return CustomAggregate.of(methodName, unitOrCurrencyRef);
		}

		if (CdsAnnotations.SEMANTICS_CURRENCY_CODE.isTrue(element)
				|| CdsAnnotations.SEMANTICS_UNIT_OF_MEASURE.isTrue(element)) {
			return CustomAggregate.of(UNIT_OR_CURRENCY);
		}

		throw new ErrorStatusException(CdsErrorStatuses.NO_CUSTOM_AGGREGATE_DEFINED, element);
	}

	/**
	 * Translates a custom aggregation method into a CQL aggregate over {@code ref}.
	 * <p>
	 * If the aggregate carries a unit-or-currency reference, the result is only returned when all
	 * aggregated rows share the same unit/currency (its min equals its max); otherwise it is
	 * {@code NULL}, so that e.g. amounts in different currencies are not summed up.
	 *
	 * @param ref       the aggregated element
	 * @param aggMethod the custom aggregation method
	 * @return the CQL aggregate value
	 * @throws ErrorStatusException if the method name is not supported
	 */
	Value<?> toFunctionCall(ElementRef<Object> ref, CustomAggregate aggMethod) {
		final String methodName = aggMethod.methodName();
		Value<?> aggFunc = switch (methodName) {
			case "AVG", "AVERAGE" -> ref.average();
			case "COUNT" -> CQL.count(ref);
			case "COUNT_DISTINCT" -> ref.countDistinct();
			case "MIN", "MAX", "SUM" -> typed(CQL.func(methodName, ref), ref);
			// "" if there are only NULL values, the common value if all values are equal, NULL otherwise
			case UNIT_OR_CURRENCY ->
				typed(CQL.when(ref.max().isNull()).then(CQL.constant(""))
						 .when(ref.min().eq(ref.max())).then(ref.min())
						 .orElse(CQL.NULL), ref);
			// report the offending method name, not the CustomAggregate object (which has no toString)
			default ->
				throw new ErrorStatusException(CdsErrorStatuses.UNKONWN_AGGREGATION_METHOD, methodName);
		};

		if (aggMethod.hasUnitOfMeasure()) {
			ElementRef<String> uom = aggMethod.unitOfMeasureRef();
			aggFunc = CQL.when(uom.min().eq(uom.max())).then(aggFunc).orElse(CQL.NULL);
		}
		return aggFunc;
	}

	/**
	 * Gives {@code functionCall} the CDS base type of {@code ref}, if {@code ref} is a reference to an
	 * element of simple type; otherwise returns {@code functionCall} unchanged.
	 */
	private Value<?> typed(Value<?> functionCall, Value<?> ref) {
		// Use the value returned by type(...) rather than relying on it mutating functionCall in place.
		return type(ref).<Value<?>>map(functionCall::type).orElse(functionCall);
	}

	/**
	 * Translates a standard aggregation method into a CQL aggregate over {@code value}. MIN and MAX keep
	 * the element type of {@code value} (if it is an element reference).
	 *
	 * @param value          the value to aggregate
	 * @param standardMethod the standard aggregation method
	 * @return the CQL aggregate value
	 */
	public Value<?> toFunctionCall(Value<?> value, StandardMethod standardMethod) {
		return switch (standardMethod) {
			case AVERAGE -> value.average();
			case COUNT_DISTINCT -> value.countDistinct();
			case MIN -> typed(value.min(), value);
			case MAX -> typed(value.max(), value);
			default -> CQL.func(standardMethod.name(), value);
		};
	}

	/**
	 * Resolves the CDS base type of {@code value} if it is a reference to an element of simple type.
	 *
	 * @return the base type, or empty for non-references and non-simple element types
	 */
	private Optional<CdsBaseType> type(Value<?> value) {
		if (value.isRef()) {
			CdsType t = CdsModelUtils.element(expressionParser.getRootType(), value.asRef()).getType();
			if (t.isSimple()) {
				return Optional.ofNullable(t.as(CdsSimpleType.class).getType());
			}
		}
		return Optional.empty();
	}

	/**
	 * Reads an annotation value as string.
	 * <p>
	 * Map values (such as an aggregation method given as symbol) are unwrapped via their {@code "#"}
	 * entry; any other value is converted with {@code toString()}.
	 *
	 * @return the annotation value, or {@code null} if the annotation is absent or a map value has no
	 *         {@code "#"} entry
	 */
	private static String getAnnoValue(CdsElement element, CdsAnnotations annotation) {
		Object value = annotation.getOrDefault(element);
		if (value == null) {
			return null;
		}
		if (value instanceof Map<?, ?> m) {
			// Aggregation method; avoid turning a missing "#" entry into the literal string "null"
			Object symbol = m.get("#");
			return symbol == null ? null : symbol.toString();
		}
		// ref
		return value.toString();
	}

	/**
	 * Custom aggregation method of an element, optionally combined with a reference to the element
	 * holding its unit of measure or currency.
	 */
	static final class CustomAggregate {
		final String methodName;
		final String unitOrCurrencyRef;

		CustomAggregate(String methodName, String unitOrCurrencyRef) {
			this.methodName = methodName;
			this.unitOrCurrencyRef = unitOrCurrencyRef;
		}

		static CustomAggregate of(String methodName, String unitOrCurrencyRef) {
			return new CustomAggregate(methodName, unitOrCurrencyRef);
		}

		static CustomAggregate of(String methodName) {
			return of(methodName, null);
		}

		String methodName() {
			return methodName;
		}

		ElementRef<String> unitOfMeasureRef() {
			return CQL.get(unitOrCurrencyRef);
		}

		boolean hasUnitOfMeasure() {
			return !StringUtils.isEmpty(unitOrCurrencyRef);
		}
	}

}
