/**************************************************************************
 * (C) 2019-2021 SAP SE or an SAP affiliate company. All rights reserved. *
 **************************************************************************/
package com.sap.cds.services.impl.authorization;

import java.util.Collections;
import java.util.List;
import java.util.function.BiFunction;
import java.util.stream.Collectors;

import com.sap.cds.impl.parser.ExpressionParser;
import com.sap.cds.ql.CQL;
import com.sap.cds.ql.Predicate;
import com.sap.cds.ql.Select;
import com.sap.cds.ql.Value;
import com.sap.cds.ql.cqn.CqnComparisonPredicate.Operator;
import com.sap.cds.ql.cqn.CqnModifier;
import com.sap.cds.ql.cqn.CqnPredicate;
import com.sap.cds.ql.cqn.CqnReference.Segment;
import com.sap.cds.ql.cqn.CqnValue;
import com.sap.cds.services.request.UserInfo;
import com.sap.cds.services.utils.StringUtils;

public class PredicateResolver {
	// private static final Predicate FORBIDDEN_PRED = literal(1).eq(0);
	private static final Predicate UNRESTRICTED_PRED = CQL.constant(1).eq(CQL.constant(1));
	@Deprecated // TODO: change to FORBIDDEN_PRED AFTER UAA has introduces $unrestricted
	private static final Predicate UNRESTRICTED_DEPRECATED = UNRESTRICTED_PRED;

	private final CqnPredicate pred;

	private PredicateResolver(CqnPredicate pred) {
		this.pred = pred;
	}

	static PredicateResolver create(String cxnExpression) {
		CqnPredicate pred = StringUtils.isEmpty(cxnExpression) ? null : ExpressionParser.parsePredicate(cxnExpression);

		return new PredicateResolver(pred);
	}

	CqnPredicate resolve(UserInfo user) {
		if (pred == null) {
			return null;
		}
		return CQL.copy(pred, new UserInfoAttributeSubstitutor(user));
	}

	private static class UserInfoAttributeSubstitutor implements CqnModifier {
		private static final String USER = "$user";
		private static final String TENANT = "tenant";
		private final UserInfo userInfo;

		public UserInfoAttributeSubstitutor(UserInfo userInfo) {
			this.userInfo = userInfo;
		}

		@Override
		public Predicate comparison(Value<?> lhs, Operator op, Value<?> rhs) {
			if (isUserAttribute(lhs)) {
				return resolveUserAttribute(lhs, rhs, (user, val) -> CQL.comparison(user, op, val));
			} else if (isUserAttribute(rhs)) {
				return resolveUserAttribute(rhs, lhs, (user, val) -> CQL.comparison(val, op, user));
			}
			return CQL.comparison(lhs, op, rhs);
		}

		@Override
		public Predicate exists(Select<?> subQuery) {
			subQuery.where().ifPresent(w -> {
				subQuery.where(CQL.copy(w, this));
			});

			return CqnModifier.super.exists(subQuery);
		}

		private static boolean isUserAttribute(CqnValue rhs) {
			return rhs.isRef() && USER.equalsIgnoreCase(rhs.asRef().firstSegment());
		}

		private Predicate resolveUserAttribute(CqnValue userAttribute, CqnValue val,
				BiFunction<CqnValue, CqnValue, Predicate> comparison) {
			List<? extends Segment> segments = userAttribute.asRef().segments();
			if (segments.size() == 1) {
				return compare(userInfo.getName(), val, comparison);
			}
			String attribute = segments.get(1).id();
			if (TENANT.equalsIgnoreCase(attribute)) {
				return compare(userInfo.getTenant(), val, comparison);
			} else if (userInfo.isUnrestrictedAttribute(attribute)) {
				return UNRESTRICTED_PRED;
			}
			List<String> attributeValues = userInfo.getAttributeValues(attribute);
			List<Predicate> preds = attributeValues == null ? Collections.emptyList() :
				attributeValues.stream().map(v -> comparison.apply(CQL.val(v), val)).collect(Collectors.toList());

			if (preds.isEmpty()) {
				return UNRESTRICTED_DEPRECATED;
			}
			return CQL.or(preds);
		}

		private static Predicate compare(String attribute, CqnValue val, BiFunction<CqnValue, CqnValue, Predicate> comparison) {
			if (attribute == null) {
				return UNRESTRICTED_DEPRECATED;
			}
			return comparison.apply(CQL.val(attribute), val);
		}

	}

}
