/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.ir;

import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableList;
import io.trino.sql.ir.ArithmeticBinaryExpression;
import io.trino.sql.ir.ArithmeticNegation;
import io.trino.sql.ir.BetweenPredicate;
import io.trino.sql.ir.BindExpression;
import io.trino.sql.ir.Cast;
import io.trino.sql.ir.CoalesceExpression;
import io.trino.sql.ir.ComparisonExpression;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.FunctionCall;
import io.trino.sql.ir.InPredicate;
import io.trino.sql.ir.IrVisitor;
import io.trino.sql.ir.IsNullPredicate;
import io.trino.sql.ir.LambdaExpression;
import io.trino.sql.ir.LogicalExpression;
import io.trino.sql.ir.NotExpression;
import io.trino.sql.ir.NullIfExpression;
import io.trino.sql.ir.Row;
import io.trino.sql.ir.SearchedCaseExpression;
import io.trino.sql.ir.SimpleCaseExpression;
import io.trino.sql.ir.SubscriptExpression;
import io.trino.sql.ir.SymbolReference;
import io.trino.sql.ir.WhenClause;
import io.trino.sql.planner.Symbol;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;

public final class ExpressionFormatter {
    private ExpressionFormatter() {
    }

    public static String formatExpression(Expression expression) {
        return (String)new Formatter(Optional.empty(), Optional.empty()).process(expression, null);
    }

    public static class Formatter
    extends IrVisitor<String, Void> {
        private final Optional<Function<Constant, String>> literalFormatter;
        private final Optional<Function<SymbolReference, String>> symbolReferenceFormatter;

        public Formatter(Optional<Function<Constant, String>> literalFormatter, Optional<Function<SymbolReference, String>> symbolReferenceFormatter) {
            this.literalFormatter = Objects.requireNonNull(literalFormatter, "literalFormatter is null");
            this.symbolReferenceFormatter = Objects.requireNonNull(symbolReferenceFormatter, "symbolReferenceFormatter is null");
        }

        @Override
        protected String visitRow(Row node, Void context) {
            return node.getItems().stream().map(child -> (String)this.process((Expression)child, context)).collect(Collectors.joining(", ", "ROW (", ")"));
        }

        @Override
        protected String visitExpression(Expression node, Void context) {
            throw new UnsupportedOperationException("not yet implemented: %s.visit%s".formatted(this.getClass().getName(), node.getClass().getSimpleName()));
        }

        @Override
        protected String visitSubscriptExpression(SubscriptExpression node, Void context) {
            return ExpressionFormatter.formatExpression(node.getBase()) + "[" + ExpressionFormatter.formatExpression(node.getIndex()) + "]";
        }

        @Override
        protected String visitConstant(Constant node, Void context) {
            return this.literalFormatter.map(formatter -> (String)formatter.apply(node)).orElseGet(() -> {
                if (node.getValue() == null) {
                    return "null::" + String.valueOf(node.getType());
                }
                return String.valueOf(node.getType()) + " '" + String.valueOf(node.getType().getObjectValue(null, node.getValueAsBlock(), 0)) + "'";
            });
        }

        @Override
        protected String visitFunctionCall(FunctionCall node, Void context) {
            return node.getFunction().getName().toString() + "(" + this.joinExpressions(node.getArguments()) + ")";
        }

        @Override
        protected String visitLambdaExpression(LambdaExpression node, Void context) {
            return "(" + node.arguments().stream().map(Symbol::getName).collect(Collectors.joining(", ")) + ") -> " + (String)this.process(node.getBody(), context);
        }

        @Override
        protected String visitSymbolReference(SymbolReference node, Void context) {
            if (this.symbolReferenceFormatter.isPresent()) {
                return this.symbolReferenceFormatter.get().apply(node);
            }
            return node.getName();
        }

        @Override
        protected String visitBindExpression(BindExpression node, Void context) {
            StringBuilder builder = new StringBuilder();
            builder.append("\"$bind\"(");
            for (Expression value : node.getValues()) {
                builder.append((String)this.process(value, context)).append(", ");
            }
            builder.append((String)this.process(node.getFunction(), context)).append(")");
            return builder.toString();
        }

        @Override
        protected String visitLogicalExpression(LogicalExpression node, Void context) {
            return "(" + node.getTerms().stream().map(term -> (String)this.process((Expression)term, context)).collect(Collectors.joining(" " + node.getOperator().toString() + " ")) + ")";
        }

        @Override
        protected String visitNotExpression(NotExpression node, Void context) {
            return "(NOT " + (String)this.process(node.getValue(), context) + ")";
        }

        @Override
        protected String visitComparisonExpression(ComparisonExpression node, Void context) {
            return this.formatBinaryExpression(node.getOperator().getValue(), node.getLeft(), node.getRight());
        }

        @Override
        protected String visitIsNullPredicate(IsNullPredicate node, Void context) {
            return "(" + (String)this.process(node.getValue(), context) + " IS NULL)";
        }

        @Override
        protected String visitNullIfExpression(NullIfExpression node, Void context) {
            return "NULLIF(" + (String)this.process(node.getFirst(), context) + ", " + (String)this.process(node.getSecond(), context) + ")";
        }

        @Override
        protected String visitCoalesceExpression(CoalesceExpression node, Void context) {
            return "COALESCE(" + this.joinExpressions(node.getOperands()) + ")";
        }

        @Override
        protected String visitArithmeticNegation(ArithmeticNegation node, Void context) {
            return "-(" + (String)this.process(node.getValue(), context) + ")";
        }

        @Override
        protected String visitArithmeticBinary(ArithmeticBinaryExpression node, Void context) {
            return this.formatBinaryExpression(node.getOperator().getValue(), node.getLeft(), node.getRight());
        }

        @Override
        public String visitCast(Cast node, Void context) {
            return (node.isSafe() ? "TRY_CAST" : "CAST") + "(" + (String)this.process(node.getExpression(), context) + " AS " + node.getType().getDisplayName() + ")";
        }

        @Override
        protected String visitSearchedCaseExpression(SearchedCaseExpression node, Void context) {
            ImmutableList.Builder parts = ImmutableList.builder();
            parts.add((Object)"CASE");
            for (WhenClause whenClause : node.getWhenClauses()) {
                parts.add((Object)this.format(whenClause, context));
            }
            node.getDefaultValue().ifPresent(value -> parts.add((Object)"ELSE").add((Object)((String)this.process((Expression)value, context))));
            parts.add((Object)"END");
            return "(" + Joiner.on((char)' ').join((Iterable)parts.build()) + ")";
        }

        @Override
        protected String visitSimpleCaseExpression(SimpleCaseExpression node, Void context) {
            ImmutableList.Builder parts = ImmutableList.builder();
            parts.add((Object)"CASE").add((Object)((String)this.process(node.getOperand(), context)));
            for (WhenClause whenClause : node.getWhenClauses()) {
                parts.add((Object)this.format(whenClause, context));
            }
            node.getDefaultValue().ifPresent(value -> parts.add((Object)"ELSE").add((Object)((String)this.process((Expression)value, context))));
            parts.add((Object)"END");
            return "(" + Joiner.on((char)' ').join((Iterable)parts.build()) + ")";
        }

        protected String format(WhenClause node, Void context) {
            return "WHEN " + (String)this.process(node.getOperand(), context) + " THEN " + (String)this.process(node.getResult(), context);
        }

        @Override
        protected String visitBetweenPredicate(BetweenPredicate node, Void context) {
            return "(" + (String)this.process(node.getValue(), context) + " BETWEEN " + (String)this.process(node.getMin(), context) + " AND " + (String)this.process(node.getMax(), context) + ")";
        }

        @Override
        protected String visitInPredicate(InPredicate node, Void context) {
            return "(" + (String)this.process(node.getValue(), context) + " IN " + this.joinExpressions(node.getValueList()) + ")";
        }

        private String formatBinaryExpression(String operator, Expression left, Expression right) {
            return "(" + (String)this.process(left, null) + " " + operator + " " + (String)this.process(right, null) + ")";
        }

        private String joinExpressions(List<Expression> expressions) {
            return expressions.stream().map(e -> (String)this.process((Expression)e, null)).collect(Collectors.joining(", "));
        }
    }
}

