/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import com.google.common.collect.Streams;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.ResolvedFunction;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.Lambda;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.iterative.rule.PushProjectionThroughJoin;
import io.trino.sql.planner.iterative.rule.Util;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.JoinType;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.assertj.core.util.VisibleForTesting;

public class PushPartialAggregationThroughJoin {
    /**
     * Tells whether an aggregation is eligible for the rules defined in this class.
     *
     * <p>An aggregation qualifies only when all of the following hold: it is not
     * streamable, it has no hash symbol, its step is {@code PARTIAL}, and it has
     * exactly one grouping set.
     *
     * @param aggregationNode the aggregation to inspect
     * @return {@code true} when every condition above is satisfied
     */
    private static boolean isSupportedAggregationNode(AggregationNode aggregationNode) {
        // Conditions are evaluated left to right with short-circuiting.
        return !aggregationNode.isStreamable()
                && !aggregationNode.getHashSymbol().isPresent()
                && aggregationNode.getStep() == AggregationNode.Step.PARTIAL
                && aggregationNode.getGroupingSetCount() == 1;
    }

    /**
     * Returns the two rule variants exposed by this class: the one created by
     * {@code pushPartialAggregationThroughJoinWithoutProjection()} followed by the one
     * created by {@code pushPartialAggregationThroughJoinWithProjection()}.
     *
     * @return an immutable list holding both rules, in that order
     */
    public Iterable<Rule<?>> rules() {
        Rule<?> withoutProjection = pushPartialAggregationThroughJoinWithoutProjection();
        Rule<?> withProjection = pushPartialAggregationThroughJoinWithProjection();
        return ImmutableList.of(withoutProjection, withProjection);
    }

    /**
     * Creates a new instance of the rule variant whose pattern matches an aggregation
     * with a join as its source.
     *
     * @return a fresh {@code PushPartialAggregationThroughJoinWithoutProjection}
     */
    @VisibleForTesting
    Rule<?> pushPartialAggregationThroughJoinWithoutProjection() {
        Rule<AggregationNode> rule = new PushPartialAggregationThroughJoinWithoutProjection();
        return rule;
    }

    /**
     * Creates a new instance of the rule variant whose pattern matches an aggregation
     * over a projection over a join.
     *
     * @return a fresh {@code PushPartialAggregationThroughJoinWithProjection}
     */
    @VisibleForTesting
    Rule<?> pushPartialAggregationThroughJoinWithProjection() {
        Rule<AggregationNode> rule = new PushPartialAggregationThroughJoinWithProjection();
        return rule;
    }

    /**
     * Attempts to move a partial aggregation below the join that is its source.
     *
     * <p>Nothing is done unless the join is an {@code INNER} join. The left child is
     * tried first when every aggregation input is among the left child's output
     * symbols; otherwise the right child is tried under the same condition.
     *
     * @param aggregationNode aggregation whose (resolved) source is cast to {@link JoinNode}
     * @param context rule context supplying the lookup, stats provider and id allocator
     * @return the rewritten plan, or an empty result when no pushdown was performed
     */
    private Rule.Result applyPushdown(AggregationNode aggregationNode, Rule.Context context) {
        JoinNode join = (JoinNode) context.getLookup().resolve(aggregationNode.getSource());
        if (join.getType() != JoinType.INNER) {
            return Rule.Result.empty();
        }

        Map<Symbol, AggregationNode.Aggregation> aggregations = aggregationNode.getAggregations();
        Optional<PlanNode> rewritten;
        if (allAggregationsOn(aggregations, join.getLeft().getOutputSymbols())) {
            rewritten = pushPartialToLeftChild(aggregationNode, join, context);
        } else if (allAggregationsOn(aggregations, join.getRight().getOutputSymbols())) {
            rewritten = pushPartialToRightChild(aggregationNode, join, context);
        } else {
            // Aggregation inputs are not confined to a single join side.
            rewritten = Optional.empty();
        }
        return rewritten.map(Rule.Result::ofPlanNode).orElse(Rule.Result.empty());
    }

    /**
     * Checks whether every symbol referenced by the given aggregations is contained in
     * {@code symbols}.
     *
     * <p>The candidate symbols are copied into a set once so that each membership test
     * is a hash lookup instead of a linear scan of the list. An empty aggregation map
     * yields {@code true}.
     *
     * @param aggregations aggregations whose referenced symbols are extracted via
     *        {@code SymbolsExtractor.extractAll}
     * @param symbols the symbols the aggregations are allowed to reference
     * @return {@code true} when no aggregation references a symbol outside {@code symbols}
     */
    private static boolean allAggregationsOn(Map<Symbol, AggregationNode.Aggregation> aggregations, List<Symbol> symbols) {
        Set<Symbol> available = ImmutableSet.copyOf(symbols);
        return aggregations.values().stream()
                .map(SymbolsExtractor::extractAll)
                .flatMap(Collection::stream)
                .allMatch(available::contains);
    }

    /**
     * Tries to place the partial aggregation on top of the join's left child.
     *
     * @param node the aggregation being pushed down
     * @param child the join currently below {@code node}
     * @param context rule context
     * @return the rebuilt plan with the pushed aggregation as the join's left input, or
     *         empty when {@code getPushedAggregation} produced nothing
     */
    private Optional<PlanNode> pushPartialToLeftChild(AggregationNode node, JoinNode child, Rule.Context context) {
        Optional<AggregationNode> pushed = getPushedAggregation(node, child, child.getLeft(), context);
        PlanNode untouchedRight = child.getRight();
        return pushed.map(newLeft -> replaceJoin(node, newLeft, child, newLeft, untouchedRight, context));
    }

    /**
     * Tries to place the partial aggregation on top of the join's right child.
     *
     * @param node the aggregation being pushed down
     * @param child the join currently below {@code node}
     * @param context rule context
     * @return the rebuilt plan with the pushed aggregation as the join's right input, or
     *         empty when {@code getPushedAggregation} produced nothing
     */
    private Optional<PlanNode> pushPartialToRightChild(AggregationNode node, JoinNode child, Rule.Context context) {
        Optional<AggregationNode> pushed = getPushedAggregation(node, child, child.getRight(), context);
        PlanNode untouchedLeft = child.getLeft();
        return pushed.map(newRight -> replaceJoin(node, newRight, child, untouchedLeft, newRight, context));
    }

    /**
     * Builds the aggregation that would sit directly on {@code joinSource} and returns
     * it unless {@code skipPartialAggregationPushdown} rejects it.
     *
     * <p>The grouping keys of the pushed aggregation are computed by
     * {@code getPushedDownGroupingSet} from the symbols produced by {@code joinSource}
     * and the join-required symbols that {@code joinSource} produces.
     *
     * @param node the original partial aggregation
     * @param child the join below {@code node}
     * @param joinSource the join input (left or right) that will receive the aggregation
     * @param context rule context
     * @return the pushed-down aggregation, or empty when the pushdown is skipped
     */
    private Optional<AggregationNode> getPushedAggregation(AggregationNode node, JoinNode child, PlanNode joinSource, Rule.Context context) {
        Set<Symbol> joinSourceSymbols = ImmutableSet.copyOf(joinSource.getOutputSymbols());
        // Only the join-required symbols that come from this side of the join matter here.
        Set<Symbol> requiredFromSource = Sets.intersection(getJoinRequiredSymbols(child), joinSourceSymbols);
        List<Symbol> groupingSet = getPushedDownGroupingSet(node, joinSourceSymbols, requiredFromSource);
        AggregationNode pushedAggregation = replaceAggregationSource(node, joinSource, groupingSet);
        if (skipPartialAggregationPushdown(child, node, pushedAggregation, context)) {
            return Optional.empty();
        }
        return Optional.of(pushedAggregation);
    }

    /**
     * Decides, from statistics, whether the pushdown should be abandoned.
     *
     * <p>The pushdown is skipped when any of these holds:
     * <ul>
     *   <li>the row count of the pushed aggregation's source or of the join is NaN;</li>
     *   <li>the join's row count exceeds 1.1 times the source row count;</li>
     *   <li>the pushed aggregation has more distinct grouping keys than the original;</li>
     *   <li>for some grouping key of the pushed aggregation, the distinct-values count
     *       is NaN or twice that count exceeds the source row count.</li>
     * </ul>
     * NOTE(review): the thresholds (1.1 and 2.0) look like heuristics for "join does
     * not expand rows" and "aggregation reduces rows" — confirm the intent.
     *
     * @param join the join the aggregation would be pushed through
     * @param originalAggregation the aggregation as it exists above the join
     * @param pushedAggregation the candidate aggregation below the join
     * @param context rule context providing the stats provider
     * @return {@code true} when the pushdown should not be performed
     */
    private boolean skipPartialAggregationPushdown(JoinNode join, AggregationNode originalAggregation, AggregationNode pushedAggregation, Rule.Context context) {
        PlanNodeStatsEstimate sourceStats = context.getStatsProvider().getStats(pushedAggregation.getSource());
        double sourceRows = sourceStats.getOutputRowCount();
        double joinRows = context.getStatsProvider().getStats(join).getOutputRowCount();

        boolean rowCountsUnknown = Double.isNaN(sourceRows) || Double.isNaN(joinRows);
        if (rowCountsUnknown || joinRows > 1.1 * sourceRows) {
            return true;
        }

        int originalKeyCount = ImmutableSet.copyOf(originalAggregation.getGroupingKeys()).size();
        int pushedKeyCount = ImmutableSet.copyOf(pushedAggregation.getGroupingKeys()).size();
        if (pushedKeyCount > originalKeyCount) {
            return true;
        }

        // Skip as soon as one grouping key has an unknown or too-high distinct-values count.
        return pushedAggregation.getGroupingKeys().stream()
                .mapToDouble(key -> sourceStats.getSymbolStatistics(key).getDistinctValuesCount())
                .anyMatch(ndv -> Double.isNaN(ndv) || ndv * 2.0 > sourceRows);
    }

    /**
     * Collects the symbols the join node itself refers to: the left and right symbol of
     * every equi-join clause, the unique symbols extracted from the join filter (if
     * any), and the left and right hash symbols (if present).
     *
     * @param node the join to inspect
     * @return an immutable set of the collected symbols
     */
    private Set<Symbol> getJoinRequiredSymbols(JoinNode node) {
        return Streams.concat(
                        node.getCriteria().stream().map(JoinNode.EquiJoinClause::getLeft),
                        node.getCriteria().stream().map(JoinNode.EquiJoinClause::getRight),
                        node.getFilter().map(SymbolsExtractor::extractUnique).orElse(ImmutableSet.of()).stream(),
                        // Optional.stream() yields zero or one element; no throwaway set needed.
                        node.getLeftHashSymbol().stream(),
                        node.getRightHashSymbol().stream())
                .collect(ImmutableSet.toImmutableSet());
    }

    /**
     * Computes the grouping keys for the aggregation that is pushed below the join.
     *
     * <p>The result starts with the original grouping keys that are in
     * {@code availableSymbols}, in their original order, followed by every symbol of
     * {@code requiredJoinSymbols} that is not already present, in that set's iteration
     * order.
     *
     * <p>The list is assembled with an explicit builder rather than by appending to the
     * result of {@code Collectors.toList()}, whose mutability is not guaranteed.
     *
     * @param aggregation the original aggregation supplying the grouping keys
     * @param availableSymbols symbols produced by the join input receiving the aggregation
     * @param requiredJoinSymbols join-required symbols that must remain available
     * @return an immutable list of grouping keys for the pushed aggregation
     */
    private List<Symbol> getPushedDownGroupingSet(AggregationNode aggregation, Set<Symbol> availableSymbols, Set<Symbol> requiredJoinSymbols) {
        ImmutableList.Builder<Symbol> pushedDownGroupingSet = ImmutableList.builder();
        Set<Symbol> existingSymbols = new HashSet<>();

        // Keep the original grouping keys that the join input actually produces.
        for (Symbol groupingKey : aggregation.getGroupingKeys()) {
            if (availableSymbols.contains(groupingKey)) {
                pushedDownGroupingSet.add(groupingKey);
                existingSymbols.add(groupingKey);
            }
        }

        // Append join-required symbols that are not grouping keys yet.
        for (Symbol required : requiredJoinSymbols) {
            if (existingSymbols.add(required)) {
                pushedDownGroupingSet.add(required);
            }
        }
        return pushedDownGroupingSet.build();
    }

    /**
     * Derives a new aggregation from {@code aggregation} with a different source and a
     * single grouping set made of {@code groupingKeys}. The pre-grouped symbols are
     * cleared and the input-reducing flag is set to {@code false}; everything else is
     * taken over by {@code AggregationNode.builderFrom}.
     *
     * @param aggregation the aggregation used as the template
     * @param source the new source node
     * @param groupingKeys the grouping keys of the new aggregation
     * @return the derived aggregation node
     */
    private AggregationNode replaceAggregationSource(AggregationNode aggregation, PlanNode source, List<Symbol> groupingKeys) {
        List<Symbol> noPreGroupedSymbols = ImmutableList.of();
        return AggregationNode.builderFrom(aggregation)
                .setSource(source)
                .setGroupingSets(AggregationNode.singleGroupingSet(groupingKeys))
                .setPreGroupedSymbols(noPreGroupedSymbols)
                .setIsInputReducingAggregation(false)
                .build();
    }

    /**
     * Rebuilds the join on the given children and shapes its output for the original
     * aggregation's consumers.
     *
     * <p>The new join reuses the id and settings of {@code child} but outputs all
     * symbols of both new children. {@code Util.restrictOutputs} is then applied with
     * the original aggregation's output symbols. When the original aggregation is
     * flagged as input reducing and the pushed aggregation groups by at least one key
     * the original does not, an intermediate aggregation is added on top.
     *
     * @param aggregation the original partial aggregation
     * @param pushedAggregation the aggregation that was pushed below the join
     * @param child the original join
     * @param leftChild the new left input of the join
     * @param rightChild the new right input of the join
     * @param context rule context providing the id allocator
     * @return the plan that replaces the original aggregation-over-join
     */
    private PlanNode replaceJoin(AggregationNode aggregation, AggregationNode pushedAggregation, JoinNode child, PlanNode leftChild, PlanNode rightChild, Rule.Context context) {
        JoinNode rewrittenJoin = new JoinNode(
                child.getId(),
                child.getType(),
                leftChild,
                rightChild,
                child.getCriteria(),
                leftChild.getOutputSymbols(),
                rightChild.getOutputSymbols(),
                child.isMaySkipOutputDuplicates(),
                child.getFilter(),
                child.getLeftHashSymbol(),
                child.getRightHashSymbol(),
                child.getDistributionType(),
                child.isSpillable(),
                child.getDynamicFilters(),
                child.getReorderJoinStatsAndCost());

        Set<Symbol> requiredOutputs = ImmutableSet.copyOf(aggregation.getOutputSymbols());
        PlanNode restricted = Util.restrictOutputs(context.getIdAllocator(), rewrittenJoin, requiredOutputs).orElse(rewrittenJoin);

        if (!aggregation.isInputReducingAggregation()) {
            return restricted;
        }
        Set<Symbol> originalGroupingKeys = ImmutableSet.copyOf(aggregation.getGroupingKeys());
        if (originalGroupingKeys.containsAll(pushedAggregation.getGroupingKeys())) {
            return restricted;
        }
        // The pushed aggregation groups by extra keys, so add an intermediate step above the join.
        return toIntermediateAggregation(aggregation, restricted, context);
    }

    /**
     * Builds an {@code INTERMEDIATE} aggregation over {@code source} that mirrors the
     * given partial aggregation.
     *
     * <p>Each aggregation keeps its output symbol and resolved function. Its arguments
     * become a reference to that output symbol followed by the original arguments that
     * are {@link Lambda} instances. The grouping sets and group-id symbol are taken
     * from {@code partialAggregation}; the pre-grouped symbol list is empty.
     *
     * @param partialAggregation the partial aggregation to mirror
     * @param source the node the intermediate aggregation reads from
     * @param context rule context providing the id allocator
     * @return the new intermediate aggregation node
     */
    private PlanNode toIntermediateAggregation(AggregationNode partialAggregation, PlanNode source, Rule.Context context) {
        Map<Symbol, AggregationNode.Aggregation> intermediateAggregation = new HashMap<>();
        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : partialAggregation.getAggregations().entrySet()) {
            Symbol outputSymbol = entry.getKey();
            AggregationNode.Aggregation aggregation = entry.getValue();
            ResolvedFunction resolvedFunction = aggregation.getResolvedFunction();

            // Typed builder: no raw types or unchecked casts are needed to mix the
            // symbol reference with the lambda arguments.
            List<Expression> arguments = ImmutableList.<Expression>builder()
                    .add(outputSymbol.toSymbolReference())
                    .addAll(aggregation.getArguments().stream()
                            .filter(Lambda.class::isInstance)
                            .collect(ImmutableList.toImmutableList()))
                    .build();

            intermediateAggregation.put(
                    outputSymbol,
                    new AggregationNode.Aggregation(resolvedFunction, arguments, false, Optional.empty(), Optional.empty(), Optional.empty()));
        }
        return new AggregationNode(
                context.getIdAllocator().getNextId(),
                source,
                intermediateAggregation,
                partialAggregation.getGroupingSets(),
                ImmutableList.of(),
                AggregationNode.Step.INTERMEDIATE,
                Optional.empty(),
                partialAggregation.getGroupIdSymbol());
    }

    /**
     * Rule variant for an aggregation whose source is a join. The aggregation must be
     * accepted by {@code isSupportedAggregationNode}.
     */
    private class PushPartialAggregationThroughJoinWithoutProjection
            implements Rule<AggregationNode> {
        private static final Pattern<AggregationNode> PATTERN_WITHOUT_PROJECTION = Patterns.aggregation()
                .matching(PushPartialAggregationThroughJoin::isSupportedAggregationNode)
                .with(Patterns.source().matching(Patterns.join()));

        private PushPartialAggregationThroughJoinWithoutProjection() {
        }

        /** Returns the aggregation-over-join pattern. */
        @Override
        public Pattern<AggregationNode> getPattern() {
            return PATTERN_WITHOUT_PROJECTION;
        }

        /** Enabled according to {@code SystemSessionProperties.isPushPartialAggregationThroughJoin}. */
        @Override
        public boolean isEnabled(Session session) {
            return SystemSessionProperties.isPushPartialAggregationThroughJoin(session);
        }

        /** Delegates to the enclosing instance's {@code applyPushdown}; captures are not used. */
        @Override
        public Rule.Result apply(AggregationNode node, Captures captures, Rule.Context context) {
            return applyPushdown(node, context);
        }
    }

    /**
     * Rule variant for an aggregation over a projection over a join. The aggregation
     * must be accepted by {@code isSupportedAggregationNode}. The projection is first
     * pushed through the join; the usual pushdown is then applied to the aggregation
     * re-parented onto the result.
     */
    private class PushPartialAggregationThroughJoinWithProjection
            implements Rule<AggregationNode> {
        private static final Pattern<AggregationNode> PATTERN_WITH_PROJECTION = Patterns.aggregation()
                .matching(PushPartialAggregationThroughJoin::isSupportedAggregationNode)
                .with(Patterns.source().matching(
                        Patterns.project().with(Patterns.source().matching(Patterns.join()))));

        private PushPartialAggregationThroughJoinWithProjection() {
        }

        /** Returns the aggregation-over-project-over-join pattern. */
        @Override
        public Pattern<AggregationNode> getPattern() {
            return PATTERN_WITH_PROJECTION;
        }

        /** Enabled according to {@code SystemSessionProperties.isPushPartialAggregationThroughJoin}. */
        @Override
        public boolean isEnabled(Session session) {
            return SystemSessionProperties.isPushPartialAggregationThroughJoin(session);
        }

        /**
         * Pushes the projection below the join and, if that succeeds, runs the pushdown
         * on the aggregation placed directly over the resulting node.
         *
         * @return the rewritten plan, or an empty result when the projection could not
         *         be pushed through the join or the pushdown did not apply
         */
        @Override
        public Rule.Result apply(AggregationNode node, Captures captures, Rule.Context context) {
            ProjectNode projectNode = (ProjectNode) context.getLookup().resolve(node.getSource());
            Optional<PlanNode> joinNodeOptional = PushProjectionThroughJoin.pushProjectionThroughJoin(projectNode, context.getLookup(), context.getIdAllocator());
            if (joinNodeOptional.isEmpty()) {
                return Rule.Result.empty();
            }
            // NOTE(review): applyPushdown casts the aggregation's source to JoinNode, so
            // this relies on pushProjectionThroughJoin returning a JoinNode — confirm.
            List<PlanNode> newChildren = ImmutableList.of(joinNodeOptional.get());
            AggregationNode aggregationOverJoin = (AggregationNode) node.replaceChildren(newChildren);
            return applyPushdown(aggregationOverJoin, context);
        }
    }
}

