/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.StatsProvider;
import io.trino.cost.TaskCountEstimator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.plan.AggregationNode;
import java.util.Objects;

/**
 * Decides how an {@link AggregationNode} with distinct aggregations should be planned:
 * whether to add a mark-distinct step, or whether to use a pre-aggregation.
 *
 * <p>Both decisions are driven by whether a single-step distinct aggregation is expected
 * to parallelize well. That is judged by comparing a lower-bound estimate of the number
 * of distinct grouping values against a threshold. The threshold is a strategy-specific
 * multiplier times the estimated number of concurrent aggregation threads (hashed task
 * count times task concurrency).
 */
public class DistinctAggregationStrategyChooser {
    /** Threshold multiplier used by {@link #shouldAddMarkDistinct}. */
    private static final int MARK_DISTINCT_MAX_OUTPUT_ROW_COUNT_MULTIPLIER = 8;
    /** Threshold multiplier used by {@link #shouldUsePreAggregate}. */
    private static final int PRE_AGGREGATE_MAX_OUTPUT_ROW_COUNT_MULTIPLIER = 64;
    private final TaskCountEstimator taskCountEstimator;

    private DistinctAggregationStrategyChooser(TaskCountEstimator taskCountEstimator) {
        this.taskCountEstimator = Objects.requireNonNull(taskCountEstimator, "taskCountEstimator is null");
    }

    /**
     * Creates a chooser that estimates aggregation parallelism with the given estimator.
     *
     * @param taskCountEstimator source of the hashed task count estimate; must not be null
     * @return a new chooser
     * @throws NullPointerException if {@code taskCountEstimator} is null
     */
    public static DistinctAggregationStrategyChooser createDistinctAggregationStrategyChooser(TaskCountEstimator taskCountEstimator) {
        return new DistinctAggregationStrategyChooser(taskCountEstimator);
    }

    /**
     * Returns whether a mark-distinct step should be added for the given aggregation.
     *
     * <p>This is true whenever the single-step distinct aggregation cannot be parallelized
     * under the mark-distinct threshold multiplier. That includes global aggregations
     * and aggregations whose distinct-value estimate is unknown.
     */
    public boolean shouldAddMarkDistinct(AggregationNode aggregationNode, Session session, StatsProvider statsProvider) {
        return !canParallelizeSingleStepDistinctAggregation(
                aggregationNode, session, statsProvider, MARK_DISTINCT_MAX_OUTPUT_ROW_COUNT_MULTIPLIER);
    }

    /**
     * Returns whether the distinct aggregation should be planned with a pre-aggregation.
     *
     * <p>Returns false when a single-step distinct aggregation can be parallelized under the
     * pre-aggregate threshold multiplier. Otherwise, pre-aggregation is chosen only when
     * there are at most two grouping keys.
     */
    public boolean shouldUsePreAggregate(AggregationNode aggregationNode, Session session, StatsProvider statsProvider) {
        if (canParallelizeSingleStepDistinctAggregation(
                aggregationNode, session, statsProvider, PRE_AGGREGATE_MAX_OUTPUT_ROW_COUNT_MULTIPLIER)) {
            return false;
        }
        // NOTE(review): the limit of 2 keys is a heuristic; the rationale is not visible here.
        return aggregationNode.getGroupingKeys().size() <= 2;
    }

    /**
     * Returns true when the estimated number of distinct grouping values exceeds
     * {@code maxOutputRowCountMultiplier} times the estimated number of concurrent
     * aggregation threads.
     *
     * <p>Returns false for global aggregations (no grouping keys) and whenever no
     * distinct-value estimate is available for any grouping key.
     */
    private boolean canParallelizeSingleStepDistinctAggregation(AggregationNode aggregationNode, Session session, StatsProvider statsProvider, int maxOutputRowCountMultiplier) {
        if (aggregationNode.getGroupingKeys().isEmpty()) {
            // No grouping keys: a global aggregation, never treated as parallelizable here.
            return false;
        }
        double numberOfDistinctValues = getMinDistinctValueCountEstimate(aggregationNode, statsProvider);
        if (Double.isNaN(numberOfDistinctValues)) {
            // Unknown estimate: conservatively report "cannot parallelize".
            return false;
        }
        long maxNumberOfConcurrentThreadsForAggregation = getMaxNumberOfConcurrentThreadsForAggregation(session);
        // Compute the threshold in double so the product cannot overflow int and turn negative.
        double threshold = (double) maxOutputRowCountMultiplier * maxNumberOfConcurrentThreadsForAggregation;
        return numberOfDistinctValues > threshold;
    }

    /**
     * Estimates how many threads may run the aggregation concurrently: the hashed task count
     * times the session's task concurrency. Computed in {@code long} to avoid int overflow.
     */
    private long getMaxNumberOfConcurrentThreadsForAggregation(Session session) {
        return (long) taskCountEstimator.estimateHashedTaskCount(session) * SystemSessionProperties.getTaskConcurrency(session);
    }

    /**
     * Returns a lower-bound estimate of the number of distinct grouping-key combinations.
     *
     * <p>This is the largest per-key distinct-value count among the grouping keys that have
     * one. The number of distinct key combinations can be no smaller than the distinct count
     * of any single key, which is why the maximum yields a minimum for the combination.
     * Returns {@link Double#NaN} when no grouping key has a known distinct-value count.
     */
    private double getMinDistinctValueCountEstimate(AggregationNode aggregationNode, StatsProvider statsProvider) {
        PlanNodeStatsEstimate sourceStats = statsProvider.getStats(aggregationNode.getSource());
        return aggregationNode.getGroupingKeys().stream()
                .mapToDouble(symbol -> sourceStats.getSymbolStatistics(symbol).getDistinctValuesCount())
                .filter(distinctValues -> !Double.isNaN(distinctValues))
                .max()
                .orElse(Double.NaN);
    }
}

