/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.MarkDistinctNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.google.common.base.Predicates;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;

public class MultipleDistinctAggregationToMarkDistinct
implements Rule<AggregationNode> {
    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation().matching((Predicate)Predicates.and(MultipleDistinctAggregationToMarkDistinct::hasNoDistinctWithFilterOrMask, (com.google.common.base.Predicate)Predicates.or(MultipleDistinctAggregationToMarkDistinct::hasMultipleDistincts, MultipleDistinctAggregationToMarkDistinct::hasMixedDistinctAndNonDistincts)));

    private static boolean hasNoDistinctWithFilterOrMask(AggregationNode aggregation) {
        return aggregation.getAggregations().values().stream().noneMatch(e -> e.isDistinct() && (e.getFilter().isPresent() || e.getMask().isPresent()));
    }

    private static boolean hasMultipleDistincts(AggregationNode aggregation) {
        return aggregation.getAggregations().values().stream().filter(e -> e.isDistinct()).map(AggregationNode.Aggregation::getArguments).map(HashSet::new).distinct().count() > 1L;
    }

    private static boolean hasMixedDistinctAndNonDistincts(AggregationNode aggregation) {
        long distincts = aggregation.getAggregations().values().stream().filter(AggregationNode.Aggregation::isDistinct).count();
        return distincts > 0L && distincts < (long)aggregation.getAggregations().size();
    }

    @Override
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    @Override
    public Rule.Result apply(AggregationNode parent, Captures captures, Rule.Context context) {
        if (!SystemSessionProperties.useMarkDistinct(context.getSession())) {
            return Rule.Result.empty();
        }
        HashMap markers = new HashMap();
        HashMap newAggregations = new HashMap();
        PlanNode subPlan = parent.getSource();
        for (Map.Entry entry : parent.getAggregations().entrySet()) {
            AggregationNode.Aggregation aggregation = (AggregationNode.Aggregation)entry.getValue();
            if (aggregation.isDistinct() && !aggregation.getFilter().isPresent() && !aggregation.getMask().isPresent()) {
                Set inputs = aggregation.getArguments().stream().map(VariableReferenceExpression.class::cast).collect(Collectors.toSet());
                VariableReferenceExpression marker = (VariableReferenceExpression)markers.get(inputs);
                if (marker == null) {
                    marker = context.getVariableAllocator().newVariable(((VariableReferenceExpression)Iterables.getLast(inputs)).getName(), (Type)BooleanType.BOOLEAN, "distinct");
                    markers.put(inputs, marker);
                    ImmutableSet.Builder distinctVariables = ImmutableSet.builder().addAll((Iterable)parent.getGroupingKeys()).addAll(inputs);
                    parent.getGroupIdVariable().ifPresent(arg_0 -> ((ImmutableSet.Builder)distinctVariables).add(arg_0));
                    subPlan = new MarkDistinctNode(subPlan.getSourceLocation(), context.getIdAllocator().getNextId(), subPlan, marker, (List)ImmutableList.copyOf((Collection)distinctVariables.build()), Optional.empty());
                }
                newAggregations.put(entry.getKey(), new AggregationNode.Aggregation(aggregation.getCall(), aggregation.getFilter(), aggregation.getOrderBy(), false, Optional.of(marker)));
                continue;
            }
            newAggregations.put(entry.getKey(), aggregation);
        }
        return Rule.Result.ofPlanNode((PlanNode)new AggregationNode(parent.getSourceLocation(), parent.getId(), subPlan, newAggregations, parent.getGroupingSets(), (List)ImmutableList.of(), parent.getStep(), parent.getHashVariable(), parent.getGroupIdVariable()));
    }
}

