/**************************************************************************
 * (C) 2020-2023 SAP SE or an SAP affiliate company. All rights reserved. *
 **************************************************************************/
package com.sap.cds.impl.sql;

import static com.sap.cds.impl.parser.token.CqnBoolLiteral.FALSE;
import static com.sap.cds.impl.parser.token.CqnBoolLiteral.TRUE;
import static java.util.stream.Collectors.joining;

import java.util.ArrayDeque;
import java.util.Deque;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Supplier;

import com.sap.cds.impl.Context;
import com.sap.cds.impl.PreparedCqnStmt.CqnParam;
import com.sap.cds.impl.PreparedCqnStmt.Parameter;
import com.sap.cds.impl.PreparedCqnStmt.ValueParam;
import com.sap.cds.impl.builder.model.Disjunction;
import com.sap.cds.impl.builder.model.ExistsSubquery;
import com.sap.cds.impl.qat.QatElementNode;
import com.sap.cds.impl.qat.QatEntityNode;
import com.sap.cds.impl.qat.QatEntityRootNode;
import com.sap.cds.impl.qat.QatSelectableNode;
import com.sap.cds.impl.sql.SQLStatementBuilder.SQLStatement;
import com.sap.cds.impl.util.Stack;
import com.sap.cds.jdbc.spi.DbContext;
import com.sap.cds.ql.cqn.CqnArithmeticExpression;
import com.sap.cds.ql.cqn.CqnBooleanLiteral;
import com.sap.cds.ql.cqn.CqnComparisonPredicate;
import com.sap.cds.ql.cqn.CqnConnectivePredicate;
import com.sap.cds.ql.cqn.CqnContainmentTest;
import com.sap.cds.ql.cqn.CqnElementRef;
import com.sap.cds.ql.cqn.CqnExistsSubquery;
import com.sap.cds.ql.cqn.CqnExpression;
import com.sap.cds.ql.cqn.CqnFunc;
import com.sap.cds.ql.cqn.CqnInPredicate;
import com.sap.cds.ql.cqn.CqnListValue;
import com.sap.cds.ql.cqn.CqnLiteral;
import com.sap.cds.ql.cqn.CqnNegation;
import com.sap.cds.ql.cqn.CqnNullValue;
import com.sap.cds.ql.cqn.CqnNumericLiteral;
import com.sap.cds.ql.cqn.CqnParameter;
import com.sap.cds.ql.cqn.CqnPlain;
import com.sap.cds.ql.cqn.CqnPredicate;
import com.sap.cds.ql.cqn.CqnStringLiteral;
import com.sap.cds.ql.cqn.CqnToken;
import com.sap.cds.ql.cqn.CqnVisitor;
import com.sap.cds.ql.impl.CqnNormalizer;
import com.sap.cds.ql.impl.Xpr;
import com.sap.cds.reflect.CdsBaseType;
import com.sap.cds.reflect.CdsEntity;
import com.sap.cds.util.CqnStatementUtils;
import com.sap.cds.util.SessionUtils;
import com.sap.cds.util.SessionUtils.SessionContextVariable;

/**
 * Renders CQN tokens as SQL text.
 * <p>
 * While rendering, every CQN parameter and every non-inlined literal is
 * appended to the parameter list handed to the constructor; the text that
 * stands in for such a parameter in the SQL is produced by the configured
 * parameter resolver ({@code "?"} unless a custom resolver is supplied).
 */
public class TokenToSQLTransformer implements Function<CqnToken, String> {

	public static final String SQL_TRUE = "TRUE";
	public static final String SQL_FALSE = "FALSE";

	/** Index of the start-position argument within the arguments of {@code substring}. */
	private static final int SUBSTRING_START_PARAM = 1;

	private final Context context;
	/** Receives, in rendering order, every parameter created by this transformer. */
	private final List<Parameter> params;
	/** Turns an element reference into the SQL text used for it. */
	private final Function<CqnElementRef, String> aliasResolver;
	/** Outer query nodes handed on to the builder of EXISTS subqueries. */
	private final Deque<QatSelectableNode> outer;
	private final boolean noCollating;
	/** Produces the SQL placeholder text for a parameter. */
	private final Function<Parameter, String> paramResolver;
	private final CqnNormalizer cqnNormalizer;

	/** Next index handed out as the name of a positional CQN parameter. */
	private int parameterPosition = 0;

	/**
	 * Creates a transformer with a custom placeholder strategy.
	 *
	 * @param context       supplies the DB context and the session context
	 * @param params        list that collects the parameters found while rendering
	 * @param aliasResolver maps element references to SQL text
	 * @param outer         outer query nodes passed to EXISTS subquery builders
	 * @param paramResolver maps a parameter to its placeholder text in the SQL
	 * @param noCollating   flag passed through to EXISTS subquery builders
	 */
	public TokenToSQLTransformer(Context context, List<Parameter> params, Function<CqnElementRef, String> aliasResolver,
			Deque<QatSelectableNode> outer, Function<Parameter, String> paramResolver, boolean noCollating) {
		this.context = context;
		this.cqnNormalizer = new CqnNormalizer(context);
		this.params = params;
		this.paramResolver = paramResolver;
		this.aliasResolver = aliasResolver;
		this.outer = outer;
		this.noCollating = noCollating;
	}

	/**
	 * Creates a transformer that renders every parameter as {@code "?"}.
	 *
	 * @param context       supplies the DB context and the session context
	 * @param params        list that collects the parameters found while rendering
	 * @param aliasResolver maps element references to SQL text
	 * @param outer         outer query nodes passed to EXISTS subquery builders
	 * @param noCollating   flag passed through to EXISTS subquery builders
	 */
	public TokenToSQLTransformer(Context context, List<Parameter> params, Function<CqnElementRef, String> aliasResolver,
			Deque<QatSelectableNode> outer, boolean noCollating) {
		this(context, params, aliasResolver, outer, parameter -> "?", noCollating);
	}

	/**
	 * Factory for a transformer with {@code noCollating} set, a custom
	 * placeholder strategy, and an outer scope derived from {@code entity}
	 * aliased as {@code tableName}.
	 */
	public static TokenToSQLTransformer notCollating(Context context, Function<CqnElementRef, String> aliasResolver, CdsEntity entity,
			String tableName, List<Parameter> params, Function<Parameter, String> paramResolver) {
		Deque<QatSelectableNode> outerScope = outerQat(entity, tableName);
		return new TokenToSQLTransformer(context, params, aliasResolver, outerScope, paramResolver, true);
	}

	/**
	 * Factory for a transformer with {@code noCollating} set, {@code "?"}
	 * placeholders, and the given outer scope.
	 */
	public static TokenToSQLTransformer notCollating(Context context, List<Parameter> params, Function<CqnElementRef, String> aliasResolver,
			Deque<QatSelectableNode> outer) {
		return new TokenToSQLTransformer(context, params, aliasResolver, outer, true);
	}

	/**
	 * Factory for a transformer with {@code noCollating} set, {@code "?"}
	 * placeholders, and an outer scope derived from {@code entity} aliased as
	 * {@code tableName}.
	 */
	public static TokenToSQLTransformer notCollating(Context context, Function<CqnElementRef, String> aliasResolver, CdsEntity entity,
			String tableName, List<Parameter> params) {
		Deque<QatSelectableNode> outerScope = outerQat(entity, tableName);
		return new TokenToSQLTransformer(context, params, aliasResolver, outerScope, true);
	}

	/**
	 * Builds a single-entry outer scope: a root node for {@code entity} that has
	 * one child per concrete non-association element and carries
	 * {@code tableName} as its alias.
	 */
	private static Deque<QatSelectableNode> outerQat(CdsEntity entity, String tableName) {
		QatEntityNode rootNode = new QatEntityRootNode(entity);
		entity.concreteNonAssociationElements()
				.forEach(element -> rootNode.addChild(new QatElementNode(rootNode, element)));
		rootNode.setAlias(tableName);

		Deque<QatSelectableNode> scope = new ArrayDeque<>();
		scope.addLast(rootNode);
		return scope;
	}

	/**
	 * Renders a predicate as an SQL condition after running it through the DB
	 * specific predicate mapper, the contains-to-LIKE transformation and the
	 * predicate simplification.
	 *
	 * @param pred the predicate to render
	 * @return the SQL condition; {@code null} if the processed predicate is the
	 *         TRUE literal, {@code "1 = 0"} if it is the FALSE literal
	 */
	public String toSQL(CqnPredicate pred) {
		DbContext dbContext = context.getDbContext();

		CqnPredicate mapped = dbContext.getPredicateMapper().apply(pred);
		CqnPredicate withoutContains = ContainsToLike.transform(dbContext.getFunctionMapper(), mapped);
		CqnPredicate simplified = CqnStatementUtils.simplifyPredicate(withoutContains);

		if (simplified == TRUE) {
			// nothing to filter on
			return null;
		}
		if (simplified == FALSE) {
			return "1 = 0";
		}
		return apply(simplified);
	}

	/**
	 * Renders a single token as SQL by running a fresh {@link ToSQLVisitor}
	 * over it.
	 *
	 * @param pred the token to render, must not be {@code null}
	 * @return the SQL text; for an {@link Xpr} the first and the last character
	 *         of the rendered snippet are cut off (presumably the parentheses
	 *         put around expressions by the visitor)
	 * @throws IllegalArgumentException if {@code pred} is {@code null}
	 * @throws IllegalStateException    if the visitor does not end up with exactly
	 *                                  one snippet
	 */
	@Override
	public String apply(CqnToken pred) {
		if (pred == null) {
			throw new IllegalArgumentException("predicate must not be null");
		}

		ToSQLVisitor sqlVisitor = new ToSQLVisitor();
		pred.accept(sqlVisitor);
		String rendered = sqlVisitor.get(pred);

		return (pred instanceof Xpr) ? rendered.substring(1, rendered.length() - 1) : rendered;
	}

	/**
	 * Visitor that assembles SQL on a stack of snippets: each visit method pops
	 * the snippets of its operands (if any) and pushes the combined snippet, so
	 * operands are expected to have been visited before the token that uses them.
	 */
	class ToSQLVisitor implements CqnVisitor {

		Stack<String> stack = new Stack<>();

		/**
		 * Registers a CQN parameter. Positional parameters are named after a
		 * running counter, all others keep their own name.
		 */
		@Override
		public void visit(CqnParameter p) {
			String paramName;
			if (p.isPositional()) {
				paramName = String.valueOf(parameterPosition++);
			} else {
				paramName = p.name();
			}
			Parameter param = new CqnParam(paramName).type(p.type());

			bind(param);
		}

		/**
		 * Renders a function call through the DB specific function mapper, using
		 * the lower-cased function name and the already rendered arguments.
		 */
		@Override
		public void visit(CqnFunc cqnFunc) {
			String functionName = cqnFunc.func().toLowerCase(Locale.US);
			List<String> sqlArgs = stack.pop(cqnFunc.args().size());

			if ("substring".equals(functionName)) {
				// OData and CQN count the start position from 0, SQL counts from 1
				String zeroBasedStart = sqlArgs.get(SUBSTRING_START_PARAM);
				sqlArgs.set(SUBSTRING_START_PARAM, zeroBasedStart + " + 1");
			}

			push(context.getDbContext().getFunctionMapper().toSql(functionName, sqlArgs));
		}

		/** Renders a list value as a parenthesised, comma separated list. */
		@Override
		public void visit(CqnListValue listValue) {
			int size = (int) listValue.values().count();
			List<String> items = stack.pop(size);

			push(items.stream().collect(joining(", ", "(", ")")));
		}

		/**
		 * Containment tests are not rendered here.
		 *
		 * @throws IllegalStateException always
		 */
		@Override
		public void visit(CqnContainmentTest test) {
			throw new IllegalStateException("CQN containment test should be transformed to LIKE predicate");
		}

		/** Pushes the plain text unchanged. */
		@Override
		public void visit(CqnPlain plain) {
			push(plain.plain());
		}

		/** Boolean literals are always written into the SQL text. */
		@Override
		public void visit(CqnBooleanLiteral bool) {
			if (Boolean.TRUE.equals(bool.value())) {
				push(SQL_TRUE);
			} else {
				push(SQL_FALSE);
			}
		}

		/**
		 * Constant Integer, Long and Short literals are written into the SQL
		 * text; every other numeric literal becomes a value parameter.
		 */
		@Override
		public void visit(CqnNumericLiteral<?> number) {
			Number value = number.value();
			if (!number.isConstant() || !isNonDecimal(value)) {
				valueParam(number);
				return;
			}
			push(String.valueOf(value));
		}

		/** Tells whether the number is a Short, an Integer or a Long. */
		private boolean isNonDecimal(Number number) {
			return number instanceof Short || number instanceof Integer || number instanceof Long;
		}

		/**
		 * Constant string literals are written into the SQL text via
		 * {@code SQLHelper.literal}; all others become value parameters.
		 */
		@Override
		public void visit(CqnStringLiteral literal) {
			if (!literal.isConstant()) {
				valueParam(literal);
				return;
			}
			push(SQLHelper.literal(literal.value()));
		}

		/** Any other literal becomes a value parameter. */
		@Override
		public void visit(CqnLiteral<?> literal) {
			valueParam(literal);
		}

		/** Registers a value parameter that reads its value from the literal. */
		private void valueParam(CqnLiteral<?> literal) {
			Parameter param = new ValueParam(literal::value).type(literal.type());
			bind(param);
		}

		/** Registers a value parameter that reads its value from the supplier. */
		private void valueParam(Supplier<Object> valueSupplier, CdsBaseType type) {
			Parameter param = new ValueParam(valueSupplier).type(type);
			bind(param);
		}

		/**
		 * Appends the parameter to the collected parameters and then pushes the
		 * placeholder text the parameter resolver yields for it.
		 */
		private void bind(Parameter param) {
			params.add(param);
			push(paramResolver.apply(param));
		}

		@Override
		public void visit(CqnNullValue nil) {
			push("NULL");
		}

		/**
		 * A reference that resolves to a session context variable becomes a
		 * value parameter; any other reference is rendered by the alias resolver.
		 */
		@Override
		public void visit(CqnElementRef ref) {
			Optional<SessionContextVariable> sessionVariable = SessionUtils.getSessionParameter(ref,
					context.getSessionContext());
			if (!sessionVariable.isPresent()) {
				// check if ref is assoc and return fk/ref columns?
				push(aliasResolver.apply(ref));
				return;
			}
			SessionContextVariable variable = sessionVariable.get();
			valueParam(variable.getValueSupplier(), variable.getType());
		}

		/**
		 * Renders {@code EXISTS (<subquery>)}. The subquery is normalized and
		 * built by a {@link SelectStatementBuilder} that shares this
		 * transformer's parameter list.
		 */
		@Override
		public void visit(CqnExistsSubquery exists) {
			ExistsSubquery subQuery = (ExistsSubquery) exists;

			// nested subquery: keep using the current outer scope
			Deque<QatSelectableNode> outerNodes = outer;
			if (subQuery.getOuter() != null) {
				// infix filter ?
				// work on a copy: drop its first entry and append the subquery's own outer node
				outerNodes = new ArrayDeque<>(outer);
				outerNodes.removeFirst();
				outerNodes.addLast((QatSelectableNode) subQuery.getOuter());
			}

			CqnExistsSubquery normalized = cqnNormalizer.normalize(exists);
			SelectStatementBuilder subqueryBuilder = new SelectStatementBuilder(context, params,
					normalized.subquery(), outerNodes, noCollating);
			SQLStatement subqueryStmt = subqueryBuilder.build();

			push("EXISTS (" + subqueryStmt.sql() + ")");
		}

		/** Joins the snippets of an expression and wraps them in parentheses. */
		@Override
		public void visit(CqnExpression xpr) {
			int length = ((Xpr) xpr).length();
			List<String> parts = stack.pop(length);
			String joined = parts.stream().collect(SpaceSeparatedCollector.joining());

			push("(" + joined + ")");
		}

		/** Renders {@code (left <operator> right)}. */
		@Override
		public void visit(CqnArithmeticExpression expr) {
			// operands were pushed left first, so the right one is on top
			String rhs = stack.pop();
			String lhs = stack.pop();

			push("(" + lhs + " " + expr.operator().symbol() + " " + rhs + ")");
		}

		/** Renders {@code left <operator> right}. */
		@Override
		public void visit(CqnComparisonPredicate comparison) {
			// operands were pushed left first, so the right one is on top
			String rhs = stack.pop();
			String lhs = stack.pop();

			push(lhs + " " + comparison.operator().symbol + " " + rhs);
		}

		/**
		 * Joins the snippets of the connected predicates with the operator
		 * symbol. Within an AND, operands that are a {@link Disjunction} are
		 * put in parentheses.
		 */
		@Override
		public void visit(CqnConnectivePredicate connective) {
			String symbol = connective.operator().symbol;

			BiFunction<CqnPredicate, String, String> wrapper;
			if (connective.operator() == CqnConnectivePredicate.Operator.AND) {
				wrapper = this::flatAnd;
			} else {
				wrapper = this::flatOr;
			}

			List<CqnPredicate> operands = connective.predicates();
			List<String> operandSql = stack.pop(operands.size());

			String separator = " " + symbol + " ";
			StringBuilder sql = new StringBuilder();
			int index = 0;
			for (CqnPredicate operand : operands) {
				if (index > 0) {
					sql.append(separator);
				}
				sql.append(wrapper.apply(operand, operandSql.get(index)));
				index++;
			}

			push(sql.toString());
		}

		/** Operand of an AND: a disjunction is parenthesised, anything else is kept as is. */
		private String flatAnd(CqnPredicate pred, String snippet) {
			return (pred instanceof Disjunction) ? "(" + snippet + ")" : snippet;
		}

		/** Operand of an OR: always kept as is. */
		String flatOr(CqnPredicate pred, String snippet) {
			return snippet;
		}

		/** Renders {@code value in valueSet}. */
		@Override
		public void visit(CqnInPredicate in) {
			// the value set was pushed after the value, so it is on top
			String candidates = stack.pop();
			String tested = stack.pop();

			push(tested + " in " + candidates);
		}

		/** Renders a negation; a negated connective predicate is parenthesised. */
		@Override
		public void visit(CqnNegation neg) {
			boolean parenthesise = neg.predicate() instanceof CqnConnectivePredicate;
			String negated = stack.pop();

			push(parenthesise ? "not (" + negated + ")" : "not " + negated);
		}

		private void push(String snippet) {
			stack.push(snippet);
		}

		/**
		 * Returns the final SQL snippet.
		 *
		 * @param pred the visited token, used for the error message only
		 * @throws IllegalStateException if the stack does not hold exactly one snippet
		 */
		private String get(CqnToken pred) {
			if (stack.size() == 1) {
				return stack.pop();
			}
			throw new IllegalStateException("token " + pred.toJson() + " can't be mapped");
		}

	}

}
