/*
 * © 2020-2024 SAP SE or an SAP affiliate company. All rights reserved.
 */
package com.sap.cds.services.impl.authorization;

import static com.sap.cds.services.utils.model.CdsModelUtils.$USER;

import com.sap.cds.ql.CQL;
import com.sap.cds.ql.Predicate;
import com.sap.cds.ql.Select;
import com.sap.cds.ql.Value;
import com.sap.cds.ql.cqn.CqnComparisonPredicate.Operator;
import com.sap.cds.ql.cqn.CqnInSubquery;
import com.sap.cds.ql.cqn.CqnMatchPredicate;
import com.sap.cds.ql.cqn.CqnPredicate;
import com.sap.cds.ql.cqn.CqnReference.Segment;
import com.sap.cds.ql.cqn.CqnValue;
import com.sap.cds.ql.cqn.Modifier;
import com.sap.cds.services.request.UserInfo;
import com.sap.cds.services.utils.StringUtils;
import java.util.List;
import org.apache.commons.lang3.tuple.MutablePair;

/**
 * Resolves the {@code $user} / {@code $tenant} references of an authorization predicate against
 * the data of a concrete user.
 *
 * <p>The parsed predicate is kept unresolved; {@link #resolve(UserInfo)} creates a user-specific
 * copy on every call.
 */
public class PredicateResolver {

  // Handling of empty (or undefined) attributes:
  // - a user w/o role assignment will also have an empty attribute list
  // - an IdP-mapped attribute might be empty even if 'valueRequired:true' has been specified for
  //   the attribute

  // Also note that 'valueRequired:true' is the default for attributes.
  // This means: potentially unrestricted attributes need to be configured explicitly.
  // 'valueRequired:true' can be modeled by '$user.attr <op> value'
  // 'valueRequired:false' can be modeled by '$user.attr <op> value or $user.attr is null'

  /** The cached, parsed but unresolved predicate from the model; never {@code null}. */
  private final CqnPredicate unresolvedPred;

  private PredicateResolver(CqnPredicate pred) {
    this.unresolvedPred = pred;
  }

  /**
   * Creates a resolver for the given predicate.
   *
   * @param pred the parsed predicate; {@code null} is treated as {@link CQL#TRUE}
   * @return a new resolver
   */
  static PredicateResolver create(CqnPredicate pred) {
    if (pred == null) {
      pred = CQL.TRUE;
    }
    return new PredicateResolver(pred);
  }

  /**
   * Returns a copy of the cached predicate in which all {@code $user} and {@code $tenant}
   * references are replaced by the values of the given user.
   *
   * @param user the user whose name, tenant and attributes are substituted
   * @return the resolved predicate
   * @throws MultipleAttributeValuesNotSupportedException if a user attribute used as a function
   *     argument does not resolve to exactly one value
   * @throws UnsupportedOperationException if the predicate contains an {@code in} subquery
   */
  CqnPredicate resolve(UserInfo user) {
    // At this point of time, only $user attributes need to be resolved
    // TODO: pre-calculate if UserAttributeSubstitutor is required at all and return directly
    // unresolvedPred otherwise
    return CQL.copy(unresolvedPred, new UserAttributeSubstitutor(user));
  }

  /** {@link Modifier} that substitutes {@code $user} / {@code $tenant} references with values. */
  private static class UserAttributeSubstitutor implements Modifier {

    private static final String $TENANT = "$tenant";

    // NOTE: '$user.id' is meant to be reserved for a unique user id (in the Node.js runtime it is
    // already an alias of '$user'). Here it is currently resolved like any other attribute, i.e.
    // via userInfo.getAttributeValues("id").

    private final UserInfo userInfo;

    public UserAttributeSubstitutor(UserInfo userInfo) {
      this.userInfo = userInfo;
    }

    @Override
    public CqnPredicate comparison(Value<?> lhs, Operator op, Value<?> rhs) {
      MutablePair<List<CqnValue>, Boolean> lhsResolved = resolveUserValuesRaw(lhs);
      List<CqnValue> lhsValues = lhsResolved.getLeft();
      boolean isUserAttribute = lhsResolved.getRight();

      // special case for null values (rhs only; '=' resp. '<>' are not supported)
      if (isUserAttribute && rhs.isNullValue()) {
        if (op == Operator.IS) {
          // '$user.<attr> is null' is true if $user.<attr> is empty
          return lhsValues.isEmpty() ? CQL.TRUE : CQL.FALSE;
        }
        if (op == Operator.IS_NOT) {
          // '$user.<attr> is not null' is true if $user.<attr> is not empty
          return lhsValues.isEmpty() ? CQL.FALSE : CQL.TRUE;
        }
      }

      List<CqnValue> rhsValues = resolveUserValues(rhs);

      // every (lhs, rhs) value combination is compared with 'op'; the comparisons are 'or'ed
      return lhsValues.stream()
          .flatMap(lhsVal -> rhsValues.stream().map(rhsVal -> CQL.comparison(lhsVal, op, rhsVal)))
          .collect(CQL.withOr());
    }

    @Override
    public CqnPredicate in(Value<?> value, CqnValue valueSet) {
      List<CqnValue> lhsValues = resolveUserValues(value);
      // a literal list may itself contain user attributes; each one expands to all its values
      List<CqnValue> rhsValues =
          valueSet.isList()
              ? valueSet.asList().values().flatMap(v -> resolveUserValues(v).stream()).toList()
              : resolveUserValues(valueSet);

      return lhsValues.stream().map(v -> CQL.in(v, rhsValues)).collect(CQL.withOr());
    }

    // where exists subquery: the subquery's where clause is resolved in place
    @Override
    public CqnPredicate exists(Select<?> subQuery) {
      subQuery.where().ifPresent(w -> subQuery.where(CQL.copy(w, this)));
      return Modifier.super.exists(subQuery);
    }

    // in subquery
    @Override
    public CqnPredicate in(CqnInSubquery inSubquery) {
      throw new UnsupportedOperationException("Unsupported predicate: " + inSubquery);
    }

    // any/all/exists
    @Override
    public CqnPredicate match(CqnMatchPredicate match) {
      if (match.predicate().isPresent()) {
        Predicate pred = CQL.copy(match.predicate().get(), this);
        return CQL.match(match.ref(), pred, match.quantifier());
      }
      return match;
    }

    @Override
    public CqnValue function(String name, List<Value<?>> args, String cdsType) {
      substituteSingleValuedArguments(name, args);
      return Modifier.super.function(name, args, cdsType);
    }

    @Override
    public CqnPredicate booleanFunction(String name, List<Value<?>> args) {
      substituteSingleValuedArguments(name, args);
      return Modifier.super.booleanFunction(name, args);
    }

    /**
     * Replaces, in place, every argument that references a user attribute by its resolved value.
     * Function arguments cannot be expanded, so each one must resolve to exactly one value.
     *
     * @param functionName name of the function, reported in the exception
     * @param args the (mutable) argument list of the function
     * @throws MultipleAttributeValuesNotSupportedException if a user attribute argument is empty or
     *     has multiple values
     */
    private void substituteSingleValuedArguments(String functionName, List<Value<?>> args) {
      for (int i = 0; i < args.size(); ++i) {
        Value<?> arg = args.get(i);
        List<CqnValue> resolved = resolveUserValues(arg);
        if (resolved.size() != 1) {
          // only user attributes can resolve to zero or several values
          throw new MultipleAttributeValuesNotSupportedException(
              functionName, extractUserAttribute(arg));
        }
        CqnValue replacement = resolved.get(0);
        // non-user values resolve to themselves; only write back actual substitutions
        if (replacement != arg) {
          args.set(i, (Value<?>) replacement);
        }
      }
    }

    // possible results are:
    // 1. 'val' itself in a single element list if it is not an $user attribute
    // 2. A list of values of a user attribute (single element list in case of $user, $user.tenant)
    // 3. An empty list in case the user attribute has not been set or didn't have any values
    private List<CqnValue> resolveUserValues(CqnValue val) {
      return resolveUserValuesRaw(val).getLeft();
    }

    /**
     * Same as {@link #resolveUserValues(CqnValue)}, additionally reporting (right side) whether
     * {@code val} referenced a user attribute at all.
     */
    private MutablePair<List<CqnValue>, Boolean> resolveUserValuesRaw(CqnValue val) {
      String userAttribute = extractUserAttribute(val);
      if (userAttribute == null) {
        return MutablePair.of(List.of(val), false); // no user attribute
      }

      List<CqnValue> result =
          switch (userAttribute) {
            case $USER -> constantOrEmpty(userInfo.getName()); // $user
            case $TENANT -> constantOrEmpty(userInfo.getTenant()); // $tenant or $user.tenant
            default -> attributeConstants(userAttribute); // $user.<attribute>
          };
      return MutablePair.of(result, true);
    }

    /** Returns a single constant for a non-empty value, an empty list otherwise. */
    private static List<CqnValue> constantOrEmpty(String value) {
      return StringUtils.isEmpty(value) ? List.of() : List.of(CQL.constant(value));
    }

    /** Returns the values of the given user attribute as constants; empty if unset or empty. */
    private List<CqnValue> attributeConstants(String attribute) {
      List<String> attributeValues = userInfo.getAttributeValues(attribute);
      if (attributeValues == null || attributeValues.isEmpty()) {
        return List.of();
      }
      return attributeValues.stream().map(v -> (CqnValue) CQL.constant(v)).toList();
    }

    /**
     * Determines which user attribute {@code val} references.
     *
     * @return {@code $USER} for {@code $user}, {@code $TENANT} for {@code $tenant} or
     *     {@code $user.tenant}, the attribute name for {@code $user.<attribute>}, or {@code null} if
     *     {@code val} is not a user attribute reference
     */
    private String extractUserAttribute(CqnValue val) {
      if (!val.isRef()) {
        return null;
      }
      List<? extends Segment> segments = val.asRef().segments();
      if (segments.isEmpty()) {
        return null;
      }
      String firstSegment = segments.get(0).id();
      if ($USER.equalsIgnoreCase(firstSegment)) {
        if (segments.size() == 1) { // $user
          return $USER;
        }
        // NOTE(review): segments beyond the second one are not inspected
        String attribute = segments.get(1).id();
        // $user.tenant is an alias of $tenant; anything else is $user.<attribute>
        return "tenant".equalsIgnoreCase(attribute) ? $TENANT : attribute;
      }
      if (segments.size() == 1 && $TENANT.equalsIgnoreCase(firstSegment)) {
        return $TENANT;
      }
      return null; // no user attribute
    }
  }

  /**
   * Thrown if a user attribute used as a function argument does not resolve to exactly one value,
   * i.e. it is empty or has multiple values.
   */
  public static class MultipleAttributeValuesNotSupportedException extends RuntimeException {

    private static final long serialVersionUID = 1L;

    private final String attributeName;
    private final String resourceName;

    MultipleAttributeValuesNotSupportedException(String resourceName, String attributeName) {
      super(
          "User attribute '%s' used in '%s' must have exactly one value"
              .formatted(attributeName, resourceName));
      this.resourceName = resourceName;
      this.attributeName = attributeName;
    }

    /** Returns the name of the offending user attribute. */
    public String getAttributeName() {
      return attributeName;
    }

    /** Returns the name of the function in which the attribute was used. */
    public String getResourceName() {
      return resourceName;
    }
  }
}
