/**************************************************************************
 * (C) 2019-2024 SAP SE or an SAP affiliate company. All rights reserved. *
 **************************************************************************/
package com.sap.cds.services.impl.authorization;

import static com.sap.cds.services.utils.model.CdsModelUtils.$USER;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

import org.apache.commons.lang3.tuple.MutablePair;

import com.sap.cds.ql.CQL;
import com.sap.cds.ql.Predicate;
import com.sap.cds.ql.Select;
import com.sap.cds.ql.Value;
import com.sap.cds.ql.cqn.CqnComparisonPredicate.Operator;
import com.sap.cds.ql.cqn.CqnMatchPredicate;
import com.sap.cds.ql.cqn.CqnPredicate;
import com.sap.cds.ql.cqn.CqnReference.Segment;
import com.sap.cds.ql.cqn.CqnValue;
import com.sap.cds.ql.cqn.Modifier;
import com.sap.cds.services.request.UserInfo;
import com.sap.cds.services.utils.StringUtils;

public class PredicateResolver {

	static final Predicate UNRESTRICTED_PRED = CQL.TRUE;
	static final Predicate RESTRICTED_PRED = CQL.FALSE;


	// Handling of empty (or undefined) attributes:
	// - a user w/o role assignment will also have an empty attribute list
	// - an IdP-mapped attribute might be empty even if 'valueRequired:true' has been specified for the attribute

	// Also note that 'valueRequired:true' is the default for attributes.
	// This means: potentially unrestricted attributes need to be configured explicitly.

	// case emptyAttributesAreRestricted == true (current and recommended default):
	// 'valueRequired:true' can be modeled by '$user.attr <op> value'
	// 'valueRequired:false' can be modeled by '$user.attr <op> value or $user.attr is null'

	// case emptyAttributesAreRestricted == false (meant as kill switch):
	// 'valueRequired:true' can be modeled by '$user.attr <op> value and $user.attr is not null' (also covering empty IdP values)
	// 'valueRequired:false' can be modeled by '$user.attr <op> value'


	private final boolean emptyAttributesAreRestricted;

	// The cached unresolved, but parsed predicate from the model
	private final CqnPredicate unresolvedPred;

	private PredicateResolver(CqnPredicate pred, boolean emptyAttributesAreRestricted) {
		this.unresolvedPred = pred;
		this.emptyAttributesAreRestricted = emptyAttributesAreRestricted;
	}

	static PredicateResolver create(CqnPredicate pred, boolean emptyAttributesAreRestricted) {
		if (pred == null) {
			pred = UNRESTRICTED_PRED;
		}
		return new PredicateResolver(pred, emptyAttributesAreRestricted);
	}

	CqnPredicate resolve(UserInfo user) {
		// At this point of time, only $user attributes need to be resolved
		// TODO: pre-calculate if UserAttributeSubstitutor is required at all and return directly unresolvedPred otherwise
		return CQL.copy(unresolvedPred, new UserAttributeSubstitutor(user, emptyAttributesAreRestricted));
	}

	private static class UserAttributeSubstitutor implements Modifier {

		private static final String $TENANT = "$tenant";
		// Reserve $user.id for unique user id. In Node it already means $user.

		private static List<CqnValue> UNRESTRICTED_VALUE = Arrays.asList();
		private static List<CqnValue> RESTRICTED_VALUE = Arrays.asList();

		private static List<CqnValue> UNRESOLVED_EMPTY_VALUE = Arrays.asList();

		private final boolean emptyAttributesAreRestricted;
		private final UserInfo userInfo;

		public UserAttributeSubstitutor(UserInfo userInfo, boolean emptyAttributesAreRestricted) {
			this.userInfo = userInfo;
			this.emptyAttributesAreRestricted = emptyAttributesAreRestricted;
		}

		private List<CqnValue> ensureNormalizedValue(List<CqnValue> value) {
			if (value == UNRESOLVED_EMPTY_VALUE) {
				return emptyAttributesAreRestricted ? RESTRICTED_VALUE : UNRESTRICTED_VALUE;
			}
			return value;
		}

		@Override
		public Predicate comparison(Value<?> lhs, Operator op, Value<?> rhs) {

			MutablePair<List<CqnValue>, Boolean> lhsValuesRaw = resolveUserValuesRaw(lhs);

			List<CqnValue> lhsValues = lhsValuesRaw.getLeft();
			boolean isUserAttribute = lhsValuesRaw.getRight();

			// special case for null values (rhs only, don't support '=' resp. '<>')
			if (isUserAttribute && rhs.isNullValue()) {
				if (op == Operator.IS) {
					// '$user.<attr> is null' is unrestricted iff $user.<attr> is empty
					return (lhsValues == UNRESOLVED_EMPTY_VALUE) ? UNRESTRICTED_PRED : RESTRICTED_PRED;
				} else if (op == Operator.IS_NOT) {
					// '$user.<attr> is not null' is unrestricted iff $user.<attr> is not empty
					return (lhsValues != UNRESOLVED_EMPTY_VALUE) ? UNRESTRICTED_PRED : RESTRICTED_PRED;
				}
			}

			// ensure normalized values (could be still empty value)
			lhsValues = ensureNormalizedValue(lhsValues);

			List<CqnValue> rhsValues = resolveUserValues(rhs);

			if (lhsValues == RESTRICTED_VALUE || rhsValues == RESTRICTED_VALUE) { // Forbidden is stronger than unrestricted!
				return RESTRICTED_PRED;
			} else if (lhsValues == UNRESTRICTED_VALUE || rhsValues == UNRESTRICTED_VALUE) {
				return UNRESTRICTED_PRED;
			} else {
				// 'or'ed cross product based on op
				ArrayList<Predicate> preds = new ArrayList<>(lhsValues.size() * rhsValues.size());
				for (CqnValue lhsVal : lhsValues) {
					for (CqnValue rhsVal : rhsValues) {
						preds.add( CQL.comparison(lhsVal, op, rhsVal) );
					}
				}
				return CQL.or(preds);
			}
		}

		@Override
		public Predicate in(Value<?> value, CqnValue valueSet) {
			List<CqnValue> lhsValues = resolveUserValues(value);
			if(lhsValues == RESTRICTED_VALUE) {
				return RESTRICTED_PRED;
			} else if (lhsValues == UNRESTRICTED_VALUE) {
				return UNRESTRICTED_PRED;
			}

			List<CqnValue> rhsValues;
			if(valueSet.isList()) {
				rhsValues = new ArrayList<>();
				for(CqnValue listValue : valueSet.asList().values().collect(Collectors.toList())) {
					List<CqnValue> listValues = resolveUserValues(listValue);
					if(listValues == RESTRICTED_VALUE) {
						return RESTRICTED_PRED;
					} else if (listValues == UNRESTRICTED_VALUE) {
						return UNRESTRICTED_PRED;
					} else {
						rhsValues.addAll(listValues);
					}
				}
			} else {
				rhsValues = resolveUserValues(valueSet);
				if(rhsValues == RESTRICTED_VALUE) {
					return RESTRICTED_PRED;
				} else if (rhsValues == UNRESTRICTED_VALUE) {
					return UNRESTRICTED_PRED;
				}
			}

			return CQL.or(lhsValues.stream().map(v -> CQL.in(v, rhsValues)).collect(Collectors.toList()));
		}

		// subselect
		@Override
		public CqnPredicate exists(Select<?> subQuery) {
			subQuery.where().ifPresent(w -> {
				subQuery.where(CQL.copy(w, this));
			});

			return Modifier.super.exists(subQuery);
		}

		// any/all/exists
		@Override
		public CqnPredicate match(CqnMatchPredicate match) {
			if (match.predicate().isPresent()) {
				Predicate pred = CQL.copy(match.predicate().get(), this);
				return CQL.match(match.ref(), pred, match.quantifier());
			}
			return match;
		}

		@Override
		public CqnValue function(String name, List<Value<?>> args, String cdsType) {
			for (int i = 0; i <args.size(); ++i) {
				List<CqnValue> resolvedValue = resolveUserValues(args.get(i));
				if (resolvedValue.size() != 1) {
					throw new MultipleAttributeValuesNotSupportedException(name, extractUserAttribute(args.get(i)));
				}
				args.set(i, (Value<?>) resolvedValue.get(0));
			}
			return Modifier.super.function(name, args, cdsType);
		}

		@Override
		public CqnPredicate booleanFunction(String name, List<Value<?>> args) {
			for (int i = 0; i <args.size(); ++i) {
				List<CqnValue> resolvedValue = resolveUserValues(args.get(i));
				if (resolvedValue.size() != 1) {
					throw new MultipleAttributeValuesNotSupportedException(name, extractUserAttribute(args.get(i)));
				}
				args.set(i, (Value<?>) resolvedValue.get(0));
			}
			return Modifier.super.booleanFunction(name, args);
		}

		// possible results are:
		// 1. 'val' itself in a single element list if it is not an $user attribute
		// 2. A list of values of a user attribute (single element list in case of $user, $user.tenant)
		// 3. RESTRICTED_VALUE indicating restricted value
		// 4. UNRESOLVED_EMPTY_VALUE indicating an empty value, not decided on the effective value yet
		private List<CqnValue> resolveUserValues(CqnValue val) {
			return ensureNormalizedValue( resolveUserValuesRaw(val).getLeft() );
		}

		private MutablePair<List<CqnValue>, Boolean> resolveUserValuesRaw(CqnValue val) {

			String userAttribute = extractUserAttribute(val);
			if (userAttribute == null) {
				return MutablePair.of(Arrays.asList(val), false); // no user attribute
			}

			List<CqnValue> result = null;

			switch(userAttribute) {
			case $USER: // $user
				result = !StringUtils.isEmpty(userInfo.getName()) ?
						Arrays.asList( CQL.constant(userInfo.getName()) ) : RESTRICTED_VALUE;
				break;

			case $TENANT: // $tenant or $user.tenant
				result =  !StringUtils.isEmpty(userInfo.getTenant()) ?
						Arrays.asList( CQL.constant(userInfo.getTenant()) ) : RESTRICTED_VALUE;
				break;

			default: // $user.<attribute>
				List<String> attributeValues = userInfo.getAttributeValues(userAttribute);
				if (attributeValues == null || attributeValues.isEmpty()) {
					result = UNRESOLVED_EMPTY_VALUE;
				} else {
					result = attributeValues.stream().map(CQL::constant).collect(Collectors.toList());
				}
			}
			return MutablePair.of(result, true);
		}

		private String extractUserAttribute(CqnValue val) {
			if (val.isRef()) {
				List<? extends Segment> segments = val.asRef().segments();
				if (!segments.isEmpty()) {
					String firstSegment = segments.get(0).id();
					if ($USER.equalsIgnoreCase(firstSegment)) {
						if (segments.size() == 1) { // $user
							return $USER;
						} else {
							String attribute = segments.get(1).id();
							if ("tenant".equalsIgnoreCase(attribute)) { // $user.tenant
								return $TENANT;
							} // $user.<attribute>
							return attribute;
						}
					} else if (segments.size() == 1 && $TENANT.equalsIgnoreCase(firstSegment)) {
						return $TENANT;
					}
				}
			}
			return null; // else: no user attribute
		}
	}

	public static class MultipleAttributeValuesNotSupportedException extends RuntimeException {

		private static final long serialVersionUID = 1L;

		private final String attributeName;
		private final String resourceName;

		MultipleAttributeValuesNotSupportedException(String resourceName, String attributeName) {
			this.resourceName = resourceName;
			this.attributeName = attributeName;
		}

		public String getAttributeName() {
			return attributeName;
		}

		public String getResourceName() {
			return resourceName;
		}
	}

}
