/**************************************************************************
 * (C) 2019-2021 SAP SE or an SAP affiliate company. All rights reserved. *
 **************************************************************************/
package com.sap.cds.services.impl.authorization;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

import com.sap.cds.impl.parser.ExpressionParser;
import com.sap.cds.ql.CQL;
import com.sap.cds.ql.Predicate;
import com.sap.cds.ql.Select;
import com.sap.cds.ql.StructuredTypeRef;
import com.sap.cds.ql.Value;
import com.sap.cds.ql.cqn.CqnComparisonPredicate.Operator;
import com.sap.cds.ql.cqn.CqnMatchPredicate.Quantifier;
import com.sap.cds.ql.cqn.CqnPredicate;
import com.sap.cds.ql.cqn.CqnReference.Segment;
import com.sap.cds.ql.cqn.CqnValue;
import com.sap.cds.ql.cqn.Modifier;
import com.sap.cds.services.request.UserInfo;
import com.sap.cds.services.utils.StringUtils;

/**
 * Caches a parsed authorization predicate and resolves the {@code $user} references it contains
 * ({@code $user}, {@code $user.tenant} and {@code $user.<attribute>}) against a concrete {@link UserInfo}.
 */
public class PredicateResolver {

	/** Tautology {@code 1 = 1}, used where no restriction applies. */
	static final Predicate UNRESTRICTED_PRED = CQL.constant(1).eq(CQL.constant(1));
	/** Contradiction {@code 1 = 0}, used where nothing may match. */
	static final Predicate FORBIDDEN_PRED = CQL.constant(1).eq(CQL.constant(0));

	// Assuming former behaviour as default ($UNRESTRICTED is not set): empty attribute list means fully unrestricted.
	// Unfortunately, a user w/o role assignment will also have an empty attribute list...
	// TODO: in future xsuaa versions, xsuaa binding will tell if $UNRESTRICTED is used to code unrestricted instead or not.
	static final boolean TREAT_EMPTY_ATTRIBUTES_AS_RESTRICTED = false;

	// The cached unresolved, but parsed predicate from the model
	private final CqnPredicate unresolvedPred;

	private PredicateResolver(CqnPredicate pred) {
		this.unresolvedPred = pred;
	}

	/**
	 * Creates a resolver for the given expression.
	 *
	 * @param cxnExpression the predicate expression to parse; an empty expression (per {@code StringUtils.isEmpty})
	 *                      results in {@link #UNRESTRICTED_PRED}
	 * @return a resolver caching the parsed predicate
	 */
	static PredicateResolver create(String cxnExpression) {
		CqnPredicate pred = !StringUtils.isEmpty(cxnExpression) ? ExpressionParser.parsePredicate(cxnExpression) : UNRESTRICTED_PRED;
		return new PredicateResolver(pred);
	}

	/**
	 * Returns a copy of the cached predicate in which all {@code $user} references are substituted
	 * with the values of the given user.
	 *
	 * @param user the user whose name, tenant and attributes are substituted
	 * @return the resolved predicate
	 * @throws MultipleAttributeValuesNotSupportedException if a {@code $user} reference used as a function
	 *                                                      argument does not resolve to exactly one value
	 */
	CqnPredicate resolve(UserInfo user) {
		// At this point of time, only $user attributes need to be resolved
		// TODO: pre-calculate if UserAttributeSubstitutor is required at all and return directly unresolvedPred otherwise
		return CQL.copy(unresolvedPred, new UserAttributeSubstitutor(user));
	}

	/**
	 * {@link Modifier} that substitutes {@code $user} references while a predicate is copied.
	 * A reference resolves either to a list of values or to one of the sentinels {@link #UNRESTRICTED_VALUE}
	 * and {@link #FORBIDDEN_VALUE}, which are detected by identity ({@code ==}).
	 */
	private static class UserAttributeSubstitutor implements Modifier {

		private static final String USER = "$user";
		private static final String TENANT = "tenant";
		// Reserve $user.id for unique user id. In Node it already means $user.

		// Sentinels returned by resolveUserValues(); compared by identity, hence final.
		private static final List<CqnValue> UNRESTRICTED_VALUE = null;
		private static final List<CqnValue> FORBIDDEN_VALUE = Collections.emptyList();

		private final UserInfo userInfo;

		public UserAttributeSubstitutor(UserInfo userInfo) {
			this.userInfo = userInfo;
		}

		/**
		 * Resolves {@code $user} references on both sides of a comparison. Multi-valued attributes expand
		 * into an 'or'ed cross product of comparisons with the same operator.
		 */
		@Override
		public Predicate comparison(Value<?> lhs, Operator op, Value<?> rhs) {
			List<CqnValue> lhsValues = resolveUserValues(lhs);
			List<CqnValue> rhsValues = resolveUserValues(rhs);
			if (lhsValues == FORBIDDEN_VALUE || rhsValues == FORBIDDEN_VALUE) { // Forbidden is stronger than unrestricted!
				return FORBIDDEN_PRED;
			}
			if (lhsValues == UNRESTRICTED_VALUE || rhsValues == UNRESTRICTED_VALUE) {
				return UNRESTRICTED_PRED;
			}
			// 'or'ed cross product based on op
			List<Predicate> preds = new ArrayList<>(lhsValues.size() * rhsValues.size());
			for (CqnValue lhsVal : lhsValues) {
				for (CqnValue rhsVal : rhsValues) {
					preds.add(CQL.comparison(lhsVal, op, rhsVal));
				}
			}
			return CQL.or(preds);
		}

		/**
		 * Resolves {@code $user} references in an 'in' predicate, both in the tested value and in the value set.
		 * As in {@link #comparison}, a forbidden reference wins over an unrestricted one, independent of its
		 * position in the value set.
		 */
		@Override
		public Predicate in(Value<?> value, CqnValue valueSet) {
			List<CqnValue> lhsValues = resolveUserValues(value);
			if (lhsValues == FORBIDDEN_VALUE) {
				return FORBIDDEN_PRED;
			}
			// do not return early on unrestricted: a forbidden value set must still take precedence
			boolean unrestricted = lhsValues == UNRESTRICTED_VALUE;

			List<CqnValue> candidates = new ArrayList<>();
			if (valueSet.isList()) {
				valueSet.asList().values().forEach(candidates::add);
			} else {
				candidates.add(valueSet);
			}

			List<CqnValue> rhsValues = new ArrayList<>();
			for (CqnValue candidate : candidates) {
				List<CqnValue> resolved = resolveUserValues(candidate);
				if (resolved == FORBIDDEN_VALUE) {
					return FORBIDDEN_PRED; // Forbidden is stronger than unrestricted, wherever it occurs
				} else if (resolved == UNRESTRICTED_VALUE) {
					unrestricted = true;
				} else {
					rhsValues.addAll(resolved);
				}
			}
			if (unrestricted) {
				return UNRESTRICTED_PRED;
			}

			return CQL.or(lhsValues.stream().map(v -> CQL.in(v, rhsValues)).collect(Collectors.toList()));
		}

		// subselect: resolve $user references inside the subquery's where clause as well
		@Override
		public Predicate exists(Select<?> subQuery) {
			subQuery.where().ifPresent(w -> subQuery.where(CQL.copy(w, this)));
			return Modifier.super.exists(subQuery);
		}

		// any/all/exists
		@Override
		public Predicate match(StructuredTypeRef ref, Predicate pred, Quantifier quantifier) {
			return Modifier.super.match(ref, pred != null ? CQL.copy(pred, this) : pred, quantifier);
		}

		@Override
		public Value<?> function(String name, List<Value<?>> args, String cdsType) {
			return Modifier.super.function(name, resolveSingleValuedArgs(name, args), cdsType);
		}

		@Override
		public Predicate booleanFunction(String name, List<Value<?>> args) {
			return Modifier.super.booleanFunction(name, resolveSingleValuedArgs(name, args));
		}

		/**
		 * Substitutes each {@code $user} reference among function arguments with its single value.
		 * Arguments that are not {@code $user} references are kept as they are.
		 *
		 * @param functionName the function name, reported in the exception
		 * @param args         the original arguments (not modified)
		 * @return a new list with the resolved arguments
		 * @throws MultipleAttributeValuesNotSupportedException if an argument resolves to several values,
		 *                                                      to no value (forbidden) or is unrestricted
		 */
		private List<Value<?>> resolveSingleValuedArgs(String functionName, List<Value<?>> args) {
			List<Value<?>> resolvedArgs = new ArrayList<>(args.size());
			for (Value<?> arg : args) {
				List<CqnValue> values = resolveUserValues(arg);
				// UNRESTRICTED_VALUE is null and must be checked before size() to avoid a NullPointerException
				if (values == UNRESTRICTED_VALUE || values.size() != 1) {
					throw new MultipleAttributeValuesNotSupportedException(functionName, extractUserAttribute(arg));
				}
				resolvedArgs.add((Value<?>) values.get(0));
			}
			return resolvedArgs;
		}

		/**
		 * Resolves a value that may be a {@code $user} reference. Possible results are:
		 * <ol>
		 * <li>{@code val} itself in a single-element list if it is not a {@code $user} reference</li>
		 * <li>a list of values of a user attribute (single-element list in case of {@code $user}, {@code $user.tenant})</li>
		 * <li>{@link #FORBIDDEN_VALUE} if the user name or tenant is empty, or if the attribute has no values and
		 * {@code TREAT_EMPTY_ATTRIBUTES_AS_RESTRICTED} is set</li>
		 * <li>{@link #UNRESTRICTED_VALUE} if the attribute is unrestricted, or if it has no values and
		 * {@code TREAT_EMPTY_ATTRIBUTES_AS_RESTRICTED} is not set</li>
		 * </ol>
		 */
		private List<CqnValue> resolveUserValues(CqnValue val) {

			String userAttribute = extractUserAttribute(val);
			if (userAttribute == null) {
				return Arrays.asList(val); // no user attribute
			}

			switch(userAttribute) {
			case USER: // $user
				return !StringUtils.isEmpty(userInfo.getName()) ?
						Arrays.asList( CQL.constant(userInfo.getName()) ) : FORBIDDEN_VALUE;

			case TENANT: // $user.tenant
				return !StringUtils.isEmpty(userInfo.getTenant()) ?
						Arrays.asList( CQL.constant(userInfo.getTenant()) ) : FORBIDDEN_VALUE;

			default: // $user.<attribute>
				if (userInfo.isUnrestrictedAttribute(userAttribute)) { // $unrestricted is part of attribute's value list
					return UNRESTRICTED_VALUE;
				}
				List<String> attributeValues = userInfo.getAttributeValues(userAttribute);
				if (attributeValues == null || attributeValues.isEmpty()) {
					return TREAT_EMPTY_ATTRIBUTES_AS_RESTRICTED ? FORBIDDEN_VALUE : UNRESTRICTED_VALUE;
				} else {
					return attributeValues.stream().map(v -> CQL.constant(v)).collect(Collectors.toList());
				}
			}
		}

		/**
		 * Extracts the user attribute referenced by {@code val}.
		 *
		 * @return {@link #USER} for {@code $user}, {@link #TENANT} for {@code $user.tenant} (case-insensitive),
		 *         the second segment's id for {@code $user.<attribute>}, or {@code null} if {@code val} is not a
		 *         reference starting with {@code $user}
		 */
		private String extractUserAttribute(CqnValue val) {
			if (val.isRef()) {
				List<? extends Segment> segments = val.asRef().segments();
				if (!segments.isEmpty() && USER.equalsIgnoreCase(segments.get(0).id())) {
					if (segments.size() == 1) { // $user
						return USER;
					}
					String attribute = segments.get(1).id();
					if (TENANT.equalsIgnoreCase(attribute)) { // $user.tenant
						return TENANT;
					}
					return attribute; // $user.<attribute>
				}
			}
			return null; // else: no user attribute
		}
	}

	/**
	 * Thrown if a {@code $user} reference is used in a place that requires exactly one value
	 * (e.g. as a function argument), but does not resolve to exactly one value.
	 */
	public static class MultipleAttributeValuesNotSupportedException extends RuntimeException {

		private static final long serialVersionUID = 1L;

		private final String attributeName;
		private final String resourceName;

		MultipleAttributeValuesNotSupportedException(String resourceName, String attributeName) {
			super("User attribute '" + attributeName + "' used in '" + resourceName
					+ "' does not resolve to exactly one value");
			this.resourceName = resourceName;
			this.attributeName = attributeName;
		}

		/** @return the name of the user attribute that did not resolve to exactly one value */
		public String getAttributeName() {
			return attributeName;
		}

		/** @return the name of the resource (e.g. the function) in which the attribute was used */
		public String getResourceName() {
			return resourceName;
		}
	}
}
