/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.StatsProvider;
import io.trino.cost.TaskCountEstimator;
import io.trino.metadata.Metadata;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.OptimizerConfig;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.rule.MultipleDistinctAggregationToMarkDistinct;
import io.trino.sql.planner.iterative.rule.MultipleDistinctAggregationsToSubqueries;
import io.trino.sql.planner.iterative.rule.OptimizeMixedDistinctAggregations;
import io.trino.sql.planner.optimizations.PlanNodeSearcher;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.planner.plan.UnionNode;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Set;

/**
 * Chooses how an {@link AggregationNode} containing DISTINCT aggregations is executed:
 * single step, mark-distinct, pre-aggregation, or splitting into subqueries.
 *
 * <p>If the session's distinct-aggregations strategy (see
 * {@link SystemSessionProperties#distinctAggregationsStrategy}) is not
 * {@code AUTOMATIC}, that strategy is used when the aggregation supports it. Otherwise
 * {@code SINGLE_STEP} is used. In {@code AUTOMATIC} mode the choice is driven by plan
 * statistics and by the shape of the aggregation's source plan.
 */
public class DistinctAggregationStrategyChooser {
    // The two multipliers below are scaled by the estimated number of threads that can run the
    // aggregation concurrently (hashed task count * task concurrency).
    private static final int MARK_DISTINCT_MAX_OUTPUT_ROW_COUNT_MULTIPLIER = 8;
    private static final int PRE_AGGREGATE_MAX_OUTPUT_ROW_COUNT_MULTIPLIER = 64;
    // Automatic mode only considers pre-aggregation for at most this many grouping keys.
    private static final int PRE_AGGREGATE_MAX_GROUPING_KEYS = 2;
    // 100 MiB: upper bound on the estimated grouping-key output size for an automatic split into subqueries.
    private static final double MAX_JOIN_GROUPING_KEYS_SIZE = 100 * 1024 * 1024;
    // An automatic split is rejected when the subqueries would read more than this multiple of one scan's data.
    private static final double MAX_SUBQUERIES_READ_AMPLIFICATION = 1.5;
    // A filter whose output/input row-count ratio is below this value is treated as selective.
    private static final double SELECTIVE_FILTER_MAX_OUTPUT_RATIO = 0.5;

    private final TaskCountEstimator taskCountEstimator;
    private final Metadata metadata;

    /**
     * @param taskCountEstimator used to estimate aggregation parallelism; must not be null
     * @param metadata used to check whether table reads may be split into subqueries; must not be null
     * @throws NullPointerException if either argument is null
     */
    public DistinctAggregationStrategyChooser(TaskCountEstimator taskCountEstimator, Metadata metadata) {
        this.taskCountEstimator = Objects.requireNonNull(taskCountEstimator, "taskCountEstimator is null");
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    /**
     * Static factory equivalent to the public constructor.
     */
    public static DistinctAggregationStrategyChooser createDistinctAggregationStrategyChooser(TaskCountEstimator taskCountEstimator, Metadata metadata) {
        return new DistinctAggregationStrategyChooser(taskCountEstimator, metadata);
    }

    /**
     * @return true if the chosen strategy for {@code aggregationNode} is {@code MARK_DISTINCT}
     */
    public boolean shouldAddMarkDistinct(AggregationNode aggregationNode, Session session, StatsProvider statsProvider, Lookup lookup) {
        return this.chooseMarkDistinctStrategy(aggregationNode, session, statsProvider, lookup) == OptimizerConfig.DistinctAggregationsStrategy.MARK_DISTINCT;
    }

    /**
     * @return true if the chosen strategy for {@code aggregationNode} is {@code PRE_AGGREGATE}
     */
    public boolean shouldUsePreAggregate(AggregationNode aggregationNode, Session session, StatsProvider statsProvider, Lookup lookup) {
        return this.chooseMarkDistinctStrategy(aggregationNode, session, statsProvider, lookup) == OptimizerConfig.DistinctAggregationsStrategy.PRE_AGGREGATE;
    }

    /**
     * @return true if the chosen strategy for {@code aggregationNode} is {@code SPLIT_TO_SUBQUERIES}
     */
    public boolean shouldSplitToSubqueries(AggregationNode aggregationNode, Session session, StatsProvider statsProvider, Lookup lookup) {
        return this.chooseMarkDistinctStrategy(aggregationNode, session, statsProvider, lookup) == OptimizerConfig.DistinctAggregationsStrategy.SPLIT_TO_SUBQUERIES;
    }

    /**
     * Returns the strategy for {@code aggregationNode}. The session setting is honored unless it
     * is {@code AUTOMATIC}, in which case statistics-based heuristics decide.
     */
    private OptimizerConfig.DistinctAggregationsStrategy chooseMarkDistinctStrategy(AggregationNode aggregationNode, Session session, StatsProvider statsProvider, Lookup lookup) {
        OptimizerConfig.DistinctAggregationsStrategy requestedStrategy = SystemSessionProperties.distinctAggregationsStrategy(session);
        if (requestedStrategy != OptimizerConfig.DistinctAggregationsStrategy.AUTOMATIC) {
            return this.chooseRequestedStrategy(requestedStrategy, aggregationNode, session, lookup);
        }
        return this.chooseAutomaticStrategy(aggregationNode, session, statsProvider, lookup);
    }

    /**
     * Returns {@code requestedStrategy} if it can be applied to {@code aggregationNode}.
     * Otherwise, including when {@code SINGLE_STEP} itself was requested, returns {@code SINGLE_STEP}.
     */
    private OptimizerConfig.DistinctAggregationsStrategy chooseRequestedStrategy(OptimizerConfig.DistinctAggregationsStrategy requestedStrategy, AggregationNode aggregationNode, Session session, Lookup lookup) {
        if (requestedStrategy == OptimizerConfig.DistinctAggregationsStrategy.MARK_DISTINCT
                && MultipleDistinctAggregationToMarkDistinct.canUseMarkDistinct(aggregationNode)) {
            return OptimizerConfig.DistinctAggregationsStrategy.MARK_DISTINCT;
        }
        if (requestedStrategy == OptimizerConfig.DistinctAggregationsStrategy.PRE_AGGREGATE
                && OptimizeMixedDistinctAggregations.canUsePreAggregate(aggregationNode)) {
            return OptimizerConfig.DistinctAggregationsStrategy.PRE_AGGREGATE;
        }
        if (requestedStrategy == OptimizerConfig.DistinctAggregationsStrategy.SPLIT_TO_SUBQUERIES
                && MultipleDistinctAggregationsToSubqueries.isAggregationCandidateForSplittingToSubqueries(aggregationNode)
                && this.isAggregationSourceSupportedForSubqueries(aggregationNode.getSource(), session, lookup)) {
            return OptimizerConfig.DistinctAggregationsStrategy.SPLIT_TO_SUBQUERIES;
        }
        // The requested strategy is not applicable to this aggregation: fall back to single step.
        return OptimizerConfig.DistinctAggregationsStrategy.SINGLE_STEP;
    }

    /**
     * Picks a strategy from statistics when the session requests {@code AUTOMATIC}.
     */
    private OptimizerConfig.DistinctAggregationsStrategy chooseAutomaticStrategy(AggregationNode aggregationNode, Session session, StatsProvider statsProvider, Lookup lookup) {
        double numberOfDistinctValues = this.getMinDistinctValueCountEstimate(aggregationNode, statsProvider);
        // Widened to double so that the threshold products below cannot overflow int arithmetic.
        double maxConcurrentThreads = this.getMaxNumberOfConcurrentThreadsForAggregation(session);
        int groupingKeyCount = aggregationNode.getGroupingKeys().size();

        // Grouped aggregations with enough estimated groups per thread stay single step. Presumably
        // the grouping already gives enough parallelism there; global aggregations (no keys) never
        // take this branch. NaN (unknown NDV) skips it as well.
        if (groupingKeyCount > 0 && !Double.isNaN(numberOfDistinctValues)
                && (numberOfDistinctValues > PRE_AGGREGATE_MAX_OUTPUT_ROW_COUNT_MULTIPLIER * maxConcurrentThreads
                        || (numberOfDistinctValues > MARK_DISTINCT_MAX_OUTPUT_ROW_COUNT_MULTIPLIER * maxConcurrentThreads
                                && groupingKeyCount > PRE_AGGREGATE_MAX_GROUPING_KEYS))) {
            return OptimizerConfig.DistinctAggregationsStrategy.SINGLE_STEP;
        }
        if (MultipleDistinctAggregationsToSubqueries.isAggregationCandidateForSplittingToSubqueries(aggregationNode)
                && this.shouldSplitAggregationToSubqueries(aggregationNode, session, statsProvider, lookup)) {
            return OptimizerConfig.DistinctAggregationsStrategy.SPLIT_TO_SUBQUERIES;
        }
        if (OptimizeMixedDistinctAggregations.canUsePreAggregate(aggregationNode)
                && groupingKeyCount <= PRE_AGGREGATE_MAX_GROUPING_KEYS) {
            return OptimizerConfig.DistinctAggregationsStrategy.PRE_AGGREGATE;
        }
        if (MultipleDistinctAggregationToMarkDistinct.canUseMarkDistinct(aggregationNode)) {
            return OptimizerConfig.DistinctAggregationsStrategy.MARK_DISTINCT;
        }
        return OptimizerConfig.DistinctAggregationsStrategy.SINGLE_STEP;
    }

    /**
     * Estimated number of threads that can execute the aggregation:
     * hashed task count multiplied by the session's task concurrency.
     */
    private int getMaxNumberOfConcurrentThreadsForAggregation(Session session) {
        return this.taskCountEstimator.estimateHashedTaskCount(session) * SystemSessionProperties.getTaskConcurrency(session);
    }

    /**
     * Returns the largest known distinct-values count among the grouping keys of the aggregation's source.
     * The number of distinct grouping-key combinations is at least this value, hence "min" estimate.
     *
     * @return the estimate, or NaN when there are no grouping keys or no key has a known NDV
     */
    private double getMinDistinctValueCountEstimate(AggregationNode aggregationNode, StatsProvider statsProvider) {
        PlanNodeStatsEstimate sourceStats = statsProvider.getStats(aggregationNode.getSource());
        return aggregationNode.getGroupingKeys().stream()
                .mapToDouble(symbol -> sourceStats.getSymbolStatistics(symbol).getDistinctValuesCount())
                .filter(distinctValues -> !Double.isNaN(distinctValues))
                .max()
                .orElse(Double.NaN);
    }

    /**
     * Decides, in automatic mode, whether splitting into subqueries is expected to pay off.
     * This is conservative: it returns false on any unsupported shape or unknown statistic.
     */
    private boolean shouldSplitAggregationToSubqueries(AggregationNode aggregationNode, Session session, StatsProvider statsProvider, Lookup lookup) {
        PlanNode source = aggregationNode.getSource();
        if (!this.isAggregationSourceSupportedForSubqueries(source, session, lookup)) {
            return false;
        }
        // Unions are accepted when the strategy is requested explicitly, but not chosen automatically.
        if (PlanNodeSearcher.searchFrom(source, lookup).whereIsInstanceOfAny(UnionNode.class).findFirst().isPresent()) {
            return false;
        }
        // With a selective filter, each subquery would repeat the scan and the filter work.
        boolean hasSelectiveFilter = PlanNodeSearcher.searchFrom(source, lookup)
                .where(node -> node instanceof FilterNode && isSelective((FilterNode) node, statsProvider))
                .matches();
        if (hasSelectiveFilter) {
            return false;
        }
        if (isAdditionalReadOverheadTooExpensive(aggregationNode, statsProvider, lookup)) {
            return false;
        }
        if (aggregationNode.hasSingleGlobalAggregation()) {
            return true;
        }
        double groupingKeysSizeInBytes = statsProvider.getStats(aggregationNode).getOutputSizeInBytes(aggregationNode.getGroupingKeys());
        // Unknown size (NaN) -> do not split.
        return !Double.isNaN(groupingKeysSizeInBytes) && groupingKeysSizeInBytes <= MAX_JOIN_GROUPING_KEYS_SIZE;
    }

    /**
     * Estimates the extra data read by splitting into subqueries. Non-distinct-input columns of the
     * table scan are counted once per subquery; distinct-input columns are counted once.
     *
     * <p>Expects exactly one {@link TableScanNode} in the source (callers reject unions first).
     * NOTE(review): {@code findOnlyElement} presumably throws otherwise — confirm.
     *
     * @return true if the estimate is unknown (NaN) or exceeds {@link #MAX_SUBQUERIES_READ_AMPLIFICATION}
     *         times the data of a single scan
     */
    private static boolean isAdditionalReadOverheadTooExpensive(AggregationNode aggregationNode, StatsProvider statsProvider, Lookup lookup) {
        Set<Symbol> distinctInputs = aggregationNode.getAggregations().values().stream()
                .filter(AggregationNode.Aggregation::isDistinct)
                .flatMap(aggregation -> aggregation.getArguments().stream())
                .filter(Reference.class::isInstance)
                .map(Symbol::from)
                .collect(ImmutableSet.toImmutableSet());

        TableScanNode tableScanNode = (TableScanNode) PlanNodeSearcher.searchFrom(aggregationNode.getSource(), lookup)
                .whereIsInstanceOfAny(TableScanNode.class)
                .findOnlyElement();
        Set<Symbol> additionalColumns = Sets.difference(ImmutableSet.copyOf(tableScanNode.getOutputSymbols()), distinctInputs);

        PlanNodeStatsEstimate scanStats = statsProvider.getStats(tableScanNode);
        double singleTableScanDataSize = scanStats.getOutputSizeInBytes(tableScanNode.getOutputSymbols());
        double additionalColumnsDataSize = scanStats.getOutputSizeInBytes(additionalColumns);
        long subqueryCount = OptimizeMixedDistinctAggregations.distinctAggregationsUniqueArgumentCount(aggregationNode);
        double distinctInputDataSize = singleTableScanDataSize - additionalColumnsDataSize;
        double subqueriesTotalDataSize = additionalColumnsDataSize * subqueryCount + distinctInputDataSize;

        return Double.isNaN(subqueriesTotalDataSize)
                || Double.isNaN(singleTableScanDataSize)
                || subqueriesTotalDataSize / singleTableScanDataSize > MAX_SUBQUERIES_READ_AMPLIFICATION;
    }

    /**
     * A filter is selective when it keeps less than {@link #SELECTIVE_FILTER_MAX_OUTPUT_RATIO} of its
     * input rows. An unknown or zero input row count yields NaN, which compares false (not selective).
     */
    private static boolean isSelective(FilterNode filterNode, StatsProvider statsProvider) {
        double filterOutputRowCount = statsProvider.getStats(filterNode).getOutputRowCount();
        double filterSourceRowCount = statsProvider.getStats(filterNode.getSource()).getOutputRowCount();
        return filterOutputRowCount / filterSourceRowCount < SELECTIVE_FILTER_MAX_OUTPUT_RATIO;
    }

    /**
     * Returns true if {@code source} contains only table scan, filter, project and union nodes,
     * has at least one table scan, and the metadata allows splitting the read of every scanned table.
     */
    private boolean isAggregationSourceSupportedForSubqueries(PlanNode source, Session session, Lookup lookup) {
        boolean hasUnsupportedNode = PlanNodeSearcher.searchFrom(source, lookup)
                .where(node -> !(node instanceof TableScanNode)
                        && !(node instanceof FilterNode)
                        && !(node instanceof ProjectNode)
                        && !(node instanceof UnionNode))
                .findFirst()
                .isPresent();
        if (hasUnsupportedNode) {
            return false;
        }
        List<PlanNode> tableScans = PlanNodeSearcher.searchFrom(source, lookup).whereIsInstanceOfAny(TableScanNode.class).findAll();
        if (tableScans.isEmpty()) {
            return false;
        }
        return tableScans.stream()
                .allMatch(tableScan -> this.metadata.allowSplittingReadIntoMultipleSubQueries(session, ((TableScanNode) tableScan).getTable()));
    }
}

