/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.trino.cost.TaskCountEstimator;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.Metadata;
import io.trino.sql.planner.NodeAndMappings;
import io.trino.sql.planner.PlanCopier;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.iterative.rule.DistinctAggregationStrategyChooser;
import io.trino.sql.planner.iterative.rule.OptimizeMixedDistinctAggregations;
import io.trino.sql.planner.iterative.rule.SingleDistinctAggregationToGroupBy;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.DynamicFilterId;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.JoinType;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.IntStream;

/**
 * Splits an aggregation whose aggregates are all {@code DISTINCT}, and which uses more than one
 * distinct argument set, into one sub-aggregation per argument set. The sub-aggregations are
 * combined with INNER joins on the grouping keys.
 *
 * <p>Each sub-aggregation is a copy (made with {@link PlanCopier}) of the original aggregation,
 * restricted to the aggregates that share one argument set. A final {@link ProjectNode} maps the
 * copied symbols back to the original output symbols, so the rewritten plan exposes the same
 * outputs as the original aggregation. Whether to apply the rewrite at all is decided by
 * {@link DistinctAggregationStrategyChooser#shouldSplitToSubqueries}.
 */
public class MultipleDistinctAggregationsToSubqueries
        implements Rule<AggregationNode> {
    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation()
            .matching(MultipleDistinctAggregationsToSubqueries::isAggregationCandidateForSplittingToSubqueries);

    private final DistinctAggregationStrategyChooser distinctAggregationStrategyChooser;

    /**
     * Returns whether {@code aggregationNode} has the structural shape this rule can rewrite.
     *
     * <p>All of the following must hold:
     * <ul>
     *   <li>every aggregate is distinct, as reported by
     *       {@link SingleDistinctAggregationToGroupBy#allDistinctAggregates};</li>
     *   <li>there are multiple distincts, as reported by
     *       {@link OptimizeMixedDistinctAggregations#hasMultipleDistincts};</li>
     *   <li>there is exactly one grouping set;</li>
     *   <li>there is no hash symbol.</li>
     * </ul>
     *
     * @param aggregationNode the candidate aggregation
     * @return {@code true} if all of the conditions above hold
     */
    public static boolean isAggregationCandidateForSplittingToSubqueries(AggregationNode aggregationNode) {
        return SingleDistinctAggregationToGroupBy.allDistinctAggregates(aggregationNode)
                && OptimizeMixedDistinctAggregations.hasMultipleDistincts(aggregationNode)
                && aggregationNode.getGroupingSetCount() == 1
                && aggregationNode.getHashSymbol().isEmpty();
    }

    public MultipleDistinctAggregationsToSubqueries(TaskCountEstimator taskCountEstimator, Metadata metadata) {
        this.distinctAggregationStrategyChooser = DistinctAggregationStrategyChooser.createDistinctAggregationStrategyChooser(taskCountEstimator, metadata);
    }

    @Override
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    /**
     * Rewrites the aggregation into joined per-argument-set sub-aggregations when the strategy
     * chooser approves it.
     *
     * @return an empty result when the chooser declines or fewer than two argument groups exist;
     *     otherwise a {@link ProjectNode} (reusing the original node's id) over the join tree
     */
    @Override
    public Rule.Result apply(AggregationNode aggregationNode, Captures captures, Rule.Context context) {
        if (!distinctAggregationStrategyChooser.shouldSplitToSubqueries(aggregationNode, context.getSession(), context.getStatsProvider(), context.getLookup())) {
            return Rule.Result.empty();
        }

        List<Map<Symbol, AggregationNode.Aggregation>> aggregationGroups = groupAggregationsByArguments(aggregationNode);
        if (aggregationGroups.size() < 2) {
            // The pattern should already rule this out. Without at least two groups there is no
            // right side to join against, and the rewrite would fail on a null join input.
            return Rule.Result.empty();
        }

        Assignments.Builder assignments = Assignments.builder();

        // Build the right side of the top join from the last group backwards. The last group's
        // grouping keys are used as the join symbols for every join in this chain.
        int lastIndex = aggregationGroups.size() - 1;
        AggregationNode rightmost = buildSubAggregation(aggregationNode, aggregationGroups.get(lastIndex), assignments, context);
        List<Symbol> rightJoinSymbols = rightmost.getGroupingKeys();
        PlanNode right = rightmost;
        for (int i = lastIndex - 1; i > 0; i--) {
            AggregationNode subAggregation = buildSubAggregation(aggregationNode, aggregationGroups.get(i), assignments, context);
            right = buildJoin(subAggregation, subAggregation.getGroupingKeys(), right, rightJoinSymbols, context);
        }

        // The first group is the left side of the top join. Its grouping keys stand in for the
        // original grouping keys in the final projection.
        AggregationNode left = buildSubAggregation(aggregationNode, aggregationGroups.getFirst(), assignments, context);
        List<Symbol> originalGroupingKeys = aggregationNode.getGroupingKeys();
        List<Symbol> leftGroupingKeys = left.getGroupingKeys();
        for (int i = 0; i < leftGroupingKeys.size(); i++) {
            assignments.put(originalGroupingKeys.get(i), leftGroupingKeys.get(i).toSymbolReference());
        }

        JoinNode topJoin = buildJoin(left, leftGroupingKeys, right, rightJoinSymbols, context);
        return Rule.Result.ofPlanNode(new ProjectNode(aggregationNode.getId(), topJoin, assignments.build()));
    }

    /**
     * Groups the aggregations of {@code aggregationNode} by their argument set (order-insensitive).
     *
     * <p>Aggregations are visited in order of output symbol name. Both the group order and the
     * order within each group follow that sort, which keeps the resulting join layout stable.
     */
    private static List<Map<Symbol, AggregationNode.Aggregation>> groupAggregationsByArguments(AggregationNode aggregationNode) {
        List<Map.Entry<Symbol, AggregationNode.Aggregation>> sortedAggregations = aggregationNode.getAggregations().entrySet().stream()
                .sorted(Comparator.comparing(entry -> entry.getKey().name()))
                .collect(ImmutableList.toImmutableList());

        // Keys are the argument sets; only the grouped values are returned to callers.
        Map<Set<?>, Map<Symbol, AggregationNode.Aggregation>> aggregationsByArguments = new LinkedHashMap<>();
        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : sortedAggregations) {
            aggregationsByArguments
                    .computeIfAbsent(ImmutableSet.copyOf(entry.getValue().getArguments()), arguments -> new LinkedHashMap<>())
                    .put(entry.getKey(), entry.getValue());
        }
        return ImmutableList.copyOf(aggregationsByArguments.values());
    }

    /**
     * Copies {@code aggregationNode}, restricted to {@code aggregations}, with fresh symbols and ids.
     *
     * <p>For every original aggregation output symbol, the method records in {@code assignments} a
     * mapping to the corresponding symbol of the copy.
     *
     * @return the copied aggregation node
     */
    private AggregationNode buildSubAggregation(AggregationNode aggregationNode, Map<Symbol, AggregationNode.Aggregation> aggregations, Assignments.Builder assignments, Rule.Context context) {
        List<Symbol> originalOutputSymbols = ImmutableList.copyOf(aggregations.keySet());
        NodeAndMappings copied = PlanCopier.copyPlan(
                AggregationNode.builderFrom(aggregationNode).setAggregations(aggregations).build(),
                originalOutputSymbols,
                context.getSymbolAllocator(),
                context.getIdAllocator(),
                context.getLookup());
        AggregationNode subAggregationNode = (AggregationNode) copied.getNode();
        // copied.getFields() is positionally aligned with originalOutputSymbols.
        List<Symbol> copiedOutputSymbols = copied.getFields();
        for (int i = 0; i < originalOutputSymbols.size(); i++) {
            assignments.put(originalOutputSymbols.get(i), copiedOutputSymbols.get(i).toSymbolReference());
        }
        return subAggregationNode;
    }

    /**
     * Builds an INNER join of {@code left} and {@code right}.
     *
     * <p>The join uses pairwise equality {@code leftJoinSymbols[i] = rightJoinSymbols[i]} and
     * outputs all symbols of both sides.
     *
     * @throws IllegalArgumentException if the two join symbol lists differ in length
     */
    private JoinNode buildJoin(PlanNode left, List<Symbol> leftJoinSymbols, PlanNode right, List<Symbol> rightJoinSymbols, Rule.Context context) {
        Preconditions.checkArgument(
                leftJoinSymbols.size() == rightJoinSymbols.size(),
                "join symbol lists differ in size: left %s, right %s",
                leftJoinSymbols,
                rightJoinSymbols);
        List<JoinNode.EquiJoinClause> criteria = IntStream.range(0, leftJoinSymbols.size())
                .mapToObj(i -> new JoinNode.EquiJoinClause(leftJoinSymbols.get(i), rightJoinSymbols.get(i)))
                .collect(ImmutableList.toImmutableList());
        // All remaining optional join properties are left empty, and there are no dynamic filters.
        return new JoinNode(
                context.getIdAllocator().getNextId(),
                JoinType.INNER,
                left,
                right,
                criteria,
                left.getOutputSymbols(),
                right.getOutputSymbols(),
                false,
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                ImmutableMap.<DynamicFilterId, Symbol>of(),
                Optional.empty());
    }
}

