/*
 * Decompiled with CFR 0.152.
 */
package io.trino.cost;

import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.SimpleStatsRule;
import io.trino.cost.StatsCalculator;
import io.trino.cost.StatsNormalizer;
import io.trino.cost.SymbolStatsEstimate;
import io.trino.matching.Pattern;
import io.trino.spi.connector.SortOrder;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.TopNNode;
import java.util.Optional;

/**
 * Stats rule that estimates the output statistics of a {@link TopNNode}.
 *
 * <p>A {@code PARTIAL} TopN step is estimated from the source row count scaled by
 * {@link #ESTIMATED_PARTIAL_TOPN_INPUT_PER_DRIVER}. Any other step is estimated as a limit of
 * {@code node.getCount()} rows. In that case the nulls fraction of the first ORDER BY symbol is
 * also adjusted according to whether nulls sort first or last.
 */
public class TopNStatsRule
extends SimpleStatsRule<TopNNode> {
    /**
     * Divisor applied to the source row count when estimating a partial TopN.
     *
     * <p>The name suggests each partial TopN instance (driver) is modeled as consuming roughly
     * this many input rows and emitting up to {@code count} rows.
     * NOTE(review): heuristic value; confirm the intended model with the planner owners.
     */
    private static final int ESTIMATED_PARTIAL_TOPN_INPUT_PER_DRIVER = 1000000;
    private static final Pattern<TopNNode> PATTERN = Patterns.topN();

    /**
     * @param normalizer normalizer handed to {@link SimpleStatsRule}
     */
    public TopNStatsRule(StatsNormalizer normalizer) {
        super(normalizer);
    }

    /**
     * @return the pattern matching {@link TopNNode} plan nodes
     */
    @Override
    public Pattern<TopNNode> getPattern() {
        return PATTERN;
    }

    /**
     * Estimates statistics for {@code node} from the statistics of its source.
     *
     * <ul>
     *   <li>{@code PARTIAL} step: output rows are
     *       {@code max(sourceRows / ESTIMATED_PARTIAL_TOPN_INPUT_PER_DRIVER, 1) * count},
     *       capped at the source row count.</li>
     *   <li>Otherwise, if the source row count is at most {@code count}, the source statistics
     *       are returned unchanged.</li>
     *   <li>Otherwise, the output row count is set to {@code count}. When {@code count} is
     *       non-zero, the nulls fraction of the first ORDER BY symbol is also re-estimated.</li>
     * </ul>
     *
     * @param node the TopN node being estimated
     * @param context stats context used to look up the source statistics
     * @return the estimated statistics; every path in this method returns a present value
     */
    @Override
    protected Optional<PlanNodeStatsEstimate> doCalculate(TopNNode node, StatsCalculator.Context context) {
        PlanNodeStatsEstimate sourceStats = context.statsProvider().getStats(node.getSource());
        double rowCount = sourceStats.getOutputRowCount();
        long limitCount = node.getCount();

        if (node.getStep() == TopNNode.Step.PARTIAL) {
            // At least one partial instance is assumed, each emitting up to limitCount rows.
            // The estimate can never exceed the number of input rows.
            double estimatedOutputRowCount =
                    Math.max(rowCount / ESTIMATED_PARTIAL_TOPN_INPUT_PER_DRIVER, 1.0) * limitCount;
            return Optional.of(PlanNodeStatsEstimate.buildFrom(sourceStats)
                    .setOutputRowCount(Math.min(estimatedOutputRowCount, rowCount))
                    .build());
        }

        // The source already fits within the limit, so TopN does not change the statistics.
        // A NaN (unknown) row count fails this comparison and falls through to the limit estimate.
        if (rowCount <= limitCount) {
            return Optional.of(sourceStats);
        }

        PlanNodeStatsEstimate resultStats = PlanNodeStatsEstimate.buildFrom(sourceStats)
                .setOutputRowCount(limitCount)
                .build();
        if (limitCount == 0L) {
            // Nothing is emitted, and the nulls-fraction estimate below divides by limitCount.
            return Optional.of(resultStats);
        }

        // Only the first ORDER BY symbol determines which rows are kept first, so only its
        // nulls fraction is re-estimated.
        // NOTE(review): assumes the ordering scheme has at least one ORDER BY symbol.
        Symbol firstOrderSymbol = node.getOrderingScheme().getOrderBy().get(0);
        SortOrder sortOrder = node.getOrderingScheme().getOrdering(firstOrderSymbol);
        resultStats = resultStats.mapSymbolColumnStatistics(firstOrderSymbol, symbolStats -> {
            SymbolStatsEstimate.Builder newStats = SymbolStatsEstimate.buildFrom(symbolStats);
            newStats.setNullsFraction(estimateTopNullsFraction(
                    rowCount, symbolStats.getNullsFraction(), limitCount, sortOrder.isNullsFirst()));
            return newStats.build();
        });
        return Optional.of(resultStats);
    }

    /**
     * Estimates the nulls fraction of the leading sort symbol among the first {@code limitCount}
     * rows of its source.
     *
     * @param rowCount source row count (NaN propagates to a NaN result)
     * @param sourceNullsFraction nulls fraction of the symbol in the source
     * @param limitCount number of rows kept; callers pass a non-zero value
     * @param nullsFirst whether nulls sort before non-null values
     * @return for nulls first, 1.0 if the nulls alone exceed the limit, otherwise their share of
     *     the limit. For nulls last, 0.0 if the non-null rows alone exceed the limit, otherwise the
     *     share of the limit left over for nulls.
     */
    private static double estimateTopNullsFraction(
            double rowCount, double sourceNullsFraction, long limitCount, boolean nullsFirst) {
        double nullCount = rowCount * sourceNullsFraction;
        if (nullsFirst) {
            // Nulls occupy the leading rows: either they fill the whole limit or all of them are kept.
            return nullCount > limitCount ? 1.0 : nullCount / limitCount;
        }
        // Nulls trail the non-null rows, so they only appear once every non-null row has been kept.
        double nonNullCount = rowCount - nullCount;
        return nonNullCount > limitCount ? 0.0 : (limitCount - nonNullCount) / limitCount;
    }
}

