/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import com.google.common.collect.Streams;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.iterative.rule.Util;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Optimizer rule that moves a partial aggregation below an inner join.
 *
 * <p>The rule matches a {@link AggregationNode} whose source is a {@link JoinNode} and which is
 * a non-streamable {@code PARTIAL} aggregation with exactly one grouping set and no hash symbol.
 * When the join is an {@code INNER} join and every symbol referenced by the aggregations is
 * produced by a single join input, the aggregation is re-created on top of that input and the
 * join is rebuilt above it. The resulting plan is restricted to the output symbols of the
 * original aggregation.
 */
public class PushPartialAggregationThroughJoin implements Rule<AggregationNode> {
    private static final Capture<JoinNode> JOIN_NODE = Capture.newCapture();
    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation()
            .matching(PushPartialAggregationThroughJoin::isSupportedAggregationNode)
            .with(Patterns.source().matching(Patterns.join().capturedAs(JOIN_NODE)));

    /**
     * Tells whether the aggregation has a shape this rule can push down.
     *
     * @param aggregationNode candidate aggregation
     * @return {@code true} only for a non-streamable {@code PARTIAL} aggregation that has no
     *         hash symbol and exactly one grouping set
     */
    private static boolean isSupportedAggregationNode(AggregationNode aggregationNode) {
        if (aggregationNode.isStreamable()) {
            return false;
        }
        if (aggregationNode.getHashSymbol().isPresent()) {
            return false;
        }
        return aggregationNode.getStep() == AggregationNode.Step.PARTIAL
                && aggregationNode.getGroupingSetCount() == 1;
    }

    @Override
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    /**
     * The rule is toggled by the session property read through
     * {@link SystemSessionProperties#isPushPartialAggregationThroughJoin(Session)}.
     */
    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.isPushPartialAggregationThroughJoin(session);
    }

    /**
     * Pushes the aggregation into the left or right join input when all aggregation inputs
     * come from that input; the left side is tried first.
     *
     * @return the rewritten plan, or {@link Rule.Result#empty()} when the join is not an inner
     *         join or the aggregations reference symbols that no single side provides
     */
    @Override
    public Rule.Result apply(AggregationNode aggregationNode, Captures captures, Rule.Context context) {
        JoinNode joinNode = captures.get(JOIN_NODE);
        if (joinNode.getType() != JoinNode.Type.INNER) {
            return Rule.Result.empty();
        }

        Map<Symbol, AggregationNode.Aggregation> aggregations = aggregationNode.getAggregations();
        if (allAggregationsOn(aggregations, joinNode.getLeft().getOutputSymbols())) {
            return Rule.Result.ofPlanNode(pushPartialToLeftChild(aggregationNode, joinNode, context));
        }
        if (allAggregationsOn(aggregations, joinNode.getRight().getOutputSymbols())) {
            return Rule.Result.ofPlanNode(pushPartialToRightChild(aggregationNode, joinNode, context));
        }
        return Rule.Result.empty();
    }

    /**
     * Checks that every symbol extracted from the given aggregations is contained in
     * {@code symbols}.
     *
     * @param aggregations aggregations whose referenced symbols are inspected
     * @param symbols symbols offered by one join input
     * @return {@code true} when all referenced symbols are present in {@code symbols}
     *         (trivially {@code true} when the aggregations reference no symbols)
     */
    private static boolean allAggregationsOn(Map<Symbol, AggregationNode.Aggregation> aggregations, List<Symbol> symbols) {
        Set<Symbol> inputs = aggregations.values().stream()
                .map(SymbolsExtractor::extractAll)
                .flatMap(Collection::stream)
                .collect(ImmutableSet.toImmutableSet());
        return symbols.containsAll(inputs);
    }

    /**
     * Rebuilds the plan with the aggregation placed on top of the join's left input.
     */
    private PlanNode pushPartialToLeftChild(AggregationNode node, JoinNode child, Rule.Context context) {
        Set<Symbol> joinLeftChildSymbols = ImmutableSet.copyOf(child.getLeft().getOutputSymbols());
        // Argument order matters: the intersection view iterates in the order of its first set.
        Set<Symbol> requiredLeftSymbols = Sets.intersection(getJoinRequiredSymbols(child), joinLeftChildSymbols);
        List<Symbol> groupingSet = getPushedDownGroupingSet(node, joinLeftChildSymbols, requiredLeftSymbols);
        AggregationNode pushedAggregation = replaceAggregationSource(node, child.getLeft(), groupingSet);
        return pushPartialToJoin(node, child, pushedAggregation, child.getRight(), context);
    }

    /**
     * Rebuilds the plan with the aggregation placed on top of the join's right input.
     */
    private PlanNode pushPartialToRightChild(AggregationNode node, JoinNode child, Rule.Context context) {
        Set<Symbol> joinRightChildSymbols = ImmutableSet.copyOf(child.getRight().getOutputSymbols());
        // Argument order matters: the intersection view iterates in the order of its first set.
        Set<Symbol> requiredRightSymbols = Sets.intersection(getJoinRequiredSymbols(child), joinRightChildSymbols);
        List<Symbol> groupingSet = getPushedDownGroupingSet(node, joinRightChildSymbols, requiredRightSymbols);
        AggregationNode pushedAggregation = replaceAggregationSource(node, child.getRight(), groupingSet);
        return pushPartialToJoin(node, child, child.getLeft(), pushedAggregation, context);
    }

    /**
     * Collects the symbols the join itself refers to: both sides of every equi-join clause,
     * the unique symbols of the optional filter, and the optional left/right hash symbols.
     *
     * @param node join to inspect
     * @return the de-duplicated set of symbols, in encounter order
     */
    private Set<Symbol> getJoinRequiredSymbols(JoinNode node) {
        return Streams.concat(
                node.getCriteria().stream().map(JoinNode.EquiJoinClause::getLeft),
                node.getCriteria().stream().map(JoinNode.EquiJoinClause::getRight),
                node.getFilter().map(SymbolsExtractor::extractUnique).orElse(ImmutableSet.of()).stream(),
                node.getLeftHashSymbol().map(ImmutableSet::of).orElse(ImmutableSet.of()).stream(),
                node.getRightHashSymbol().map(ImmutableSet::of).orElse(ImmutableSet.of()).stream())
                .collect(ImmutableSet.toImmutableSet());
    }

    /**
     * Computes the grouping keys for the pushed-down aggregation: the original grouping keys
     * that are available on the chosen join side, followed by the join-required symbols of that
     * side that are not already among them (presumably so the join above can still reference
     * them after the aggregation).
     *
     * @param aggregation original aggregation
     * @param availableSymbols output symbols of the join input receiving the aggregation
     * @param requiredJoinSymbols join-required symbols produced by that input
     * @return immutable list of grouping keys in the order described above
     */
    private List<Symbol> getPushedDownGroupingSet(AggregationNode aggregation, Set<Symbol> availableSymbols, Set<Symbol> requiredJoinSymbols) {
        ImmutableList.Builder<Symbol> pushedDownGroupingSet = ImmutableList.builder();
        Set<Symbol> existingSymbols = new HashSet<>();
        for (Symbol groupingKey : aggregation.getGroupingKeys()) {
            if (availableSymbols.contains(groupingKey)) {
                pushedDownGroupingSet.add(groupingKey);
                existingSymbols.add(groupingKey);
            }
        }
        for (Symbol requiredSymbol : requiredJoinSymbols) {
            // Set.add returns false for symbols that are already grouping keys.
            if (existingSymbols.add(requiredSymbol)) {
                pushedDownGroupingSet.add(requiredSymbol);
            }
        }
        return pushedDownGroupingSet.build();
    }

    /**
     * Copies the aggregation onto a new source with a single grouping set made of
     * {@code groupingKeys} and with the pre-grouped symbols cleared.
     */
    private AggregationNode replaceAggregationSource(AggregationNode aggregation, PlanNode source, List<Symbol> groupingKeys) {
        return AggregationNode.builderFrom(aggregation)
                .setSource(source)
                .setGroupingSets(AggregationNode.singleGroupingSet(groupingKeys))
                .setPreGroupedSymbols(ImmutableList.of())
                .build();
    }

    /**
     * Re-creates the join over the given children, keeping the original join's id and all other
     * properties, exposing every output symbol of both children, and then restricts the outputs
     * to those of the original aggregation.
     *
     * @return the node produced by {@code Util.restrictOutputs}, or the new join itself when
     *         that call returns an empty result
     */
    private PlanNode pushPartialToJoin(AggregationNode aggregation, JoinNode child, PlanNode leftChild, PlanNode rightChild, Rule.Context context) {
        JoinNode joinNode = new JoinNode(
                child.getId(),
                child.getType(),
                leftChild,
                rightChild,
                child.getCriteria(),
                leftChild.getOutputSymbols(),
                rightChild.getOutputSymbols(),
                child.isMaySkipOutputDuplicates(),
                child.getFilter(),
                child.getLeftHashSymbol(),
                child.getRightHashSymbol(),
                child.getDistributionType(),
                child.isSpillable(),
                child.getDynamicFilters(),
                child.getReorderJoinStatsAndCost());
        Set<Symbol> aggregationOutputs = ImmutableSet.copyOf(aggregation.getOutputSymbols());
        return Util.restrictOutputs(context.getIdAllocator(), joinNode, aggregationOutputs).orElse(joinNode);
    }
}

