/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.expressions.DefaultRowExpressionTraversalVisitor;
import com.facebook.presto.expressions.LogicalRowExpressions;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.RowExpressionVisitor;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.RowExpressionVariableInliner;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.VariablesExtractor;
import com.facebook.presto.sql.planner.optimizations.PlanOptimizer;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
 * Plan optimizer that pushes the condition of {@code IF(condition, aggregationOutput)} projections
 * that sit above an aggregation down into that aggregation as a mask.
 *
 * <p>An IF qualifies when:
 * <ul>
 *   <li>it is deterministic;</li>
 *   <li>its true branch is a bare variable;</li>
 *   <li>it has either no ELSE branch or an explicit NULL constant ELSE branch.</li>
 * </ul>
 * The IF itself is left untouched in its projection. The matching aggregation additionally gets a
 * mask: the condition, AND-ed with any mask it already had. The mask is computed by a new projection
 * inserted directly below the aggregation.
 *
 * <p>The rewrite only runs when
 * {@code SystemSessionProperties.isOptimizeConditionalAggregationEnabled(session)} returns true.
 */
public class RewriteIfOverAggregation
implements PlanOptimizer {
    private final FunctionAndTypeManager functionAndTypeManager;

    /**
     * @param functionAndTypeManager used to build the determinism evaluator applied to candidate expressions
     */
    public RewriteIfOverAggregation(FunctionAndTypeManager functionAndTypeManager) {
        this.functionAndTypeManager = functionAndTypeManager;
    }

    @Override
    public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) {
        if (!SystemSessionProperties.isOptimizeConditionalAggregationEnabled(session)) {
            return plan;
        }
        Rewriter rewriter = new Rewriter(variableAllocator, idAllocator, new RowExpressionDeterminismEvaluator(this.functionAndTypeManager));
        return SimplePlanRewriter.rewriteWith(rewriter, plan, ImmutableMap.of());
    }

    /**
     * Plan rewriter whose context maps a variable to the qualifying IF expression whose true branch
     * is that variable. The key is the variable the IF would mask if it turns out to be an
     * aggregation output. The IF expression is written in terms of the current node's output
     * variables.
     */
    private static class Rewriter
    extends SimplePlanRewriter<Map<VariableReferenceExpression, RowExpression>> {
        private final VariableAllocator planVariableAllocator;
        private final PlanNodeIdAllocator planNodeIdAllocator;
        private final RowExpressionDeterminismEvaluator determinismEvaluator;

        private Rewriter(VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, RowExpressionDeterminismEvaluator determinismEvaluator) {
            this.planVariableAllocator = variableAllocator;
            this.planNodeIdAllocator = idAllocator;
            this.determinismEvaluator = determinismEvaluator;
        }

        /**
         * Returns the true-branch variable (argument 1) of an IF special form.
         *
         * @throws IllegalStateException if the expression is not a special form whose second
         *         argument is a variable reference
         */
        private static VariableReferenceExpression getTrueValueFromIf(RowExpression rowExpression) {
            Preconditions.checkState(
                    rowExpression instanceof SpecialFormExpression
                            && ((SpecialFormExpression) rowExpression).getArguments().get(1) instanceof VariableReferenceExpression,
                    "Expected a special form with a variable true branch: %s",
                    rowExpression);
            return (VariableReferenceExpression) ((SpecialFormExpression) rowExpression).getArguments().get(1);
        }

        /**
         * Replaces every variable in {@code expression} that the given assignments define with its
         * assigned expression. Variables not defined there are kept as-is.
         */
        private static RowExpression inlineReferences(RowExpression expression, Assignments assignments) {
            Map<VariableReferenceExpression, RowExpression> definitions = assignments.getMap();
            return RowExpressionVariableInliner.inlineVariables(variable -> definitions.getOrDefault(variable, variable), expression);
        }

        /** Any node other than a project or aggregation stops candidate propagation. */
        @Override
        public PlanNode visitPlan(PlanNode node, SimplePlanRewriter.RewriteContext<Map<VariableReferenceExpression, RowExpression>> context) {
            return context.defaultRewrite(node, ImmutableMap.of());
        }

        /**
         * Collects qualifying IF expressions from this project's assignments. Candidates received
         * from the parent are forwarded after being re-expressed over this project's inputs.
         */
        @Override
        public PlanNode visitProject(ProjectNode node, SimplePlanRewriter.RewriteContext<Map<VariableReferenceExpression, RowExpression>> context) {
            Assignments assignments = node.getAssignments();

            // Only variables referenced exactly once across all assignments are eligible. A variable
            // referenced more than once may feed other expressions, which must keep seeing the
            // unmasked aggregation result.
            Set<VariableReferenceExpression> candidateVariables = assignments.getExpressions().stream()
                    .flatMap(expression -> VariablesExtractor.extractAll(expression).stream())
                    .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()))
                    .entrySet().stream()
                    .filter(entry -> entry.getValue() == 1L)
                    .map(Map.Entry::getKey)
                    .collect(ImmutableSet.toImmutableSet());

            ImmutableSet.Builder<RowExpression> ifExpressions = ImmutableSet.builder();
            IfExpressionExtractor extractor = new IfExpressionExtractor();
            for (RowExpression expression : assignments.getExpressions()) {
                expression.accept(extractor, ifExpressions);
            }

            Map<VariableReferenceExpression, RowExpression> candidatesInAssignments = ifExpressions.build().stream()
                    .filter(ifExpression -> candidateVariables.contains(getTrueValueFromIf(ifExpression)))
                    .collect(ImmutableMap.toImmutableMap(Rewriter::getTrueValueFromIf, Function.identity()));

            ImmutableMap.Builder<VariableReferenceExpression, RowExpression> candidatesFromContext = ImmutableMap.builder();
            for (Map.Entry<VariableReferenceExpression, RowExpression> entry : context.get().entrySet()) {
                if (!candidateVariables.contains(entry.getKey())) {
                    continue;
                }
                RowExpression inlined = inlineReferences(entry.getValue(), assignments);
                // Inlining can pull a nondeterministic projection (e.g. random()) into the IF. Pushed
                // below the aggregation as a mask, it would be evaluated per input row instead of
                // once per output row, changing results. So such candidates are dropped.
                if (this.determinismEvaluator.isDeterministic(inlined)) {
                    candidatesFromContext.put(entry.getKey(), inlined);
                }
            }

            ImmutableMap.Builder<VariableReferenceExpression, RowExpression> candidates = ImmutableMap.builder();
            candidates.putAll(candidatesInAssignments);
            candidates.putAll(candidatesFromContext.build());
            return context.defaultRewrite(node, candidates.build());
        }

        /**
         * Adds the IF condition as a mask to every aggregation that has a candidate in the context.
         * A candidate is only used when every other variable in its IF is produced by the
         * aggregation's source. The masks are computed by a new project inserted between the
         * aggregation and its source; that project also passes all source outputs through unchanged.
         */
        @Override
        public PlanNode visitAggregation(AggregationNode node, SimplePlanRewriter.RewriteContext<Map<VariableReferenceExpression, RowExpression>> context) {
            Map<VariableReferenceExpression, AggregationNode.Aggregation> aggregations = node.getAggregations();
            Set<VariableReferenceExpression> sourceOutputs = ImmutableSet.copyOf(node.getSource().getOutputVariables());

            Map<VariableReferenceExpression, RowExpression> candidates = context.get().entrySet().stream()
                    .filter(entry -> aggregations.containsKey(entry.getKey()))
                    .filter(entry -> VariablesExtractor.extractUnique(entry.getValue()).stream()
                            .filter(variable -> !variable.equals(entry.getKey()))
                            .allMatch(sourceOutputs::contains))
                    .collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
            if (candidates.isEmpty()) {
                return context.defaultRewrite(node, ImmutableMap.of());
            }

            Assignments.Builder sourceProjection = Assignments.builder();
            ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> newAggregations = ImmutableMap.builder();
            // Walk the aggregations in their original order so the node's output layout is unchanged.
            for (Map.Entry<VariableReferenceExpression, AggregationNode.Aggregation> entry : aggregations.entrySet()) {
                VariableReferenceExpression aggregationOutput = entry.getKey();
                AggregationNode.Aggregation aggregation = entry.getValue();
                RowExpression ifExpression = candidates.get(aggregationOutput);
                if (ifExpression == null) {
                    newAggregations.put(aggregationOutput, aggregation);
                    continue;
                }
                Preconditions.checkState(
                        ifExpression instanceof SpecialFormExpression
                                && ((SpecialFormExpression) ifExpression).getForm() == SpecialFormExpression.Form.IF,
                        "Expected an IF expression: %s",
                        ifExpression);
                RowExpression condition = ((SpecialFormExpression) ifExpression).getArguments().get(0);
                RowExpression maskExpression = aggregation.getMask().isPresent()
                        ? LogicalRowExpressions.and(aggregation.getMask().get(), condition)
                        : condition;
                VariableReferenceExpression maskVariable = this.planVariableAllocator.newVariable(maskExpression);
                sourceProjection.put(maskVariable, maskExpression);
                newAggregations.put(aggregationOutput, new AggregationNode.Aggregation(
                        aggregation.getCall(),
                        aggregation.getFilter(),
                        aggregation.getOrderBy(),
                        aggregation.isDistinct(),
                        Optional.of(maskVariable)));
            }
            for (VariableReferenceExpression variable : node.getSource().getOutputVariables()) {
                sourceProjection.put(variable, variable);
            }

            ProjectNode maskProjection = new ProjectNode(this.planNodeIdAllocator.getNextId(), node.getSource(), sourceProjection.build());
            AggregationNode aggregationNode = new AggregationNode(
                    node.getSourceLocation(),
                    node.getId(),
                    maskProjection,
                    newAggregations.build(),
                    node.getGroupingSets(),
                    node.getPreGroupedVariables(),
                    node.getStep(),
                    node.getHashVariable(),
                    node.getGroupIdVariable());
            return context.defaultRewrite(aggregationNode, ImmutableMap.of());
        }

        /**
         * An IF qualifies when:
         * <ul>
         *   <li>its true branch is a bare variable;</li>
         *   <li>it has either only two arguments or a NULL constant ELSE branch;</li>
         *   <li>the whole expression is deterministic.</li>
         * </ul>
         */
        private boolean isCandidateIfExpression(RowExpression rowExpression) {
            if (!(rowExpression instanceof SpecialFormExpression)) {
                return false;
            }
            SpecialFormExpression specialForm = (SpecialFormExpression) rowExpression;
            if (specialForm.getForm() != SpecialFormExpression.Form.IF) {
                return false;
            }
            if (!(specialForm.getArguments().get(1) instanceof VariableReferenceExpression)) {
                return false;
            }
            boolean elseIsNull = specialForm.getArguments().size() == 2 || isNullConstant(specialForm.getArguments().get(2));
            return elseIsNull && this.determinismEvaluator.isDeterministic(rowExpression);
        }

        /** Returns true if the expression is a NULL literal. */
        private static boolean isNullConstant(RowExpression expression) {
            return expression instanceof ConstantExpression && ((ConstantExpression) expression).isNull();
        }

        /**
         * Traverses an expression tree and adds every qualifying IF special form (including nested
         * ones) to the supplied set builder.
         */
        private class IfExpressionExtractor
        extends DefaultRowExpressionTraversalVisitor<ImmutableSet.Builder<RowExpression>> {
            private IfExpressionExtractor() {
            }

            @Override
            public Void visitSpecialForm(SpecialFormExpression specialForm, ImmutableSet.Builder<RowExpression> context) {
                if (isCandidateIfExpression(specialForm)) {
                    context.add(specialForm);
                }
                return super.visitSpecialForm(specialForm, context);
            }
        }
    }
}

