/*
 * Decompiled with CFR 0.152.
 */
package io.substrait.relation;

import io.substrait.expression.AggregateFunctionInvocation;
import io.substrait.expression.Expression;
import io.substrait.expression.FieldReference;
import io.substrait.expression.FunctionArg;
import io.substrait.relation.Aggregate;
import io.substrait.relation.ConsistentPartitionWindow;
import io.substrait.relation.CopyOnWriteUtils;
import io.substrait.relation.Cross;
import io.substrait.relation.EmptyScan;
import io.substrait.relation.ExpressionCopyOnWriteVisitor;
import io.substrait.relation.ExtensionLeaf;
import io.substrait.relation.ExtensionMulti;
import io.substrait.relation.ExtensionSingle;
import io.substrait.relation.ExtensionTable;
import io.substrait.relation.Fetch;
import io.substrait.relation.Filter;
import io.substrait.relation.ImmutableJoin;
import io.substrait.relation.Join;
import io.substrait.relation.LocalFiles;
import io.substrait.relation.NamedScan;
import io.substrait.relation.Project;
import io.substrait.relation.Rel;
import io.substrait.relation.RelVisitor;
import io.substrait.relation.Set;
import io.substrait.relation.Sort;
import io.substrait.relation.VirtualTableScan;
import io.substrait.relation.physical.HashJoin;
import io.substrait.relation.physical.MergeJoin;
import io.substrait.relation.physical.NestedLoopJoin;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;

/**
 * A {@link RelVisitor} that rebuilds a relation tree only where something actually changed.
 *
 * <p>Every {@code visit}/{@code visitXxx} method returns {@link Optional#empty()} when neither the
 * node's child relations nor any of its visited expressions were replaced. Otherwise it returns a
 * copy of the node built with {@code builder().from(original)}, in which only the replaced parts
 * differ; unchanged children are reused by reference. Subclasses override individual methods to
 * substitute nodes.
 *
 * <p>Expressions are rewritten through the {@link ExpressionCopyOnWriteVisitor} supplied at
 * construction; an empty result from it is likewise treated as "unchanged".
 *
 * @param <EXCEPTION> checked exception type that the visit methods may throw
 */
public class RelCopyOnWriteVisitor<EXCEPTION extends Exception>
        implements RelVisitor<Optional<Rel>, EXCEPTION> {
    private final ExpressionCopyOnWriteVisitor<EXCEPTION> expressionCopyOnWriteVisitor;

    /**
     * Creates a visitor whose expression visitor is a default {@link ExpressionCopyOnWriteVisitor}
     * bound back to this relation visitor.
     */
    public RelCopyOnWriteVisitor() {
        // NOTE(review): `this` escapes before construction completes; presumably the expression
        // visitor only stores the reference — confirm it does not call back during construction.
        this.expressionCopyOnWriteVisitor = new ExpressionCopyOnWriteVisitor<>(this);
    }

    /**
     * Creates a visitor that rewrites expressions with the given expression visitor.
     *
     * @param expressionCopyOnWriteVisitor visitor used for every expression encountered
     */
    public RelCopyOnWriteVisitor(ExpressionCopyOnWriteVisitor<EXCEPTION> expressionCopyOnWriteVisitor) {
        this.expressionCopyOnWriteVisitor = expressionCopyOnWriteVisitor;
    }

    /**
     * Creates a visitor whose expression visitor is produced by {@code fn}, which receives this
     * relation visitor so the two can be mutually linked.
     *
     * @param fn factory producing the expression visitor from this instance
     */
    public RelCopyOnWriteVisitor(Function<RelCopyOnWriteVisitor<EXCEPTION>, ExpressionCopyOnWriteVisitor<EXCEPTION>> fn) {
        // NOTE(review): `this` escapes to `fn` before construction completes.
        this.expressionCopyOnWriteVisitor = fn.apply(this);
    }

    /** Returns the expression visitor used to rewrite every expression in the tree. */
    protected ExpressionCopyOnWriteVisitor<EXCEPTION> getExpressionCopyOnWriteVisitor() {
        return this.expressionCopyOnWriteVisitor;
    }

    /** Visits the input, every grouping and every measure of an {@link Aggregate}. */
    @Override
    public Optional<Rel> visit(Aggregate aggregate) throws EXCEPTION {
        Optional<Rel> input = aggregate.getInput().accept(this);
        Optional<List<Aggregate.Grouping>> groupings =
                CopyOnWriteUtils.transformList(aggregate.getGroupings(), this::visitGrouping);
        Optional<List<Aggregate.Measure>> measures =
                CopyOnWriteUtils.transformList(aggregate.getMeasures(), this::visitMeasure);
        if (CopyOnWriteUtils.allEmpty(input, groupings, measures)) {
            return Optional.empty();
        }
        return Optional.of(Aggregate.builder()
                .from(aggregate)
                .input(input.orElse(aggregate.getInput()))
                .groupings(groupings.orElse(aggregate.getGroupings()))
                .measures(measures.orElse(aggregate.getMeasures()))
                .build());
    }

    /** Visits the grouping expressions; returns empty if none of them changed. */
    protected Optional<Aggregate.Grouping> visitGrouping(Aggregate.Grouping grouping) throws EXCEPTION {
        return visitExprList(grouping.getExpressions())
                .map(exprs -> Aggregate.Grouping.builder().from(grouping).expressions(exprs).build());
    }

    /** Visits a measure's optional pre-measure filter and its aggregate function invocation. */
    protected Optional<Aggregate.Measure> visitMeasure(Aggregate.Measure measure) throws EXCEPTION {
        Optional<Expression> preMeasureFilter = visitOptionalExpression(measure.getPreMeasureFilter());
        Optional<AggregateFunctionInvocation> afi = visitAggregateFunction(measure.getFunction());
        if (CopyOnWriteUtils.allEmpty(preMeasureFilter, afi)) {
            return Optional.empty();
        }
        return Optional.of(Aggregate.Measure.builder()
                .from(measure)
                .preMeasureFilter(CopyOnWriteUtils.or(preMeasureFilter, measure::getPreMeasureFilter))
                .function(afi.orElse(measure.getFunction()))
                .build());
    }

    /** Visits the arguments and sort fields of an aggregate function invocation. */
    protected Optional<AggregateFunctionInvocation> visitAggregateFunction(AggregateFunctionInvocation afi) throws EXCEPTION {
        Optional<List<FunctionArg>> arguments = visitFunctionArguments(afi.arguments());
        Optional<List<Expression.SortField>> sort =
                CopyOnWriteUtils.transformList(afi.sort(), this::visitSortField);
        if (CopyOnWriteUtils.allEmpty(arguments, sort)) {
            return Optional.empty();
        }
        return Optional.of(AggregateFunctionInvocation.builder()
                .from(afi)
                .arguments(arguments.orElse(afi.arguments()))
                .sort(sort.orElse(afi.sort()))
                .build());
    }

    /** Visits the optional filter of an {@link EmptyScan}. */
    @Override
    public Optional<Rel> visit(EmptyScan emptyScan) throws EXCEPTION {
        Optional<Expression> filter = visitOptionalExpression(emptyScan.getFilter());
        if (CopyOnWriteUtils.allEmpty(filter)) {
            return Optional.empty();
        }
        return Optional.of(EmptyScan.builder()
                .from(emptyScan)
                .filter(CopyOnWriteUtils.or(filter, emptyScan::getFilter))
                .build());
    }

    /** Visits the input of a {@link Fetch}; its other properties are copied unchanged. */
    @Override
    public Optional<Rel> visit(Fetch fetch) throws EXCEPTION {
        Optional<Rel> input = fetch.getInput().accept(this);
        return input.map(newInput -> Fetch.builder().from(fetch).input(newInput).build());
    }

    /** Visits the input and the condition of a {@link Filter}. */
    @Override
    public Optional<Rel> visit(Filter filter) throws EXCEPTION {
        Optional<Rel> input = filter.getInput().accept(this);
        Optional<Expression> condition = filter.getCondition().accept(getExpressionCopyOnWriteVisitor());
        if (CopyOnWriteUtils.allEmpty(input, condition)) {
            return Optional.empty();
        }
        return Optional.of(Filter.builder()
                .from(filter)
                .input(input.orElse(filter.getInput()))
                .condition(condition.orElse(filter.getCondition()))
                .build());
    }

    /** Visits both sides, the optional join condition and the optional post-join filter of a {@link Join}. */
    @Override
    public Optional<Rel> visit(Join join) throws EXCEPTION {
        Optional<Rel> left = join.getLeft().accept(this);
        Optional<Rel> right = join.getRight().accept(this);
        Optional<Expression> condition = visitOptionalExpression(join.getCondition());
        Optional<Expression> postFilter = visitOptionalExpression(join.getPostJoinFilter());
        if (CopyOnWriteUtils.allEmpty(left, right, condition, postFilter)) {
            return Optional.empty();
        }
        return Optional.of(ImmutableJoin.builder()
                .from(join)
                .left(left.orElse(join.getLeft()))
                .right(right.orElse(join.getRight()))
                .condition(CopyOnWriteUtils.or(condition, join::getCondition))
                .postJoinFilter(CopyOnWriteUtils.or(postFilter, join::getPostJoinFilter))
                .build());
    }

    /** Visits every input of a {@link Set}. */
    @Override
    public Optional<Rel> visit(Set set) throws EXCEPTION {
        Optional<List<Rel>> inputs = CopyOnWriteUtils.transformList(set.getInputs(), rel -> rel.accept(this));
        return inputs.map(newInputs -> Set.builder().from(set).inputs(newInputs).build());
    }

    /** Visits the optional filter of a {@link NamedScan}. */
    @Override
    public Optional<Rel> visit(NamedScan namedScan) throws EXCEPTION {
        Optional<Expression> filter = visitOptionalExpression(namedScan.getFilter());
        if (CopyOnWriteUtils.allEmpty(filter)) {
            return Optional.empty();
        }
        return Optional.of(NamedScan.builder()
                .from(namedScan)
                .filter(CopyOnWriteUtils.or(filter, namedScan::getFilter))
                .build());
    }

    /** Visits the optional filter of a {@link LocalFiles} scan. */
    @Override
    public Optional<Rel> visit(LocalFiles localFiles) throws EXCEPTION {
        Optional<Expression> filter = visitOptionalExpression(localFiles.getFilter());
        if (CopyOnWriteUtils.allEmpty(filter)) {
            return Optional.empty();
        }
        return Optional.of(LocalFiles.builder()
                .from(localFiles)
                .filter(CopyOnWriteUtils.or(filter, localFiles::getFilter))
                .build());
    }

    /** Visits the input and every projected expression of a {@link Project}. */
    @Override
    public Optional<Rel> visit(Project project) throws EXCEPTION {
        Optional<Rel> input = project.getInput().accept(this);
        Optional<List<Expression>> expressions = visitExprList(project.getExpressions());
        if (CopyOnWriteUtils.allEmpty(input, expressions)) {
            return Optional.empty();
        }
        return Optional.of(Project.builder()
                .from(project)
                .input(input.orElse(project.getInput()))
                .expressions(expressions.orElse(project.getExpressions()))
                .build());
    }

    /** Visits the input and every sort field of a {@link Sort}. */
    @Override
    public Optional<Rel> visit(Sort sort) throws EXCEPTION {
        Optional<Rel> input = sort.getInput().accept(this);
        Optional<List<Expression.SortField>> sortFields =
                CopyOnWriteUtils.transformList(sort.getSortFields(), this::visitSortField);
        if (CopyOnWriteUtils.allEmpty(input, sortFields)) {
            return Optional.empty();
        }
        return Optional.of(Sort.builder()
                .from(sort)
                .input(input.orElse(sort.getInput()))
                .sortFields(sortFields.orElse(sort.getSortFields()))
                .build());
    }

    /** Visits both sides of a {@link Cross} product. */
    @Override
    public Optional<Rel> visit(Cross cross) throws EXCEPTION {
        Optional<Rel> left = cross.getLeft().accept(this);
        Optional<Rel> right = cross.getRight().accept(this);
        if (CopyOnWriteUtils.allEmpty(left, right)) {
            return Optional.empty();
        }
        return Optional.of(Cross.builder()
                .from(cross)
                .left(left.orElse(cross.getLeft()))
                .right(right.orElse(cross.getRight()))
                .build());
    }

    /** Visits the optional filter of a {@link VirtualTableScan}. */
    @Override
    public Optional<Rel> visit(VirtualTableScan virtualTableScan) throws EXCEPTION {
        Optional<Expression> filter = visitOptionalExpression(virtualTableScan.getFilter());
        if (CopyOnWriteUtils.allEmpty(filter)) {
            return Optional.empty();
        }
        return Optional.of(VirtualTableScan.builder()
                .from(virtualTableScan)
                .filter(CopyOnWriteUtils.or(filter, virtualTableScan::getFilter))
                .build());
    }

    /** Extension leaves have nothing this visitor can traverse, so they are always unchanged. */
    @Override
    public Optional<Rel> visit(ExtensionLeaf extensionLeaf) throws EXCEPTION {
        return Optional.empty();
    }

    /** Visits the single input of an {@link ExtensionSingle}. */
    @Override
    public Optional<Rel> visit(ExtensionSingle extensionSingle) throws EXCEPTION {
        Optional<Rel> input = extensionSingle.getInput().accept(this);
        return input.map(newInput -> ExtensionSingle.builder().from(extensionSingle).input(newInput).build());
    }

    /** Visits every input of an {@link ExtensionMulti}. */
    @Override
    public Optional<Rel> visit(ExtensionMulti extensionMulti) throws EXCEPTION {
        Optional<List<Rel>> inputs =
                CopyOnWriteUtils.transformList(extensionMulti.getInputs(), rel -> rel.accept(this));
        return inputs.map(newInputs -> ExtensionMulti.builder().from(extensionMulti).inputs(newInputs).build());
    }

    /** Visits the optional filter of an {@link ExtensionTable}. */
    @Override
    public Optional<Rel> visit(ExtensionTable extensionTable) throws EXCEPTION {
        Optional<Expression> filter = visitOptionalExpression(extensionTable.getFilter());
        if (CopyOnWriteUtils.allEmpty(filter)) {
            return Optional.empty();
        }
        return Optional.of(ExtensionTable.builder()
                .from(extensionTable)
                .filter(CopyOnWriteUtils.or(filter, extensionTable::getFilter))
                .build());
    }

    /** Visits both sides, both key lists and the optional post-join filter of a {@link HashJoin}. */
    @Override
    public Optional<Rel> visit(HashJoin hashJoin) throws EXCEPTION {
        Optional<Rel> left = hashJoin.getLeft().accept(this);
        Optional<Rel> right = hashJoin.getRight().accept(this);
        Optional<List<FieldReference>> leftKeys =
                CopyOnWriteUtils.transformList(hashJoin.getLeftKeys(), this::visitFieldReference);
        Optional<List<FieldReference>> rightKeys =
                CopyOnWriteUtils.transformList(hashJoin.getRightKeys(), this::visitFieldReference);
        Optional<Expression> postFilter = visitOptionalExpression(hashJoin.getPostJoinFilter());
        if (CopyOnWriteUtils.allEmpty(left, right, leftKeys, rightKeys, postFilter)) {
            return Optional.empty();
        }
        return Optional.of(HashJoin.builder()
                .from(hashJoin)
                .left(left.orElse(hashJoin.getLeft()))
                .right(right.orElse(hashJoin.getRight()))
                .leftKeys(leftKeys.orElse(hashJoin.getLeftKeys()))
                .rightKeys(rightKeys.orElse(hashJoin.getRightKeys()))
                .postJoinFilter(CopyOnWriteUtils.or(postFilter, hashJoin::getPostJoinFilter))
                .build());
    }

    /** Visits both sides, both key lists and the optional post-join filter of a {@link MergeJoin}. */
    @Override
    public Optional<Rel> visit(MergeJoin mergeJoin) throws EXCEPTION {
        Optional<Rel> left = mergeJoin.getLeft().accept(this);
        Optional<Rel> right = mergeJoin.getRight().accept(this);
        Optional<List<FieldReference>> leftKeys =
                CopyOnWriteUtils.transformList(mergeJoin.getLeftKeys(), this::visitFieldReference);
        Optional<List<FieldReference>> rightKeys =
                CopyOnWriteUtils.transformList(mergeJoin.getRightKeys(), this::visitFieldReference);
        Optional<Expression> postFilter = visitOptionalExpression(mergeJoin.getPostJoinFilter());
        if (CopyOnWriteUtils.allEmpty(left, right, leftKeys, rightKeys, postFilter)) {
            return Optional.empty();
        }
        return Optional.of(MergeJoin.builder()
                .from(mergeJoin)
                .left(left.orElse(mergeJoin.getLeft()))
                .right(right.orElse(mergeJoin.getRight()))
                .leftKeys(leftKeys.orElse(mergeJoin.getLeftKeys()))
                .rightKeys(rightKeys.orElse(mergeJoin.getRightKeys()))
                .postJoinFilter(CopyOnWriteUtils.or(postFilter, mergeJoin::getPostJoinFilter))
                .build());
    }

    /** Visits both sides and the join condition of a {@link NestedLoopJoin}. */
    @Override
    public Optional<Rel> visit(NestedLoopJoin nestedLoopJoin) throws EXCEPTION {
        Optional<Rel> left = nestedLoopJoin.getLeft().accept(this);
        Optional<Rel> right = nestedLoopJoin.getRight().accept(this);
        Optional<Expression> condition =
                nestedLoopJoin.getCondition().accept(getExpressionCopyOnWriteVisitor());
        if (CopyOnWriteUtils.allEmpty(left, right, condition)) {
            return Optional.empty();
        }
        return Optional.of(NestedLoopJoin.builder()
                .from(nestedLoopJoin)
                .left(left.orElse(nestedLoopJoin.getLeft()))
                .right(right.orElse(nestedLoopJoin.getRight()))
                .condition(condition.orElse(nestedLoopJoin.getCondition()))
                .build());
    }

    /**
     * Visits the input, window functions, partition expressions and sort fields of a
     * {@link ConsistentPartitionWindow}.
     */
    @Override
    public Optional<Rel> visit(ConsistentPartitionWindow consistentPartitionWindow) throws EXCEPTION {
        // The input must be traversed like every other single-input relation; otherwise rewrites
        // produced beneath a window relation would be silently dropped.
        Optional<Rel> input = consistentPartitionWindow.getInput().accept(this);
        Optional<List<ConsistentPartitionWindow.WindowRelFunctionInvocation>> windowFunctions =
                CopyOnWriteUtils.transformList(consistentPartitionWindow.getWindowFunctions(), this::visitWindowRelFunction);
        Optional<List<Expression>> partitionExpressions =
                CopyOnWriteUtils.transformList(
                        consistentPartitionWindow.getPartitionExpressions(),
                        expr -> expr.accept(getExpressionCopyOnWriteVisitor()));
        Optional<List<Expression.SortField>> sorts =
                CopyOnWriteUtils.transformList(consistentPartitionWindow.getSorts(), this::visitSortField);
        if (CopyOnWriteUtils.allEmpty(input, windowFunctions, partitionExpressions, sorts)) {
            return Optional.empty();
        }
        return Optional.of(ConsistentPartitionWindow.builder()
                .from(consistentPartitionWindow)
                .input(input.orElse(consistentPartitionWindow.getInput()))
                .partitionExpressions(partitionExpressions.orElse(consistentPartitionWindow.getPartitionExpressions()))
                .sorts(sorts.orElse(consistentPartitionWindow.getSorts()))
                .windowFunctions(windowFunctions.orElse(consistentPartitionWindow.getWindowFunctions()))
                .build());
    }

    /** Visits the arguments of a window relation function invocation. */
    protected Optional<ConsistentPartitionWindow.WindowRelFunctionInvocation> visitWindowRelFunction(ConsistentPartitionWindow.WindowRelFunctionInvocation windowRelFunctionInvocation) throws EXCEPTION {
        Optional<List<FunctionArg>> functionArgs = visitFunctionArguments(windowRelFunctionInvocation.arguments());
        if (CopyOnWriteUtils.allEmpty(functionArgs)) {
            return Optional.empty();
        }
        return Optional.of(ConsistentPartitionWindow.WindowRelFunctionInvocation.builder()
                .from(windowRelFunctionInvocation)
                .arguments(functionArgs.orElse(windowRelFunctionInvocation.arguments()))
                .build());
    }

    /**
     * Visits each expression in {@code exprs}.
     *
     * @return the list with replaced elements substituted, or empty if no element changed
     */
    protected Optional<List<Expression>> visitExprList(List<Expression> exprs) throws EXCEPTION {
        return CopyOnWriteUtils.transformList(exprs, expr -> expr.accept(getExpressionCopyOnWriteVisitor()));
    }

    /**
     * Visits the optional input expression of a field reference.
     *
     * @return a copy of {@code fieldReference} with the new input expression, or empty if unchanged
     */
    public Optional<FieldReference> visitFieldReference(FieldReference fieldReference) throws EXCEPTION {
        Optional<Expression> inputExpression = visitOptionalExpression(fieldReference.inputExpression());
        if (CopyOnWriteUtils.allEmpty(inputExpression)) {
            return Optional.empty();
        }
        // Start from the original reference so every other property it carries is retained;
        // only the input expression is replaced.
        return Optional.of(FieldReference.builder()
                .from(fieldReference)
                .inputExpression(inputExpression)
                .build());
    }

    /**
     * Visits every argument that is an {@link Expression}; other argument kinds are never
     * rewritten.
     *
     * @return the argument list with replaced elements substituted, or empty if none changed
     */
    protected Optional<List<FunctionArg>> visitFunctionArguments(List<FunctionArg> funcArgs) throws EXCEPTION {
        return CopyOnWriteUtils.transformList(funcArgs, arg -> {
            if (arg instanceof Expression) {
                // Widen Optional<Expression> to Optional<FunctionArg>.
                return ((Expression) arg).accept(getExpressionCopyOnWriteVisitor()).map(FunctionArg.class::cast);
            }
            return Optional.empty();
        });
    }

    /** Visits the expression of a sort field; its other properties are copied unchanged. */
    protected Optional<Expression.SortField> visitSortField(Expression.SortField sortField) throws EXCEPTION {
        Optional<Expression> expr = sortField.expr().accept(getExpressionCopyOnWriteVisitor());
        return expr.map(newExpr -> Expression.SortField.builder().from(sortField).expr(newExpr).build());
    }

    /** Visits {@code optExpr} if present; an absent expression is reported as unchanged. */
    private Optional<Expression> visitOptionalExpression(Optional<Expression> optExpr) throws EXCEPTION {
        // A plain if is used because accept() throws a checked EXCEPTION, which Optional.flatMap's
        // Function cannot propagate.
        if (optExpr.isPresent()) {
            return optExpr.get().accept(getExpressionCopyOnWriteVisitor());
        }
        return Optional.empty();
    }
}

