/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.sql.ir.Expression;
import io.trino.sql.planner.Partitioning;
import io.trino.sql.planner.PartitioningScheme;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.SystemPartitioningHandle;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import java.util.List;
import java.util.Map;
import java.util.Optional;

public class AddIntermediateAggregations
implements Rule<AggregationNode> {
    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation().with(Patterns.Aggregation.step().equalTo((Object)AggregationNode.Step.FINAL)).with(Pattern.empty(Patterns.Aggregation.groupingColumns())).matching(node -> !node.hasOrderings());

    @Override
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.isEnableIntermediateAggregations(session);
    }

    @Override
    public Rule.Result apply(AggregationNode aggregation, Captures captures, Rule.Context context) {
        Lookup lookup = context.getLookup();
        PlanNodeIdAllocator idAllocator = context.getIdAllocator();
        Session session = context.getSession();
        Optional<PlanNode> rewrittenSource = this.recurseToPartial(lookup.resolve(aggregation.getSource()), lookup, idAllocator);
        if (rewrittenSource.isEmpty()) {
            return Rule.Result.empty();
        }
        PlanNode source = rewrittenSource.get();
        if (SystemSessionProperties.getTaskConcurrency(session) > 1) {
            source = ExchangeNode.partitionedExchange(idAllocator.getNextId(), ExchangeNode.Scope.LOCAL, source, new PartitioningScheme(Partitioning.create(SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION, (List<Symbol>)ImmutableList.of()), source.getOutputSymbols()));
            source = new AggregationNode(idAllocator.getNextId(), source, AddIntermediateAggregations.inputsAsOutputs(aggregation.getAggregations()), aggregation.getGroupingSets(), aggregation.getPreGroupedSymbols(), AggregationNode.Step.INTERMEDIATE, aggregation.getHashSymbol(), aggregation.getGroupIdSymbol());
            source = ExchangeNode.gatheringExchange(idAllocator.getNextId(), ExchangeNode.Scope.LOCAL, source);
        }
        return Rule.Result.ofPlanNode(aggregation.replaceChildren((List<PlanNode>)ImmutableList.of((Object)source)));
    }

    private Optional<PlanNode> recurseToPartial(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator) {
        AggregationNode aggregationNode;
        if (node instanceof AggregationNode && (aggregationNode = (AggregationNode)node).getStep() == AggregationNode.Step.PARTIAL) {
            return Optional.of(this.addGatheringIntermediate(aggregationNode, idAllocator));
        }
        if (!(node instanceof ExchangeNode) && !(node instanceof ProjectNode)) {
            return Optional.empty();
        }
        ImmutableList.Builder builder = ImmutableList.builder();
        for (PlanNode source : node.getSources()) {
            Optional<PlanNode> planNode = this.recurseToPartial(lookup.resolve(source), lookup, idAllocator);
            if (planNode.isEmpty()) {
                return Optional.empty();
            }
            builder.add((Object)planNode.get());
        }
        return Optional.of(node.replaceChildren((List<PlanNode>)builder.build()));
    }

    private PlanNode addGatheringIntermediate(AggregationNode aggregation, PlanNodeIdAllocator idAllocator) {
        Verify.verify((boolean)aggregation.getGroupingKeys().isEmpty(), (String)"Should be an un-grouped aggregation", (Object[])new Object[0]);
        ExchangeNode gatheringExchange = ExchangeNode.gatheringExchange(idAllocator.getNextId(), ExchangeNode.Scope.LOCAL, aggregation);
        return AggregationNode.builderFrom(aggregation).setId(idAllocator.getNextId()).setSource(gatheringExchange).setAggregations(AddIntermediateAggregations.outputsAsInputs(aggregation.getAggregations())).setStep(AggregationNode.Step.INTERMEDIATE).build();
    }

    private static Map<Symbol, AggregationNode.Aggregation> outputsAsInputs(Map<Symbol, AggregationNode.Aggregation> assignments) {
        ImmutableMap.Builder builder = ImmutableMap.builder();
        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : assignments.entrySet()) {
            Symbol output = entry.getKey();
            AggregationNode.Aggregation aggregation = entry.getValue();
            Preconditions.checkState((boolean)aggregation.getOrderingScheme().isEmpty(), (Object)"Intermediate aggregation does not support ORDER BY");
            builder.put((Object)output, (Object)new AggregationNode.Aggregation(aggregation.getResolvedFunction(), (List<Expression>)ImmutableList.of((Object)output.toSymbolReference()), false, Optional.empty(), Optional.empty(), Optional.empty()));
        }
        return builder.buildOrThrow();
    }

    private static Map<Symbol, AggregationNode.Aggregation> inputsAsOutputs(Map<Symbol, AggregationNode.Aggregation> assignments) {
        ImmutableMap.Builder builder = ImmutableMap.builder();
        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : assignments.entrySet()) {
            Symbol input = (Symbol)Iterables.getOnlyElement(SymbolsExtractor.extractAll(entry.getValue()));
            builder.put((Object)input, (Object)entry.getValue());
        }
        return builder.buildOrThrow();
    }
}

