/*
 * Decompiled with CFR 0.152.
 */
package io.trino.cost;

import com.google.common.collect.ImmutableMap;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.SimpleStatsRule;
import io.trino.cost.StatsCalculator;
import io.trino.cost.StatsNormalizer;
import io.trino.cost.SymbolStatsEstimate;
import io.trino.matching.Pattern;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Patterns;
import java.util.Collection;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

/**
 * Stats rule that derives a {@link PlanNodeStatsEstimate} for an {@link AggregationNode}.
 *
 * <p>An estimate is produced only when the node has exactly one grouping set; otherwise the
 * rule yields {@link Optional#empty()}. PARTIAL and INTERMEDIATE steps keep the source row
 * count, while every other step estimates the number of groups from the distinct-value
 * counts of the grouping keys.
 */
public class AggregationStatsRule extends SimpleStatsRule<AggregationNode> {
    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation();

    /**
     * @param normalizer normalizer handed to the {@link SimpleStatsRule} base class
     */
    public AggregationStatsRule(StatsNormalizer normalizer) {
        super(normalizer);
    }

    /** Returns the pattern selecting the aggregation nodes this rule applies to. */
    @Override
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    /**
     * Computes the estimate for {@code node} from the statistics of its source.
     *
     * @param node aggregation node to estimate
     * @param context calculation context; its stats provider supplies the source statistics
     * @return the estimate, or empty when the node does not have exactly one grouping set
     */
    @Override
    protected Optional<PlanNodeStatsEstimate> doCalculate(AggregationNode node, StatsCalculator.Context context) {
        if (node.getGroupingSetCount() != 1) {
            return Optional.empty();
        }

        AggregationNode.Step step = node.getStep();
        PlanNodeStatsEstimate sourceStats = context.statsProvider().getStats(node.getSource());

        PlanNodeStatsEstimate estimate;
        if (step == AggregationNode.Step.PARTIAL || step == AggregationNode.Step.INTERMEDIATE) {
            estimate = partialGroupBy(sourceStats, node.getGroupingKeys(), node.getAggregations());
        } else {
            estimate = groupBy(sourceStats, node.getGroupingKeys(), node.getAggregations());
        }
        return Optional.of(estimate);
    }

    /**
     * Estimates the output of a grouping aggregation.
     *
     * <p>With no grouping symbols the output row count is set to 1. Otherwise the row count
     * is the result of {@link #getRowsCount}, capped at the source row count, and the
     * grouping symbols carry over their source statistics with an adjusted nulls fraction.
     * Every aggregation output symbol is registered with unknown statistics.
     *
     * @param sourceStats statistics of the aggregation input
     * @param groupBySymbols grouping key symbols; may be empty
     * @param aggregations aggregation output symbols mapped to their aggregations
     * @return the estimated statistics of the aggregation output
     */
    public static PlanNodeStatsEstimate groupBy(PlanNodeStatsEstimate sourceStats, Collection<Symbol> groupBySymbols, Map<Symbol, AggregationNode.Aggregation> aggregations) {
        PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder();
        if (groupBySymbols.isEmpty()) {
            result.setOutputRowCount(1.0);
        } else {
            result.addSymbolStatistics(getGroupBySymbolsStatistics(sourceStats, groupBySymbols));
            double rowsCount = getRowsCount(sourceStats, groupBySymbols);
            // The number of groups cannot exceed the number of input rows.
            result.setOutputRowCount(Math.min(rowsCount, sourceStats.getOutputRowCount()));
        }
        for (Map.Entry<Symbol, AggregationNode.Aggregation> aggregationEntry : aggregations.entrySet()) {
            result.addSymbolStatistics(aggregationEntry.getKey(), estimateAggregationStats(aggregationEntry.getValue(), sourceStats));
        }
        return result.build();
    }

    /**
     * Multiplies, over all grouping symbols, the distinct-value count of the symbol plus one
     * extra value when the symbol's nulls fraction is not exactly zero.
     *
     * @param sourceStats statistics of the aggregation input
     * @param groupBySymbols grouping key symbols
     * @return the product described above; {@code 1.0} when {@code groupBySymbols} is empty
     */
    public static double getRowsCount(PlanNodeStatsEstimate sourceStats, Collection<Symbol> groupBySymbols) {
        double rowsCount = 1.0;
        for (Symbol groupBySymbol : groupBySymbols) {
            SymbolStatsEstimate symbolStatistics = sourceStats.getSymbolStatistics(groupBySymbol);
            // NULL adds one more group. A NaN nulls fraction compares unequal to 0.0 and is
            // therefore also counted as "has nulls".
            int nullRow = (symbolStatistics.getNullsFraction() == 0.0) ? 0 : 1;
            rowsCount *= symbolStatistics.getDistinctValuesCount() + nullRow;
        }
        return rowsCount;
    }

    /**
     * Estimates the output of a PARTIAL or INTERMEDIATE aggregation step: the source row
     * count is kept unchanged, grouping symbols get their adjusted source statistics and
     * aggregation output symbols get unknown statistics.
     */
    private static PlanNodeStatsEstimate partialGroupBy(PlanNodeStatsEstimate sourceStats, Collection<Symbol> groupBySymbols, Map<Symbol, AggregationNode.Aggregation> aggregations) {
        PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder();
        result.setOutputRowCount(sourceStats.getOutputRowCount());
        result.addSymbolStatistics(getGroupBySymbolsStatistics(sourceStats, groupBySymbols));
        for (Map.Entry<Symbol, AggregationNode.Aggregation> aggregationEntry : aggregations.entrySet()) {
            result.addSymbolStatistics(aggregationEntry.getKey(), estimateAggregationStats(aggregationEntry.getValue(), sourceStats));
        }
        return result.build();
    }

    /**
     * Builds the statistics of the grouping symbols from their source statistics, remapping
     * only the nulls fraction: it stays 0 when the source fraction is exactly 0, and becomes
     * {@code 1 / (distinctValuesCount + 1)} otherwise.
     */
    private static Map<Symbol, SymbolStatsEstimate> getGroupBySymbolsStatistics(PlanNodeStatsEstimate sourceStats, Collection<Symbol> groupBySymbols) {
        ImmutableMap.Builder<Symbol, SymbolStatsEstimate> symbolSymbolStatsEstimates = ImmutableMap.builder();
        for (Symbol groupBySymbol : groupBySymbols) {
            SymbolStatsEstimate symbolStatistics = sourceStats.getSymbolStatistics(groupBySymbol);
            symbolSymbolStatsEstimates.put(groupBySymbol, symbolStatistics.mapNullsFraction(nullsFraction -> {
                if (nullsFraction == 0.0) {
                    return 0.0;
                }
                // NULL is counted as one value next to the distinct non-null values.
                return 1.0 / (symbolStatistics.getDistinctValuesCount() + 1.0);
            }));
        }
        return symbolSymbolStatsEstimates.buildOrThrow();
    }

    /**
     * Returns the statistics for one aggregation output symbol. Both arguments are only
     * null-checked; the result is always {@link SymbolStatsEstimate#unknown()}.
     *
     * @throws NullPointerException if {@code aggregation} or {@code sourceStats} is null
     */
    private static SymbolStatsEstimate estimateAggregationStats(AggregationNode.Aggregation aggregation, PlanNodeStatsEstimate sourceStats) {
        Objects.requireNonNull(aggregation, "aggregation is null");
        Objects.requireNonNull(sourceStats, "sourceStats is null");
        return SymbolStatsEstimate.unknown();
    }
}

