/*
 * Decompiled with CFR 0.152.
 */
package io.trino.cost;

import com.google.common.base.Verify;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.SimpleStatsRule;
import io.trino.cost.StatsCalculator;
import io.trino.cost.StatsNormalizer;
import io.trino.cost.SymbolStatsEstimate;
import io.trino.matching.Pattern;
import io.trino.spi.type.BigintType;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.TopNRankingNode;
import java.util.List;
import java.util.Optional;

public class TopNRankingStatsRule
extends SimpleStatsRule<TopNRankingNode> {
    private static final double INDEPENDENCE_FACTOR = 0.9;
    private static final Pattern<TopNRankingNode> PATTERN = Patterns.topNRanking();

    public TopNRankingStatsRule(StatsNormalizer normalizer) {
        super(normalizer);
    }

    @Override
    public Pattern<TopNRankingNode> getPattern() {
        return PATTERN;
    }

    @Override
    public Optional<PlanNodeStatsEstimate> doCalculate(TopNRankingNode node, StatsCalculator.Context context) {
        double rankDistinctCount;
        double topRowsPerPartition;
        PlanNodeStatsEstimate sourceStats = context.statsProvider().getStats(node.getSource());
        if (sourceStats.isOutputRowCountUnknown()) {
            return Optional.empty();
        }
        if (node.isPartial()) {
            return Optional.of(sourceStats);
        }
        double sourceRowCount = sourceStats.getOutputRowCount();
        if (sourceRowCount == 0.0) {
            return Optional.of(PlanNodeStatsEstimate.buildFrom(sourceStats).addSymbolStatistics(node.getRankingSymbol(), SymbolStatsEstimate.zero()).build());
        }
        double partitionCount = TopNRankingStatsRule.estimateCorrelatedPartitionCount(node.getPartitionBy(), sourceStats);
        if (Double.isNaN(partitionCount = Math.min(sourceRowCount, partitionCount))) {
            return Optional.of(sourceStats);
        }
        double averageRowsPerPartition = sourceRowCount / partitionCount;
        if (node.getRankingType() == TopNRankingNode.RankingType.ROW_NUMBER) {
            rankDistinctCount = topRowsPerPartition = Math.min(averageRowsPerPartition, (double)node.getMaxRankingPerPartition());
        } else if (node.getRankingType() == TopNRankingNode.RankingType.RANK) {
            distinctCount = Math.min(TopNRankingStatsRule.estimateCorrelatedPartitionCount(node.getOrderingScheme().orderBy(), sourceStats), averageRowsPerPartition);
            double rowsPerDistinctValue = averageRowsPerPartition / distinctCount;
            rankDistinctCount = Math.ceil((double)node.getMaxRankingPerPartition() / rowsPerDistinctValue);
            topRowsPerPartition = Math.min(averageRowsPerPartition, rankDistinctCount * rowsPerDistinctValue);
        } else {
            distinctCount = Math.min(TopNRankingStatsRule.estimateCorrelatedPartitionCount(node.getOrderingScheme().orderBy(), sourceStats), averageRowsPerPartition);
            double rowsPerDistinctValue = averageRowsPerPartition / distinctCount;
            topRowsPerPartition = Math.min(averageRowsPerPartition, (double)node.getMaxRankingPerPartition() * rowsPerDistinctValue);
            rankDistinctCount = Math.min(distinctCount, (double)node.getMaxRankingPerPartition());
        }
        if (Double.isNaN(topRowsPerPartition)) {
            topRowsPerPartition = averageRowsPerPartition;
        }
        if (Double.isNaN(rankDistinctCount)) {
            rankDistinctCount = topRowsPerPartition;
        }
        PlanNodeStatsEstimate.Builder estimateBuilder = PlanNodeStatsEstimate.buildFrom(sourceStats);
        TopNRankingStatsRule.adjustOrderBySymbolDistinctCount(node.getOrderingScheme().orderBy(), partitionCount, topRowsPerPartition, averageRowsPerPartition, sourceStats, estimateBuilder);
        double outputRowsCount = partitionCount * topRowsPerPartition;
        return Optional.of(estimateBuilder.setOutputRowCount(outputRowsCount).addSymbolStatistics(node.getRankingSymbol(), SymbolStatsEstimate.builder().setLowValue(1.0).setHighValue(node.getMaxRankingPerPartition()).setDistinctValuesCount(rankDistinctCount).setNullsFraction(0.0).setAverageRowSize(BigintType.BIGINT.getFixedSize()).build()).build());
    }

    private static void adjustOrderBySymbolDistinctCount(List<Symbol> orderBy, double partitionCount, double topRowsPerPartition, double averageRowsPerPartition, PlanNodeStatsEstimate sourceStats, PlanNodeStatsEstimate.Builder estimateBuilder) {
        Verify.verify((!orderBy.isEmpty() ? 1 : 0) != 0, (String)"Order by symbols should not be empty for TopNRankingNode.", (Object[])new Object[0]);
        Symbol firstSortSymbol = orderBy.getFirst();
        SymbolStatsEstimate symbolStats = sourceStats.getSymbolStatistics(firstSortSymbol);
        double distinctCountPerPartition = Math.min(symbolStats.getDistinctValuesCount(), averageRowsPerPartition);
        double adjustedDistinctCountPerPartition = Math.ceil(distinctCountPerPartition / averageRowsPerPartition * topRowsPerPartition);
        if (Double.isNaN(adjustedDistinctCountPerPartition)) {
            return;
        }
        double newDistinctCount = Math.min(symbolStats.getDistinctValuesCount(), adjustedDistinctCountPerPartition * partitionCount);
        SymbolStatsEstimate newFirstSortSymbolStats = SymbolStatsEstimate.buildFrom(symbolStats).setDistinctValuesCount(newDistinctCount).build();
        estimateBuilder.addSymbolStatistics(firstSortSymbol, newFirstSortSymbolStats);
    }

    private static double estimateCorrelatedPartitionCount(List<Symbol> partitionBy, PlanNodeStatsEstimate sourceStats) {
        double distinctCount = 1.0;
        double combinedIndependenceFactor = 1.0;
        for (Symbol partitionSymbol : partitionBy) {
            SymbolStatsEstimate symbolStatistics = sourceStats.getSymbolStatistics(partitionSymbol);
            boolean nullRow = symbolStatistics.getNullsFraction() != 0.0;
            distinctCount *= Math.pow(symbolStatistics.getDistinctValuesCount() + (double)nullRow, combinedIndependenceFactor);
            combinedIndependenceFactor *= 0.9;
        }
        return distinctCount;
    }
}

