/*
 * Decompiled with CFR 0.152.
 */
package io.prestosql.cost;

import io.prestosql.Session;
import io.prestosql.cost.PlanNodeStatsEstimate;
import io.prestosql.cost.SimpleStatsRule;
import io.prestosql.cost.StatsNormalizer;
import io.prestosql.cost.StatsProvider;
import io.prestosql.cost.SymbolStatsEstimate;
import io.prestosql.matching.Pattern;
import io.prestosql.spi.block.SortOrder;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.TypeProvider;
import io.prestosql.sql.planner.iterative.Lookup;
import io.prestosql.sql.planner.plan.Patterns;
import io.prestosql.sql.planner.plan.TopNNode;
import java.util.Optional;

/**
 * Statistics rule for {@link TopNNode}.
 *
 * <p>For a {@code SINGLE}-step TopN whose source is estimated to produce more rows than the
 * TopN count, the output row count is capped at that count. The null fraction of the first
 * ORDER BY symbol is also re-estimated from where nulls sort. All other source statistics are
 * carried over unchanged. Non-{@code SINGLE} steps produce no estimate from this rule.
 */
public class TopNStatsRule
extends SimpleStatsRule<TopNNode> {
    private static final Pattern<TopNNode> PATTERN = Patterns.topN();

    public TopNStatsRule(StatsNormalizer normalizer) {
        super(normalizer);
    }

    @Override
    public Pattern<TopNNode> getPattern() {
        return PATTERN;
    }

    /**
     * Derives the output statistics of a TopN node from its source statistics.
     *
     * @param node the TopN node being estimated
     * @param statsProvider supplies the statistics of {@code node}'s source
     * @param lookup unused by this rule
     * @param session unused by this rule
     * @param types unused by this rule
     * @return empty for non-{@code SINGLE} steps. For {@code SINGLE} steps whose source row count
     *     is at most the TopN count, the unchanged source stats. Otherwise, the source stats with
     *     the row count set to the TopN count and, when the count is non-zero, the first ORDER BY
     *     symbol's null fraction re-estimated.
     */
    @Override
    protected Optional<PlanNodeStatsEstimate> doCalculate(TopNNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types) {
        // Check the step first so that source statistics are not fetched only to be discarded.
        if (node.getStep() != TopNNode.Step.SINGLE) {
            return Optional.empty();
        }
        PlanNodeStatsEstimate sourceStats = statsProvider.getStats(node.getSource());
        double rowCount = sourceStats.getOutputRowCount();
        long limitCount = node.getCount();
        // Source estimated to fit within the limit: pass its statistics through untouched.
        // A NaN (unknown) row count fails this comparison and falls through to the capped estimate.
        if (rowCount <= limitCount) {
            return Optional.of(sourceStats);
        }
        PlanNodeStatsEstimate resultStats = PlanNodeStatsEstimate.buildFrom(sourceStats)
                .setOutputRowCount(limitCount)
                .build();
        // With a zero limit there are no rows whose null fraction could be re-estimated
        // (and the helper would divide by zero).
        if (limitCount == 0L) {
            return Optional.of(resultStats);
        }
        // Only the first ORDER BY symbol decides which rows are kept, so only its null fraction
        // is refined. NOTE(review): relies on the ordering scheme having at least one symbol.
        Symbol firstOrderSymbol = node.getOrderingScheme().getOrderBy().get(0);
        boolean nullsFirst = node.getOrderingScheme().getOrdering(firstOrderSymbol).isNullsFirst();
        resultStats = resultStats.mapSymbolColumnStatistics(firstOrderSymbol, symbolStats -> {
            SymbolStatsEstimate.Builder newStats = SymbolStatsEstimate.buildFrom(symbolStats);
            newStats.setNullsFraction(estimateNullsFraction(rowCount, symbolStats.getNullsFraction(), limitCount, nullsFirst));
            return newStats.build();
        });
        return Optional.of(resultStats);
    }

    /**
     * Estimates the null fraction among the first {@code limitCount} rows of a sorted input.
     *
     * <p>When nulls sort first, the kept rows are all null until the estimated null count is
     * exhausted. When nulls sort last, nulls appear only after all non-null rows have been taken.
     * NaN inputs propagate to a NaN result, which keeps "unknown" unknown.
     *
     * @param sourceRowCount estimated number of source rows
     * @param sourceNullsFraction estimated null fraction of the ordering symbol in the source
     * @param limitCount number of rows kept; must be positive
     * @param nullsFirst whether nulls sort before non-null values
     * @return the estimated null fraction in the kept rows
     */
    private static double estimateNullsFraction(double sourceRowCount, double sourceNullsFraction, long limitCount, boolean nullsFirst) {
        double nullCount = sourceRowCount * sourceNullsFraction;
        if (nullsFirst) {
            return nullCount > limitCount ? 1.0 : nullCount / limitCount;
        }
        double nonNullCount = sourceRowCount - nullCount;
        return nonNullCount > limitCount ? 0.0 : (limitCount - nonNullCount) / limitCount;
    }
}

