/*
 * Decompiled with CFR 0.152.
 */
package io.prestosql.sql.planner.iterative.rule;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import io.prestosql.matching.Captures;
import io.prestosql.matching.Pattern;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.iterative.Rule;
import io.prestosql.sql.planner.plan.AggregationNode;
import io.prestosql.sql.planner.plan.Patterns;
import io.prestosql.sql.tree.Expression;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Rewrites an aggregation whose aggregates are all {@code DISTINCT} over one and the same set of
 * arguments into a non-distinct aggregation on top of a {@code GROUP BY} that removes duplicates.
 * <p>
 * For example:
 * <pre>
 *     SELECT a, count(DISTINCT b), sum(DISTINCT b) FROM t GROUP BY a
 * </pre>
 * is planned as
 * <pre>
 *     SELECT a, count(b), sum(b) FROM (SELECT a, b FROM t GROUP BY a, b) GROUP BY a
 * </pre>
 * The rule only matches when every aggregate is DISTINCT, all DISTINCT aggregates share a single
 * argument set, and no aggregate carries a filter or a mask.
 */
public class SingleDistinctAggregationToGroupBy
implements Rule<AggregationNode> {
    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation()
            .matching(SingleDistinctAggregationToGroupBy::hasSingleDistinctInput)
            .matching(SingleDistinctAggregationToGroupBy::allDistinctAggregates)
            .matching(SingleDistinctAggregationToGroupBy::noFilters)
            .matching(SingleDistinctAggregationToGroupBy::noMasks);

    /** Returns true when the DISTINCT aggregates of the node use exactly one (unordered) argument set. */
    private static boolean hasSingleDistinctInput(AggregationNode aggregationNode) {
        return extractArgumentSets(aggregationNode).count() == 1L;
    }

    /** Returns true when every aggregate of the node is DISTINCT. */
    private static boolean allDistinctAggregates(AggregationNode aggregationNode) {
        return aggregationNode.getAggregations().values().stream()
                .allMatch(AggregationNode.Aggregation::isDistinct);
    }

    /** Returns true when no aggregate of the node has a filter. */
    private static boolean noFilters(AggregationNode aggregationNode) {
        return aggregationNode.getAggregations().values().stream()
                .noneMatch(aggregation -> aggregation.getFilter().isPresent());
    }

    /** Returns true when no aggregate of the node has a mask. */
    private static boolean noMasks(AggregationNode aggregationNode) {
        return aggregationNode.getAggregations().values().stream()
                .noneMatch(aggregation -> aggregation.getMask().isPresent());
    }

    /**
     * Returns the distinct argument sets of the node's DISTINCT aggregates. Arguments are compared
     * as sets, so {@code f(DISTINCT a, b)} and {@code g(DISTINCT b, a)} yield a single entry.
     */
    private static Stream<Set<Expression>> extractArgumentSets(AggregationNode aggregationNode) {
        return aggregationNode.getAggregations().values().stream()
                .filter(AggregationNode.Aggregation::isDistinct)
                .map(AggregationNode.Aggregation::getArguments)
                // Explicit type witness: without it map() infers Stream<HashSet<Expression>>, which is
                // not convertible to the declared Stream<Set<Expression>> return type.
                .<Set<Expression>>map(HashSet::new)
                .distinct();
    }

    @Override
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    /**
     * Replaces the matched DISTINCT aggregation with a non-distinct aggregation over a
     * de-duplicating {@code GROUP BY (grouping keys, distinct arguments)}.
     *
     * @param aggregation the matched aggregation node
     * @param captures unused
     * @param context rule context; used to allocate the id of the new inner node
     * @return the rewritten plan; the outer node keeps the id, grouping sets, step and hash/group-id
     *         symbols of the original node
     */
    @Override
    public Rule.Result apply(AggregationNode aggregation, Captures captures, Rule.Context context) {
        Set<Expression> distinctArguments = Iterables.getOnlyElement(
                extractArgumentSets(aggregation).collect(Collectors.toList()));
        // NOTE(review): assumes every DISTINCT argument is a symbol reference accepted by Symbol.from — confirm
        Set<Symbol> distinctSymbols = distinctArguments.stream()
                .map(Symbol::from)
                .collect(Collectors.toSet());

        // Group by the original keys followed by the DISTINCT arguments. An argument that is already
        // a grouping key (e.g. "GROUP BY a" with "count(DISTINCT a)") is listed only once, so the
        // inner node never receives a duplicated grouping key.
        List<Symbol> innerGroupingKeys = Stream.concat(aggregation.getGroupingKeys().stream(), distinctSymbols.stream())
                .distinct()
                .collect(ImmutableList.toImmutableList());

        AggregationNode deduplicatingSource = new AggregationNode(
                context.getIdAllocator().getNextId(),
                aggregation.getSource(),
                ImmutableMap.of(),
                AggregationNode.singleGroupingSet(innerGroupingKeys),
                ImmutableList.of(),
                AggregationNode.Step.SINGLE,
                Optional.empty(),
                Optional.empty());

        // Same aggregates as before, with the DISTINCT flag cleared.
        Map<Symbol, AggregationNode.Aggregation> nonDistinctAggregations = aggregation.getAggregations().entrySet().stream()
                .collect(Collectors.toMap(Map.Entry::getKey, entry -> removeDistinct(entry.getValue())));

        return Rule.Result.ofPlanNode(new AggregationNode(
                aggregation.getId(),
                deduplicatingSource,
                nonDistinctAggregations,
                aggregation.getGroupingSets(),
                Collections.emptyList(),
                aggregation.getStep(),
                aggregation.getHashSymbol(),
                aggregation.getGroupIdSymbol()));
    }

    /**
     * Returns a copy of the given DISTINCT aggregation with the DISTINCT flag cleared; every
     * other property is copied unchanged.
     *
     * @throws IllegalArgumentException if the aggregation is not DISTINCT
     */
    private static AggregationNode.Aggregation removeDistinct(AggregationNode.Aggregation aggregation) {
        Preconditions.checkArgument(aggregation.isDistinct(), "Expected aggregation to have DISTINCT input");
        return new AggregationNode.Aggregation(
                aggregation.getResolvedFunction(),
                aggregation.getArguments(),
                false,
                aggregation.getFilter(),
                aggregation.getOrderingScheme(),
                aggregation.getMask());
    }
}

