/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

/**
 * Optimizer rule that removes constant grouping keys from a single-grouping-set aggregation
 * and re-introduces them in a projection above it.
 *
 * <p>The rule matches an {@link AggregationNode} with exactly one grouping set whose source is a
 * {@link ProjectNode}. The project may assign some of the aggregation's output variables a
 * {@link ConstantExpression}. Any grouping key assigned this way cannot split the input into more
 * than one group, so the rule:
 * <ol>
 *   <li>drops it from the aggregation's grouping keys;</li>
 *   <li>places a new {@link ProjectNode} on top of the rewritten aggregation;</li>
 *   <li>has that project re-assign the same constant expression to the variable.</li>
 * </ol>
 * All other output variables are passed through unchanged.
 *
 * <p>The rule is active only when
 * {@link SystemSessionProperties#isOptimizeConstantGroupingKeys(Session)} returns {@code true}.
 */
public class PullConstantsAboveGroupBy
implements Rule<AggregationNode> {
    /** Captures the {@link ProjectNode} directly beneath the matched aggregation. */
    private static final Capture<ProjectNode> SOURCE = Capture.newCapture();

    /** Matches aggregations with exactly one grouping set that read directly from a projection. */
    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation()
            .matching(agg -> agg.getGroupingSetCount() == 1)
            .with(Patterns.source().matching(Patterns.project().capturedAs(SOURCE)));

    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.isOptimizeConstantGroupingKeys(session);
    }

    @Override
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    /**
     * Rewrites {@code parent} so that grouping keys bound to constants by the source projection
     * are no longer grouped on, and are instead projected above the aggregation.
     *
     * @param parent the matched aggregation
     * @param captures holds the captured source {@link ProjectNode}
     * @param context supplies the session and the plan-node id allocator
     * @return a {@link ProjectNode} over the rewritten aggregation, or an empty result when
     *     the rule is disabled or there is nothing (or nothing safe) to rewrite
     */
    @Override
    public Rule.Result apply(AggregationNode parent, Captures captures, Rule.Context context) {
        // Re-checked here so the rule stays a no-op even when apply() is invoked directly.
        if (!isEnabled(context.getSession())) {
            return Rule.Result.empty();
        }
        ProjectNode source = captures.get(SOURCE);
        List<VariableReferenceExpression> outputVariables = parent.getOutputVariables();
        Map<VariableReferenceExpression, RowExpression> constSourceVars = extractConstVars(source, outputVariables);
        List<VariableReferenceExpression> groupingKeys = parent.getGroupingKeys();
        List<VariableReferenceExpression> newGroupingKeys = groupingKeys.stream()
                .filter(key -> !constSourceVars.containsKey(key))
                .collect(ImmutableList.toImmutableList());
        if (constSourceVars.isEmpty() || newGroupingKeys.equals(groupingKeys)) {
            // No grouping key is bound to a constant: nothing to pull up.
            return Rule.Result.empty();
        }
        if (newGroupingKeys.isEmpty()) {
            // Dropping every key would turn this into a global aggregation. That is not
            // equivalent: a global aggregation emits one row even for empty input, whereas
            // grouping by constants emits none. Leave the plan untouched.
            return Rule.Result.empty();
        }
        AggregationNode newAgg = new AggregationNode(
                parent.getSourceLocation(),
                parent.getId(),
                source,
                parent.getAggregations(),
                AggregationNode.singleGroupingSet(newGroupingKeys),
                // NOTE(review): this argument (presumably the pre-grouped variables) is reset to
                // empty rather than filtered from the parent's value; conservative but loses info.
                ImmutableList.of(),
                parent.getStep(),
                parent.getHashVariable(),
                parent.getGroupIdVariable());
        // Walk the parent's outputs in order so the replacement keeps the same column layout.
        // Constant keys are re-assigned their constant; everything else passes through.
        Assignments.Builder assignments = Assignments.builder();
        for (VariableReferenceExpression variable : outputVariables) {
            RowExpression constant = constSourceVars.get(variable);
            assignments.put(variable, constant != null ? constant : variable);
        }
        return Rule.Result.ofPlanNode(new ProjectNode(
                parent.getSourceLocation(),
                context.getIdAllocator().getNextId(),
                newAgg,
                assignments.build(),
                source.getLocality()));
    }

    /**
     * Returns the assignments of {@code projectNode} that bind one of {@code outputVariables}
     * to a {@link ConstantExpression}.
     *
     * @param projectNode the projection whose assignments are inspected
     * @param outputVariables the variables of interest (the aggregation's outputs)
     * @return an immutable map from each such variable to its constant expression, in the
     *     projection's assignment order
     */
    private static Map<VariableReferenceExpression, RowExpression> extractConstVars(
            ProjectNode projectNode, List<VariableReferenceExpression> outputVariables) {
        return projectNode.getAssignments().entrySet().stream()
                .filter(entry -> entry.getValue() instanceof ConstantExpression)
                .filter(entry -> outputVariables.contains(entry.getKey()))
                .collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
    }
}

