/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.SymbolReference;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.iterative.rule.DereferencePushdown;
import io.trino.sql.planner.optimizations.SymbolMapper;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.planner.plan.TopNNode;
import java.util.Collection;
import java.util.List;
import java.util.Optional;
import java.util.Set;

/**
 * Pushes a {@link TopNNode} below its non-identity {@link ProjectNode} child. After the rewrite
 * the projection is applied to the TopN's output instead of the other way around.
 *
 * <p>The rule does not fire when:
 * <ul>
 *   <li>the project is an identity projection, or its source is a {@link TableScanNode}
 *       (both excluded by the pattern);</li>
 *   <li>the projection contains row subscripts and {@code DereferencePushdown} reports the
 *       projections as exclusive dereferences;</li>
 *   <li>the project's source is a {@link FilterNode} whose own source is a {@link TableScanNode};</li>
 *   <li>some ordering symbol of the TopN is not assigned a plain {@link SymbolReference} by the
 *       projection.</li>
 * </ul>
 */
public final class PushTopNThroughProject
implements Rule<TopNNode> {
    private static final Capture<ProjectNode> PROJECT_CHILD = Capture.newCapture();

    private static final Pattern<TopNNode> PATTERN = Patterns.topN()
            .with(Patterns.source().matching(
                    Patterns.project()
                            // identity projections are not rewritten
                            .matching(projectNode -> !projectNode.isIdentity())
                            .capturedAs(PROJECT_CHILD)
                            // NOTE(review): presumably keeps Project directly over TableScan so the
                            // two can be fused downstream — confirm.
                            .with(Patterns.source().matching(node -> !(node instanceof TableScanNode)))));

    @Override
    public Pattern<TopNNode> getPattern() {
        return PATTERN;
    }

    /**
     * Rewrites {@code TopN -> Project -> X} into {@code Project -> TopN' -> X}. Here {@code TopN'}
     * is {@code parent} with its ordering symbols mapped to the project's input symbols.
     *
     * @param parent the matched TopN node
     * @param captures holds the captured child {@link ProjectNode}
     * @param context rule context used for lookup resolution and id allocation
     * @return the rewritten plan, or {@link Rule.Result#empty()} when the rule does not apply
     */
    @Override
    public Rule.Result apply(TopNNode parent, Captures captures, Rule.Context context) {
        ProjectNode projectNode = captures.get(PROJECT_CHILD);

        // Bail out when the projection has row subscripts that are all exclusive dereferences.
        // NOTE(review): presumably this avoids undoing a rule that pushes dereferences below
        // TopN — confirm.
        Set<Expression> projections = ImmutableSet.copyOf(projectNode.getAssignments().getExpressions());
        if (!DereferencePushdown.extractRowSubscripts(projections, false).isEmpty()
                && DereferencePushdown.exclusiveDereferences(projections)) {
            return Rule.Result.empty();
        }

        // Leave a Project -> Filter -> TableScan chain intact.
        PlanNode projectSource = context.getLookup().resolve(projectNode.getSource());
        if (projectSource instanceof FilterNode) {
            PlanNode filterSource = context.getLookup().resolve(((FilterNode) projectSource).getSource());
            if (filterSource instanceof TableScanNode) {
                return Rule.Result.empty();
            }
        }

        // Each ordering symbol must be a plain symbol reference in the projection. Only then can
        // the ordering be restated in terms of the project's input symbols.
        Optional<SymbolMapper> symbolMapper = symbolMapper(parent.getOrderingScheme().getOrderBy(), projectNode.getAssignments());
        if (symbolMapper.isEmpty()) {
            return Rule.Result.empty();
        }

        // Rebuild the TopN over the project's source, then put the project on top of it.
        TopNNode mappedTopN = symbolMapper.get().map(parent, projectNode.getSource(), context.getIdAllocator().getNextId());
        return Rule.Result.ofPlanNode(projectNode.replaceChildren(ImmutableList.of(mappedTopN)));
    }

    /**
     * Builds a mapping from each of {@code symbols} to the symbol that its assignment references.
     *
     * @param symbols the TopN ordering symbols
     * @param assignments the project's assignments
     * @return the mapper, or empty if any symbol's assignment is not a {@link SymbolReference}
     *         (including when the symbol has no assignment at all)
     */
    private static Optional<SymbolMapper> symbolMapper(List<Symbol> symbols, Assignments assignments) {
        SymbolMapper.Builder mapper = SymbolMapper.builder();
        for (Symbol symbol : symbols) {
            Expression expression = assignments.get(symbol);
            if (!(expression instanceof SymbolReference)) {
                return Optional.empty();
            }
            mapper.put(symbol, Symbol.from(expression));
        }
        return Optional.of(mapper.build());
    }
}

