/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.optimizations;

import com.google.common.base.MoreObjects;
import com.google.common.collect.ImmutableList;
import io.trino.sql.ir.Expression;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.optimizations.PlanOptimizer;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.DistinctLimitNode;
import io.trino.sql.planner.plan.LimitNode;
import io.trino.sql.planner.plan.MarkDistinctNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.SemiJoinNode;
import io.trino.sql.planner.plan.SimplePlanRewriter;
import io.trino.sql.planner.plan.SortNode;
import io.trino.sql.planner.plan.TopNNode;
import io.trino.sql.planner.plan.UnionNode;
import io.trino.sql.planner.plan.ValuesNode;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;

public class LimitPushDown
implements PlanOptimizer {
    /**
     * Rewrites {@code plan} with a fresh {@link Rewriter} that draws new node ids
     * from {@code context.idAllocator()}.
     *
     * @param plan root of the plan to rewrite; must not be null
     * @param context optimizer context; only its id allocator is read here
     * @return the plan node returned by {@code SimplePlanRewriter.rewriteWith}
     * @throws NullPointerException if {@code plan} is null
     */
    @Override
    public PlanNode optimize(PlanNode plan, PlanOptimizer.Context context) {
        Objects.requireNonNull(plan, "plan is null");
        Rewriter rewriter = new Rewriter(context.idAllocator());
        // The initial context is null: no limit is pending at the root of the plan.
        return SimplePlanRewriter.rewriteWith(rewriter, plan, null);
    }

    /**
     * Plan rewriter whose context is the limit that is still waiting to be placed
     * (a {@link LimitContext}), or {@code null} when no limit is pending.
     *
     * <p>Each visit method either forwards the pending limit to the node's source(s),
     * folds it into the node it visits, or materializes it as a new {@link LimitNode}.
     */
    private static class Rewriter
    extends SimplePlanRewriter<LimitContext> {
        private final PlanNodeIdAllocator idAllocator;

        private Rewriter(PlanNodeIdAllocator idAllocator) {
            this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
        }

        /**
         * Fallback for node types without a dedicated visit method: the children are
         * rewritten without a pending limit and the pending limit, if any, is placed
         * directly above the rewritten node.
         */
        @Override
        public PlanNode visitPlan(PlanNode node, SimplePlanRewriter.RewriteContext<LimitContext> context) {
            return wrapWithLimit(context.defaultRewrite(node), context.get());
        }

        /**
         * Merges this limit with the pending one (taking the smaller count) and, when
         * possible, removes the node by handing the merged limit to its source.
         */
        @Override
        public PlanNode visitLimit(LimitNode node, SimplePlanRewriter.RewriteContext<LimitContext> context) {
            LimitContext outer = context.get();
            long effectiveCount = outer == null
                    ? node.getCount()
                    : Math.min(node.getCount(), outer.getCount());

            // A limit of zero rows is replaced by a ValuesNode with no rows.
            if (effectiveCount == 0L) {
                return new ValuesNode(idAllocator.getNextId(), node.getOutputSymbols(), (List<Expression>) ImmutableList.of());
            }

            // A WITH TIES limit may only be merged when the pending limit's count does not exceed its own.
            boolean tiesCoveredByOuter = outer != null && node.getCount() >= outer.getCount();
            boolean mergeable = !node.requiresPreSortedInputs() && (!node.isWithTies() || tiesCoveredByOuter);
            if (!mergeable) {
                // NOTE(review): the pending limit is forwarded beneath this node rather than
                // kept above it -- confirm this is intended for WITH TIES / pre-sorted limits.
                return context.defaultRewrite(node, outer);
            }

            // The merged limit is partial only if every limit that contributed to it is partial.
            boolean partial = node.isPartial() && (outer == null || outer.isPartial());
            return context.rewrite(node.getSource(), new LimitContext(effectiveCount, partial));
        }

        /**
         * Turns a distinct-only aggregation under a pending limit into a
         * {@link DistinctLimitNode}; for any other aggregation the pending limit is
         * placed above the rewritten aggregation.
         */
        @Override
        @Deprecated
        public PlanNode visitAggregation(AggregationNode node, SimplePlanRewriter.RewriteContext<LimitContext> context) {
            LimitContext limit = context.get();
            if (limit != null && isDistinctOnly(node)) {
                PlanNode rewrittenSource = context.rewrite(node.getSource());
                // The DistinctLimitNode is always created non-partial; the pending limit's partial flag is not consulted.
                return new DistinctLimitNode(
                        idAllocator.getNextId(),
                        rewrittenSource,
                        limit.getCount(),
                        false,
                        rewrittenSource.getOutputSymbols(),
                        Optional.empty());
            }
            return wrapWithLimit(context.defaultRewrite(node), limit);
        }

        /**
         * Forwards the pending limit unchanged via {@code defaultRewrite}, so the
         * {@link #visitPlan} fallback does not place a LimitNode above this node.
         */
        @Override
        public PlanNode visitMarkDistinct(MarkDistinctNode node, SimplePlanRewriter.RewriteContext<LimitContext> context) {
            LimitContext pending = context.get();
            return context.defaultRewrite(node, pending);
        }

        /**
         * Forwards the pending limit unchanged via {@code defaultRewrite}, so the
         * {@link #visitPlan} fallback does not place a LimitNode above this node.
         */
        @Override
        public PlanNode visitProject(ProjectNode node, SimplePlanRewriter.RewriteContext<LimitContext> context) {
            LimitContext pending = context.get();
            return context.defaultRewrite(node, pending);
        }

        /**
         * Tightens the TopN count to the pending limit's count when that is smaller.
         * The source is rewritten without a pending limit.
         */
        @Override
        public PlanNode visitTopN(TopNNode node, SimplePlanRewriter.RewriteContext<LimitContext> context) {
            LimitContext limit = context.get();
            PlanNode rewrittenSource = context.rewrite(node.getSource());

            boolean nothingChanged = limit == null && rewrittenSource == node.getSource();
            if (nothingChanged) {
                return node;
            }

            long count = limit == null
                    ? node.getCount()
                    : Math.min(node.getCount(), limit.getCount());
            return new TopNNode(node.getId(), rewrittenSource, count, node.getOrderingScheme(), node.getStep());
        }

        /**
         * Replaces a sort under a pending limit with a single-step TopN that reuses the
         * sort's id and ordering scheme. Without a pending limit the sort is kept and
         * only rebuilt if its source changed.
         */
        @Override
        @Deprecated
        public PlanNode visitSort(SortNode node, SimplePlanRewriter.RewriteContext<LimitContext> context) {
            LimitContext limit = context.get();
            PlanNode rewrittenSource = context.rewrite(node.getSource());

            if (limit == null) {
                if (rewrittenSource == node.getSource()) {
                    return node;
                }
                return new SortNode(node.getId(), rewrittenSource, node.getOrderingScheme(), node.isPartial());
            }
            return new TopNNode(node.getId(), rewrittenSource, limit.getCount(), node.getOrderingScheme(), TopNNode.Step.SINGLE);
        }

        /**
         * Hands every union source a partial limit with the pending count and keeps the
         * pending limit itself above the rebuilt union.
         */
        @Override
        public PlanNode visitUnion(UnionNode node, SimplePlanRewriter.RewriteContext<LimitContext> context) {
            LimitContext limit = context.get();
            LimitContext perSourceLimit = limit == null ? null : new LimitContext(limit.getCount(), true);

            List<PlanNode> rewrittenSources = new ArrayList<>();
            for (PlanNode source : node.getSources()) {
                rewrittenSources.add(context.rewrite(source, perSourceLimit));
            }

            PlanNode union = new UnionNode(node.getId(), rewrittenSources, node.getSymbolMapping(), node.getOutputSymbols());
            return wrapWithLimit(union, limit);
        }

        /**
         * Forwards the pending limit to the semi join's source only; the filtering
         * source is reused as is and is not rewritten here.
         */
        @Override
        public PlanNode visitSemiJoin(SemiJoinNode node, SimplePlanRewriter.RewriteContext<LimitContext> context) {
            PlanNode rewrittenSource = context.rewrite(node.getSource(), context.get());
            if (rewrittenSource == node.getSource()) {
                return node;
            }
            return new SemiJoinNode(
                    node.getId(),
                    rewrittenSource,
                    node.getFilteringSource(),
                    node.getSourceJoinSymbol(),
                    node.getFilteringSourceJoinSymbol(),
                    node.getSemiJoinOutput(),
                    node.getSourceHashSymbol(),
                    node.getFilteringSourceHashSymbol(),
                    node.getDistributionType(),
                    node.getDynamicFilterId());
        }

        /**
         * Returns {@code node} unchanged when {@code limit} is null; otherwise returns a
         * new LimitNode above it carrying the limit's count and partial flag.
         */
        private PlanNode wrapWithLimit(PlanNode node, LimitContext limit) {
            if (limit == null) {
                return node;
            }
            return new LimitNode(idAllocator.getNextId(), node, limit.getCount(), limit.isPartial());
        }

        /**
         * True when the aggregation has no aggregate functions, has at least one
         * grouping key, and outputs exactly its grouping keys.
         */
        private static boolean isDistinctOnly(AggregationNode node) {
            if (!node.getAggregations().isEmpty() || node.getGroupingKeys().isEmpty()) {
                return false;
            }
            return node.getOutputSymbols().size() == node.getGroupingKeys().size()
                    && node.getOutputSymbols().containsAll(node.getGroupingKeys());
        }
    }

    /**
     * Immutable pair describing a limit that has not yet been placed in the plan:
     * the row count and whether the limit is partial.
     */
    private static class LimitContext {
        private final long count;
        private final boolean partial;

        /**
         * @param count row count of the limit
         * @param partial whether the limit is partial
         */
        public LimitContext(long count, boolean partial) {
            this.count = count;
            this.partial = partial;
        }

        /** Returns the row count given at construction. */
        public long getCount() {
            return count;
        }

        /** Returns the partial flag given at construction. */
        public boolean isPartial() {
            return partial;
        }

        @Override
        public String toString() {
            return MoreObjects.toStringHelper(this)
                    .add("count", count)
                    .add("partial", partial)
                    .toString();
        }
    }
}

