/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.common.CatalogSchemaName;
import com.facebook.presto.common.QualifiedObjectName;
import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.function.StandardFunctionResolution;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.MarkDistinctNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.analyzer.TypeSignatureProvider;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.optimizations.PlanOptimizer;
import com.facebook.presto.sql.planner.optimizations.PlanOptimizerResult;
import com.facebook.presto.sql.planner.plan.GroupIdNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.relational.Expressions;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
 * Rewrites an aggregation in which some, but not all, aggregations share a single DISTINCT mask
 * that is produced by a {@link MarkDistinctNode} below it.
 *
 * <p>The matching {@code MarkDistinctNode} is replaced by:
 * <ol>
 *   <li>a {@link GroupIdNode} with two grouping sets: (group-by keys + non-distinct arguments) and
 *       (group-by keys + the distinct argument);</li>
 *   <li>a SINGLE-step aggregation grouped by the group-by keys, the distinct argument and the group
 *       id, which evaluates the original non-distinct aggregations;</li>
 *   <li>a projection that exposes the distinct argument only when the group id is 1 and the inner
 *       aggregation results only when the group id is 0 (NULL otherwise).</li>
 * </ol>
 * The original aggregation is then rebuilt on top of that projection: distinct aggregations read the
 * projected distinct argument without a mask, and non-distinct aggregations become {@code arbitrary}
 * over the projected inner results. Results of {@code count}, {@code count_if} and
 * {@code approx_distinct} are additionally wrapped in {@code COALESCE(result, 0)}.
 */
public class OptimizeMixedDistinctAggregations
        implements PlanOptimizer {
    private final Metadata metadata;
    private final StandardFunctionResolution functionResolution;
    private boolean isEnabledForTesting;

    public OptimizeMixedDistinctAggregations(Metadata metadata) {
        this.metadata = metadata;
        this.functionResolution = new FunctionResolution(metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver());
    }

    @Override
    public void setEnabledForTesting(boolean isSet) {
        this.isEnabledForTesting = isSet;
    }

    /** Enabled when forced on for testing or when the distinct-aggregation session property is set. */
    @Override
    public boolean isEnabled(Session session) {
        return isEnabledForTesting || SystemSessionProperties.isOptimizeDistinctAggregationEnabled(session);
    }

    @Override
    public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider types, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) {
        if (!isEnabled(session)) {
            return PlanOptimizerResult.optimizerResult(plan, false);
        }
        Optimizer optimizer = new Optimizer(idAllocator, variableAllocator, metadata, functionResolution);
        PlanNode rewrittenPlan = SimplePlanRewriter.rewriteWith(optimizer, plan, Optional.empty());
        return PlanOptimizerResult.optimizerResult(rewrittenPlan, optimizer.isPlanChanged());
    }

    /**
     * State shared between {@code visitAggregation}, which creates it for a candidate aggregation,
     * and {@code visitMarkDistinct}, which records whether the matching MarkDistinct was found and
     * which projected variables replace the original aggregation inputs.
     */
    private static class AggregateInfo {
        private final List<VariableReferenceExpression> groupByVariables;
        private final VariableReferenceExpression mask;
        private final Map<VariableReferenceExpression, AggregationNode.Aggregation> aggregations;
        // original non-distinct aggregation output -> projected variable that now carries its value
        private Map<VariableReferenceExpression, VariableReferenceExpression> newNonDistinctAggregateVariables;
        // projected variable that now feeds every distinct (masked) aggregation
        private VariableReferenceExpression newDistinctAggregateVariable;
        private boolean foundMarkDistinct;

        public AggregateInfo(List<VariableReferenceExpression> groupByVariables, VariableReferenceExpression mask, Map<VariableReferenceExpression, AggregationNode.Aggregation> aggregations) {
            this.groupByVariables = ImmutableList.copyOf(groupByVariables);
            this.mask = mask;
            this.aggregations = ImmutableMap.copyOf(aggregations);
        }

        /**
         * Returns the de-duplicated arguments of the aggregations without a mask, as a fresh
         * mutable list (callers may modify it). Every argument must be a variable.
         */
        public List<VariableReferenceExpression> getOriginalNonDistinctAggregateArgs() {
            return collectArguments(false);
        }

        /**
         * Returns the de-duplicated arguments of the masked (distinct) aggregations, as a fresh
         * mutable list. Every argument must be a variable.
         */
        public List<VariableReferenceExpression> getOriginalDistinctAggregateArgs() {
            return collectArguments(true);
        }

        private List<VariableReferenceExpression> collectArguments(boolean masked) {
            // Collect into an explicit ArrayList: Collectors.toList() makes no mutability guarantee,
            // and visitMarkDistinct updates the non-distinct list in place.
            return aggregations.values().stream()
                    .filter(aggregation -> aggregation.getMask().isPresent() == masked)
                    .flatMap(aggregation -> aggregation.getArguments().stream())
                    .distinct()
                    .map(VariableReferenceExpression.class::cast)
                    .collect(Collectors.toCollection(ArrayList::new));
        }

        public VariableReferenceExpression getNewDistinctAggregateVariable() {
            return newDistinctAggregateVariable;
        }

        public void setNewDistinctAggregateVariable(VariableReferenceExpression newDistinctAggregateVariable) {
            this.newDistinctAggregateVariable = newDistinctAggregateVariable;
        }

        public Map<VariableReferenceExpression, VariableReferenceExpression> getNewNonDistinctAggregateVariables() {
            return newNonDistinctAggregateVariables;
        }

        public void setNewNonDistinctAggregateVariables(Map<VariableReferenceExpression, VariableReferenceExpression> newNonDistinctAggregateVariables) {
            this.newNonDistinctAggregateVariables = newNonDistinctAggregateVariables;
        }

        public VariableReferenceExpression getMask() {
            return mask;
        }

        public List<VariableReferenceExpression> getGroupByVariables() {
            return groupByVariables;
        }

        public Map<VariableReferenceExpression, AggregationNode.Aggregation> getAggregations() {
            return aggregations;
        }

        public void foundMarkDistinct() {
            this.foundMarkDistinct = true;
        }

        public boolean isFoundMarkDistinct() {
            return foundMarkDistinct;
        }
    }

    private static class Optimizer
            extends SimplePlanRewriter<Optional<AggregateInfo>> {
        /** Aggregations whose rewritten ({@code arbitrary}) result is wrapped in {@code COALESCE(result, 0)}. */
        private static final Set<QualifiedObjectName> COALESCE_TO_ZERO_FUNCTIONS = ImmutableSet.of(
                QualifiedObjectName.valueOf(BuiltInTypeAndFunctionNamespaceManager.DEFAULT_NAMESPACE, "count"),
                QualifiedObjectName.valueOf(BuiltInTypeAndFunctionNamespaceManager.DEFAULT_NAMESPACE, "count_if"),
                QualifiedObjectName.valueOf(BuiltInTypeAndFunctionNamespaceManager.DEFAULT_NAMESPACE, "approx_distinct"));

        private final PlanNodeIdAllocator idAllocator;
        private final VariableAllocator variableAllocator;
        private final Metadata metadata;
        private final StandardFunctionResolution functionResolution;
        private boolean planChanged;

        private Optimizer(PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator, Metadata metadata, StandardFunctionResolution functionResolution) {
            this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
            this.variableAllocator = Objects.requireNonNull(variableAllocator, "variableAllocator is null");
            this.metadata = Objects.requireNonNull(metadata, "metadata is null");
            this.functionResolution = Objects.requireNonNull(functionResolution, "functionResolution is null");
        }

        public boolean isPlanChanged() {
            return planChanged;
        }

        /**
         * Rewrites {@code node} when some (but not all) of its aggregations share exactly one mask,
         * none has a filter or ordering, all arguments are supported and comparable, and a
         * MarkDistinct producing that mask is found below. Otherwise the subtree is rewritten with
         * no candidate in context.
         */
        @Override
        public PlanNode visitAggregation(AggregationNode node, SimplePlanRewriter.RewriteContext<Optional<AggregateInfo>> context) {
            Map<VariableReferenceExpression, AggregationNode.Aggregation> originalAggregations = node.getAggregations();
            List<VariableReferenceExpression> masks = originalAggregations.values().stream()
                    .map(AggregationNode.Aggregation::getMask)
                    .filter(Optional::isPresent)
                    .map(Optional::get)
                    .collect(ImmutableList.toImmutableList());
            Set<VariableReferenceExpression> uniqueMasks = ImmutableSet.copyOf(masks);
            // Requires exactly one distinct mask, and at least one aggregation without a mask.
            if (uniqueMasks.size() != 1 || masks.size() == originalAggregations.size()) {
                return context.defaultRewrite(node, Optional.empty());
            }
            if (originalAggregations.values().stream().map(AggregationNode.Aggregation::getFilter).anyMatch(Optional::isPresent)) {
                return context.defaultRewrite(node, Optional.empty());
            }
            if (node.hasOrderings()) {
                return context.defaultRewrite(node, Optional.empty());
            }
            // Must run before AggregateInfo's argument accessors, which cast every argument to a variable.
            if (!hasSupportedArguments(originalAggregations.values())) {
                return context.defaultRewrite(node, Optional.empty());
            }
            AggregateInfo aggregateInfo = new AggregateInfo(node.getGroupingKeys(), Iterables.getOnlyElement(uniqueMasks), originalAggregations);
            if (!checkAllEquatableTypes(aggregateInfo)) {
                return context.defaultRewrite(node, Optional.empty());
            }
            PlanNode source = context.rewrite(node.getSource(), Optional.of(aggregateInfo));
            if (!aggregateInfo.isFoundMarkDistinct()) {
                return context.defaultRewrite(node, Optional.empty());
            }

            ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();
            // new (uncoalesced) output variable -> original output variable
            ImmutableMap.Builder<VariableReferenceExpression, VariableReferenceExpression> coalesceVariablesBuilder = ImmutableMap.builder();
            for (Map.Entry<VariableReferenceExpression, AggregationNode.Aggregation> entry : originalAggregations.entrySet()) {
                VariableReferenceExpression output = entry.getKey();
                AggregationNode.Aggregation original = entry.getValue();
                CallExpression originalCall = original.getCall();
                if (original.getMask().isPresent()) {
                    // Distinct aggregation: same function, now fed by the projected distinct argument, no mask.
                    List<RowExpression> distinctInput = ImmutableList.<RowExpression>of(aggregateInfo.getNewDistinctAggregateVariable());
                    aggregations.put(output, newAggregation(replaceArguments(originalCall, distinctInput)));
                    continue;
                }
                // Non-distinct aggregation: its value was computed by the inner aggregation; pick it with arbitrary().
                VariableReferenceExpression argument = aggregateInfo.getNewNonDistinctAggregateVariables().get(output);
                CallExpression arbitrary = new CallExpression(
                        originalCall.getSourceLocation(),
                        "arbitrary",
                        metadata.getFunctionAndTypeManager().lookupFunction("arbitrary", TypeSignatureProvider.fromTypes(ImmutableList.of(argument.getType()))),
                        output.getType(),
                        ImmutableList.<RowExpression>of(argument));
                AggregationNode.Aggregation aggregation = newAggregation(arbitrary);
                if (isCoalescedToZero(original)) {
                    VariableReferenceExpression newVariable = variableAllocator.newVariable(originalCall.getSourceLocation(), "expr", output.getType());
                    aggregations.put(newVariable, aggregation);
                    coalesceVariablesBuilder.put(newVariable, output);
                    continue;
                }
                aggregations.put(output, aggregation);
            }
            Map<VariableReferenceExpression, VariableReferenceExpression> coalesceVariables = coalesceVariablesBuilder.build();
            AggregationNode aggregationNode = new AggregationNode(
                    node.getSourceLocation(),
                    idAllocator.getNextId(),
                    source,
                    aggregations.build(),
                    node.getGroupingSets(),
                    ImmutableList.of(),
                    node.getStep(),
                    Optional.empty(),
                    node.getGroupIdVariable(),
                    node.getAggregationId());
            planChanged = true;
            if (coalesceVariables.isEmpty()) {
                return aggregationNode;
            }

            Assignments.Builder outputVariables = Assignments.builder();
            for (VariableReferenceExpression variable : aggregationNode.getOutputVariables()) {
                VariableReferenceExpression originalOutput = coalesceVariables.get(variable);
                if (originalOutput == null) {
                    outputVariables.put(variable, variable);
                    continue;
                }
                RowExpression coalesce = new SpecialFormExpression(
                        variable.getSourceLocation(),
                        SpecialFormExpression.Form.COALESCE,
                        BigintType.BIGINT,
                        new RowExpression[] {variable, Expressions.constant(0L, BigintType.BIGINT)});
                outputVariables.put(originalOutput, coalesce);
            }
            return new ProjectNode(node.getSourceLocation(), idAllocator.getNextId(), aggregationNode, outputVariables.build(), ProjectNode.Locality.LOCAL);
        }

        /**
         * Replaces the MarkDistinct whose marker equals the candidate's mask with
         * GroupId -> inner aggregation -> projection, and records the projected variables in the
         * candidate {@link AggregateInfo}. Any other MarkDistinct is left as is.
         */
        @Override
        public PlanNode visitMarkDistinct(MarkDistinctNode node, SimplePlanRewriter.RewriteContext<Optional<AggregateInfo>> context) {
            Optional<AggregateInfo> candidate = context.get();
            if (!candidate.isPresent() || !candidate.get().getMask().equals(node.getMarkerVariable())) {
                return context.defaultRewrite(node, Optional.empty());
            }
            AggregateInfo aggregateInfo = candidate.get();
            aggregateInfo.foundMarkDistinct();
            PlanNode source = context.rewrite(node.getSource(), Optional.empty());

            List<VariableReferenceExpression> groupByVariables = aggregateInfo.getGroupByVariables();
            // Fresh mutable list (see AggregateInfo#collectArguments); may be updated below.
            List<VariableReferenceExpression> nonDistinctAggregateVariables = aggregateInfo.getOriginalNonDistinctAggregateArgs();
            // Exactly one distinct argument is guaranteed by hasSupportedArguments in visitAggregation.
            VariableReferenceExpression distinctVariable = Iterables.getOnlyElement(aggregateInfo.getOriginalDistinctAggregateArgs());
            VariableReferenceExpression duplicatedDistinctVariable = distinctVariable;
            int sharedIndex = nonDistinctAggregateVariables.indexOf(distinctVariable);
            if (sharedIndex >= 0) {
                // The distinct argument also feeds a non-distinct aggregation: give that use its own
                // variable (mapped back to the same input by the GroupIdNode) so the two grouping
                // sets carry it separately.
                duplicatedDistinctVariable = variableAllocator.newVariable(distinctVariable);
                nonDistinctAggregateVariables.set(sharedIndex, duplicatedDistinctVariable);
            }

            Set<VariableReferenceExpression> allVariables = new HashSet<>();
            allVariables.addAll(groupByVariables);
            allVariables.addAll(nonDistinctAggregateVariables);
            allVariables.add(distinctVariable);
            VariableReferenceExpression groupVariable = variableAllocator.newVariable("group", BigintType.BIGINT);
            GroupIdNode groupIdNode = createGroupIdNode(groupByVariables, nonDistinctAggregateVariables, distinctVariable, duplicatedDistinctVariable, groupVariable, allVariables, source);

            Set<VariableReferenceExpression> groupByKeys = new HashSet<>(groupByVariables);
            groupByKeys.add(distinctVariable);
            groupByKeys.add(groupVariable);
            ImmutableMap.Builder<VariableReferenceExpression, VariableReferenceExpression> aggregationOutputVariablesMapBuilder = ImmutableMap.builder();
            AggregationNode aggregationNode = createNonDistinctAggregation(aggregateInfo, distinctVariable, duplicatedDistinctVariable, groupByKeys, groupIdNode, node, aggregationOutputVariablesMapBuilder);
            return createProjectNode(aggregationNode, aggregateInfo, distinctVariable, groupVariable, groupByVariables, aggregationOutputVariablesMapBuilder.build());
        }

        /**
         * Returns true if every aggregation argument is a plain variable and every masked
         * (distinct) aggregation takes exactly one argument, the same variable for all of them.
         * The rewrite groups by these arguments and re-targets each distinct aggregation to a
         * single projected variable, so anything else (constants, lambdas, multi-argument distinct
         * calls) cannot be rewritten and the optimization is skipped.
         */
        private static boolean hasSupportedArguments(Collection<AggregationNode.Aggregation> aggregations) {
            Set<RowExpression> distinctArguments = new HashSet<>();
            for (AggregationNode.Aggregation aggregation : aggregations) {
                List<RowExpression> arguments = aggregation.getArguments();
                if (!arguments.stream().allMatch(VariableReferenceExpression.class::isInstance)) {
                    return false;
                }
                if (aggregation.getMask().isPresent()) {
                    if (arguments.size() != 1) {
                        return false;
                    }
                    distinctArguments.addAll(arguments);
                }
            }
            return distinctArguments.size() == 1;
        }

        /** Returns true if all non-distinct arguments and the mask have comparable types (they become grouping keys). */
        private boolean checkAllEquatableTypes(AggregateInfo aggregateInfo) {
            return aggregateInfo.getOriginalNonDistinctAggregateArgs().stream().allMatch(variable -> variable.getType().isComparable())
                    && aggregateInfo.getMask().getType().isComparable();
        }

        private boolean isCoalescedToZero(AggregationNode.Aggregation aggregation) {
            QualifiedObjectName functionName = metadata.getFunctionAndTypeManager().getFunctionMetadata(aggregation.getFunctionHandle()).getName();
            return COALESCE_TO_ZERO_FUNCTIONS.contains(functionName);
        }

        /**
         * Projects, on top of the inner aggregation: the distinct argument (only for group id 1),
         * each inner aggregation result (only for group id 0), the group-by keys unchanged, and
         * the original mask as a NULL constant. Records the new variables in {@code aggregateInfo}.
         */
        private ProjectNode createProjectNode(AggregationNode source, AggregateInfo aggregateInfo, VariableReferenceExpression distinctVariable, VariableReferenceExpression groupVariable, List<VariableReferenceExpression> groupByVariables, Map<VariableReferenceExpression, VariableReferenceExpression> aggregationOutputVariablesMap) {
            Assignments.Builder outputVariables = Assignments.builder();
            ImmutableMap.Builder<VariableReferenceExpression, VariableReferenceExpression> outputNonDistinctAggregateVariables = ImmutableMap.builder();
            for (VariableReferenceExpression variable : source.getOutputVariables()) {
                if (distinctVariable.equals(variable)) {
                    VariableReferenceExpression newVariable = variableAllocator.newVariable(variable.getSourceLocation(), "expr", variable.getType());
                    aggregateInfo.setNewDistinctAggregateVariable(newVariable);
                    outputVariables.put(newVariable, valueIfGroupEquals(groupVariable, 1L, variable));
                }
                else if (aggregationOutputVariablesMap.containsKey(variable)) {
                    VariableReferenceExpression newVariable = variableAllocator.newVariable(variable.getSourceLocation(), "expr", variable.getType());
                    outputNonDistinctAggregateVariables.put(aggregationOutputVariablesMap.get(variable), newVariable);
                    outputVariables.put(newVariable, valueIfGroupEquals(groupVariable, 0L, variable));
                }
                if (groupByVariables.contains(variable)) {
                    outputVariables.put(variable, variable);
                }
            }
            outputVariables.put(aggregateInfo.getMask(), Expressions.constantNull(aggregateInfo.getMask().getType()));
            aggregateInfo.setNewNonDistinctAggregateVariables(outputNonDistinctAggregateVariables.build());
            return new ProjectNode(source.getSourceLocation(), idAllocator.getNextId(), source, outputVariables.build(), ProjectNode.Locality.LOCAL);
        }

        /** Builds {@code IF(groupVariable = groupId, value, NULL)}. */
        private RowExpression valueIfGroupEquals(VariableReferenceExpression groupVariable, long groupId, VariableReferenceExpression value) {
            RowExpression condition = Expressions.call(
                    OperatorType.EQUAL.name(),
                    functionResolution.comparisonFunction(OperatorType.EQUAL, BigintType.BIGINT, BigintType.BIGINT),
                    BooleanType.BOOLEAN,
                    ImmutableList.<RowExpression>of(groupVariable, Expressions.constant(groupId, BigintType.BIGINT)));
            return new SpecialFormExpression(
                    SpecialFormExpression.Form.IF,
                    value.getType(),
                    ImmutableList.<RowExpression>of(condition, value, Expressions.constantNull(value.getSourceLocation(), value.getType())));
        }

        /**
         * Builds a GroupIdNode with grouping set 0 = group-by keys + non-distinct arguments and
         * grouping set 1 = group-by keys + distinct argument. The duplicated distinct variable (if
         * any) is mapped back to the original distinct input column.
         */
        private GroupIdNode createGroupIdNode(List<VariableReferenceExpression> groupByVariables, List<VariableReferenceExpression> nonDistinctAggregateVariables, VariableReferenceExpression distinctVariable, VariableReferenceExpression duplicatedDistinctVariable, VariableReferenceExpression groupVariable, Set<VariableReferenceExpression> allVariables, PlanNode source) {
            Set<VariableReferenceExpression> group0 = new HashSet<>();
            group0.addAll(groupByVariables);
            group0.addAll(nonDistinctAggregateVariables);
            Set<VariableReferenceExpression> group1 = new HashSet<>(groupByVariables);
            group1.add(distinctVariable);
            List<List<VariableReferenceExpression>> groups = ImmutableList.of(ImmutableList.copyOf(group0), ImmutableList.copyOf(group1));

            Map<VariableReferenceExpression, VariableReferenceExpression> groupingColumns = allVariables.stream()
                    .collect(Collectors.toMap(Function.identity(), variable -> variable.equals(duplicatedDistinctVariable) ? distinctVariable : variable));
            return new GroupIdNode(source.getSourceLocation(), idAllocator.getNextId(), source, groups, groupingColumns, ImmutableList.of(), groupVariable);
        }

        /**
         * Builds the inner SINGLE-step aggregation that evaluates every non-distinct aggregation
         * (with the distinct argument swapped for its duplicate where needed), grouped by
         * {@code groupByKeys}. Each new output is recorded in
         * {@code aggregationOutputSymbolsMapBuilder} as new output -> original output.
         */
        private AggregationNode createNonDistinctAggregation(AggregateInfo aggregateInfo, VariableReferenceExpression distinctVariable, VariableReferenceExpression duplicatedDistinctVariable, Set<VariableReferenceExpression> groupByKeys, GroupIdNode groupIdNode, MarkDistinctNode originalNode, ImmutableMap.Builder<VariableReferenceExpression, VariableReferenceExpression> aggregationOutputSymbolsMapBuilder) {
            boolean distinctWasDuplicated = !duplicatedDistinctVariable.equals(distinctVariable);
            ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();
            for (Map.Entry<VariableReferenceExpression, AggregationNode.Aggregation> entry : aggregateInfo.getAggregations().entrySet()) {
                AggregationNode.Aggregation aggregation = entry.getValue();
                if (aggregation.getMask().isPresent()) {
                    continue;
                }
                VariableReferenceExpression newVariable = variableAllocator.newVariable(entry.getKey());
                aggregationOutputSymbolsMapBuilder.put(newVariable, entry.getKey());
                List<RowExpression> arguments = aggregation.getArguments();
                if (distinctWasDuplicated && arguments.contains(distinctVariable)) {
                    arguments = arguments.stream()
                            .map(argument -> argument.equals(distinctVariable) ? duplicatedDistinctVariable : argument)
                            .collect(ImmutableList.toImmutableList());
                }
                aggregations.put(newVariable, newAggregation(replaceArguments(aggregation.getCall(), arguments)));
            }
            return new AggregationNode(
                    groupIdNode.getSourceLocation(),
                    idAllocator.getNextId(),
                    groupIdNode,
                    aggregations.build(),
                    AggregationNode.singleGroupingSet(ImmutableList.copyOf(groupByKeys)),
                    ImmutableList.of(),
                    AggregationNode.Step.SINGLE,
                    originalNode.getHashVariable(),
                    Optional.empty(),
                    Optional.empty());
        }

        /** Copies {@code call} with a different argument list. */
        private static CallExpression replaceArguments(CallExpression call, List<RowExpression> arguments) {
            return new CallExpression(call.getSourceLocation(), call.getDisplayName(), call.getFunctionHandle(), call.getType(), arguments);
        }

        /** Wraps {@code call} in an aggregation with no filter, ordering, DISTINCT flag or mask. */
        private static AggregationNode.Aggregation newAggregation(CallExpression call) {
            return new AggregationNode.Aggregation(call, Optional.empty(), Optional.empty(), false, Optional.empty());
        }
    }
}

