/*
 * © 2025 SAP SE or an SAP affiliate company. All rights reserved.
 */
package com.sap.cds.services.impl.cds;

import static com.sap.cds.services.impl.cds.ValidationErrorHandler.THROW_IF_ERROR_IN_AFTER;
import static com.sap.cds.services.utils.model.CdsAnnotations.ASSERT_CONSTRAINT;

import com.sap.cds.Result;
import com.sap.cds.Row;
import com.sap.cds.impl.DataProcessor;
import com.sap.cds.impl.parser.ExprParser;
import com.sap.cds.ql.CQL;
import com.sap.cds.ql.Predicate;
import com.sap.cds.ql.Select;
import com.sap.cds.ql.Selectable;
import com.sap.cds.ql.Value;
import com.sap.cds.ql.cqn.CqnElementRef;
import com.sap.cds.ql.cqn.CqnPredicate;
import com.sap.cds.ql.cqn.CqnSelect;
import com.sap.cds.ql.cqn.CqnSelectListValue;
import com.sap.cds.ql.cqn.CqnStatement;
import com.sap.cds.ql.cqn.CqnValue;
import com.sap.cds.ql.cqn.Path;
import com.sap.cds.ql.impl.PathImpl;
import com.sap.cds.reflect.CdsAnnotation;
import com.sap.cds.reflect.CdsElement;
import com.sap.cds.reflect.CdsEntity;
import com.sap.cds.reflect.CdsModel;
import com.sap.cds.reflect.CdsSimpleType;
import com.sap.cds.reflect.CdsStructuredType;
import com.sap.cds.services.EventContext;
import com.sap.cds.services.cds.ApplicationService;
import com.sap.cds.services.cds.CqnService;
import com.sap.cds.services.draft.DraftService;
import com.sap.cds.services.handler.EventHandler;
import com.sap.cds.services.handler.annotations.After;
import com.sap.cds.services.handler.annotations.HandlerOrder;
import com.sap.cds.services.handler.annotations.ServiceName;
import com.sap.cds.services.impl.utils.ValidatorErrorUtils;
import com.sap.cds.services.impl.utils.ValidatorExecutor;
import com.sap.cds.services.messages.Message;
import com.sap.cds.services.messages.MessageTarget;
import com.sap.cds.services.messages.Messages;
import com.sap.cds.services.persistence.PersistenceService;
import com.sap.cds.services.runtime.CdsRuntime;
import com.sap.cds.services.utils.CdsErrorStatuses;
import com.sap.cds.services.utils.ErrorStatusException;
import com.sap.cds.services.utils.OrderConstants;
import com.sap.cds.services.utils.TenantAwareCache;
import com.sap.cds.services.utils.model.CdsAnnotations;
import com.sap.cds.util.CdsModelUtils;
import com.sap.cds.util.DataUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * After-handler, registered for all {@link ApplicationService}s, that evaluates declarative
 * constraints on the data returned by write events.
 *
 * <p>Three kinds of constraints are collected from the model:
 *
 * <ul>
 *   <li>{@code assert.constraint.<name>.*} annotations on entities and elements,
 *   <li>expression-valued {@code @mandatory} annotations on elements,
 *   <li>expression-valued {@code @inapplicable} annotations on elements.
 * </ul>
 *
 * <p>The handler works in three passes: it collects the keys of all result entries whose type
 * declares constraints, runs one select per entity type against the default {@link
 * PersistenceService} in which every constraint is a (negated) column, and finally walks the
 * result data again to turn violated constraints into error messages.
 */
@ServiceName(value = "*", type = ApplicationService.class)
public class ConstraintAssertionHandler implements EventHandler {

  /** Common name prefix of all flattened {@code assert.constraint} annotations. */
  private static final String ASSERT_CONSTRAINT_PREFIX = "assert.constraint.";

  /** Event name prefix used to recognise draft events. */
  private static final String DRAFT_EVENT_PREFIX = "DRAFT_";

  // aliases used on the select list of the validation query
  private static final String CONSTRAINT_PREFIX = "@constraint_";
  private static final String MANDATORY_PREFIX = "@mandatory_";
  private static final String INAPPLICABLE_PREFIX = "@inapplicable_";
  private static final String PARAMETER_PREFIX = "@parameter_";
  private static final Value<String> EMPTY = CQL.constant("");

  private final PersistenceService db;
  private final TenantAwareCache<Map<CdsStructuredType, Map<String, Constraint>>, CdsModel>
      constraintCache;

  /**
   * Creates the handler.
   *
   * @param runtime the runtime used to look up the default {@link PersistenceService} and to
   *     create the constraint cache
   */
  ConstraintAssertionHandler(CdsRuntime runtime) {
    this.db =
        runtime
            .getServiceCatalog()
            .getService(PersistenceService.class, PersistenceService.DEFAULT_NAME);
    this.constraintCache = TenantAwareCache.create(ConcurrentHashMap::new, runtime);
  }

  /**
   * Checks the constraints of all entries in the event result, provided {@link
   * ValidatorExecutor#isValidationEvent} accepts the event.
   *
   * @param result the result of the write event
   * @param cqn the statement of the event (not used, required by the handler signature)
   * @param context the event context that receives the error messages
   */
  @After(
      event = {
        CqnService.EVENT_CREATE,
        CqnService.EVENT_UPDATE,
        CqnService.EVENT_UPSERT,
        DraftService.EVENT_DRAFT_NEW,
        DraftService.EVENT_DRAFT_PATCH
      })
  @HandlerOrder(OrderConstants.After.CHECK_CONSTRAINTS)
  public void assertConstraints(Result result, CqnStatement cqn, EventContext context) {
    if (ValidatorExecutor.isValidationEvent(context, true)) {
      assertConstraintsInternal(result, context);
    }
  }

  /**
   * Runs the three validation passes (collect keys, query violations, report violations).
   *
   * @param result the result of the write event
   * @param context the event context that receives the error messages
   */
  private void assertConstraintsInternal(Result result, EventContext context) {
    // due to annotation flattening in the compiler, we need to reconstruct the annotation structure
    // use the CdsDataProcessor to determine all assert.constraints on the entities and elements
    Map<CdsStructuredType, Map<String, Constraint>> constraintsPerEntity =
        constraintCache.findOrCreate();
    CdsStructuredType resultType =
        result.rowType() != null ? result.rowType() : context.getTarget();

    Map<CdsStructuredType, List<Map<String, Object>>> keysPerEntity =
        collectKeys(result, resultType, constraintsPerEntity);
    if (keysPerEntity.isEmpty()) {
      return;
    }

    // run through entity keys and execute their constraints
    Map<CdsStructuredType, Map<Map<String, Object>, ViolatedConstraints>> violationsPerEntity =
        new HashMap<>(keysPerEntity.size());
    keysPerEntity.forEach(
        (entity, keyValues) -> {
          if (keyValues.isEmpty()) {
            return;
          }
          Map<Map<String, Object>, ViolatedConstraints> violations =
              findViolations(entity, keyValues, constraintsPerEntity.get(entity));
          if (!violations.isEmpty()) {
            violationsPerEntity.put(entity, violations);
          }
        });
    if (violationsPerEntity.isEmpty()) {
      return;
    }

    // run through the data again and process the found errors to create the messages
    boolean draftEvent = context.getEvent().startsWith(DRAFT_EVENT_PREFIX);
    reportViolations(result, resultType, context, draftEvent, violationsPerEntity);

    if (!draftEvent) {
      // ensure exception is thrown, in case errors are collected
      context.put(THROW_IF_ERROR_IN_AFTER, true);
    }
  }

  /**
   * First pass: collects, per structured type, the complete key sets of all result entries whose
   * type declares at least one constraint. Constraints of a type are looked up in (and lazily added
   * to) {@code constraintsPerEntity}.
   *
   * @return the collected key sets; types without constraints do not appear
   */
  private static Map<CdsStructuredType, List<Map<String, Object>>> collectKeys(
      Result result,
      CdsStructuredType resultType,
      Map<CdsStructuredType, Map<String, Constraint>> constraintsPerEntity) {
    Map<CdsStructuredType, List<Map<String, Object>>> keysPerEntity = new HashMap<>();

    // TODO handle structured types
    DataProcessor.create()
        .bulkAction(
            (type, entries) -> {
              Map<String, Constraint> constraints =
                  constraintsPerEntity.computeIfAbsent(
                      type, ConstraintAssertionHandler::collectConstraints);

              // if no constraints found, there is no need to collect keys
              if (!constraints.isEmpty()) {
                List<Map<String, Object>> keyValues =
                    keysPerEntity.computeIfAbsent(type, (t) -> new ArrayList<>());
                Set<String> keyNames = CdsModelUtils.concreteKeyNames(type);
                entries.forEach(entry -> addKeysTo(keyNames, entry, keyValues));
              }
            })
        .process(result, resultType);
    return keysPerEntity;
  }

  /**
   * Second pass: executes a single query for the given entity that evaluates all its constraints
   * for the given keys.
   *
   * @param entity the entity to select from
   * @param keyValues the non-empty list of complete key sets to check
   * @param constraints the constraints of the entity by name
   * @return the violated constraints per key set; empty if no constraint is violated
   */
  private Map<Map<String, Object>, ViolatedConstraints> findViolations(
      CdsStructuredType entity,
      List<Map<String, Object>> keyValues,
      Map<String, Constraint> constraints) {
    // put constraints conditions and their message parameters on select list
    List<Selectable> selectables = new ArrayList<>(constraints.size());
    constraints
        .values()
        .forEach(
            c -> {
              selectables.add(c.slv);
              selectables.addAll(c.dynamicParameters);
            });

    // keys need to be part of the select list to map constraint results to entity instances
    List<String> keyElements = new ArrayList<>(keyValues.get(0).keySet());
    keyElements.forEach(keyElement -> selectables.add(CQL.get(keyElement)));

    // execute a single query, which performs all checks at once on the DB using an `IN`
    // operator for the given entity
    CqnSelect query =
        Select.from(entity.as(CdsEntity.class))
            .columns(selectables)
            .where(CQL.in(keyElements, keyValues));
    Result validationResult = db.run(query);

    // collect the entity keys and erroneous constraints
    Map<Map<String, Object>, ViolatedConstraints> violations = new HashMap<>(keyValues.size());
    validationResult.forEach(
        row -> {
          Map<String, Object> keys = new HashMap<>(keyElements.size());
          List<Constraint> erroneousConstraints = new ArrayList<>(constraints.size());

          // a constraint column that evaluates to TRUE marks a violation (the column holds the
          // negated condition); every other relevant column is a key
          row.forEach(
              (k, v) -> {
                if (k.startsWith(CONSTRAINT_PREFIX) && Boolean.TRUE.equals(v)) {
                  erroneousConstraints.add(
                      constraints.get(k.substring(CONSTRAINT_PREFIX.length())));
                } else if (keyElements.contains(k)) {
                  keys.put(k, v);
                }
              });

          if (!erroneousConstraints.isEmpty()) {
            violations.put(keys, new ViolatedConstraints(erroneousConstraints, row));
          }
        });
    return violations;
  }

  /**
   * Third pass: walks the result data again and creates an error message for every violated
   * constraint of every entry.
   *
   * @param draftEvent whether the current event is a draft event; for draft events only
   *     constraints whose targets are contained in the entry are reported
   */
  private static void reportViolations(
      Result result,
      CdsStructuredType resultType,
      EventContext context,
      boolean draftEvent,
      Map<CdsStructuredType, Map<Map<String, Object>, ViolatedConstraints>> violationsPerEntity) {
    DataProcessor.create()
        .action(
            new DataProcessor.Action() {
              @Override
              public void entries(
                  Path path,
                  CdsElement element,
                  CdsStructuredType type,
                  Iterable<Map<String, Object>> entries) {
                Map<Map<String, Object>, ViolatedConstraints> violations =
                    violationsPerEntity.get(type);
                if (violations == null) {
                  return;
                }

                Set<String> keyNames = CdsModelUtils.concreteKeyNames(type);
                for (Map<String, Object> entry : entries) {
                  ViolatedConstraints violated = violations.get(keysOf(keyNames, entry));
                  if (violated == null) {
                    continue;
                  }

                  Path enhancedPath = ((PathImpl) path).append(element, type, entry);
                  for (Constraint constraint : violated.erroneous()) {
                    // report all errors for non-draft events
                    // report errors for targets only for draft events
                    // the latter makes sure it nicely plays together with draft messages
                    if (!draftEvent || containsAnyTarget(entry, constraint)) {
                      handleError(
                          context.getMessages(),
                          constraint,
                          effectiveParameters(constraint, violated.row()),
                          enhancedPath);
                    }
                  }
                }
              }
            })
        .process(result, resultType);
  }

  /**
   * Builds the key map of an entry for the lookup of its violated constraints.
   *
   * <p>Missing or {@code null} key values are kept as {@code null} instead of failing: such an
   * entry was skipped when keys were collected (or cannot match a queried row) and simply finds no
   * violation. {@code Collectors.toMap} must not be used here, as it rejects {@code null} values
   * with a {@link NullPointerException}.
   */
  private static Map<String, Object> keysOf(Set<String> keyNames, Map<String, Object> entry) {
    Map<String, Object> keys = new HashMap<>(keyNames.size());
    for (String keyName : keyNames) {
      keys.put(keyName, entry.get(keyName));
    }
    return keys;
  }

  /** Returns whether the entry contains a value for at least one target of the constraint. */
  private static boolean containsAnyTarget(Map<String, Object> entry, Constraint constraint) {
    return constraint.targets.stream()
        .map(CqnElementRef::path)
        .anyMatch(p -> DataUtils.containsKey(entry, p));
  }

  /**
   * Determines the message parameters of a violated constraint. Static parameters are preferred;
   * dynamic parameters are read from the validation query row only if no static ones exist.
   */
  private static Object[] effectiveParameters(Constraint constraint, Row row) {
    if (constraint.parameters.length > 0 || constraint.dynamicParameters.isEmpty()) {
      return constraint.parameters;
    }
    Object[] values = new Object[constraint.dynamicParameters.size()];
    for (int i = 0; i < values.length; ++i) {
      values[i] = row.get(PARAMETER_PREFIX + constraint.name + "_" + i);
    }
    return values;
  }

  /**
   * The constraints violated by one entity instance together with the validation query row, which
   * carries the values of the dynamic message parameters.
   */
  private record ViolatedConstraints(List<Constraint> erroneous, Row row) {}

  /**
   * Collects all constraints declared on the given type and its elements.
   *
   * @return an unmodifiable map of constraints by name, empty if the type declares none
   */
  private static Map<String, Constraint> collectConstraints(CdsStructuredType type) {
    Map<String, Constraint> constraints = new HashMap<>();
    constraints.putAll(
        collectAssertConstraintAnnotations(type.annotations(), null, type.getName()));
    type.elements()
        .forEach(
            e -> {
              constraints.putAll(
                  collectAssertConstraintAnnotations(e.annotations(), e, type.getName()));
              // handle dynamic @mandatory annotations
              constraints.putAll(collectMandatoryAnnotations(e, type));
              // handle dynamic @inapplicable annotations
              constraints.putAll(collectInapplicableAnnotations(e, type));
            });
    return constraints.isEmpty() ? Map.of() : Collections.unmodifiableMap(constraints);
  }

  /**
   * Reassembles constraints from flattened {@code assert.constraint.<name>.<property>} annotations.
   *
   * @param annotations the annotations of the entity or element
   * @param annotatedElement the annotated element, or {@code null} for entity-level annotations
   * @param entityName the entity name, used in error messages
   * @return the constraints by name
   * @throws ErrorStatusException if an annotation has no {@code <name>.<property>} part
   */
  @SuppressWarnings("unchecked")
  private static Map<String, Constraint> collectAssertConstraintAnnotations(
      Stream<CdsAnnotation<?>> annotations, CdsElement annotatedElement, String entityName) {
    Map<String, Constraint> constraints = new HashMap<>();
    annotations
        .filter(a -> a.getName().startsWith(ASSERT_CONSTRAINT_PREFIX))
        .forEach(
            a -> {
              String[] segments =
                  a.getName().substring(ASSERT_CONSTRAINT_PREFIX.length()).split("\\.");
              if (segments.length < 2) {
                if (annotatedElement == null) {
                  throw new ErrorStatusException(
                      CdsErrorStatuses.INVALID_ANNOTATION_ENTITY, ASSERT_CONSTRAINT, entityName);
                } else {
                  throw new ErrorStatusException(
                      CdsErrorStatuses.INVALID_ANNOTATION,
                      ASSERT_CONSTRAINT,
                      annotatedElement.getName(),
                      entityName);
                }
              }

              Constraint constraint =
                  constraints.computeIfAbsent(
                      segments[0], name -> new Constraint(name, annotatedElement));
              // TODO error code
              // unknown properties are ignored
              switch (segments[1]) {
                case "condition" ->
                    constraint.setCondition(
                        new ExprParser().parsePredicate(((CqnValue) a.getValue()).tokens()));
                case "message" -> constraint.setMessage((String) a.getValue());
                case "targets" -> constraint.setTargets((List<CqnElementRef>) a.getValue());
                case "parameters" ->
                    constraint.setDynamicParameters((List<Value<?>>) a.getValue());
              }
            });
    return constraints;
  }

  /**
   * Creates the constraint for an expression-valued {@code @inapplicable} annotation: the element
   * must not carry a value while the annotation expression holds.
   *
   * @return a single-entry map with the constraint, or an empty map if the element has no
   *     expression-valued annotation
   */
  private static Map<String, Constraint> collectInapplicableAnnotations(
      CdsElement e, CdsStructuredType type) {
    if (CdsAnnotations.INAPPLICABLE.isExpression(e)) {
      String constraintName = INAPPLICABLE_PREFIX + e.getName();
      Constraint c = new Constraint(constraintName, e);
      c.setMessage(
          CdsAnnotations.INAPPLICABLE_MESSAGE,
          CdsErrorStatuses.VALUE_NOT_APPLICABLE,
          e.getName(),
          type.getQualifiedName());
      Predicate isNotNull = CQL.get(e.getName()).isNotNull();
      if (e.getType() instanceof CdsSimpleType simple
          && String.class.isAssignableFrom(simple.getJavaType())) {
        // a blank string counts as "no value", mirroring the dynamic @mandatory check; both
        // parts must hold, as combining them with OR would make the blank check ineffective
        isNotNull = isNotNull.and(CQL.get(e.getName()).trim().ne(EMPTY));
      }
      c.setCondition(isNotNull.and(CdsAnnotations.INAPPLICABLE.asPredicate(e)).not());
      return Map.of(constraintName, c);
    }
    return Map.of();
  }

  /**
   * Creates the constraint for an expression-valued {@code @mandatory} annotation: the element must
   * carry a value (for strings: a non-blank value) while the annotation expression holds.
   *
   * @return a single-entry map with the constraint, or an empty map if the element has no
   *     expression-valued annotation
   */
  private static Map<String, Constraint> collectMandatoryAnnotations(
      CdsElement e, CdsStructuredType type) {
    if (CdsAnnotations.MANDATORY.isExpression(e)) {
      String constraintName = MANDATORY_PREFIX + e.getName();
      Constraint c = new Constraint(constraintName, e);
      c.setMessage(
          CdsAnnotations.MANDATORY_MESSAGE,
          CdsErrorStatuses.VALUE_REQUIRED,
          e.getName(),
          type.getQualifiedName());
      Predicate isNull = CQL.get(e.getName()).isNull();
      if (e.getType() instanceof CdsSimpleType simple
          && String.class.isAssignableFrom(simple.getJavaType())) {
        // blank strings are treated like missing values
        isNull = isNull.or(CQL.get(e.getName()).trim().eq(EMPTY));
      }
      c.setCondition(isNull.and(CdsAnnotations.MANDATORY.asPredicate(e)).not());
      return Map.of(constraintName, c);
    }
    return Map.of();
  }

  /**
   * Extracts the key values of an entry and appends them to {@code keyValues}, but only if the
   * entry contains all keys. Also used by ReadOnlyHandler.
   *
   * @param keyNames the names of the key elements
   * @param entry the data entry to read the keys from
   * @param keyValues the list the complete key set is added to
   */
  static void addKeysTo(
      Set<String> keyNames, Map<String, Object> entry, List<Map<String, Object>> keyValues) {
    Map<String, Object> keys = new HashMap<>(keyNames.size());
    for (String key : keyNames) {
      if (entry.containsKey(key)) {
        keys.put(key, entry.get(key));
      }
    }
    // only add complete key sets
    if (keys.size() == keyNames.size()) {
      keyValues.add(keys);
    }
  }

  /**
   * Creates the error message for a violated constraint and attaches the constraint's targets to
   * it: the first one as target, all further ones as additional targets.
   *
   * @param messages the message container of the request
   * @param c the violated constraint
   * @param effectiveParameters the message parameters
   * @param path the path to the entry that violates the constraint
   */
  private static void handleError(
      Messages messages, Constraint c, Object[] effectiveParameters, Path path) {
    Message message = messages.error(c.message, effectiveParameters);
    if (c.code != null) {
      message.code(c.code);
    }

    if (!c.targets.isEmpty()) {
      // collect into an explicit ArrayList, as the first element is removed below and
      // Collectors.toList() gives no guarantee about the mutability of its result
      List<MessageTarget> targets =
          c.targets.stream()
              .map(t -> MessageTarget.create(path, CdsModelUtils.element(path.target().type(), t)))
              .collect(Collectors.toCollection(ArrayList::new));

      message.target(targets.remove(0));
      if (!targets.isEmpty()) {
        message.additionalTargets(targets);
      }
    }
  }

  /**
   * A single constraint: its negated condition as select list value, the message to raise on
   * violation and the targets and parameters of that message.
   */
  private static final class Constraint {

    private final String name;
    // the annotated element, null for entity-level constraints
    private final CdsElement element;
    // negated condition, aliased with CONSTRAINT_PREFIX + name; TRUE means violated
    private CqnSelectListValue slv;
    private String message;
    private String code;
    private List<CqnElementRef> targets = List.of();
    // static message parameters
    private Object[] parameters;
    // message parameters evaluated by the validation query
    private List<CqnSelectListValue> dynamicParameters = List.of();

    /**
     * Creates a constraint with the default violation message. If an element is given, it becomes
     * the default target.
     */
    public Constraint(String name, CdsElement element) {
      this.name = name;
      this.element = element;
      if (element != null) {
        setTargets(List.of(CQL.get(element.getName())));
      }
      setMessage(null, CdsErrorStatuses.CONSTRAINT_VIOLATED, name);
    }

    /** Sets the condition that has to hold for valid data. */
    public void setCondition(CqnPredicate predicate) {
      // wrap the condition in a not to have simpler handling of null values
      this.slv = CQL.not(predicate).as(CONSTRAINT_PREFIX + name);
    }

    /** Sets a custom message; this clears the error code and the static parameters. */
    public void setMessage(String message) {
      this.message = message;
      this.code = null;
      this.parameters = new Object[0];
    }

    /**
     * Sets the message key resolved by {@link ValidatorErrorUtils#getMessageKey}, together with
     * the code of the error status and the static message parameters.
     */
    public void setMessage(
        CdsAnnotations messageAnnotation, CdsErrorStatuses errorStatus, Object... parameters) {
      this.message = ValidatorErrorUtils.getMessageKey(messageAnnotation, element, errorStatus);
      this.code = errorStatus.getCodeString();
      this.parameters = parameters;
    }

    /** Sets the elements the violation message points to. */
    public void setTargets(List<CqnElementRef> targets) {
      this.targets = targets;
    }

    /**
     * Sets the message parameters that are evaluated by the validation query. Each value is
     * aliased with {@code PARAMETER_PREFIX + name + "_" + index}.
     */
    public void setDynamicParameters(List<Value<?>> parameters) {
      this.dynamicParameters = new ArrayList<>(parameters.size());
      for (int i = 0; i < parameters.size(); ++i) {
        this.dynamicParameters.add(parameters.get(i).as(PARAMETER_PREFIX + name + "_" + i));
      }
    }
  }
}
