/*
 * Decompiled with CFR 0.152.
 */
package org.partiql.plan;

import java.util.ArrayList;
import java.util.List;
import org.jetbrains.annotations.NotNull;
import org.partiql.plan.Collation;
import org.partiql.plan.Exclusion;
import org.partiql.plan.Operator;
import org.partiql.plan.OperatorVisitor;
import org.partiql.plan.Operators;
import org.partiql.plan.rel.Rel;
import org.partiql.plan.rel.RelAggregate;
import org.partiql.plan.rel.RelCorrelate;
import org.partiql.plan.rel.RelDistinct;
import org.partiql.plan.rel.RelExcept;
import org.partiql.plan.rel.RelExclude;
import org.partiql.plan.rel.RelFilter;
import org.partiql.plan.rel.RelIntersect;
import org.partiql.plan.rel.RelIterate;
import org.partiql.plan.rel.RelJoin;
import org.partiql.plan.rel.RelLimit;
import org.partiql.plan.rel.RelOffset;
import org.partiql.plan.rel.RelProject;
import org.partiql.plan.rel.RelScan;
import org.partiql.plan.rel.RelSort;
import org.partiql.plan.rel.RelUnion;
import org.partiql.plan.rel.RelUnpivot;
import org.partiql.plan.rex.Rex;
import org.partiql.plan.rex.RexArray;
import org.partiql.plan.rex.RexBag;
import org.partiql.plan.rex.RexCall;
import org.partiql.plan.rex.RexCase;
import org.partiql.plan.rex.RexCast;
import org.partiql.plan.rex.RexCoalesce;
import org.partiql.plan.rex.RexDispatch;
import org.partiql.plan.rex.RexError;
import org.partiql.plan.rex.RexLit;
import org.partiql.plan.rex.RexNullIf;
import org.partiql.plan.rex.RexPathIndex;
import org.partiql.plan.rex.RexPathKey;
import org.partiql.plan.rex.RexPathSymbol;
import org.partiql.plan.rex.RexPivot;
import org.partiql.plan.rex.RexSelect;
import org.partiql.plan.rex.RexSpread;
import org.partiql.plan.rex.RexStruct;
import org.partiql.plan.rex.RexSubquery;
import org.partiql.plan.rex.RexSubqueryComp;
import org.partiql.plan.rex.RexSubqueryIn;
import org.partiql.plan.rex.RexSubqueryTest;
import org.partiql.plan.rex.RexTable;
import org.partiql.plan.rex.RexVar;

/**
 * Base class for rewriting a PartiQL plan, one operator at a time.
 *
 * <p>Each {@code visitX} method first visits the node's children. It rebuilds the node through the
 * configured {@link Operators} factory only when at least one child (or child list) comes back as
 * a different object. When nothing changed, the original node instance is returned, so an untouched
 * subtree keeps its identity. Change detection is by reference ({@code !=}), not by
 * {@code equals}.
 *
 * <p>Subclasses override the {@code visitX} methods for the nodes they want to transform.
 *
 * @param <C> the type of the context value threaded through the traversal
 */
public abstract class OperatorRewriter<C>
implements OperatorVisitor<Operator, C> {

    /** Factory used to construct replacement nodes when a child changes. */
    private final Operators operators;

    /** Creates a rewriter that rebuilds nodes with {@link Operators#STANDARD}. */
    public OperatorRewriter() {
        this.operators = Operators.STANDARD;
    }

    /**
     * Creates a rewriter that rebuilds nodes with the given factory.
     *
     * @param operators factory used to construct rewritten nodes
     */
    public OperatorRewriter(Operators operators) {
        this.operators = operators;
    }

    /** Returns {@code operator} unchanged. */
    @Override
    @NotNull
    public Operator defaultReturn(@NotNull Operator operator, C ctx) {
        return operator;
    }

    /** Returns {@code operator} unchanged. */
    @Override
    @NotNull
    public Operator defaultVisit(Operator operator, C ctx) {
        return operator;
    }

    // ------------------------------------------------------------------
    // Relational operators
    // ------------------------------------------------------------------

    /** Rewrites the input, measures and group keys of an aggregation. */
    @Override
    @NotNull
    public Operator visitAggregate(@NotNull RelAggregate rel, C ctx) {
        Rel input = rel.getInput();
        Rel input_new = this.visit(input, ctx, Rel.class);
        List<RelAggregate.Measure> measures = rel.getMeasures();
        List<RelAggregate.Measure> measures_new = this.visitAll(measures, ctx, this::visitAggregateMeasure);
        List<Rex> groups = rel.getGroups();
        List<Rex> groups_new = this.visitAll(groups, ctx, this::visitAggregateGroup);
        if (input != input_new || measures != measures_new || groups != groups_new) {
            return this.operators.aggregate(input_new, measures_new, groups_new);
        }
        return rel;
    }

    /**
     * Rewrites the argument expressions of an aggregate measure.
     *
     * <p>NOTE(review): arguments are routed through {@link #visitAggregateGroup}, so overriding that
     * method also affects measure arguments. Confirm this is intended.
     *
     * @return a copy of {@code measure} with new arguments, or {@code measure} itself if no argument
     *     changed
     */
    @NotNull
    public RelAggregate.Measure visitAggregateMeasure(@NotNull RelAggregate.Measure measure, C ctx) {
        List<Rex> args = measure.getArgs();
        List<Rex> args_new = this.visitAll(args, ctx, this::visitAggregateGroup);
        if (args != args_new) {
            return measure.copy(args_new);
        }
        return measure;
    }

    /** Rewrites a single grouping (or measure argument) expression. */
    @NotNull
    public Rex visitAggregateGroup(@NotNull Rex rex, C ctx) {
        return this.visit(rex, ctx, Rex.class);
    }

    /** Rewrites both sides of a correlated join; the join type is preserved. */
    @Override
    public Operator visitCorrelate(@NotNull RelCorrelate rel, C ctx) {
        Rel left = rel.getLeft();
        Rel left_new = this.visit(left, ctx, Rel.class);
        Rel right = rel.getRight();
        Rel right_new = this.visit(right, ctx, Rel.class);
        if (left != left_new || right != right_new) {
            return this.operators.correlate(left_new, right_new, rel.getJoinType());
        }
        return rel;
    }

    /** Rewrites the input of a DISTINCT. */
    @Override
    @NotNull
    public Operator visitDistinct(@NotNull RelDistinct rel, C ctx) {
        Rel input = rel.getInput();
        Rel input_new = this.visit(input, ctx, Rel.class);
        if (input != input_new) {
            return this.operators.distinct(input_new);
        }
        return rel;
    }

    /** Rewrites both sides of an EXCEPT; the ALL flag is preserved. */
    @Override
    @NotNull
    public Operator visitExcept(@NotNull RelExcept rel, C ctx) {
        Rel left = rel.getLeft();
        Rel left_new = this.visit(left, ctx, Rel.class);
        Rel right = rel.getRight();
        Rel right_new = this.visit(right, ctx, Rel.class);
        if (left != left_new || right != right_new) {
            return this.operators.except(left_new, right_new, rel.isAll());
        }
        return rel;
    }

    /**
     * Rewrites the input and exclusions of an EXCLUDE.
     *
     * <p>The node is rebuilt when either the input or any exclusion changed.
     */
    @Override
    @NotNull
    public Operator visitExclude(@NotNull RelExclude rel, C ctx) {
        Rel input = rel.getInput();
        Rel input_new = this.visit(input, ctx, Rel.class);
        List<Exclusion> exclusions = rel.getExclusions();
        List<Exclusion> exclusions_new = this.visitAll(exclusions, ctx, this::visitExclusions);
        // Previously only the input was checked, so rewritten exclusions were silently dropped.
        if (input != input_new || exclusions != exclusions_new) {
            return this.operators.exclude(input_new, exclusions_new);
        }
        return rel;
    }

    /** Rewrites a single exclusion; the default returns it unchanged. */
    @NotNull
    public Exclusion visitExclusions(@NotNull Exclusion exclusion, C ctx) {
        return exclusion;
    }

    /** Rewrites the input and predicate of a filter. */
    @Override
    public Operator visitFilter(@NotNull RelFilter rel, C ctx) {
        Rel input = rel.getInput();
        Rel input_new = this.visit(input, ctx, Rel.class);
        Rex predicate = rel.getPredicate();
        Rex predicate_new = this.visit(predicate, ctx, Rex.class);
        if (input != input_new || predicate != predicate_new) {
            return this.operators.filter(input_new, predicate_new);
        }
        return rel;
    }

    /** Rewrites both sides of an INTERSECT; the ALL flag is preserved. */
    @Override
    public Operator visitIntersect(@NotNull RelIntersect rel, C ctx) {
        Rel left = rel.getLeft();
        Rel left_new = this.visit(left, ctx, Rel.class);
        Rel right = rel.getRight();
        Rel right_new = this.visit(right, ctx, Rel.class);
        if (left != left_new || right != right_new) {
            return this.operators.intersect(left_new, right_new, rel.isAll());
        }
        return rel;
    }

    /** Rewrites the expression of an iterate. */
    @Override
    public Operator visitIterate(@NotNull RelIterate rel, C ctx) {
        Rex rex = rel.getRex();
        Rex rex_new = this.visit(rex, ctx, Rex.class);
        if (rex != rex_new) {
            return this.operators.iterate(rex_new);
        }
        return rel;
    }

    /** Rewrites both sides and the condition of a join; the join type is preserved. */
    @Override
    public Operator visitJoin(@NotNull RelJoin rel, C ctx) {
        Rel left = rel.getLeft();
        Rel left_new = this.visit(left, ctx, Rel.class);
        Rel right = rel.getRight();
        Rel right_new = this.visit(right, ctx, Rel.class);
        Rex condition = rel.getCondition();
        Rex condition_new = this.visit(condition, ctx, Rex.class);
        if (left != left_new || right != right_new || condition != condition_new) {
            return this.operators.join(left_new, right_new, condition_new, rel.getJoinType());
        }
        return rel;
    }

    /** Rewrites the input and limit expression of a LIMIT. */
    @Override
    public Operator visitLimit(@NotNull RelLimit rel, C ctx) {
        Rel input = rel.getInput();
        Rel input_new = this.visit(input, ctx, Rel.class);
        Rex limit = rel.getLimit();
        Rex limit_new = this.visit(limit, ctx, Rex.class);
        if (input != input_new || limit != limit_new) {
            return this.operators.limit(input_new, limit_new);
        }
        return rel;
    }

    /** Rewrites the input and offset expression of an OFFSET. */
    @Override
    public Operator visitOffset(@NotNull RelOffset rel, C ctx) {
        Rel input = rel.getInput();
        Rel input_new = this.visit(input, ctx, Rel.class);
        Rex offset = rel.getOffset();
        Rex offset_new = this.visit(offset, ctx, Rex.class);
        if (input != input_new || offset != offset_new) {
            return this.operators.offset(input_new, offset_new);
        }
        return rel;
    }

    /** Rewrites the input and projection expressions of a project. */
    @Override
    public Operator visitProject(@NotNull RelProject rel, C ctx) {
        Rel input = rel.getInput();
        Rel input_new = this.visit(input, ctx, Rel.class);
        List<Rex> projections = rel.getProjections();
        List<Rex> projects_new = this.visitAll(projections, ctx, this::visitProjection);
        if (input != input_new || projections != projects_new) {
            return this.operators.project(input_new, projects_new);
        }
        return rel;
    }

    /** Rewrites a single projection expression. */
    @NotNull
    public Rex visitProjection(@NotNull Rex rex, C ctx) {
        return this.visit(rex, ctx, Rex.class);
    }

    /** Rewrites the expression of a scan. */
    @Override
    public Operator visitScan(@NotNull RelScan rel, C ctx) {
        Rex rex = rel.getRex();
        Rex rex_new = this.visit(rex, ctx, Rex.class);
        if (rex != rex_new) {
            return this.operators.scan(rex_new);
        }
        return rel;
    }

    /** Rewrites the input and collations of a sort. */
    @Override
    public Operator visitSort(@NotNull RelSort rel, C ctx) {
        Rel input = rel.getInput();
        Rel input_new = this.visit(input, ctx, Rel.class);
        List<Collation> collations = rel.getCollations();
        List<Collation> collations_new = this.visitAll(collations, ctx, this::visitCollation);
        if (input != input_new || collations != collations_new) {
            return this.operators.sort(input_new, collations_new);
        }
        return rel;
    }

    /** Rewrites a single collation; the default returns it unchanged. */
    @NotNull
    public Collation visitCollation(@NotNull Collation collation, C ctx) {
        return collation;
    }

    /** Rewrites both sides of a UNION; the ALL flag is preserved. */
    @Override
    public Operator visitUnion(@NotNull RelUnion rel, C ctx) {
        Rel left = rel.getLeft();
        Rel left_new = this.visit(left, ctx, Rel.class);
        Rel right = rel.getRight();
        Rel right_new = this.visit(right, ctx, Rel.class);
        if (left != left_new || right != right_new) {
            return this.operators.union(left_new, right_new, rel.isAll());
        }
        return rel;
    }

    /** Rewrites the expression of an unpivot. */
    @Override
    public Operator visitUnpivot(@NotNull RelUnpivot rel, C ctx) {
        Rex rex = rel.getRex();
        Rex rex_new = this.visit(rex, ctx, Rex.class);
        if (rex != rex_new) {
            // Must pass the rewritten expression; passing the original discarded the rewrite.
            return this.operators.unpivot(rex_new);
        }
        return rel;
    }

    // ------------------------------------------------------------------
    // Expression operators
    // ------------------------------------------------------------------

    /** Rewrites the element expressions of an array constructor. */
    @Override
    public Operator visitArray(@NotNull RexArray rex, C ctx) {
        List<Rex> values = rex.getValues();
        List<Rex> values_new = this.visitAll(values, ctx, this::visitRex);
        if (values != values_new) {
            return this.operators.array(values_new);
        }
        return rex;
    }

    /** Rewrites the element expressions of a bag constructor. */
    @Override
    public Operator visitBag(@NotNull RexBag rex, C ctx) {
        // Snapshot the values as a List so visitAll can be used; visitAll returns this same
        // snapshot when nothing changed, which keeps the identity check below meaningful.
        List<Rex> values = List.copyOf(rex.getValues());
        List<Rex> values_new = this.visitAll(values, ctx, this::visitRex);
        if (values != values_new) {
            return this.operators.bag(values_new);
        }
        return rex;
    }

    /** Rewrites the arguments of a function call; the resolved function is preserved. */
    @Override
    public Operator visitCall(@NotNull RexCall rex, C ctx) {
        List<Rex> args = rex.getArgs();
        List<Rex> args_new = this.visitAll(args, ctx, this::visitRex);
        if (args != args_new) {
            return this.operators.call(rex.getFunction(), args_new);
        }
        return rex;
    }

    /**
     * Rewrites a CASE expression: its optional match operand, its branches and its optional default.
     * A {@code null} match or default is passed through as {@code null}.
     */
    @Override
    public Operator visitCase(@NotNull RexCase rex, C ctx) {
        Rex match = rex.getMatch();
        Rex match_new = match != null ? this.visit(match, ctx, Rex.class) : null;
        List<RexCase.Branch> branches = rex.getBranches();
        List<RexCase.Branch> branches_new = this.visitAll(branches, ctx, this::visitCaseBranch);
        Rex default_ = rex.getDefault();
        Rex default_new = default_ != null ? this.visit(default_, ctx, Rex.class) : null;
        if (match != match_new || branches != branches_new || default_ != default_new) {
            return this.operators.caseWhen(match_new, branches_new, default_new);
        }
        return rex;
    }

    /** Rewrites the condition and result of a single CASE branch. */
    @NotNull
    public RexCase.Branch visitCaseBranch(@NotNull RexCase.Branch branch, C ctx) {
        Rex condition = branch.getCondition();
        Rex condition_new = this.visit(condition, ctx, Rex.class);
        Rex result = branch.getResult();
        Rex result_new = this.visit(result, ctx, Rex.class);
        if (condition != condition_new || result != result_new) {
            return RexCase.branch(condition_new, result_new);
        }
        return branch;
    }

    /** Rewrites the operand of a cast; the target type is preserved. */
    @Override
    public Operator visitCast(@NotNull RexCast rex, C ctx) {
        Rex operand = rex.getOperand();
        Rex operand_new = this.visit(operand, ctx, Rex.class);
        if (operand != operand_new) {
            return this.operators.cast(operand_new, rex.getTarget());
        }
        return rex;
    }

    /** Rewrites the arguments of a COALESCE. */
    @Override
    public Operator visitCoalesce(@NotNull RexCoalesce rex, C ctx) {
        List<Rex> args = rex.getArgs();
        List<Rex> args_new = this.visitAll(args, ctx, this::visitRex);
        if (args != args_new) {
            return this.operators.coalesce(args_new);
        }
        return rex;
    }

    /** Rewrites the arguments of a dynamic dispatch; name and candidate functions are preserved. */
    @Override
    public Operator visitDispatch(@NotNull RexDispatch rex, C ctx) {
        List<Rex> args = rex.getArgs();
        List<Rex> args_new = this.visitAll(args, ctx, this::visitRex);
        if (args != args_new) {
            return this.operators.dispatch(rex.getName(), rex.getFunctions(), args_new);
        }
        return rex;
    }

    /** Returns the error node unchanged. */
    @Override
    public Operator visitError(@NotNull RexError rex, C ctx) {
        return rex;
    }

    /** Returns the literal unchanged. */
    @Override
    public Operator visitLit(@NotNull RexLit rex, C ctx) {
        return rex;
    }

    /** Rewrites both operands of a NULLIF. */
    @Override
    public Operator visitNullIf(@NotNull RexNullIf rex, C ctx) {
        Rex v1 = rex.getV1();
        Rex v1_new = this.visit(v1, ctx, Rex.class);
        Rex v2 = rex.getV2();
        Rex v2_new = this.visit(v2, ctx, Rex.class);
        if (v1 != v1_new || v2 != v2_new) {
            return this.operators.nullIf(v1_new, v2_new);
        }
        return rex;
    }

    /** Rewrites the operand and index expression of an index path step. */
    @Override
    public Operator visitPathIndex(@NotNull RexPathIndex rex, C ctx) {
        Rex operand = rex.getOperand();
        Rex operand_new = this.visit(operand, ctx, Rex.class);
        Rex index = rex.getIndex();
        Rex index_new = this.visit(index, ctx, Rex.class);
        if (operand != operand_new || index != index_new) {
            return this.operators.pathIndex(operand_new, index_new);
        }
        return rex;
    }

    /** Rewrites the operand and key expression of a key path step. */
    @Override
    public Operator visitPathKey(@NotNull RexPathKey rex, C ctx) {
        Rex operand = rex.getOperand();
        Rex operand_new = this.visit(operand, ctx, Rex.class);
        Rex key = rex.getKey();
        Rex key_new = this.visit(key, ctx, Rex.class);
        if (operand != operand_new || key != key_new) {
            return this.operators.pathKey(operand_new, key_new);
        }
        return rex;
    }

    /** Rewrites the operand of a symbol path step; the symbol is preserved. */
    @Override
    public Operator visitPathSymbol(@NotNull RexPathSymbol rex, C ctx) {
        Rex operand = rex.getOperand();
        Rex operand_new = this.visit(operand, ctx, Rex.class);
        if (operand != operand_new) {
            return this.operators.pathSymbol(operand_new, rex.getSymbol());
        }
        return rex;
    }

    /** Rewrites the input, key and value of a PIVOT. */
    @Override
    public Operator visitPivot(@NotNull RexPivot rex, C ctx) {
        Rel input = rex.getInput();
        Rel input_new = this.visit(input, ctx, Rel.class);
        Rex key = rex.getKey();
        Rex key_new = this.visit(key, ctx, Rex.class);
        Rex value = rex.getValue();
        Rex value_new = this.visit(value, ctx, Rex.class);
        if (input != input_new || key != key_new || value != value_new) {
            return this.operators.pivot(input_new, key_new, value_new);
        }
        return rex;
    }

    /** Rewrites the input and constructor of a SELECT. */
    @Override
    public Operator visitSelect(@NotNull RexSelect rex, C ctx) {
        Rel input = rex.getInput();
        Rel input_new = this.visit(input, ctx, Rel.class);
        Rex constructor = rex.getConstructor();
        Rex constructor_new = this.visit(constructor, ctx, Rex.class);
        if (input != input_new || constructor != constructor_new) {
            return this.operators.select(input_new, constructor_new);
        }
        return rex;
    }

    /** Rewrites the fields of a struct constructor. */
    @Override
    public Operator visitStruct(@NotNull RexStruct rex, C ctx) {
        List<RexStruct.Field> fields = rex.getFields();
        List<RexStruct.Field> fields_new = this.visitAll(fields, ctx, this::visitStructField);
        if (fields != fields_new) {
            return this.operators.struct(fields_new);
        }
        return rex;
    }

    /** Rewrites the key and value expressions of a single struct field. */
    @NotNull
    public RexStruct.Field visitStructField(@NotNull RexStruct.Field field, C ctx) {
        Rex key = field.getKey();
        Rex key_new = this.visit(key, ctx, Rex.class);
        Rex value = field.getValue();
        Rex value_new = this.visit(value, ctx, Rex.class);
        if (key != key_new || value != value_new) {
            return RexStruct.field(key_new, value_new);
        }
        return field;
    }

    /** Rewrites the input and constructor of a subquery; the scalar flag is preserved. */
    @Override
    public Operator visitSubquery(@NotNull RexSubquery rex, C ctx) {
        Rel input = rex.getInput();
        Rel input_new = this.visit(input, ctx, Rel.class);
        Rex constructor = rex.getConstructor();
        Rex constructor_new = this.visit(constructor, ctx, Rex.class);
        if (input != input_new || constructor != constructor_new) {
            return this.operators.subquery(input_new, constructor_new, rex.isScalar());
        }
        return rex;
    }

    /**
     * Rewrites the input and arguments of a subquery comparison. The comparison and quantifier are
     * preserved. The node is rebuilt when either the input or any argument changed.
     */
    @Override
    public Operator visitSubqueryComp(@NotNull RexSubqueryComp rex, C ctx) {
        Rel input = rex.getInput();
        Rel input_new = this.visit(input, ctx, Rel.class);
        List<Rex> args = rex.getArgs();
        List<Rex> args_new = this.visitAll(args, ctx, this::visitRex);
        // Previously only the input was checked, so rewritten arguments were silently dropped.
        if (input != input_new || args != args_new) {
            return this.operators.subqueryComp(input_new, args_new, rex.getComparison(), rex.getQuantifier());
        }
        return rex;
    }

    /**
     * Rewrites the input and arguments of an IN-subquery. The node is rebuilt when either the input
     * or any argument changed.
     */
    @Override
    public Operator visitSubqueryIn(@NotNull RexSubqueryIn rex, C ctx) {
        Rel input = rex.getInput();
        Rel input_new = this.visit(input, ctx, Rel.class);
        List<Rex> args = rex.getArgs();
        List<Rex> args_new = this.visitAll(args, ctx, this::visitRex);
        // Previously only the input was checked, so rewritten arguments were silently dropped.
        if (input != input_new || args != args_new) {
            return this.operators.subqueryIn(input_new, args_new);
        }
        return rex;
    }

    /** Rewrites the input of a subquery test; the test kind is preserved. */
    @Override
    public Operator visitSubqueryTest(@NotNull RexSubqueryTest rex, C ctx) {
        Rel input = rex.getInput();
        Rel input_new = this.visit(input, ctx, Rel.class);
        if (input != input_new) {
            return this.operators.subqueryTest(input_new, rex.getTest());
        }
        return rex;
    }

    /** Rewrites the arguments of a spread. */
    @Override
    public Operator visitSpread(@NotNull RexSpread rex, C ctx) {
        List<Rex> args = rex.getArgs();
        List<Rex> args_new = this.visitAll(args, ctx, this::visitRex);
        if (args != args_new) {
            return this.operators.spread(args_new);
        }
        return rex;
    }

    /** Returns the table reference unchanged. */
    @Override
    public Operator visitTable(@NotNull RexTable rex, C ctx) {
        return rex;
    }

    /** Returns the variable reference unchanged. */
    @Override
    public Operator visitVar(@NotNull RexVar rex, C ctx) {
        return rex;
    }

    // ------------------------------------------------------------------
    // Typed helpers
    // ------------------------------------------------------------------

    /** Visits {@code rel} and requires the result to be a {@link Rel}. */
    @NotNull
    public final Rel visitRel(@NotNull Rel rel, C ctx) {
        return this.visit(rel, ctx, Rel.class);
    }

    /** Visits {@code rex} and requires the result to be a {@link Rex}. */
    @NotNull
    public final Rex visitRex(@NotNull Rex rex, C ctx) {
        return this.visit(rex, ctx, Rex.class);
    }

    /**
     * Visits {@code operator} and checks that the result is an instance of {@code clazz}.
     *
     * @param operator the node to visit
     * @param ctx the traversal context
     * @param clazz the expected result type
     * @return the visited result, cast to {@code clazz}
     * @throws ClassCastException if the result is not a {@code clazz}. The exception comes from the
     *     default {@link #onError}; subclasses may override that behavior.
     */
    @NotNull
    public final <T extends Operator> T visit(@NotNull Operator operator, C ctx, Class<T> clazz) {
        Operator result = this.visit(operator, ctx);
        if (clazz.isInstance(result)) {
            return clazz.cast(result);
        }
        return this.onError(result, clazz);
    }

    /**
     * Applies {@code mapper} to every element of {@code objects}.
     *
     * @return {@code objects} itself when it is empty or when every mapped element is the same
     *     reference as its input; otherwise a new list holding the mapped elements in order. Callers
     *     rely on this identity guarantee to detect "no change" with {@code !=}.
     */
    @NotNull
    public final <T> List<T> visitAll(@NotNull List<T> objects, C ctx, @NotNull Mapper<T, C> mapper) {
        if (objects.isEmpty()) {
            return objects;
        }
        boolean diff = false;
        List<T> mapped = new ArrayList<>(objects.size());
        for (T o : objects) {
            T t = mapper.apply(o, ctx);
            mapped.add(t);
            diff |= o != t;
        }
        return diff ? mapped : objects;
    }

    /**
     * Called when a visit returned a node of an unexpected type.
     *
     * @param o the node that was returned
     * @param clazz the type that was expected
     * @return never returns normally in the default implementation
     * @throws ClassCastException always, in the default implementation
     */
    @NotNull
    public <T extends Operator> T onError(@NotNull Operator o, @NotNull Class<T> clazz) {
        throw new ClassCastException("OperatorRewriter expected " + clazz.getName() + ", found: " + o.getClass().getName());
    }

    /**
     * Maps one element to its (possibly identical) replacement, given a context.
     *
     * @param <T> element type
     * @param <C> context type
     */
    @FunctionalInterface
    public interface Mapper<T, C> {
        T apply(T var1, C var2);
    }
}

