/*
 * © 2024-2025 SAP SE or an SAP affiliate company. All rights reserved.
 */
package com.sap.cds.jdbc.hana.hierarchies;

import com.sap.cds.impl.builder.model.InSubquery;
import com.sap.cds.ql.CQL;
import com.sap.cds.ql.ElementRef;
import com.sap.cds.ql.Literal;
import com.sap.cds.ql.Predicate;
import com.sap.cds.ql.Select;
import com.sap.cds.ql.Value;
import com.sap.cds.ql.cqn.CqnPredicate;
import com.sap.cds.ql.cqn.CqnSelect;
import com.sap.cds.ql.cqn.CqnSelectListItem;
import com.sap.cds.ql.cqn.CqnSelectListValue;
import com.sap.cds.ql.cqn.CqnSortSpecification;
import com.sap.cds.ql.cqn.transformation.CqnAncestorsTransformation;
import com.sap.cds.ql.cqn.transformation.CqnDescendantsTransformation;
import com.sap.cds.ql.cqn.transformation.CqnFilterTransformation;
import com.sap.cds.ql.cqn.transformation.CqnHierarchySubsetTransformation;
import com.sap.cds.ql.cqn.transformation.CqnSearchTransformation;
import com.sap.cds.ql.cqn.transformation.CqnTopLevelsTransformation;
import com.sap.cds.ql.cqn.transformation.CqnTransformation;
import com.sap.cds.ql.cqn.transformation.CqnTransformationVisitor;
import com.sap.cds.ql.hana.HANA;
import com.sap.cds.ql.hana.Hierarchy;
import com.sap.cds.ql.hana.HierarchySubset;
import com.sap.cds.ql.impl.SelectBuilder;
import com.sap.cds.reflect.CdsBaseType;
import com.sap.cds.reflect.CdsElement;
import com.sap.cds.reflect.CdsEntity;
import com.sap.cds.reflect.CdsModel;
import com.sap.cds.reflect.CdsStructuredType;
import com.sap.cds.util.CqnStatementUtils;
import com.sap.cds.util.transformations.HierarchyUtils;
import com.sap.cds.util.transformations.TransformationToSelect;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Stream;

public class HanaHierarchyResolver extends TransformationToSelect {

  // Expand level marking nodes whose entire subtree is expanded (see expandFilter).
  private static final long LEVELS_ALL = -1L;
  // Numeric literals shared by the computed columns and the drill-state conditions.
  private static final Literal<Number> ZERO = CQL.constant(0);
  private static final Literal<Number> ONE = CQL.constant(1);

  // HANA elements: element names used when selecting from the HANA hierarchy
  private static final String PARENT_ID = "parent_id";
  private static final String NODE_ID = "node_id";
  private static final String HIERARCHY_TREE_SIZE = "hierarchy_tree_size";
  private static final String HIERARCHY_LEVEL = "hierarchy_level";
  private static final String HIERARCHY_RANK = "hierarchy_rank";

  // OData elements: names (aliases) of the computed elements exposed in the result
  private static final String DISTANCE_FROM_ROOT = "DistanceFromRoot";
  private static final String LIMITED_DESCENDANT_COUNT = "LimitedDescendantCount";
  private static final String DESCENDANT_COUNT = "DescendantCount";
  // reference to the DescendantCount element, used by the drill-state computation
  private static final ElementRef<Number> DESCENDANT_COUNT_REF = CQL.get(DESCENDANT_COUNT);
  private static final String DRILL_STATE = "DrillState";
  private static final String RANK = "Rank";
  private static final String LIMITED_RANK = "LimitedRank";

  // set in before(): true if at least one transformation of the original select is
  // hierarchical according to HierarchyUtils.isHierarchical
  private boolean isHierarchicalSelect;

  // drill states: literal values emitted for the DrillState element
  private static final Literal<String> COLLAPSED = CQL.constant("collapsed");
  private static final Literal<String> EXPANDED = CQL.constant("expanded");
  private static final Literal<String> LEAF = CQL.constant("leaf");

  // CDS model passed to the constructor; used to resolve row types
  private final CdsModel model;
  // hierarchy built in before() over the source select; start point of the
  // ancestors/descendants subsets
  private Hierarchy sourceHierarchy;
  // transformations of the original select; only assigned for hierarchical selects
  private List<CqnTransformation> transformations;
  // select list items of the original select, captured in before()
  private List<CqnSelectListItem> originalItems;
  // element names of the node id and parent id in the hierarchy source, resolved
  // case-insensitively in before() with NODE_ID / PARENT_ID as fallback
  private String nodeId;
  private String parentId;

  /**
   * Creates a resolver for the given select.
   *
   * @param model the CDS model, kept for resolving row types in {@link #before(CqnSelect)}
   * @param select the select statement handed to the {@link TransformationToSelect} base class
   */
  public HanaHierarchyResolver(CdsModel model, Select<?> select) {
    super(select);
    this.model = model;
  }

  /**
   * Captures the original select list and, if the statement contains at least one hierarchical
   * transformation, replaces {@code select} by a select over a HANA hierarchy that exposes the
   * computed elements {@code DescendantCount}, {@code DistanceFromRoot} and {@code Rank}.
   *
   * @param original the statement whose transformations are being resolved
   */
  @Override
  protected void before(CqnSelect original) {
    originalItems = original.items();
    isHierarchicalSelect =
        original.transformations().stream().anyMatch(HierarchyUtils::isHierarchical);
    if (!isHierarchicalSelect) {
      // no hierarchical transformation, nothing to prepare
      return;
    }

    transformations = original.transformations();
    Select<?> source = HierarchyUtils.addNodeAndParentElements(model, select, transformations);
    sourceHierarchy = HANA.hierarchy(source).orderBy(CQL.get(NODE_ID).asc());

    // expose the computed elements instead of the raw hierarchy columns
    select =
        Select.from(sourceHierarchy)
            .columns(CQL.star(), descendantCount(), distanceFromRoot(), rank())
            .excluding(HIERARCHY_TREE_SIZE, HIERARCHY_LEVEL, HIERARCHY_RANK)
            .hints(original.hints());

    // resolve the actual names of the node id and parent id elements of the source
    CdsStructuredType rowType = CqnStatementUtils.rowType(model, source);
    nodeId = findCaseInsensitive(rowType, NODE_ID);
    parentId = findCaseInsensitive(rowType, PARENT_ID);

    // default order (by ID or sort specification from projection)
    applyDefaultOrder((CdsEntity) CqnStatementUtils.rowType(model, original.from()));
  }

  /**
   * Applies the default order: first ascending by the node id, then, if the target entity is
   * defined by a query, the sort specifications of that query are handed to {@code applyOrderBy}.
   *
   * @param target the entity the original select reads from
   */
  private void applyDefaultOrder(CdsEntity target) {
    applyOrderBy(() -> List.of(CQL.get(nodeId).asc()));

    var definingQuery = target.query();
    if (definingQuery.isPresent()) {
      var projection = definingQuery.get();
      applyOrderBy(() -> projection.orderBy());
    }
  }

  /**
   * Looks up the name of the first element of {@code type} that equals {@code name} ignoring case.
   *
   * @param type the structured type whose elements are searched
   * @param name the element name to look for
   * @return the matching element name as declared in {@code type}, or {@code name} itself if no
   *     element matches
   */
  private static String findCaseInsensitive(CdsStructuredType type, String name) {
    var declaredNames = type.elements().map(CdsElement::getName);
    var firstMatch =
        declaredNames.filter(candidate -> name.equalsIgnoreCase(candidate)).findFirst();
    return firstMatch.orElse(name);
  }

  /**
   * For hierarchical selects, adds the drill state for the last descendants transformation of the
   * statement, if there is one.
   */
  @Override
  protected void after() {
    if (!isHierarchicalSelect) {
      return;
    }
    // walk backwards: only the last descendants transformation is considered
    for (int i = transformations.size() - 1; i >= 0; i--) {
      if (transformations.get(i) instanceof CqnDescendantsTransformation descendants) {
        addDescendantsDrillState((SelectBuilder<?>) select, descendants);
        return;
      }
    }
  }

  /**
   * Copies the select list via the base class for non-hierarchical selects only.
   *
   * @param slis the select list items to copy
   */
  @Override
  protected void copySelectList(List<CqnSelectListItem> slis) {
    if (isHierarchicalSelect) {
      // don't modify select list, as we add calculated elements
      return;
    }
    super.copySelectList(slis);
  }

  /**
   * Restricts the hierarchy to the requested number of levels plus the explicitly expanded nodes
   * and wraps the filtered select into a new hierarchy, so that the virtual elements of the
   * original select list are computed on the limited hierarchy.
   *
   * @param topLevels the top levels transformation providing the level limit and expand levels
   */
  @Override
  protected void applyTopLevels(CqnTopLevelsTransformation topLevels) {
    wrapIfHas(Sql.WHERE, Sql.GROUPBY, Sql.AGGREGATE, Sql.HAVING, Sql.SKIP, Sql.TOP);

    // TODO: optimize to use depth instead where clause
    Predicate levelFilter = CQL.TRUE;
    long maxLevels = topLevels.levels();
    if (maxLevels > 0) {
      if (select.from().isSelect()) {
        // sub-select aliases {hierarchy_level} to {DistanceFromRoot = hierarchy_level - 1}
        levelFilter = levelFilter.and(CQL.get(DISTANCE_FROM_ROOT).lt(maxLevels));
      } else {
        levelFilter = levelFilter.and(CQL.get(HIERARCHY_LEVEL).le(maxLevels));
      }
    }
    CqnPredicate nodeFilter = levelFilter.or(expandFilter(topLevels.expandLevels()));

    // the order-by becomes the sibling order of the wrapping hierarchy
    List<CqnSortSpecification> siblingOrder = stashOrderBy();

    applyFilter(() -> nodeFilter);

    // wrap into hierarchy select(hierarchy(select))
    Hierarchy limitedHierarchy = HANA.hierarchy(select).orderBy(siblingOrder);
    select = Select.from(limitedHierarchy).columns(computeVirtual(originalItems));
  }

  /**
   * Builds the select list of the wrapping hierarchy select.
   *
   * @param slis the select list items of the original select
   * @return for an empty list: star plus {@code LimitedDescendantCount} and {@code DrillState};
   *     otherwise the given items with the virtual elements replaced by their computation
   */
  private List<CqnSelectListItem> computeVirtual(List<CqnSelectListItem> slis) {
    if (!slis.isEmpty()) {
      return slis.stream().map(item -> computeVirtual(item)).toList();
    }

    // no explicit select list: select everything and add the computed elements
    List<CqnSelectListItem> defaults = new ArrayList<>(3);
    defaults.add(CQL.star());
    defaults.add(limitedDescendantCount());
    defaults.add(drillState());
    return defaults;
  }

  /**
   * Replaces a reference to one of the virtual elements {@code LimitedDescendantCount}, {@code
   * LimitedRank} or {@code DrillState} by the expression computing it.
   *
   * @param sli the select list item to inspect
   * @return the computed item for a virtual element, otherwise {@code sli} unchanged
   */
  private CqnSelectListItem computeVirtual(CqnSelectListItem sli) {
    if (!sli.isRef()) {
      return sli;
    }
    switch (sli.asRef().path()) {
      case LIMITED_DESCENDANT_COUNT:
        return limitedDescendantCount();
      case LIMITED_RANK:
        return limitedRank();
      case DRILL_STATE:
        return drillState();
      default:
        return sli;
    }
  }

  /** Returns {@code hierarchy_level - 1}, aliased as {@code DistanceFromRoot}. */
  private static CqnSelectListValue distanceFromRoot() {
    var zeroBasedLevel = CQL.get(HIERARCHY_LEVEL).minus(ONE);
    return zeroBasedLevel.as(DISTANCE_FROM_ROOT);
  }

  /** Returns {@code hierarchy_tree_size - 1}, aliased as {@code DescendantCount}. */
  private static CqnSelectListValue descendantCount() {
    var treeSizeWithoutNode = CQL.get(HIERARCHY_TREE_SIZE).minus(ONE);
    return treeSizeWithoutNode.as(DESCENDANT_COUNT);
  }

  /**
   * Returns the value of {@link #limitedDescendantCountVal()}, aliased as {@code
   * LimitedDescendantCount}.
   */
  private static CqnSelectListValue limitedDescendantCount() {
    Value<Number> count = limitedDescendantCountVal();
    return count.as(LIMITED_DESCENDANT_COUNT);
  }

  /** Returns the unaliased value {@code hierarchy_tree_size - 1}. */
  private static Value<Number> limitedDescendantCountVal() {
    ElementRef<Number> treeSize = CQL.get(HIERARCHY_TREE_SIZE);
    return treeSize.minus(ONE);
  }

  /** Returns {@code hierarchy_rank - 1}, aliased as {@code Rank}. */
  private static CqnSelectListValue rank() {
    var zeroBasedRank = CQL.get(HIERARCHY_RANK).minus(ONE);
    return zeroBasedRank.as(RANK);
  }

  /** Returns {@code hierarchy_rank - 1}, aliased as {@code LimitedRank}. */
  private static CqnSelectListValue limitedRank() {
    var zeroBasedRank = CQL.get(HIERARCHY_RANK).minus(ONE);
    return zeroBasedRank.as(LIMITED_RANK);
  }

  /**
   * Builds the {@code DrillState} element as a string-typed case expression:
   *
   * <ul>
   *   <li>{@code leaf} if {@code DescendantCount} is zero,
   *   <li>{@code collapsed} if {@code hierarchy_tree_size - 1} is zero,
   *   <li>{@code expanded} otherwise.
   * </ul>
   */
  private CqnSelectListItem drillState() {
    Value<?> state =
        CQL.when(DESCENDANT_COUNT_REF.eq(ZERO))
            .then(LEAF)
            .when(limitedDescendantCountVal().eq(ZERO))
            .then(COLLAPSED)
            .orElse(EXPANDED);

    return state.type(CdsBaseType.STRING).as(DRILL_STATE);
  }

  /**
   * Adds a {@code DrillState} element to the select for a descendants transformation with a
   * distance of at most one that does not keep the start node: {@code collapsed} if {@code
   * hierarchy_tree_size} is greater than one, {@code leaf} otherwise. Other descendants
   * transformations are left untouched.
   *
   * @param select the select to which the element is added
   * @param d the descendants transformation
   */
  private static void addDescendantsDrillState(
      SelectBuilder<?> select, CqnDescendantsTransformation d) {
    boolean supported = d.distanceFromStart() <= 1 && !d.keepStart();
    if (!supported) {
      return; // not yet implemented
    }

    CqnSelectListValue childDrillState =
        CQL.when(CQL.get(HIERARCHY_TREE_SIZE).gt(ONE))
            .then(COLLAPSED)
            .orElse(LEAF)
            .type(CdsBaseType.STRING)
            .as(DRILL_STATE);
    select.addItem(childDrillState);
  }

  /**
   * Identifies the nodes to be expanded.
   *
   * <p>Only the expand levels 0, 1 and -1 (all levels) are evaluated; entries with other levels
   * do not contribute to the predicate.
   *
   * @param ids maps node ids to their expand level
   * @return a predicate matching the nodes with level 0 or 1, the direct children of nodes with
   *     level 1 and all descendants of nodes with level -1
   */
  private Predicate expandFilter(Map<Object, Long> ids) {
    List<Object> levelZeroIds = getExpandIdsOfLevel(ids, 0L);
    List<Object> levelOneIds = getExpandIdsOfLevel(ids, 1L);
    List<Object> allLevelsIds = getExpandIdsOfLevel(ids, LEVELS_ALL);

    // the nodes themselves (levels 0 and 1)
    CqnPredicate nodes =
        filterNodesOf(Stream.concat(levelZeroIds.stream(), levelOneIds.stream()).toList());
    // expand levels = 1 (direct children, flat)
    CqnPredicate directChildren = filterDirectDescendantsOf(levelOneIds);
    // expand levels = -1 (all children, deep)
    CqnPredicate subtrees = filterDeepDescendantsOf(allLevelsIds);

    return CQL.or(nodes, CQL.or(directChildren, subtrees));
  }

  /**
   * Collects the ids that are mapped to the given expand level.
   *
   * @param ids maps node ids to their expand level
   * @param level the expand level to select
   * @return the keys of {@code ids} whose value equals {@code level}
   */
  private List<Object> getExpandIdsOfLevel(Map<Object, Long> ids, long level) {
    var entries = ids.entrySet().stream();
    return entries
        .filter(entry -> entry.getValue() == level)
        .map(Map.Entry::getKey)
        .toList();
  }

  /**
   * Builds a predicate for all (deep) descendants of the given nodes, including the nodes
   * themselves.
   *
   * @param ids the ids of the start nodes
   * @return {@code CQL.FALSE} for an empty id list, otherwise an IN predicate on the node id over
   *     a subquery selecting the descendants from the source hierarchy
   */
  private Predicate filterDeepDescendantsOf(List<Object> ids) {
    if (ids.isEmpty()) {
      return CQL.FALSE;
    }
    HierarchySubset subtree =
        HANA.descendants(sourceHierarchy)
            .startWhere(CQL.get(nodeId).in(ids))
            .distance(Integer.MAX_VALUE, true);
    var subtreeNodeIds = Select.from(subtree).columns(nodeId);
    return CQL.get(nodeId).in(subtreeNodeIds);
  }

  /**
   * Builds a predicate for the direct descendants (level 1) of the given nodes.
   *
   * @param expandIds the ids of the nodes whose direct children are to be matched
   * @return {@code CQL.FALSE} for an empty id list, otherwise an IN predicate on the parent id
   */
  private Predicate filterDirectDescendantsOf(List<Object> expandIds) {
    if (expandIds.isEmpty()) {
      // nothing to expand: match no node instead of building an IN predicate without values,
      // consistent with filterDeepDescendantsOf
      return CQL.FALSE;
    }
    return CQL.get(parentId).in(expandIds);
  }

  /**
   * Builds a predicate matching the given nodes themselves.
   *
   * @param expandIds the ids of the nodes to be matched
   * @return {@code CQL.FALSE} for an empty id list, otherwise an IN predicate on the node id
   */
  private Predicate filterNodesOf(List<Object> expandIds) {
    if (expandIds.isEmpty()) {
      // no ids given: match no node instead of building an IN predicate without values,
      // consistent with filterDeepDescendantsOf
      return CQL.FALSE;
    }
    return CQL.get(nodeId).in(expandIds);
  }

  /**
   * Restricts the select to the ancestors subset described by the transformation.
   *
   * @param transformation the ancestors transformation
   */
  @Override
  protected void applyAncestors(CqnAncestorsTransformation transformation) {
    subSet(transformation, hierarchy -> HANA.ancestors(hierarchy));
  }

  /**
   * Restricts the select to the descendants subset described by the transformation.
   *
   * @param transformation the descendants transformation
   */
  @Override
  protected void applyDescendants(CqnDescendantsTransformation transformation) {
    subSet(transformation, hierarchy -> HANA.descendants(hierarchy));
  }

  /**
   * Removes the order-by from the current select.
   *
   * @return a copy of the sort specifications the select had before they were removed
   */
  private List<CqnSortSpecification> stashOrderBy() {
    List<CqnSortSpecification> stashed = List.copyOf(select.orderBy());
    select.orderBy(List.of());
    return stashed;
  }

  /**
   * Restricts the select to the nodes of a hierarchy subset (ancestors or descendants) of the
   * source hierarchy.
   *
   * <p>The start nodes are determined by the filter and search transformations nested in {@code
   * subsetTrafo}. If the start nodes are kept and there is no start condition, the select is left
   * unchanged. For an ancestors subset with a start condition, the filtered select is wrapped into
   * a new hierarchy to recompute {@code DescendantCount}. For a descendants subset, the previous
   * order-by is restored on the select.
   *
   * @param subsetTrafo the ancestors or descendants transformation
   * @param factory creates the hierarchy subset from the source hierarchy
   */
  private void subSet(
      CqnHierarchySubsetTransformation subsetTrafo, Function<Hierarchy, HierarchySubset> factory) {
    TrafoPredicatedVisitor collector = new TrafoPredicatedVisitor();
    for (var nested : subsetTrafo.transformations()) {
      nested.accept(collector);
    }
    CqnPredicate startWhere = collector.predicate;
    // the visitor starts with CQL.TRUE, so identity tells that no condition was collected
    boolean unrestrictedStart = startWhere == CQL.TRUE;

    if (subsetTrafo.keepStart() && unrestrictedStart) {
      // we keep all nodes
      return;
    }

    HierarchySubset subset = factory.apply(sourceHierarchy);

    subset.distance(subsetTrafo.distanceFromStart(), subsetTrafo.keepStart());
    subset.startWhere(startWhere);

    CqnSelect subsetNodeIds =
        Select.from(subset).columns(nodeId).distinct().hints(select.hints());

    List<CqnSortSpecification> stashedOrder = stashOrderBy();
    applyFilter(() -> InSubquery.in(CQL.get(nodeId), subsetNodeIds));

    if (subset.isAncestors() && !unrestrictedStart) {
      // recompute descendant count for filtered hierarchy
      Hierarchy filteredHierarchy =
          HANA.hierarchy(
              select.excluding(
                  HIERARCHY_TREE_SIZE, HIERARCHY_LEVEL, HIERARCHY_RANK, DESCENDANT_COUNT));
      filteredHierarchy.orderBy(stashedOrder);

      select =
          Select.from(filteredHierarchy)
              .columns(CQL.star(), descendantCount())
              .excluding(HIERARCHY_TREE_SIZE, HIERARCHY_LEVEL, HIERARCHY_RANK);
    }
    if (subset.isDescendants()) {
      select.orderBy(stashedOrder);
    }
  }

  /**
   * Visitor that conjoins the predicates of the visited filter and search transformations.
   *
   * <p>{@code predicate} starts as {@code CQL.TRUE} and is only reassigned when a filter or search
   * transformation is visited.
   */
  private static class TrafoPredicatedVisitor implements CqnTransformationVisitor {
    CqnPredicate predicate = CQL.TRUE;

    /** Adds the condition of the filter transformation to the collected predicate. */
    @Override
    public void visit(CqnFilterTransformation filter) {
      var condition = filter.filter();
      predicate = CQL.and(predicate, condition);
    }

    /** Adds the condition of the search transformation to the collected predicate. */
    @Override
    public void visit(CqnSearchTransformation search) {
      var condition = search.search();
      predicate = CQL.and(predicate, condition);
    }
  }
}
