/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.SystemSessionProperties;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.function.AggregationFunctionMetadata;
import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeManager;
import io.trino.sql.PlannerContext;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.Lambda;
import io.trino.sql.planner.Partitioning;
import io.trino.sql.planner.PartitioningScheme;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.optimizations.SymbolMapper;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

public class PushPartialAggregationThroughExchange
implements Rule<AggregationNode> {
    private final PlannerContext plannerContext;
    private static final Capture<ExchangeNode> EXCHANGE_NODE = Capture.newCapture();
    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation().with(Patterns.source().matching(Patterns.exchange().matching(node -> node.getOrderingScheme().isEmpty()).capturedAs(EXCHANGE_NODE)));

    public PushPartialAggregationThroughExchange(PlannerContext plannerContext) {
        this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
    }

    @Override
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    @Override
    public Rule.Result apply(AggregationNode aggregationNode, Captures captures, Rule.Context context) {
        ExchangeNode exchangeNode = (ExchangeNode)captures.get(EXCHANGE_NODE);
        boolean decomposable = aggregationNode.isDecomposable(context.getSession(), this.plannerContext.getMetadata());
        if (aggregationNode.getStep() == AggregationNode.Step.SINGLE && aggregationNode.hasEmptyGroupingSet() && aggregationNode.hasNonEmptyGroupingSet() && exchangeNode.getType() == ExchangeNode.Type.REPARTITION) {
            Preconditions.checkState((boolean)decomposable, (Object)"Distributed aggregation with empty grouping set requires partial but functions are not decomposable");
            return Rule.Result.ofPlanNode(this.split(aggregationNode, context));
        }
        if (!decomposable || !SystemSessionProperties.preferPartialAggregation(context.getSession())) {
            return Rule.Result.empty();
        }
        if (exchangeNode.getType() != ExchangeNode.Type.GATHER && exchangeNode.getType() != ExchangeNode.Type.REPARTITION || exchangeNode.getPartitioningScheme().isReplicateNullsAndAny()) {
            return Rule.Result.empty();
        }
        if (exchangeNode.getType() == ExchangeNode.Type.REPARTITION) {
            List partitioningColumns = exchangeNode.getPartitioningScheme().getPartitioning().getArguments().stream().filter(Partitioning.ArgumentBinding::isVariable).map(Partitioning.ArgumentBinding::getColumn).collect(Collectors.toList());
            if (!aggregationNode.getGroupingKeys().containsAll(partitioningColumns)) {
                return Rule.Result.empty();
            }
        }
        if (aggregationNode.getHashSymbol().isPresent() || exchangeNode.getPartitioningScheme().getHashColumn().isPresent()) {
            return Rule.Result.empty();
        }
        return switch (aggregationNode.getStep()) {
            case AggregationNode.Step.SINGLE -> Rule.Result.ofPlanNode(this.split(aggregationNode, context));
            case AggregationNode.Step.PARTIAL -> Rule.Result.ofPlanNode(this.pushPartial(aggregationNode, exchangeNode, context));
            default -> Rule.Result.empty();
        };
    }

    private PlanNode pushPartial(AggregationNode aggregation, ExchangeNode exchange, Rule.Context context) {
        ArrayList<PlanNode> partials = new ArrayList<PlanNode>();
        for (int i = 0; i < exchange.getSources().size(); ++i) {
            PlanNode source = exchange.getSources().get(i);
            SymbolMapper.Builder mappingsBuilder = SymbolMapper.builder();
            for (int outputIndex = 0; outputIndex < exchange.getOutputSymbols().size(); ++outputIndex) {
                Symbol input;
                Symbol output = exchange.getOutputSymbols().get(outputIndex);
                if (output.equals(input = exchange.getInputs().get(i).get(outputIndex))) continue;
                mappingsBuilder.put(output, input);
            }
            SymbolMapper symbolMapper = mappingsBuilder.build();
            AggregationNode mappedPartial = symbolMapper.map(aggregation, source, context.getIdAllocator().getNextId());
            mappedPartial = AggregationNode.builderFrom(mappedPartial).setIsInputReducingAggregation(true).build();
            Assignments.Builder assignments = Assignments.builder();
            for (Symbol output : aggregation.getOutputSymbols()) {
                Symbol input = symbolMapper.map(output);
                assignments.put(output, input.toSymbolReference());
            }
            partials.add(new ProjectNode(context.getIdAllocator().getNextId(), mappedPartial, assignments.build()));
        }
        for (PlanNode node : partials) {
            Verify.verify((boolean)aggregation.getOutputSymbols().equals(node.getOutputSymbols()));
        }
        PartitioningScheme partitioning = new PartitioningScheme(exchange.getPartitioningScheme().getPartitioning(), aggregation.getOutputSymbols(), exchange.getPartitioningScheme().getHashColumn(), exchange.getPartitioningScheme().isReplicateNullsAndAny(), exchange.getPartitioningScheme().getBucketToPartition(), exchange.getPartitioningScheme().getPartitionCount());
        return new ExchangeNode(context.getIdAllocator().getNextId(), exchange.getType(), exchange.getScope(), partitioning, partials, (List<List<Symbol>>)ImmutableList.copyOf(Collections.nCopies(partials.size(), aggregation.getOutputSymbols())), Optional.empty());
    }

    private PlanNode split(AggregationNode node, Rule.Context context) {
        ImmutableMap.Builder intermediateAggregation = ImmutableMap.builder();
        ImmutableMap.Builder finalAggregation = ImmutableMap.builder();
        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : node.getAggregations().entrySet()) {
            AggregationNode.Aggregation originalAggregation = entry.getValue();
            ResolvedFunction resolvedFunction = originalAggregation.getResolvedFunction();
            AggregationFunctionMetadata functionMetadata = this.plannerContext.getMetadata().getAggregationFunctionMetadata(context.getSession(), resolvedFunction);
            List intermediateTypes = (List)functionMetadata.getIntermediateTypes().stream().map(arg_0 -> ((TypeManager)this.plannerContext.getTypeManager()).getType(arg_0)).collect(ImmutableList.toImmutableList());
            RowType intermediateType = intermediateTypes.size() == 1 ? (Type)intermediateTypes.get(0) : RowType.anonymous((List)intermediateTypes);
            Symbol intermediateSymbol = context.getSymbolAllocator().newSymbol(resolvedFunction.signature().getName().getFunctionName(), (Type)intermediateType);
            Preconditions.checkState((boolean)originalAggregation.getOrderingScheme().isEmpty(), (Object)"Aggregate with ORDER BY does not support partial aggregation");
            intermediateAggregation.put((Object)intermediateSymbol, (Object)new AggregationNode.Aggregation(resolvedFunction, originalAggregation.getArguments(), originalAggregation.isDistinct(), originalAggregation.getFilter(), originalAggregation.getOrderingScheme(), originalAggregation.getMask()));
            finalAggregation.put((Object)entry.getKey(), (Object)new AggregationNode.Aggregation(resolvedFunction, (List<Expression>)ImmutableList.builder().add((Object)intermediateSymbol.toSymbolReference()).addAll((Iterable)originalAggregation.getArguments().stream().filter(Lambda.class::isInstance).collect(ImmutableList.toImmutableList())).build(), false, Optional.empty(), Optional.empty(), Optional.empty()));
        }
        AggregationNode partial = new AggregationNode(context.getIdAllocator().getNextId(), node.getSource(), (Map<Symbol, AggregationNode.Aggregation>)intermediateAggregation.buildOrThrow(), node.getGroupingSets(), (List<Symbol>)ImmutableList.of(), AggregationNode.Step.PARTIAL, node.getHashSymbol(), node.getGroupIdSymbol());
        return new AggregationNode(node.getId(), partial, (Map<Symbol, AggregationNode.Aggregation>)finalAggregation.buildOrThrow(), node.getGroupingSets(), (List<Symbol>)ImmutableList.of(), AggregationNode.Step.FINAL, node.getHashSymbol(), node.getGroupIdSymbol());
    }
}

