/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.optimizations.AggregationNodeUtils;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/**
 * Planner rule that inserts {@code INTERMEDIATE} aggregation steps below a
 * {@code FINAL} aggregation.
 *
 * <p>The rule matches aggregations whose step is {@code FINAL}, whose grouping
 * columns are empty, and which have no orderings. For such a node it:
 * <ol>
 *   <li>walks down through {@link ExchangeNode}s and {@link ProjectNode}s to every
 *       {@code PARTIAL} aggregation and places a local gathering exchange followed by
 *       an {@code INTERMEDIATE} aggregation on top of each of them;</li>
 *   <li>when the session's task concurrency is greater than one, additionally places a
 *       local round-robin exchange, an {@code INTERMEDIATE} aggregation and a local
 *       gathering exchange directly below the {@code FINAL} aggregation.</li>
 * </ol>
 * If any step cannot be applied the rule returns {@link Rule.Result#empty()} and the
 * plan is left as it was.
 */
public class AddIntermediateAggregations implements Rule<AggregationNode> {
    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation()
            // Only FINAL aggregations with no grouping columns are considered.
            .with(Patterns.Aggregation.step().equalTo(AggregationNode.Step.FINAL))
            .with(Pattern.empty(Patterns.Aggregation.groupingColumns()))
            // Aggregations carrying orderings are excluded.
            .matching(node -> !node.hasOrderings());

    @Override
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    /**
     * The rule is active only when the intermediate-aggregations session property is on.
     */
    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.isEnableIntermediateAggregations(session);
    }

    /**
     * Rewrites the sub-plan below the matched FINAL aggregation.
     *
     * @param aggregation the matched FINAL aggregation
     * @param captures    pattern captures (unused)
     * @param context     supplies the lookup, id allocator, session and variable allocator
     * @return the FINAL aggregation on top of the rewritten source, or an empty result when
     *         no PARTIAL aggregation chain was found or the extra intermediate step cannot
     *         be expressed
     */
    @Override
    public Rule.Result apply(AggregationNode aggregation, Captures captures, Rule.Context context) {
        Lookup lookup = context.getLookup();
        PlanNodeIdAllocator idAllocator = context.getIdAllocator();
        Session session = context.getSession();
        TypeProvider types = TypeProvider.viewOf(context.getVariableAllocator().getVariables());

        Optional<PlanNode> rewrittenSource =
                recurseToPartial(lookup.resolve(aggregation.getSource()), lookup, idAllocator, types);
        if (!rewrittenSource.isPresent()) {
            return Rule.Result.empty();
        }

        PlanNode source = rewrittenSource.get();
        if (SystemSessionProperties.getTaskConcurrency(session) > 1) {
            Map<VariableReferenceExpression, AggregationNode.Aggregation> variableToAggregations =
                    inputsAsOutputs(aggregation.getAggregations(), types);
            // An empty map means the aggregations cannot be re-keyed by their input
            // (or there were none); the whole rewrite is abandoned in that case.
            if (variableToAggregations.isEmpty()) {
                return Rule.Result.empty();
            }
            source = ExchangeNode.roundRobinExchange(idAllocator.getNextId(), ExchangeNode.Scope.LOCAL, source);
            source = new AggregationNode(
                    aggregation.getSourceLocation(),
                    idAllocator.getNextId(),
                    source,
                    variableToAggregations,
                    aggregation.getGroupingSets(),
                    aggregation.getPreGroupedVariables(),
                    AggregationNode.Step.INTERMEDIATE,
                    aggregation.getHashVariable(),
                    aggregation.getGroupIdVariable(),
                    aggregation.getAggregationId());
            source = ExchangeNode.gatheringExchange(idAllocator.getNextId(), ExchangeNode.Scope.LOCAL, source);
        }

        return Rule.Result.ofPlanNode(aggregation.replaceChildren(ImmutableList.of(source)));
    }

    /**
     * Descends through exchange and project nodes until PARTIAL aggregations are reached and
     * wraps each of them via {@link #addGatheringIntermediate}.
     *
     * @return the rewritten node, or empty if any branch hits a node that is neither a PARTIAL
     *         aggregation, an {@link ExchangeNode} nor a {@link ProjectNode}
     */
    private Optional<PlanNode> recurseToPartial(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, TypeProvider types) {
        if (node instanceof AggregationNode && ((AggregationNode) node).getStep() == AggregationNode.Step.PARTIAL) {
            return Optional.of(addGatheringIntermediate((AggregationNode) node, idAllocator, types));
        }
        if (!(node instanceof ExchangeNode) && !(node instanceof ProjectNode)) {
            return Optional.empty();
        }

        ImmutableList.Builder<PlanNode> rewrittenChildren = ImmutableList.builder();
        for (PlanNode child : node.getSources()) {
            Optional<PlanNode> rewrittenChild = recurseToPartial(lookup.resolve(child), lookup, idAllocator, types);
            // One unsupported branch invalidates the rewrite of the whole subtree.
            if (!rewrittenChild.isPresent()) {
                return Optional.empty();
            }
            rewrittenChildren.add(rewrittenChild.get());
        }
        return Optional.of(node.replaceChildren(rewrittenChildren.build()));
    }

    /**
     * Places a local gathering exchange and an INTERMEDIATE aggregation on top of the given
     * PARTIAL aggregation. The aggregation must have no grouping keys.
     */
    private PlanNode addGatheringIntermediate(AggregationNode aggregation, PlanNodeIdAllocator idAllocator, TypeProvider types) {
        Verify.verify(aggregation.getGroupingKeys().isEmpty(), "Should be an un-grouped aggregation");
        ExchangeNode gatheringExchange =
                ExchangeNode.gatheringExchange(idAllocator.getNextId(), ExchangeNode.Scope.LOCAL, aggregation);
        return new AggregationNode(
                aggregation.getSourceLocation(),
                idAllocator.getNextId(),
                gatheringExchange,
                outputsAsInputs(aggregation.getAggregations()),
                aggregation.getGroupingSets(),
                aggregation.getPreGroupedVariables(),
                AggregationNode.Step.INTERMEDIATE,
                aggregation.getHashVariable(),
                aggregation.getGroupIdVariable(),
                aggregation.getAggregationId());
    }

    /**
     * Rewrites each assignment so that the aggregation consumes its own output variable,
     * keeping the call's original return type. For example {@code a := sum(b)} becomes
     * {@code a := sum(a)}.
     *
     * @throws IllegalStateException if any aggregation has an ORDER BY
     */
    private static Map<VariableReferenceExpression, AggregationNode.Aggregation> outputsAsInputs(Map<VariableReferenceExpression, AggregationNode.Aggregation> assignments) {
        ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> builder = ImmutableMap.builder();
        for (Map.Entry<VariableReferenceExpression, AggregationNode.Aggregation> entry : assignments.entrySet()) {
            VariableReferenceExpression output = entry.getKey();
            AggregationNode.Aggregation aggregation = entry.getValue();
            Preconditions.checkState(!aggregation.getOrderBy().isPresent(), "Intermediate aggregation does not support ORDER BY");
            appendAggregation(builder, aggregation, output, aggregation.getCall().getType());
        }
        return builder.build();
    }

    /**
     * Rewrites each assignment so that it is keyed by its single input variable and returns
     * the type of its first call argument. For example {@code a := sum(b)} becomes
     * {@code b := sum(b)}.
     *
     * @return the rewritten assignments, or an empty map as soon as one aggregation does not
     *         have exactly one argument, or has an ORDER BY or a filter
     */
    private static Map<VariableReferenceExpression, AggregationNode.Aggregation> inputsAsOutputs(Map<VariableReferenceExpression, AggregationNode.Aggregation> assignments, TypeProvider types) {
        ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> builder = ImmutableMap.builder();
        for (Map.Entry<VariableReferenceExpression, AggregationNode.Aggregation> entry : assignments.entrySet()) {
            AggregationNode.Aggregation aggregation = entry.getValue();
            if (aggregation.getArguments().size() != 1 || aggregation.getOrderBy().isPresent() || aggregation.getFilter().isPresent()) {
                return ImmutableMap.of();
            }
            // The aggregation is expected to reference exactly one variable;
            // getOnlyElement fails if it references none or several.
            VariableReferenceExpression input =
                    Iterables.getOnlyElement(AggregationNodeUtils.extractAggregationUniqueVariables(aggregation, types));
            RowExpression argumentExpr = aggregation.getCall().getArguments().get(0);
            // The intermediate step yields the same type it consumes.
            Type returnType = argumentExpr.getType();
            appendAggregation(builder, aggregation, input, returnType);
        }
        return builder.build();
    }

    /**
     * Adds {@code varRef := f(varRef)} to {@code builder}, where {@code f} reuses the source
     * location, display name and function handle of {@code aggregation}'s call and returns
     * {@code returnType}. The remaining {@code Aggregation} constructor arguments are passed
     * as empty optionals and {@code false}.
     */
    private static void appendAggregation(ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> builder, AggregationNode.Aggregation aggregation, VariableReferenceExpression varRef, Type returnType) {
        CallExpression original = aggregation.getCall();
        CallExpression rewrittenCall = new CallExpression(
                original.getSourceLocation(),
                original.getDisplayName(),
                original.getFunctionHandle(),
                returnType,
                ImmutableList.of(varRef));
        builder.put(
                varRef,
                new AggregationNode.Aggregation(
                        rewrittenCall,
                        Optional.empty(),
                        Optional.empty(),
                        false,
                        Optional.empty()));
    }
}

