/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.ir;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import io.trino.sql.ir.Array;
import io.trino.sql.ir.Between;
import io.trino.sql.ir.Bind;
import io.trino.sql.ir.Call;
import io.trino.sql.ir.Case;
import io.trino.sql.ir.Cast;
import io.trino.sql.ir.Coalesce;
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.ExpressionRewriter;
import io.trino.sql.ir.FieldReference;
import io.trino.sql.ir.In;
import io.trino.sql.ir.IrVisitor;
import io.trino.sql.ir.IsNull;
import io.trino.sql.ir.Lambda;
import io.trino.sql.ir.Logical;
import io.trino.sql.ir.NullIf;
import io.trino.sql.ir.Reference;
import io.trino.sql.ir.Row;
import io.trino.sql.ir.Switch;
import io.trino.sql.ir.WhenClause;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;

/**
 * Recursively rewrites an IR {@link Expression} tree using an {@link ExpressionRewriter}.
 *
 * <p>For every visited node the matching {@code ExpressionRewriter.rewriteXxx} hook is consulted
 * first (unless the node is being visited via {@link #defaultRewrite}). A non-null hook result
 * replaces the node and this class does not descend into its children. A {@code null} result
 * falls back to the default behaviour: each child is rewritten recursively, and a new node is
 * constructed only when at least one child came back as a different instance (compared by
 * reference, {@code !=}); otherwise the original node instance is returned.
 *
 * @param <C> type of the caller-supplied context threaded through every hook invocation
 */
public final class ExpressionTreeRewriter<C> {
    private final ExpressionRewriter<C> rewriter;
    private final IrVisitor<Expression, Context<C>> visitor;

    /**
     * Rewrites {@code node} with a rewriter that needs no context; hooks receive {@code null}.
     *
     * @param rewriter hooks to apply
     * @param node root of the tree to rewrite
     * @return the rewritten tree, or {@code node} itself if nothing changed
     */
    public static <T extends Expression> T rewriteWith(ExpressionRewriter<Void> rewriter, T node) {
        return new ExpressionTreeRewriter<Void>(rewriter).rewrite(node, null);
    }

    /**
     * Rewrites {@code node} with {@code rewriter}, passing {@code context} to every hook.
     *
     * @param rewriter hooks to apply
     * @param node root of the tree to rewrite
     * @param context caller context handed to each hook
     * @return the rewritten tree, or {@code node} itself if nothing changed
     */
    public static <C, T extends Expression> T rewriteWith(ExpressionRewriter<C> rewriter, T node, C context) {
        return new ExpressionTreeRewriter<C>(rewriter).rewrite(node, context);
    }

    /**
     * Creates a tree rewriter driven by the given hooks.
     *
     * @param rewriter hooks consulted for each visited node
     */
    public ExpressionTreeRewriter(ExpressionRewriter<C> rewriter) {
        this.rewriter = rewriter;
        this.visitor = new RewritingVisitor();
    }

    /**
     * Rewrites every element of {@code items} in order (hooks enabled for each element) and
     * returns the results as an immutable list.
     */
    private List<Expression> rewrite(List<? extends Expression> items, Context<C> context) {
        ImmutableList.Builder<Expression> builder = ImmutableList.builder();
        for (Expression item : items) {
            builder.add(rewrite(item, context.get()));
        }
        return builder.build();
    }

    /**
     * Rewrites {@code node}, consulting the rewriter hook for {@code node} itself first.
     *
     * <p>The result is cast to {@code T}; if a hook returns an expression of a different
     * concrete type than the caller expects, a {@link ClassCastException} surfaces at the call
     * site.
     *
     * @param node expression to rewrite
     * @param context caller context handed to the hooks
     * @return the rewritten expression, or {@code node} itself if nothing changed
     */
    @SuppressWarnings("unchecked")
    public <T extends Expression> T rewrite(T node, C context) {
        return (T) visitor.process(node, new Context<C>(context, false));
    }

    /**
     * Rewrites {@code node} without consulting the hook for {@code node} itself, but with hooks
     * enabled for all of its descendants. Intended for hooks that want "rewrite my children, then
     * let me post-process".
     *
     * @param node expression whose children should be rewritten
     * @param context caller context handed to descendant hooks
     * @return {@code node} rebuilt with rewritten children, or {@code node} itself if none changed
     */
    @SuppressWarnings("unchecked")
    public <T extends Expression> T defaultRewrite(T node, C context) {
        return (T) visitor.process(node, new Context<C>(context, true));
    }

    /** Returns whether both optionals are empty, or both hold the very same instance. */
    private static <T> boolean sameElements(Optional<T> a, Optional<T> b) {
        if (a.isPresent() != b.isPresent()) {
            return false;
        }
        return a.isEmpty() || a.get() == b.get();
    }

    /**
     * Returns whether both iterables have the same size and hold the very same instances in the
     * same order (reference comparison, not {@code equals}).
     */
    private static <T> boolean sameElements(Iterable<? extends T> a, Iterable<? extends T> b) {
        if (Iterables.size(a) != Iterables.size(b)) {
            return false;
        }
        // Wildcard iterators: Iterable<? extends T>.iterator() is not an Iterator<T>.
        Iterator<? extends T> first = a.iterator();
        Iterator<? extends T> second = b.iterator();
        while (first.hasNext() && second.hasNext()) {
            if (first.next() != second.next()) {
                return false;
            }
        }
        return true;
    }

    /** Visitor implementing the hook-first, then rebuild-if-children-changed traversal. */
    private class RewritingVisitor
            extends IrVisitor<Expression, Context<C>> {
        @Override
        protected Expression visitExpression(Expression node, Context<C> context) {
            throw new UnsupportedOperationException("visit() not implemented for " + node.getClass().getName());
        }

        @Override
        protected Expression visitArray(Array node, Context<C> context) {
            if (!context.isDefaultRewrite()) {
                Expression result = rewriter.rewriteArray(node, context.get(), ExpressionTreeRewriter.this);
                if (result != null) {
                    return result;
                }
            }

            List<Expression> elements = rewrite(node.elements(), context);
            if (!sameElements(node.elements(), elements)) {
                return new Array(node.elementType(), elements);
            }
            return node;
        }

        @Override
        protected Expression visitRow(Row node, Context<C> context) {
            if (!context.isDefaultRewrite()) {
                Expression result = rewriter.rewriteRow(node, context.get(), ExpressionTreeRewriter.this);
                if (result != null) {
                    return result;
                }
            }

            List<Expression> items = rewrite(node.items(), context);
            if (!sameElements(node.items(), items)) {
                return new Row(items, node.type());
            }
            return node;
        }

        @Override
        protected Expression visitFieldReference(FieldReference node, Context<C> context) {
            if (!context.isDefaultRewrite()) {
                // NOTE(review): FieldReference is routed to the rewriteSubscript hook; presumably a
                // legacy hook name — confirm against ExpressionRewriter before renaming.
                Expression result = rewriter.rewriteSubscript(node, context.get(), ExpressionTreeRewriter.this);
                if (result != null) {
                    return result;
                }
            }

            Expression base = rewrite(node.base(), context.get());
            if (base != node.base()) {
                return new FieldReference(base, node.field());
            }
            return node;
        }

        @Override
        public Expression visitComparison(Comparison node, Context<C> context) {
            if (!context.isDefaultRewrite()) {
                Expression result = rewriter.rewriteComparison(node, context.get(), ExpressionTreeRewriter.this);
                if (result != null) {
                    return result;
                }
            }

            Expression left = rewrite(node.left(), context.get());
            Expression right = rewrite(node.right(), context.get());
            if (left != node.left() || right != node.right()) {
                return new Comparison(node.operator(), left, right);
            }
            return node;
        }

        @Override
        protected Expression visitBetween(Between node, Context<C> context) {
            if (!context.isDefaultRewrite()) {
                Expression result = rewriter.rewriteBetween(node, context.get(), ExpressionTreeRewriter.this);
                if (result != null) {
                    return result;
                }
            }

            Expression value = rewrite(node.value(), context.get());
            Expression min = rewrite(node.min(), context.get());
            Expression max = rewrite(node.max(), context.get());
            if (value != node.value() || min != node.min() || max != node.max()) {
                return new Between(value, min, max);
            }
            return node;
        }

        @Override
        public Expression visitLogical(Logical node, Context<C> context) {
            if (!context.isDefaultRewrite()) {
                Expression result = rewriter.rewriteLogical(node, context.get(), ExpressionTreeRewriter.this);
                if (result != null) {
                    return result;
                }
            }

            List<Expression> terms = rewrite(node.terms(), context);
            if (!sameElements(node.terms(), terms)) {
                return new Logical(node.operator(), terms);
            }
            return node;
        }

        @Override
        protected Expression visitIsNull(IsNull node, Context<C> context) {
            if (!context.isDefaultRewrite()) {
                Expression result = rewriter.rewriteIsNull(node, context.get(), ExpressionTreeRewriter.this);
                if (result != null) {
                    return result;
                }
            }

            Expression value = rewrite(node.value(), context.get());
            if (value != node.value()) {
                return new IsNull(value);
            }
            return node;
        }

        @Override
        protected Expression visitNullIf(NullIf node, Context<C> context) {
            if (!context.isDefaultRewrite()) {
                Expression result = rewriter.rewriteNullIf(node, context.get(), ExpressionTreeRewriter.this);
                if (result != null) {
                    return result;
                }
            }

            Expression first = rewrite(node.first(), context.get());
            Expression second = rewrite(node.second(), context.get());
            if (first != node.first() || second != node.second()) {
                return new NullIf(first, second);
            }
            return node;
        }

        @Override
        protected Expression visitCase(Case node, Context<C> context) {
            if (!context.isDefaultRewrite()) {
                Expression result = rewriter.rewriteCase(node, context.get(), ExpressionTreeRewriter.this);
                if (result != null) {
                    return result;
                }
            }

            // Order matters: hooks may have side effects, so clauses are rewritten before the default.
            List<WhenClause> whenClauses = rewriteWhenClauses(node.whenClauses(), context);
            Expression defaultValue = rewrite(node.defaultValue(), context.get());
            if (defaultValue != node.defaultValue() || !sameElements(node.whenClauses(), whenClauses)) {
                return new Case(whenClauses, defaultValue);
            }
            return node;
        }

        @Override
        protected Expression visitSwitch(Switch node, Context<C> context) {
            if (!context.isDefaultRewrite()) {
                Expression result = rewriter.rewriteSwitch(node, context.get(), ExpressionTreeRewriter.this);
                if (result != null) {
                    return result;
                }
            }

            // Order matters: operand, then clauses, then default — same as evaluation of the hooks before.
            Expression operand = rewrite(node.operand(), context.get());
            List<WhenClause> whenClauses = rewriteWhenClauses(node.whenClauses(), context);
            Expression defaultValue = rewrite(node.defaultValue(), context.get());
            if (operand != node.operand() || defaultValue != node.defaultValue() || !sameElements(node.whenClauses(), whenClauses)) {
                return new Switch(operand, whenClauses, defaultValue);
            }
            return node;
        }

        /** Rewrites each clause in order and returns the results as an immutable list. */
        private List<WhenClause> rewriteWhenClauses(Iterable<WhenClause> clauses, Context<C> context) {
            ImmutableList.Builder<WhenClause> builder = ImmutableList.builder();
            for (WhenClause clause : clauses) {
                builder.add(rewriteWhenClause(clause, context));
            }
            return builder.build();
        }

        /**
         * Rewrites the operand and result of a single WHEN clause; returns the original clause
         * instance when neither changed.
         */
        protected WhenClause rewriteWhenClause(WhenClause node, Context<C> context) {
            Expression operand = rewrite(node.getOperand(), context.get());
            Expression result = rewrite(node.getResult(), context.get());
            if (operand != node.getOperand() || result != node.getResult()) {
                return new WhenClause(operand, result);
            }
            return node;
        }

        @Override
        protected Expression visitCoalesce(Coalesce node, Context<C> context) {
            if (!context.isDefaultRewrite()) {
                Expression result = rewriter.rewriteCoalesce(node, context.get(), ExpressionTreeRewriter.this);
                if (result != null) {
                    return result;
                }
            }

            List<Expression> operands = rewrite(node.operands(), context);
            if (!sameElements(node.operands(), operands)) {
                return new Coalesce(operands);
            }
            return node;
        }

        @Override
        public Expression visitCall(Call node, Context<C> context) {
            if (!context.isDefaultRewrite()) {
                Expression result = rewriter.rewriteCall(node, context.get(), ExpressionTreeRewriter.this);
                if (result != null) {
                    return result;
                }
            }

            List<Expression> arguments = rewrite(node.arguments(), context);
            if (!sameElements(node.arguments(), arguments)) {
                return new Call(node.function(), arguments);
            }
            return node;
        }

        @Override
        protected Expression visitLambda(Lambda node, Context<C> context) {
            if (!context.isDefaultRewrite()) {
                Expression result = rewriter.rewriteLambda(node, context.get(), ExpressionTreeRewriter.this);
                if (result != null) {
                    return result;
                }
            }

            // Only the body is rewritten; the argument list is carried over unchanged.
            Expression body = rewrite(node.body(), context.get());
            if (body != node.body()) {
                return new Lambda(node.arguments(), body);
            }
            return node;
        }

        @Override
        protected Expression visitBind(Bind node, Context<C> context) {
            if (!context.isDefaultRewrite()) {
                Expression result = rewriter.rewriteBind(node, context.get(), ExpressionTreeRewriter.this);
                if (result != null) {
                    return result;
                }
            }

            List<Expression> values = rewrite(node.values(), context);
            Lambda function = rewrite(node.function(), context.get());
            if (!sameElements(values, node.values()) || function != node.function()) {
                return new Bind(values, function);
            }
            return node;
        }

        @Override
        public Expression visitIn(In node, Context<C> context) {
            if (!context.isDefaultRewrite()) {
                Expression result = rewriter.rewriteIn(node, context.get(), ExpressionTreeRewriter.this);
                if (result != null) {
                    return result;
                }
            }

            Expression value = rewrite(node.value(), context.get());
            List<Expression> values = rewrite(node.valueList(), context);
            if (value != node.value() || !sameElements(values, node.valueList())) {
                return new In(value, values);
            }
            return node;
        }

        // Leaf node: only the hook can replace it.
        @Override
        public Expression visitConstant(Constant node, Context<C> context) {
            if (!context.isDefaultRewrite()) {
                Expression result = rewriter.rewriteConstant(node, context.get(), ExpressionTreeRewriter.this);
                if (result != null) {
                    return result;
                }
            }
            return node;
        }

        @Override
        public Expression visitCast(Cast node, Context<C> context) {
            if (!context.isDefaultRewrite()) {
                Expression result = rewriter.rewriteCast(node, context.get(), ExpressionTreeRewriter.this);
                if (result != null) {
                    return result;
                }
            }

            Expression expression = rewrite(node.expression(), context.get());
            if (expression != node.expression()) {
                return new Cast(expression, node.type());
            }
            return node;
        }

        // Leaf node: only the hook can replace it.
        @Override
        protected Expression visitReference(Reference node, Context<C> context) {
            if (!context.isDefaultRewrite()) {
                Expression result = rewriter.rewriteReference(node, context.get(), ExpressionTreeRewriter.this);
                if (result != null) {
                    return result;
                }
            }
            return node;
        }
    }

    /**
     * Per-visit state: the caller's context plus a flag that, when set, suppresses the
     * {@link ExpressionRewriter} hook for the node being visited (its children are still
     * visited with hooks enabled).
     *
     * @param <C> type of the caller-supplied context
     */
    public static final class Context<C> {
        private final boolean defaultRewrite;
        private final C context;

        private Context(C context, boolean defaultRewrite) {
            this.context = context;
            this.defaultRewrite = defaultRewrite;
        }

        /** Returns the caller-supplied context (may be {@code null}). */
        public C get() {
            return this.context;
        }

        /** Returns {@code true} when the hook for the current node must be skipped. */
        public boolean isDefaultRewrite() {
            return this.defaultRewrite;
        }
    }
}

