/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.base.Predicates;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.trino.SystemSessionProperties;
import io.trino.cost.TaskCountEstimator;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.FunctionResolver;
import io.trino.metadata.GlobalFunctionCatalog;
import io.trino.security.AllowAllAccessControl;
import io.trino.spi.function.CatalogSchemaFunctionName;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.Type;
import io.trino.sql.PlannerContext;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.ir.Coalesce;
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.planner.OptimizerConfig;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.iterative.rule.DistinctAggregationStrategyChooser;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.GroupIdNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.tree.QualifiedName;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;

public class OptimizeMixedDistinctAggregations
implements Rule<AggregationNode> {
    private static final CatalogSchemaFunctionName COUNT_NAME = GlobalFunctionCatalog.builtinFunctionName("count");
    private static final CatalogSchemaFunctionName COUNT_IF_NAME = GlobalFunctionCatalog.builtinFunctionName("count_if");
    private static final CatalogSchemaFunctionName APPROX_DISTINCT_NAME = GlobalFunctionCatalog.builtinFunctionName("approx_distinct");
    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation().matching((Predicate)Predicates.and((com.google.common.base.Predicate[])new com.google.common.base.Predicate[]{Predicates.or(OptimizeMixedDistinctAggregations::hasMultipleDistincts, OptimizeMixedDistinctAggregations::hasMixedDistinctAndNonDistincts), OptimizeMixedDistinctAggregations::allDistinctAggregationsHaveSingleArgument, OptimizeMixedDistinctAggregations::noFilters, OptimizeMixedDistinctAggregations::noMasks, aggregation -> !aggregation.hasOrderings(), aggregation -> aggregation.getStep().equals((Object)AggregationNode.Step.SINGLE)}));
    private final FunctionResolver functionResolver;
    private final DistinctAggregationStrategyChooser distinctAggregationStrategyChooser;

    private static boolean hasMultipleDistincts(AggregationNode aggregationNode) {
        return aggregationNode.getAggregations().values().stream().filter(AggregationNode.Aggregation::isDistinct).map(AggregationNode.Aggregation::getArguments).map(HashSet::new).distinct().count() > 1L;
    }

    private static boolean hasMixedDistinctAndNonDistincts(AggregationNode aggregationNode) {
        long distincts = aggregationNode.getAggregations().values().stream().filter(AggregationNode.Aggregation::isDistinct).count();
        return distincts > 0L && distincts < (long)aggregationNode.getAggregations().size();
    }

    private static boolean allDistinctAggregationsHaveSingleArgument(AggregationNode aggregation) {
        return aggregation.getAggregations().values().stream().filter(AggregationNode.Aggregation::isDistinct).allMatch(node -> node.getArguments().size() == 1);
    }

    private static boolean noFilters(AggregationNode aggregationNode) {
        return aggregationNode.getAggregations().values().stream().noneMatch(aggregation -> aggregation.getFilter().isPresent());
    }

    private static boolean noMasks(AggregationNode aggregationNode) {
        return aggregationNode.getAggregations().values().stream().noneMatch(aggregation -> aggregation.getMask().isPresent());
    }

    public OptimizeMixedDistinctAggregations(PlannerContext plannerContext, TaskCountEstimator taskCountEstimator) {
        this.functionResolver = plannerContext.getFunctionResolver();
        this.distinctAggregationStrategyChooser = DistinctAggregationStrategyChooser.createDistinctAggregationStrategyChooser(taskCountEstimator);
    }

    @Override
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    @Override
    public Rule.Result apply(AggregationNode node, Captures captures, Rule.Context context) {
        OptimizerConfig.DistinctAggregationsStrategy distinctAggregationsStrategy = SystemSessionProperties.distinctAggregationsStrategy(context.getSession());
        if (!(distinctAggregationsStrategy.equals((Object)OptimizerConfig.DistinctAggregationsStrategy.PRE_AGGREGATE) || distinctAggregationsStrategy.equals((Object)OptimizerConfig.DistinctAggregationsStrategy.AUTOMATIC) && this.distinctAggregationStrategyChooser.shouldUsePreAggregate(node, context.getSession(), context.getStatsProvider()))) {
            return Rule.Result.empty();
        }
        SymbolAllocator symbolAllocator = context.getSymbolAllocator();
        Set originalDistinctAggregationArguments = (Set)node.getAggregations().values().stream().filter(AggregationNode.Aggregation::isDistinct).flatMap(aggregation -> aggregation.getArguments().stream()).map(Symbol::from).collect(ImmutableSet.toImmutableSet());
        boolean hasNonDistinctAggregation = node.getAggregations().values().stream().anyMatch(aggregation -> !aggregation.isDistinct());
        ImmutableMap.Builder distinctAggregationArgumentToGroupIdMapBuilder = ImmutableMap.builder();
        int distinctGroupId = hasNonDistinctAggregation ? 1 : 0;
        for (Symbol distinctAggregationInput : originalDistinctAggregationArguments) {
            distinctAggregationArgumentToGroupIdMapBuilder.put((Object)distinctAggregationInput, (Object)distinctGroupId++);
        }
        ImmutableMap distinctAggregationArgumentToGroupIdMap = distinctAggregationArgumentToGroupIdMapBuilder.buildOrThrow();
        ImmutableMap.Builder groupIdOutputToInputColumnMapping = ImmutableMap.builder();
        for (Symbol distinctArgumentSymbol : originalDistinctAggregationArguments) {
            groupIdOutputToInputColumnMapping.put((Object)distinctArgumentSymbol, (Object)distinctArgumentSymbol);
        }
        for (Symbol groupingKey : node.getGroupingKeys()) {
            groupIdOutputToInputColumnMapping.put((Object)groupingKey, (Object)groupingKey);
        }
        Symbol groupSymbol = symbolAllocator.newSymbol("group", (Type)BigintType.BIGINT);
        Assignments.Builder groupIdFilters = Assignments.builder();
        Symbol nonDistinctGroupFilterSymbol = symbolAllocator.newSymbol("non-distinct-gid-filter", (Type)BooleanType.BOOLEAN);
        if (hasNonDistinctAggregation) {
            groupIdFilters.put(nonDistinctGroupFilterSymbol, new Comparison(Comparison.Operator.EQUAL, groupSymbol.toSymbolReference(), new Constant((Type)BigintType.BIGINT, 0L)));
        }
        ImmutableMap.Builder outerAggregations = ImmutableMap.builder();
        HashMap<Integer, Symbol> groupIdFilterSymbolByGroupId = new HashMap<Integer, Symbol>();
        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : node.getAggregations().entrySet()) {
            AggregationNode.Aggregation originalAggregation = entry.getValue();
            if (!originalAggregation.isDistinct()) continue;
            Symbol aggregationInput = Symbol.from(originalAggregation.getArguments().get(0));
            Integer groupId = (Integer)distinctAggregationArgumentToGroupIdMap.get(aggregationInput);
            Symbol groupIdFilterSymbol = groupIdFilterSymbolByGroupId.computeIfAbsent(groupId, id -> {
                Symbol filterSymbol = symbolAllocator.newSymbol("gid-filter-" + groupId, (Type)BooleanType.BOOLEAN);
                groupIdFilters.put(filterSymbol, new Comparison(Comparison.Operator.EQUAL, groupSymbol.toSymbolReference(), new Constant((Type)BigintType.BIGINT, groupId)));
                return filterSymbol;
            });
            outerAggregations.put((Object)entry.getKey(), (Object)new AggregationNode.Aggregation(originalAggregation.getResolvedFunction(), originalAggregation.getArguments(), false, Optional.of(groupIdFilterSymbol), Optional.empty(), Optional.empty()));
        }
        ImmutableMap.Builder innerAggregations = ImmutableMap.builder();
        ImmutableMap.Builder coalesceSymbolsBuilder = ImmutableMap.builder();
        ImmutableSet.Builder nonDistinctAggregationArguments = ImmutableSet.builder();
        HashMap<Symbol, Symbol> duplicatedGroupIdInputToOutput = new HashMap<Symbol, Symbol>();
        if (hasNonDistinctAggregation) {
            for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : node.getAggregations().entrySet()) {
                AggregationNode.Aggregation originalAggregation = entry.getValue();
                if (originalAggregation.isDistinct()) continue;
                Symbol origalAggregationOutputSymbol = entry.getKey();
                ImmutableList.Builder mappedArguments = ImmutableList.builder();
                for (Expression argument : originalAggregation.getArguments()) {
                    Symbol argumentSymbol;
                    Symbol finalArgumentSymbol = argumentSymbol = Symbol.from(argument);
                    if (originalDistinctAggregationArguments.contains(argumentSymbol)) {
                        finalArgumentSymbol = duplicatedGroupIdInputToOutput.computeIfAbsent(argumentSymbol, symbol -> symbolAllocator.newSymbol("gid-non-distinct", symbol.type()));
                    }
                    groupIdOutputToInputColumnMapping.put((Object)finalArgumentSymbol, (Object)argumentSymbol);
                    mappedArguments.add((Object)finalArgumentSymbol.toSymbolReference());
                    nonDistinctAggregationArguments.add((Object)finalArgumentSymbol);
                }
                Symbol innerAggregationOutputSymbol = symbolAllocator.newSymbol("inner", origalAggregationOutputSymbol.type());
                AggregationNode.Aggregation innerAggregation = new AggregationNode.Aggregation(originalAggregation.getResolvedFunction(), (List<Expression>)mappedArguments.build(), false, Optional.empty(), Optional.empty(), Optional.empty());
                innerAggregations.put((Object)innerAggregationOutputSymbol, (Object)innerAggregation);
                AggregationNode.Aggregation outerAggregation = new AggregationNode.Aggregation(this.functionResolver.resolveFunction(context.getSession(), QualifiedName.of((String)"arbitrary"), TypeSignatureProvider.fromTypes(originalAggregation.getResolvedFunction().signature().getReturnType()), new AllowAllAccessControl()), (List<Expression>)ImmutableList.of((Object)innerAggregationOutputSymbol.toSymbolReference()), false, Optional.of(nonDistinctGroupFilterSymbol), Optional.empty(), Optional.empty());
                Symbol outerAggregationOutputSymbol = origalAggregationOutputSymbol;
                CatalogSchemaFunctionName name = originalAggregation.getResolvedFunction().signature().getName();
                if (name.equals((Object)COUNT_NAME) || name.equals((Object)COUNT_IF_NAME) || name.equals((Object)APPROX_DISTINCT_NAME)) {
                    Symbol coalesceSymbol;
                    outerAggregationOutputSymbol = coalesceSymbol = symbolAllocator.newSymbol("coalesce_expr", origalAggregationOutputSymbol.type());
                    coalesceSymbolsBuilder.put((Object)coalesceSymbol, (Object)origalAggregationOutputSymbol);
                }
                outerAggregations.put((Object)outerAggregationOutputSymbol, (Object)outerAggregation);
            }
        }
        GroupIdNode groupIdNode = new GroupIdNode(context.getIdAllocator().getNextId(), node.getSource(), OptimizeMixedDistinctAggregations.createGroups(node.getGroupingKeys(), (Set<Symbol>)nonDistinctAggregationArguments.build(), hasNonDistinctAggregation, (Map<Symbol, Integer>)distinctAggregationArgumentToGroupIdMap), (Map<Symbol, Symbol>)groupIdOutputToInputColumnMapping.buildKeepingLast(), (List<Symbol>)ImmutableList.of(), groupSymbol);
        ImmutableSet innerAggregationGropingKeys = ImmutableSet.builder().addAll(node.getGroupingKeys()).addAll((Iterable)originalDistinctAggregationArguments).add((Object)groupSymbol).build();
        AggregationNode innerAggregationNode = new AggregationNode(context.getIdAllocator().getNextId(), groupIdNode, (Map<Symbol, AggregationNode.Aggregation>)innerAggregations.buildOrThrow(), AggregationNode.singleGroupingSet((List<Symbol>)ImmutableList.copyOf((Collection)innerAggregationGropingKeys)), (List<Symbol>)ImmutableList.of(), AggregationNode.Step.SINGLE, node.getHashSymbol(), Optional.empty());
        groupIdFilters.putIdentities(innerAggregationNode.getOutputSymbols());
        ProjectNode groupIdFiltersProjectNode = new ProjectNode(context.getIdAllocator().getNextId(), innerAggregationNode, groupIdFilters.build());
        AggregationNode outerAggregationNode = new AggregationNode(context.getIdAllocator().getNextId(), groupIdFiltersProjectNode, (Map<Symbol, AggregationNode.Aggregation>)outerAggregations.buildOrThrow(), node.getGroupingSets(), (List<Symbol>)ImmutableList.of(), node.getStep(), Optional.empty(), node.getGroupIdSymbol());
        ImmutableMap coalesceSymbols = coalesceSymbolsBuilder.buildOrThrow();
        if (coalesceSymbols.isEmpty()) {
            return Rule.Result.ofPlanNode(outerAggregationNode);
        }
        Assignments.Builder outputSymbols = Assignments.builder();
        for (Symbol symbol2 : outerAggregationNode.getOutputSymbols()) {
            if (coalesceSymbols.containsKey(symbol2)) {
                Coalesce expression = new Coalesce(symbol2.toSymbolReference(), new Constant((Type)BigintType.BIGINT, 0L), new Expression[0]);
                outputSymbols.put((Symbol)coalesceSymbols.get(symbol2), expression);
                continue;
            }
            outputSymbols.putIdentity(symbol2);
        }
        return Rule.Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), outerAggregationNode, outputSymbols.build()));
    }

    private static List<List<Symbol>> createGroups(List<Symbol> groupingKeys, Set<Symbol> nonDistinctAggregationArguments, boolean hasNonDistinctAggregation, Map<Symbol, Integer> distinctAggregationArgumentToGroupIdMap) {
        ImmutableList.Builder groups = ImmutableList.builder();
        if (hasNonDistinctAggregation) {
            groups.add((Object)ImmutableList.copyOf((Collection)ImmutableSet.builder().addAll(groupingKeys).addAll(nonDistinctAggregationArguments).build()));
        }
        distinctAggregationArgumentToGroupIdMap.entrySet().stream().sorted(Map.Entry.comparingByValue()).forEach(entry -> {
            Symbol distinctAggregationInput = (Symbol)entry.getKey();
            groups.add((Object)ImmutableList.copyOf((Collection)ImmutableSet.builder().addAll((Iterable)groupingKeys).add((Object)distinctAggregationInput).build()));
        });
        return groups.build();
    }
}

