/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMultiset;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Multiset;
import io.airlift.units.DataSize;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.SymbolStatsEstimate;
import io.trino.cost.TaskCountEstimator;
import io.trino.execution.TaskManagerConfig;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.sql.PlannerContext;
import io.trino.sql.planner.Partitioning;
import io.trino.sql.planner.PartitioningScheme;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SystemPartitioningHandle;
import io.trino.sql.planner.TypeAnalyzer;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.optimizations.StreamPreferredProperties;
import io.trino.sql.planner.optimizations.StreamPropertyDerivations;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.GroupIdNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

public class AddExchangesBelowPartialAggregationOverGroupIdRuleSet {
    private static final Capture<ProjectNode> PROJECTION = Capture.newCapture();
    private static final Capture<AggregationNode> AGGREGATION = Capture.newCapture();
    private static final Capture<GroupIdNode> GROUP_ID = Capture.newCapture();
    private static final Pattern<ExchangeNode> WITH_PROJECTION = Pattern.typeOf(ExchangeNode.class).with(Patterns.Exchange.scope().equalTo((Object)ExchangeNode.Scope.REMOTE)).with(Patterns.source().matching(Pattern.typeOf(ProjectNode.class).capturedAs(PROJECTION).with(Patterns.source().matching(Pattern.typeOf(AggregationNode.class).capturedAs(AGGREGATION).with(Patterns.Aggregation.step().equalTo((Object)AggregationNode.Step.PARTIAL)).with(Pattern.nonEmpty(Patterns.Aggregation.groupingColumns())).with(Patterns.source().matching(Pattern.typeOf(GroupIdNode.class).capturedAs(GROUP_ID)))))));
    private static final Pattern<ExchangeNode> WITHOUT_PROJECTION = Pattern.typeOf(ExchangeNode.class).with(Patterns.Exchange.scope().equalTo((Object)ExchangeNode.Scope.REMOTE)).with(Patterns.source().matching(Pattern.typeOf(AggregationNode.class).capturedAs(AGGREGATION).with(Patterns.Aggregation.step().equalTo((Object)AggregationNode.Step.PARTIAL)).with(Pattern.nonEmpty(Patterns.Aggregation.groupingColumns())).with(Patterns.source().matching(Pattern.typeOf(GroupIdNode.class).capturedAs(GROUP_ID)))));
    private static final double GROUPING_SETS_SYMBOL_REQUIRED_FREQUENCY = 0.5;
    private static final double ANTI_SKEWNESS_MARGIN = 3.0;
    private final PlannerContext plannerContext;
    private final TypeAnalyzer typeAnalyzer;
    private final TaskCountEstimator taskCountEstimator;
    private final DataSize maxPartialAggregationMemoryUsage;

    public AddExchangesBelowPartialAggregationOverGroupIdRuleSet(PlannerContext plannerContext, TypeAnalyzer typeAnalyzer, TaskCountEstimator taskCountEstimator, TaskManagerConfig taskManagerConfig) {
        this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
        this.typeAnalyzer = Objects.requireNonNull(typeAnalyzer, "typeAnalyzer is null");
        this.taskCountEstimator = Objects.requireNonNull(taskCountEstimator, "taskCountEstimator is null");
        this.maxPartialAggregationMemoryUsage = Objects.requireNonNull(taskManagerConfig, "taskManagerConfig is null").getMaxPartialAggregationMemoryUsage();
    }

    public Set<Rule<?>> rules() {
        return ImmutableSet.of((Object)new AddExchangesBelowProjectionPartialAggregationGroupId(), (Object)new AddExchangesBelowExchangePartialAggregationGroupId());
    }

    private abstract class BaseAddExchangesBelowExchangePartialAggregationGroupId
    implements Rule<ExchangeNode> {
        private BaseAddExchangesBelowExchangePartialAggregationGroupId() {
        }

        @Override
        public boolean isEnabled(Session session) {
            if (!SystemSessionProperties.isEnableStatsCalculator(session)) {
                return false;
            }
            return SystemSessionProperties.isEnableForcedExchangeBelowGroupId(session);
        }

        protected Optional<PlanNode> transform(AggregationNode aggregation, GroupIdNode groupId, Rule.Context context) {
            StreamPropertyDerivations.StreamProperties sourceProperties;
            if (groupId.getGroupingSets().size() < 2) {
                return Optional.empty();
            }
            Set groupingKeys = (Set)aggregation.getGroupingKeys().stream().filter(symbol -> !groupId.getGroupIdSymbol().equals(symbol)).collect(ImmutableSet.toImmutableSet());
            Multiset groupingSetHistogram = (Multiset)groupId.getGroupingSets().stream().flatMap(Collection::stream).collect(ImmutableMultiset.toImmutableMultiset());
            if (!Objects.equals(groupingSetHistogram.elementSet(), groupingKeys)) {
                return Optional.empty();
            }
            double aggregationMemoryRequirements = this.estimateAggregationMemoryRequirements(groupingKeys, groupId, (Multiset<Symbol>)groupingSetHistogram, context);
            if (Double.isNaN(aggregationMemoryRequirements) || aggregationMemoryRequirements < (double)AddExchangesBelowPartialAggregationOverGroupIdRuleSet.this.maxPartialAggregationMemoryUsage.toBytes()) {
                return Optional.empty();
            }
            List desiredHashSymbols = (List)groupingSetHistogram.entrySet().stream().filter(entry -> (double)entry.getCount() >= (double)groupId.getGroupingSets().size() * 0.5).map(Multiset.Entry::getElement).peek(symbol -> Verify.verify((boolean)groupingKeys.contains(symbol))).map(groupId.getGroupingColumns()::get).collect(ImmutableList.toImmutableList());
            StreamPreferredProperties requiredProperties = StreamPreferredProperties.fixedParallelism().withPartitioning(desiredHashSymbols);
            if (requiredProperties.isSatisfiedBy(sourceProperties = this.derivePropertiesRecursively(groupId.getSource(), context))) {
                return Optional.empty();
            }
            double estimatedGroups = this.estimatedGroupCount(desiredHashSymbols, context.getStatsProvider().getStats(groupId.getSource()));
            if (Double.isNaN(estimatedGroups) || estimatedGroups * 3.0 < (double)this.maximalConcurrencyAfterRepartition(context)) {
                return Optional.empty();
            }
            PlanNode source = groupId.getSource();
            source = ExchangeNode.partitionedExchange(context.getIdAllocator().getNextId(), ExchangeNode.Scope.REMOTE, source, new PartitioningScheme(Partitioning.create(SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION, desiredHashSymbols), source.getOutputSymbols()));
            source = ExchangeNode.partitionedExchange(context.getIdAllocator().getNextId(), ExchangeNode.Scope.LOCAL, source, new PartitioningScheme(Partitioning.create(SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION, desiredHashSymbols), source.getOutputSymbols()));
            PlanNode newGroupId = groupId.replaceChildren((List<PlanNode>)ImmutableList.of((Object)source));
            PlanNode newAggregation = aggregation.replaceChildren((List<PlanNode>)ImmutableList.of((Object)newGroupId));
            return Optional.of(newAggregation);
        }

        private int maximalConcurrencyAfterRepartition(Rule.Context context) {
            return SystemSessionProperties.getTaskConcurrency(context.getSession()) * AddExchangesBelowPartialAggregationOverGroupIdRuleSet.this.taskCountEstimator.estimateHashedTaskCount(context.getSession());
        }

        private double estimateAggregationMemoryRequirements(Set<Symbol> groupingKeys, GroupIdNode groupId, Multiset<Symbol> groupingSetHistogram, Rule.Context context) {
            Preconditions.checkArgument((boolean)Objects.equals(groupingSetHistogram.elementSet(), groupingKeys));
            PlanNodeStatsEstimate sourceStats = context.getStatsProvider().getStats(groupId.getSource());
            double keysMemoryRequirements = 0.0;
            for (List<Symbol> groupingSet : groupId.getGroupingSets()) {
                List sourceSymbols = (List)groupingSet.stream().map(groupId.getGroupingColumns()::get).collect(ImmutableList.toImmutableList());
                double keyWidth = sourceStats.getOutputSizeInBytes(sourceSymbols, context.getSymbolAllocator().getTypes()) / sourceStats.getOutputRowCount();
                double keyNdv = Math.min(this.estimatedGroupCount(sourceSymbols, sourceStats), sourceStats.getOutputRowCount());
                keysMemoryRequirements += keyWidth * keyNdv;
            }
            return keysMemoryRequirements;
        }

        private double estimatedGroupCount(List<Symbol> symbols, PlanNodeStatsEstimate statsEstimate) {
            return symbols.stream().map(statsEstimate::getSymbolStatistics).mapToDouble(this::ndvIncludingNull).reduce(1.0, (a, b) -> a * b);
        }

        private double ndvIncludingNull(SymbolStatsEstimate symbolStatsEstimate) {
            if (symbolStatsEstimate.getNullsFraction() == 0.0) {
                return symbolStatsEstimate.getDistinctValuesCount();
            }
            return symbolStatsEstimate.getDistinctValuesCount() + 1.0;
        }

        private StreamPropertyDerivations.StreamProperties derivePropertiesRecursively(PlanNode node, Rule.Context context) {
            PlanNode resolvedPlanNode = context.getLookup().resolve(node);
            List inputProperties = (List)resolvedPlanNode.getSources().stream().map(source -> this.derivePropertiesRecursively((PlanNode)source, context)).collect(ImmutableList.toImmutableList());
            return StreamPropertyDerivations.deriveProperties(resolvedPlanNode, inputProperties, AddExchangesBelowPartialAggregationOverGroupIdRuleSet.this.plannerContext, context.getSession(), context.getSymbolAllocator().getTypes(), AddExchangesBelowPartialAggregationOverGroupIdRuleSet.this.typeAnalyzer);
        }
    }

    private class AddExchangesBelowExchangePartialAggregationGroupId
    extends BaseAddExchangesBelowExchangePartialAggregationGroupId {
        private AddExchangesBelowExchangePartialAggregationGroupId() {
        }

        @Override
        public Pattern<ExchangeNode> getPattern() {
            return WITHOUT_PROJECTION;
        }

        @Override
        public Rule.Result apply(ExchangeNode exchange, Captures captures, Rule.Context context) {
            AggregationNode aggregation = (AggregationNode)captures.get(AGGREGATION);
            GroupIdNode groupId = (GroupIdNode)captures.get(GROUP_ID);
            return this.transform(aggregation, groupId, context).map(newAggregation -> {
                PlanNode newExchange = exchange.replaceChildren((List<PlanNode>)ImmutableList.of((Object)newAggregation));
                return Rule.Result.ofPlanNode(newExchange);
            }).orElseGet(Rule.Result::empty);
        }
    }

    private class AddExchangesBelowProjectionPartialAggregationGroupId
    extends BaseAddExchangesBelowExchangePartialAggregationGroupId {
        private AddExchangesBelowProjectionPartialAggregationGroupId() {
        }

        @Override
        public Pattern<ExchangeNode> getPattern() {
            return WITH_PROJECTION;
        }

        @Override
        public Rule.Result apply(ExchangeNode exchange, Captures captures, Rule.Context context) {
            ProjectNode project = (ProjectNode)captures.get(PROJECTION);
            AggregationNode aggregation = (AggregationNode)captures.get(AGGREGATION);
            GroupIdNode groupId = (GroupIdNode)captures.get(GROUP_ID);
            return this.transform(aggregation, groupId, context).map(newAggregation -> Rule.Result.ofPlanNode(exchange.replaceChildren((List<PlanNode>)ImmutableList.of((Object)project.replaceChildren((List<PlanNode>)ImmutableList.of((Object)newAggregation)))))).orElseGet(Rule.Result::empty);
        }
    }
}

