/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.ErrorCodeSupplier;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.StandardErrorCode;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.eventlistener.PlanOptimizerInformation;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.PartitioningScheme;
import com.facebook.presto.sql.planner.PlannerUtils;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.optimizations.AggregationNodeUtils;
import com.facebook.presto.sql.planner.optimizations.PlanOptimizer;
import com.facebook.presto.sql.planner.plan.AssignmentUtils;
import com.facebook.presto.sql.planner.plan.ChildReplacer;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.relational.Expressions;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

/**
 * Merges masked aggregations into identical unmasked aggregations across a partial/final aggregation pair.
 *
 * <p>Consider a partial aggregation that computes both an unmasked aggregation and an otherwise identical copy
 * masked by a boolean variable {@code m}. Filters must already have been rewritten to masks before this rule runs.
 * The masked copy is dropped from the partial step, and {@code m} is added to the partial grouping keys. The
 * matching final aggregation then consumes {@code IF(m, <unmasked partial result>, NULL)} in place of the masked
 * copy's own partial result.
 *
 * <p>The rewrite runs only when
 * {@link SystemSessionProperties#isMergeAggregationsWithAndWithoutFilter(Session)} is true. It applies only to
 * aggregations that have grouping keys and whose functions are all not called on NULL input.
 */
public class MergePartialAggregationsWithFilter
implements PlanOptimizer {
    private final FunctionAndTypeManager functionAndTypeManager;

    public MergePartialAggregationsWithFilter(FunctionAndTypeManager functionAndTypeManager) {
        this.functionAndTypeManager = Objects.requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
    }

    /**
     * Rewrites {@code plan} when the session enables this optimization. Otherwise returns {@code plan} unchanged.
     */
    @Override
    public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) {
        if (!SystemSessionProperties.isMergeAggregationsWithAndWithoutFilter(session)) {
            return plan;
        }
        return SimplePlanRewriter.rewriteWith(new Rewriter(session, variableAllocator, idAllocator, this.functionAndTypeManager), plan, new Context());
    }

    private static class Rewriter
    extends SimplePlanRewriter<Context> {
        private final Session session;
        private final VariableAllocator variableAllocator;
        private final PlanNodeIdAllocator planNodeIdAllocator;
        private final FunctionAndTypeManager functionAndTypeManager;

        public Rewriter(Session session, VariableAllocator variableAllocator, PlanNodeIdAllocator planNodeIdAllocator, FunctionAndTypeManager functionAndTypeManager) {
            this.session = Objects.requireNonNull(session, "session is null");
            this.variableAllocator = Objects.requireNonNull(variableAllocator, "variableAllocator is null");
            this.planNodeIdAllocator = Objects.requireNonNull(planNodeIdAllocator, "planNodeIdAllocator is null");
            this.functionAndTypeManager = Objects.requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
        }

        /**
         * Builds an {@code IF(condition, then, else)} special form.
         *
         * @param arguments condition, then-value and else-value, in that order; the result type is taken from
         *        the then-value ({@code arguments[1]})
         * @return the IF expression
         */
        public static RowExpression ifThenElse(RowExpression ... arguments) {
            return Expressions.specialForm(SpecialFormExpression.Form.IF, arguments[1].getType(), arguments);
        }

        /**
         * Default handling for plan nodes without a dedicated visitor: rewrite the children, then fail if
         * partial-aggregation merge state is still pending. The node must not sit between a merged partial
         * aggregation and its final aggregation.
         */
        @Override
        public PlanNode visitPlan(PlanNode node, SimplePlanRewriter.RewriteContext<Context> context) {
            List<PlanNode> children = node.getSources().stream()
                    .map(child -> context.rewrite(child, context.get()))
                    .collect(ImmutableList.toImmutableList());
            if (!context.get().isEmpty()) {
                throw new PrestoException((ErrorCodeSupplier) StandardErrorCode.GENERIC_INTERNAL_ERROR, "Unexpected plan node between partial and final aggregation");
            }
            return ChildReplacer.replaceChildren(node, children);
        }

        /**
         * Rewrites the source first, so that a partial aggregation below has already recorded its merges. Then
         * rewrites this node if it is an eligible PARTIAL or FINAL aggregation.
         */
        public PlanNode visitAggregation(AggregationNode node, SimplePlanRewriter.RewriteContext<Context> context) {
            PlanNode rewrittenSource = context.rewrite(node.getSource(), context.get());
            if (canOptimize(node)) {
                Preconditions.checkState(node.getAggregations().values().stream().noneMatch(aggregation -> aggregation.getFilter().isPresent()), "All aggregation filters should already be rewritten to mask before this optimization");
                if (node.getStep() == AggregationNode.Step.PARTIAL) {
                    return createPartialAggregationNode(node, rewrittenSource, context);
                }
                if (node.getStep() == AggregationNode.Step.FINAL) {
                    return createFinalAggregationNode(node, rewrittenSource, context);
                }
            }
            return node.replaceChildren(ImmutableList.of(rewrittenSource));
        }

        /**
         * Returns true when {@code node} has grouping keys and none of its aggregation functions is called on
         * NULL input. The merged final aggregation receives NULL for rows whose mask is false. That is only
         * equivalent to masking when the function ignores NULL inputs.
         */
        private boolean canOptimize(AggregationNode node) {
            return !node.getGroupingKeys().isEmpty()
                    && node.getAggregations().values().stream()
                    .map(aggregation -> functionAndTypeManager.getFunctionMetadata(aggregation.getFunctionHandle()))
                    .noneMatch(metadata -> metadata.isCalledOnNullInput());
        }

        /**
         * Removes every masked aggregation that has an identical unmasked counterpart from a PARTIAL aggregation.
         * Appends the masks of the removed aggregations to the grouping keys. Records, in the context, which
         * partial output each removed aggregation maps to and under which mask.
         */
        private AggregationNode createPartialAggregationNode(AggregationNode node, PlanNode rewrittenSource, SimplePlanRewriter.RewriteContext<Context> context) {
            Context mergeContext = context.get();
            Preconditions.checkState(mergeContext.isEmpty(), "There should be no partial aggregation left unmerged for a partial aggregation node");
            Map<VariableReferenceExpression, AggregationNode.Aggregation> aggregations = node.getAggregations();

            // Unmasked aggregation -> its output variable. For identical unmasked duplicates, keep the first output.
            Map<AggregationNode.Aggregation, VariableReferenceExpression> unmaskedAggregationToOutput = aggregations.entrySet().stream()
                    .filter(entry -> !entry.getValue().getMask().isPresent())
                    .collect(ImmutableMap.toImmutableMap(Map.Entry::getValue, Map.Entry::getKey, (first, second) -> first));

            // Record merges keyed by output variable, which is always unique. Two identical masked aggregations
            // therefore both merge instead of colliding on a duplicate Aggregation key.
            for (Map.Entry<VariableReferenceExpression, AggregationNode.Aggregation> entry : aggregations.entrySet()) {
                Optional<VariableReferenceExpression> mask = entry.getValue().getMask();
                if (!mask.isPresent()) {
                    continue;
                }
                VariableReferenceExpression unmaskedOutput = unmaskedAggregationToOutput.get(AggregationNodeUtils.removeFilterAndMask(entry.getValue()));
                if (unmaskedOutput == null) {
                    continue;
                }
                mergeContext.getPartialResultToMask().put(entry.getKey(), mask.get());
                mergeContext.getPartialOutputMapping().put(entry.getKey(), unmaskedOutput);
            }

            if (mergeContext.isEmpty()) {
                return (AggregationNode) node.replaceChildren(ImmutableList.of(rewrittenSource));
            }

            // Masks become extra grouping keys: distinct, in aggregation order, skipping masks that are
            // already grouping keys.
            AggregationNode.GroupingSetDescriptor groupingSets = node.getGroupingSets();
            List<VariableReferenceExpression> originalGroupingKeys = groupingSets.getGroupingKeys();
            List<VariableReferenceExpression> addedGroupingKeys = mergeContext.getPartialResultToMask().values().stream()
                    .distinct()
                    .filter(mask -> !originalGroupingKeys.contains(mask))
                    .collect(ImmutableList.toImmutableList());
            List<VariableReferenceExpression> partialGroupingKeys = ImmutableList.<VariableReferenceExpression>builder()
                    .addAll(originalGroupingKeys)
                    .addAll(addedGroupingKeys)
                    .build();
            AggregationNode.GroupingSetDescriptor partialGroupingSets = new AggregationNode.GroupingSetDescriptor(partialGroupingKeys, groupingSets.getGroupingSetCount(), groupingSets.getGlobalGroupingSets());

            Map<VariableReferenceExpression, AggregationNode.Aggregation> remainingAggregations = aggregations.entrySet().stream()
                    .filter(entry -> !mergeContext.getPartialOutputMapping().containsKey(entry.getKey()))
                    .collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));

            session.getOptimizerInformationCollector().addInformation(new PlanOptimizerInformation(MergePartialAggregationsWithFilter.class.getSimpleName(), true, Optional.empty()));
            return new AggregationNode(node.getSourceLocation(), node.getId(), rewrittenSource, remainingAggregations, partialGroupingSets, node.getPreGroupedVariables(), AggregationNode.Step.PARTIAL, node.getHashVariable(), node.getGroupIdVariable());
        }

        /**
         * Rewires a FINAL aggregation after its partial counterpart was merged. Each final aggregation whose
         * partial input was merged away instead consumes a new variable projected as
         * {@code IF(mask, <unmasked partial result>, NULL)}. Clears the context afterwards.
         */
        private AggregationNode createFinalAggregationNode(AggregationNode node, PlanNode rewrittenSource, SimplePlanRewriter.RewriteContext<Context> context) {
            Context mergeContext = context.get();
            if (mergeContext.isEmpty()) {
                return (AggregationNode) node.replaceChildren(ImmutableList.of(rewrittenSource));
            }
            Map<VariableReferenceExpression, VariableReferenceExpression> partialResultToMask = mergeContext.getPartialResultToMask();
            Map<VariableReferenceExpression, VariableReferenceExpression> partialOutputMapping = mergeContext.getPartialOutputMapping();

            // Validate every first argument before relying on it (getIntermediateInput checks shape and type).
            HashSet<VariableReferenceExpression> intermediateVariables = node.getAggregations().values().stream()
                    .map(Rewriter::getIntermediateInput)
                    .collect(Collectors.toCollection(HashSet::new));
            Preconditions.checkState(intermediateVariables.containsAll(partialResultToMask.keySet()), "Final aggregation does not consume every merged partial aggregation result");

            ImmutableList.Builder<RowExpression> projectionsFromPartialAgg = ImmutableList.builder();
            ImmutableList.Builder<VariableReferenceExpression> variablesForPartialAggResult = ImmutableList.builder();
            ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> newFinalAggregations = ImmutableMap.builder();
            for (Map.Entry<VariableReferenceExpression, AggregationNode.Aggregation> entry : node.getAggregations().entrySet()) {
                AggregationNode.Aggregation aggregation = entry.getValue();
                VariableReferenceExpression partialInput = getIntermediateInput(aggregation);
                VariableReferenceExpression maskVariable = partialResultToMask.get(partialInput);
                if (maskVariable == null) {
                    // Not merged: keep the final aggregation untouched.
                    newFinalAggregations.put(entry.getKey(), aggregation);
                    continue;
                }
                VariableReferenceExpression toMergePartialInput = partialOutputMapping.get(partialInput);
                // Use the unmasked partial state only for rows whose mask is true. NULL is skipped by the final step.
                RowExpression conditionalResult = ifThenElse(maskVariable, toMergePartialInput, Expressions.constantNull(toMergePartialInput.getType()));
                VariableReferenceExpression maskedPartialResult = variableAllocator.newVariable(toMergePartialInput);
                projectionsFromPartialAgg.add(conditionalResult);
                variablesForPartialAggResult.add(maskedPartialResult);

                // Same call, with the first (intermediate) argument swapped for the masked partial result.
                CallExpression originalCall = aggregation.getCall();
                List<RowExpression> originalArguments = originalCall.getArguments();
                List<RowExpression> newArguments = ImmutableList.<RowExpression>builder()
                        .add(maskedPartialResult)
                        .addAll(originalArguments.subList(1, originalArguments.size()))
                        .build();
                CallExpression newCall = new CallExpression(originalCall.getSourceLocation(), originalCall.getDisplayName(), originalCall.getFunctionHandle(), originalCall.getType(), newArguments);
                newFinalAggregations.put(entry.getKey(), new AggregationNode.Aggregation(newCall, aggregation.getFilter(), aggregation.getOrderBy(), aggregation.isDistinct(), aggregation.getMask()));
            }
            PlanNode projectNode = PlannerUtils.addProjections(rewrittenSource, planNodeIdAllocator, variableAllocator, projectionsFromPartialAgg.build(), variablesForPartialAggResult.build());
            mergeContext.clear();
            return new AggregationNode(node.getSourceLocation(), node.getId(), projectNode, newFinalAggregations.build(), node.getGroupingSets(), node.getPreGroupedVariables(), node.getStep(), node.getHashVariable(), node.getGroupIdVariable());
        }

        /**
         * Returns the first argument of a final aggregation. That argument must be a variable holding the
         * partial result.
         *
         * @throws IllegalStateException if the aggregation has no arguments or its first argument is not a variable
         */
        private static VariableReferenceExpression getIntermediateInput(AggregationNode.Aggregation aggregation) {
            List<RowExpression> arguments = aggregation.getArguments();
            Preconditions.checkState(!arguments.isEmpty() && arguments.get(0) instanceof VariableReferenceExpression, "Final aggregation must take its partial result variable as first argument: %s", aggregation);
            return (VariableReferenceExpression) arguments.get(0);
        }

        /**
         * While merge state is pending, this method adjusts a project between partial and final aggregation. It
         * drops pass-through assignments of merged partial outputs, which no longer exist below. It forwards the
         * mask variables so the final step can read them.
         */
        public PlanNode visitProject(ProjectNode node, SimplePlanRewriter.RewriteContext<Context> context) {
            PlanNode rewrittenSource = context.rewrite(node.getSource(), context.get());
            Context mergeContext = context.get();
            if (mergeContext.isEmpty()) {
                return node.replaceChildren(ImmutableList.of(rewrittenSource));
            }
            Map<VariableReferenceExpression, VariableReferenceExpression> partialOutputMapping = mergeContext.getPartialOutputMapping();
            Map<VariableReferenceExpression, RowExpression> retainedAssignments = node.getAssignments().getMap().entrySet().stream()
                    .filter(entry -> !(entry.getValue() instanceof VariableReferenceExpression && partialOutputMapping.containsKey(entry.getValue())))
                    .collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
            Assignments.Builder assignments = Assignments.builder();
            assignments.putAll(retainedAssignments);
            assignments.putAll(AssignmentUtils.identityAssignments(mergeContext.getPartialResultToMask().values()));
            return new ProjectNode(node.getSourceLocation(), node.getId(), rewrittenSource, assignments.build(), node.getLocality());
        }

        /**
         * Rewrites each exchange source with freshly cleared merge state. When the sources changed shape, the
         * exchange's partitioning scheme and inputs are rebuilt from the rewritten sources.
         */
        @Override
        public PlanNode visitExchange(ExchangeNode node, SimplePlanRewriter.RewriteContext<Context> context) {
            ImmutableList.Builder<PlanNode> rewrittenChildren = ImmutableList.builder();
            for (PlanNode child : node.getSources()) {
                // Only the merge state produced by the LAST source survives this loop.
                // NOTE(review): this assumes all sources merge identically — confirm for multi-source exchanges.
                context.get().clear();
                rewrittenChildren.add(context.rewrite(child, context.get()));
            }
            List<PlanNode> children = rewrittenChildren.build();
            if (context.get().isEmpty()) {
                return node.replaceChildren(children);
            }
            PartitioningScheme originalScheme = node.getPartitioningScheme();
            PartitioningScheme partitioning = new PartitioningScheme(originalScheme.getPartitioning(), children.get(children.size() - 1).getOutputVariables(), originalScheme.getHashColumn(), originalScheme.isReplicateNullsAndAny(), originalScheme.getBucketToPartition());
            List<List<VariableReferenceExpression>> inputs = children.stream()
                    .map(PlanNode::getOutputVariables)
                    .collect(ImmutableList.toImmutableList());
            return new ExchangeNode(node.getSourceLocation(), node.getId(), node.getType(), node.getScope(), partitioning, children, inputs, node.isEnsureSourceOrdering(), node.getOrderingScheme());
        }
    }

    /**
     * Mutable merge state passed from a rewritten partial aggregation up to its final aggregation.
     * Both maps are keyed by the partial output variable of a merged (masked) aggregation. They are
     * insertion-ordered so that lists derived from them follow aggregation order.
     */
    private static class Context {
        // merged partial output -> mask variable of the removed masked aggregation
        private final Map<VariableReferenceExpression, VariableReferenceExpression> partialResultToMask = new LinkedHashMap<>();
        // merged partial output -> output of the identical unmasked aggregation that replaces it
        private final Map<VariableReferenceExpression, VariableReferenceExpression> partialOutputMapping = new LinkedHashMap<>();

        public boolean isEmpty() {
            return this.partialOutputMapping.isEmpty();
        }

        public void clear() {
            this.partialResultToMask.clear();
            this.partialOutputMapping.clear();
        }

        public Map<VariableReferenceExpression, VariableReferenceExpression> getPartialOutputMapping() {
            return this.partialOutputMapping;
        }

        public Map<VariableReferenceExpression, VariableReferenceExpression> getPartialResultToMask() {
            return this.partialResultToMask;
        }
    }
}

