/*
 * © 2018-2025 SAP SE or an SAP affiliate company. All rights reserved.
 */
package com.sap.cds.impl.sql;

import static com.sap.cds.impl.parser.token.CqnBoolLiteral.FALSE;
import static com.sap.cds.impl.parser.token.CqnBoolLiteral.TRUE;
import static java.util.stream.Collectors.joining;

import com.sap.cds.impl.Context;
import com.sap.cds.impl.PreparedCqnStmt.CqnParam;
import com.sap.cds.impl.PreparedCqnStmt.Parameter;
import com.sap.cds.impl.PreparedCqnStmt.ValueParam;
import com.sap.cds.impl.builder.model.Disjunction;
import com.sap.cds.impl.builder.model.ExistsSubquery;
import com.sap.cds.impl.qat.CTE;
import com.sap.cds.impl.qat.QatElementNode;
import com.sap.cds.impl.qat.QatEntityNode;
import com.sap.cds.impl.qat.QatEntityRootNode;
import com.sap.cds.impl.qat.QatSelectRootNode;
import com.sap.cds.impl.qat.QatSelectableNode;
import com.sap.cds.impl.qat.Ref2Column;
import com.sap.cds.impl.qat.Ref2Column.Clause;
import com.sap.cds.impl.sql.SQLStatementBuilder.SQLStatement;
import com.sap.cds.impl.sql.collate.Collating;
import com.sap.cds.impl.sql.collate.Collator;
import com.sap.cds.impl.util.Stack;
import com.sap.cds.jdbc.spi.DbContext;
import com.sap.cds.jdbc.spi.ScalarValueResolver;
import com.sap.cds.ql.cqn.CqnArithmeticExpression;
import com.sap.cds.ql.cqn.CqnBetweenPredicate;
import com.sap.cds.ql.cqn.CqnBooleanLiteral;
import com.sap.cds.ql.cqn.CqnCaseExpression;
import com.sap.cds.ql.cqn.CqnComparisonPredicate;
import com.sap.cds.ql.cqn.CqnConnectivePredicate;
import com.sap.cds.ql.cqn.CqnContainmentTest;
import com.sap.cds.ql.cqn.CqnElementRef;
import com.sap.cds.ql.cqn.CqnExistsSubquery;
import com.sap.cds.ql.cqn.CqnExpression;
import com.sap.cds.ql.cqn.CqnFunc;
import com.sap.cds.ql.cqn.CqnInPredicate;
import com.sap.cds.ql.cqn.CqnInSubquery;
import com.sap.cds.ql.cqn.CqnListValue;
import com.sap.cds.ql.cqn.CqnLiteral;
import com.sap.cds.ql.cqn.CqnNegation;
import com.sap.cds.ql.cqn.CqnNullValue;
import com.sap.cds.ql.cqn.CqnNumericLiteral;
import com.sap.cds.ql.cqn.CqnParameter;
import com.sap.cds.ql.cqn.CqnPlain;
import com.sap.cds.ql.cqn.CqnPredicate;
import com.sap.cds.ql.cqn.CqnSelect;
import com.sap.cds.ql.cqn.CqnSortSpecification;
import com.sap.cds.ql.cqn.CqnStringLiteral;
import com.sap.cds.ql.cqn.CqnToken;
import com.sap.cds.ql.cqn.CqnValue;
import com.sap.cds.ql.cqn.CqnVector;
import com.sap.cds.ql.cqn.CqnVisitor;
import com.sap.cds.ql.cqn.CqnWindowFunc;
import com.sap.cds.ql.impl.CqnNormalizer;
import com.sap.cds.ql.impl.Xpr;
import com.sap.cds.reflect.CdsBaseType;
import com.sap.cds.reflect.CdsEntity;
import com.sap.cds.reflect.CdsStructuredType;
import com.sap.cds.util.CdsModelUtils;
import com.sap.cds.util.CqnStatementUtils;
import com.sap.cds.util.SessionUtils;
import com.sap.cds.util.SessionUtils.SessionContextVariable;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;

public class TokenToSQLTransformer implements Function<CqnToken, String> {

  public static final String SQL_TRUE = "TRUE";
  public static final String SQL_FALSE = "FALSE";

  private static final int SUBSTRING_START_PARAM = 1;

  private final Context context;
  private final List<Parameter> params;
  private final Ref2Column aliasResolver;
  private final Deque<QatSelectableNode> outer;
  private final Map<String, CTE> ctes;
  private final boolean ctesInSubqueries;
  private final CqnNormalizer cqnNormalizer;
  private final ScalarValueResolver scalarValueResolver;
  private final Collator collator;

  private int parameterPosition = 0;

  public TokenToSQLTransformer(
      Context context,
      List<Parameter> params,
      Ref2Column aliasResolver,
      Deque<QatSelectableNode> outer,
      Map<String, CTE> ctes,
      boolean ctesInSubqueries,
      Collator collator) {
    this.context = context;
    this.params = params;
    this.aliasResolver = aliasResolver;
    this.outer = outer;
    this.ctes = ctes;
    this.ctesInSubqueries = ctesInSubqueries;
    this.collator = collator;
    this.cqnNormalizer = new CqnNormalizer(context);
    this.scalarValueResolver = context.getDbContext().getScalarValueResolver();
  }

  public static TokenToSQLTransformer notCollating(
      Context context,
      List<Parameter> params,
      Ref2Column aliasResolver,
      Deque<QatSelectableNode> outer) {
    return new TokenToSQLTransformer(
        context, params, aliasResolver, outer, new LinkedHashMap<>(), false, Collating.OFF);
  }

  public static TokenToSQLTransformer notCollating(
      Context context,
      Ref2Column aliasResolver,
      CdsEntity entity,
      String tableName,
      List<Parameter> params) {
    return new TokenToSQLTransformer(
        context,
        params,
        aliasResolver,
        outerQat(entity, tableName),
        new LinkedHashMap<>(),
        true,
        Collating.OFF);
  }

  public static TokenToSQLTransformer notCollating(
      Context context,
      Ref2Column aliasResolver,
      CqnSelect query,
      CdsStructuredType rowType,
      List<Parameter> params) {
    return new TokenToSQLTransformer(
        context,
        params,
        aliasResolver,
        outerQat(query, rowType),
        new LinkedHashMap<>(),
        true, // TODO check HierarchyFunctionMapper
        Collating.OFF);
  }

  private static Deque<QatSelectableNode> outerQat(CdsEntity entity, String tableName) {
    QatEntityNode root = new QatEntityRootNode(entity);
    CdsModelUtils.columnsOf(entity).forEach(e -> root.addChild(new QatElementNode(root, e)));

    root.setAlias(tableName);
    Deque<QatSelectableNode> outer = new ArrayDeque<>();
    outer.add(root);

    return outer;
  }

  private static Deque<QatSelectableNode> outerQat(CqnSelect query, CdsStructuredType rowType) {
    QatSelectRootNode root = new QatSelectRootNode(query, rowType);
    rowType.elements().forEach(e -> root.addChild(new QatElementNode(root, e)));

    root.setAlias("SQ");
    Deque<QatSelectableNode> outer = new ArrayDeque<>();
    outer.add(root);

    return outer;
  }

  public String toSQL(CdsStructuredType rowType, CqnPredicate pred) {
    DbContext dbContext = context.getDbContext();
    pred = dbContext.getPredicateMapper().apply(pred);
    pred = ContainsToLike.transform(dbContext.getFunctionMapper(), pred);
    pred = CqnStatementUtils.simplifyPredicate(rowType, pred);

    if (pred == TRUE) {
      return null;
    }

    if (pred == FALSE) {
      return "1 = 0";
    }

    return apply(pred);
  }

  public String toSQL(CqnPredicate pred) {
    return toSQL(null, pred);
  }

  public String selectColumn(CqnValue value) {
    value = scalarValueResolver.selectListValue(value);
    return visit(value, Clause.SELECT);
  }

  public String orderBy(CqnValue value) {
    return visit(value, Clause.ORDERBY);
  }

  public String sortSpec(CqnSortSpecification sortSpec) {
    return visit(sortSpec, Clause.ORDERBY);
  }

  @Override
  public String apply(CqnToken token) {
    return visit(token, Clause.WHERE);
  }

  private String visit(CqnToken token, Clause clause) {
    if (token == null) {
      throw new IllegalArgumentException("token must not be null");
    }
    ToSQLVisitor visitor = new ToSQLVisitor(clause);
    token.accept(visitor);
    String sql = visitor.get(token);
    if (token instanceof Xpr) {
      sql = sql.substring(1, sql.length() - 1);
    }

    return sql;
  }

  class ToSQLVisitor implements CqnVisitor {

    Stack<String> stack = new Stack<>();
    Clause clause;

    public ToSQLVisitor(Clause clause) {
      this.clause = clause;
    }

    @Override
    public void visit(CqnParameter p) {
      String name = p.isPositional() ? String.valueOf(parameterPosition++) : p.name();
      Parameter param = new CqnParam(name).type(p.type());

      params.add(param);

      push(scalarValueResolver.parameter(p));
    }

    @Override
    public void visit(CqnFunc cqnFunc) {
      String func = cqnFunc.func().toLowerCase(Locale.US);

      List<String> args = stack.pop(cqnFunc.args().size());
      if ("substring".equals(func)) {
        // increment the start pos as OData and CQN start at 0 but SQL starts at 1
        args.set(SUBSTRING_START_PARAM, args.get(SUBSTRING_START_PARAM) + " + 1");
      }

      push(context.getDbContext().getFunctionMapper().toSql(func, args));
    }

    @Override
    public void visit(CqnSortSpecification sortSpec) {
      StringBuilder sort = new StringBuilder(stack.pop());
      collator.collate(sortSpec).ifPresent(cc -> sort.append(" " + cc));
      sort.append(" " + sortOrderToSql(sortSpec));
      stack.push(sort.toString());
    }

    private static String sortOrderToSql(CqnSortSpecification o) {
      return switch (o.order()) {
        case DESC -> "DESC NULLS LAST";
        case DESC_NULLS_FIRST -> "DESC NULLS FIRST";
        case ASC_NULLS_LAST -> "NULLS LAST";
        default -> "NULLS FIRST";
      };
    }

    @Override
    public void visit(CqnWindowFunc cqnFunc) {
      String func = cqnFunc.func().toLowerCase(Locale.US);

      List<String> orders = stack.pop(cqnFunc.window().orderBy().size());
      List<String> partitions = stack.pop(cqnFunc.window().partitionBy().size());
      List<String> args = stack.pop(cqnFunc.args().size());

      StringBuilder sql =
          new StringBuilder(context.getDbContext().getFunctionMapper().toSql(func, args));

      sql.append(" OVER(");
      CqnWindowFunc.WindowSpecification ov = cqnFunc.window();
      if (!ov.partitionBy().isEmpty()) {
        sql.append("PARTITION BY ");
        sql.append(partitions.stream().collect(Collectors.joining(", ")));
        if (!ov.orderBy().isEmpty()) {
          sql.append(" ");
        }
      }
      if (!ov.orderBy().isEmpty()) {
        sql.append("ORDER BY ");
        sql.append(orders.stream().collect(Collectors.joining(", ")));
      }
      sql.append(")");

      push(sql.toString());
    }

    @Override
    public void visit(CqnListValue listValue) {
      int n = (int) listValue.values().count();

      String sql = stack.pop(n).stream().collect(joining(", ", "(", ")"));

      push(sql);
    }

    @Override
    public void visit(CqnContainmentTest test) {
      throw new IllegalStateException(
          "CQN containment test should be transformed to LIKE predicate");
    }

    @Override
    public void visit(CqnPlain plain) {
      push(plain.plain());
    }

    @Override
    public void visit(CqnBooleanLiteral bool) {
      push(scalarValueResolver.literal(bool));
    }

    @Override
    public void visit(CqnNumericLiteral<?> number) {
      Number val = number.value();
      if (number.isConstant()) {
        push(String.valueOf(val));
      } else {
        pushValueParam(number);
      }
    }

    @Override
    public void visit(CqnStringLiteral literal) {
      if (literal.isConstant()) {
        push(SQLHelper.literal(literal.value()));
      } else {
        pushValueParam(literal);
      }
    }

    @Override
    public void visit(CqnVector vector) {
      params.add(new ValueParam(vector::value).type(CdsBaseType.VECTOR));

      push(scalarValueResolver.parameter(vector));
    }

    @Override
    public void visit(CqnLiteral<?> literal) {
      pushValueParam(literal);
    }

    private void pushValueParam(CqnLiteral<?> literal) {
      Parameter param = new ValueParam(literal::value).type(literal.type());
      params.add(param);

      push(scalarValueResolver.literal(literal));
    }

    @Override
    public void visit(CqnNullValue nil) {
      push("NULL");
    }

    @Override
    public void visit(CqnElementRef ref) {
      Optional<SessionContextVariable> sParam =
          SessionUtils.getSessionParameter(ref, context.getSessionContext());
      if (sParam.isPresent()) {
        CdsBaseType type = sParam.get().getType();
        Parameter param = new ValueParam(sParam.get().getValueSupplier()).type(type);
        params.add(param);

        push(scalarValueResolver.parameter(type));
      } else {
        push(aliasResolver.apply(clause, ref).collect(joining(", ")));
      }
    }

    @Override
    public void visit(CqnExistsSubquery exists) {
      ExistsSubquery subQuery = (ExistsSubquery) exists;
      Deque<QatSelectableNode> outerNodes;
      if (subQuery.getOuter() != null) {
        // infix filter ?
        outerNodes = new ArrayDeque<>(outer);
        outerNodes.pop();
        outerNodes.add((QatSelectableNode) subQuery.getOuter());
      } else {
        // nested subquery
        outerNodes = outer;
      }
      CqnExistsSubquery normalized = cqnNormalizer.normalize(exists);
      SQLStatement stmt =
          SelectStatementBuilder.forSubquery(
                  context, params, normalized.subquery(), outerNodes, ctes, ctesInSubqueries)
              .build();
      push("EXISTS (" + stmt.sql() + ")");
    }

    @Override
    public void visit(CqnInSubquery in) {
      String value = stack.pop();
      CqnSelect subquery = cqnNormalizer.normalize(in.subquery());
      SQLStatement stmt =
          SelectStatementBuilder.forSubquery(
                  context, params, subquery, outer, ctes, ctesInSubqueries)
              .build();
      push(value + " IN (" + stmt.sql() + ")");
    }

    @Override
    public void visit(CqnExpression xpr) {
      List<String> snippets = stack.pop(((Xpr) xpr).size());
      push("(" + snippets.stream().collect(SpaceSeparatedCollector.joining()) + ")");
    }

    @Override
    public void visit(CqnBetweenPredicate between) {
      String high = stack.pop();
      String low = stack.pop();
      String value = stack.pop();
      push(value + " between " + low + " and " + high);
    }

    @Override
    public void visit(CqnArithmeticExpression expr) {
      String right = stack.pop();
      String left = stack.pop();

      push("(" + left + " " + expr.operator().symbol() + " " + right + ")");
    }

    @Override
    public void visit(CqnCaseExpression expr) {
      int n = expr.cases().size();
      StringBuilder sql = new StringBuilder("CASE");

      String result = stack.pop();
      Iterator<String> iter = stack.pop(n << 1).iterator();
      while (iter.hasNext()) {
        sql.append(" WHEN ").append(iter.next());
        sql.append(" THEN ").append(iter.next());
      }
      sql.append(" ELSE ").append(result).append(" END");

      push(sql.toString());
    }

    @Override
    public void visit(CqnComparisonPredicate comparison) {
      String right = stack.pop();
      String left = stack.pop();

      push(left + " " + comparison.operator().symbol + " " + right);
    }

    @Override
    public void visit(CqnConnectivePredicate connective) {
      String symbol = connective.operator().symbol;
      BiFunction<CqnPredicate, String, String> flattener =
          (connective.operator() == CqnConnectivePredicate.Operator.AND)
              ? this::flatAnd
              : this::flatOr;
      List<CqnPredicate> original = connective.predicates();
      List<String> snippets = stack.pop(original.size());
      StringBuilder sql = new StringBuilder();
      for (int i = 0; i < original.size(); i++) {
        if (i > 0) {
          sql.append(" ");
          sql.append(symbol);
          sql.append(" ");
        }
        sql.append(flattener.apply(original.get(i), snippets.get(i)));
      }
      push(sql.toString());
    }

    private String flatAnd(CqnPredicate pred, String snippet) {
      if (pred instanceof Disjunction) {
        return "(" + snippet + ")";
      }
      return snippet;
    }

    String flatOr(CqnPredicate pred, String snippet) {
      return snippet;
    }

    @Override
    public void visit(CqnInPredicate in) {
      String valueSet = stack.pop();
      String value = stack.pop();

      push(value + " in " + valueSet);
    }

    @Override
    public void visit(CqnNegation neg) {
      if (neg.predicate() instanceof CqnConnectivePredicate) {
        push("not (" + stack.pop() + ")");
      } else {
        push("not " + stack.pop());
      }
    }

    private void push(String snippet) {
      stack.push(snippet);
    }

    private String get(CqnToken t) {
      if (stack.size() != 1) {
        throw new IllegalStateException("token " + t.toJson() + " can't be mapped");
      }

      return stack.pop();
    }
  }
}
