/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Streams;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.Type;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.iterative.rule.Util;
import io.trino.sql.planner.optimizations.PlanNodeSearcher;
import io.trino.sql.planner.optimizations.QueryCardinalityUtil;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.AssignUniqueId;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.CorrelatedJoinNode;
import io.trino.sql.planner.plan.JoinType;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.UnnestNode;
import io.trino.sql.tree.BooleanLiteral;
import java.util.Collection;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;

/**
 * Decorrelates a {@link CorrelatedJoinNode} whose subquery is a chain of projections and
 * aggregations (containing a global aggregation) sitting on top of a correlated LEFT
 * {@link UnnestNode}.
 * <p>
 * The rule applies when:
 * <ul>
 * <li>the correlated join has a non-empty correlation, a {@code TRUE} filter, and is INNER or LEFT;</li>
 * <li>the subquery contains a single-step global aggregation, reachable from the subquery root
 * through projections and single-step grouped aggregations only;</li>
 * <li>the subquery contains a supported unnest (see {@link #isSupportedUnnest}), reachable through
 * projections and single-step aggregations only.</li>
 * </ul>
 * <p>
 * The rewrite:
 * <ol>
 * <li>wraps the correlated join input in an {@link AssignUniqueId} that adds a new BIGINT symbol;</li>
 * <li>re-applies the unnest's source projection (if any) on top of that input;</li>
 * <li>replaces the correlated unnest with a LEFT unnest over the input that replicates all input symbols;</li>
 * <li>rebuilds the nodes above the unnest. Each aggregation also groups by the input symbols, and each
 * projection passes them through;</li>
 * <li>restricts the outputs to those of the original correlated join.</li>
 * </ol>
 */
public class DecorrelateLeftUnnestWithGlobalAggregation
implements Rule<CorrelatedJoinNode> {
    private static final Pattern<CorrelatedJoinNode> PATTERN = Patterns.correlatedJoin()
            .with(Pattern.nonEmpty(Patterns.CorrelatedJoin.correlation()))
            .with(Patterns.CorrelatedJoin.filter().equalTo(BooleanLiteral.TRUE_LITERAL))
            .matching(node -> node.getType() == JoinType.INNER || node.getType() == JoinType.LEFT);

    @Override
    public Pattern<CorrelatedJoinNode> getPattern() {
        return PATTERN;
    }

    /**
     * Rewrites the matched correlated join, or returns {@link Rule.Result#empty()} when the subquery
     * does not have the supported shape.
     *
     * @param correlatedJoinNode the node matched by {@link #getPattern()}
     * @param captures unused
     * @param context optimizer context supplying the lookup, id allocator and symbol allocator
     * @return the decorrelated plan, or an empty result if the rule does not apply
     */
    @Override
    public Rule.Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Rule.Context context) {
        Lookup lookup = context.getLookup();

        // The subquery must contain a global aggregation reachable through projections and grouped aggregations.
        Optional<AggregationNode> globalAggregation = PlanNodeSearcher.searchFrom(correlatedJoinNode.getSubquery(), lookup)
                .where(DecorrelateLeftUnnestWithGlobalAggregation::isGlobalAggregation)
                .recurseOnlyWhen(node -> node instanceof ProjectNode || isGroupedAggregation(node))
                .findFirst();
        if (globalAggregation.isEmpty()) {
            return Rule.Result.empty();
        }

        // The subquery must contain a supported correlated unnest, reachable through projections and aggregations.
        Optional<UnnestNode> subqueryUnnest = PlanNodeSearcher.searchFrom(correlatedJoinNode.getSubquery(), lookup)
                .where(node -> isSupportedUnnest(node, correlatedJoinNode.getCorrelation(), lookup))
                .recurseOnlyWhen(node -> node instanceof ProjectNode || isGlobalAggregation(node) || isGroupedAggregation(node))
                .findFirst();
        if (subqueryUnnest.isEmpty()) {
            return Rule.Result.empty();
        }
        UnnestNode unnestNode = subqueryUnnest.get();

        // Tag every input row with a fresh id. Presumably this lets grouping by the input symbols keep
        // duplicate input rows apart.
        Symbol uniqueSymbol = context.getSymbolAllocator().newSymbol("unique", BigintType.BIGINT);
        PlanNode input = new AssignUniqueId(context.getIdAllocator().getNextId(), correlatedJoinNode.getInput(), uniqueSymbol);

        // isSupportedUnnest accepts an unnest fed either directly by correlation symbols or by a projection.
        // Any such projection is re-applied on top of the outer input, keeping the input symbols as identities.
        // NOTE(review): when the unnest inputs are correlation symbols directly, the projection's expressions
        // are not checked to reference only correlation symbols. Presumably earlier pruning keeps them
        // valid over the outer input -- confirm.
        PlanNode unnestSource = lookup.resolve(unnestNode.getSource());
        if (unnestSource instanceof ProjectNode) {
            ProjectNode sourceProjection = (ProjectNode) unnestSource;
            input = new ProjectNode(
                    sourceProjection.getId(),
                    input,
                    Assignments.builder()
                            .putIdentities(input.getOutputSymbols())
                            .putAll(sourceProjection.getAssignments())
                            .build());
        }

        // The correlated unnest becomes a LEFT unnest that replicates every symbol of the (projected) input.
        UnnestNode rewrittenUnnest = new UnnestNode(
                context.getIdAllocator().getNextId(),
                input,
                input.getOutputSymbols(),
                unnestNode.getMappings(),
                unnestNode.getOrdinalitySymbol(),
                JoinType.LEFT);

        // Rebuild the subquery chain above the unnest so the input symbols flow through it.
        PlanNode result = rewriteNodeSequence(
                lookup.resolve(correlatedJoinNode.getSubquery()),
                input.getOutputSymbols(),
                rewrittenUnnest,
                unnestNode.getId(),
                lookup);

        // Keep only the correlated join's outputs. If restrictOutputs yields nothing, the plan is used unchanged.
        return Rule.Result.ofPlanNode(
                Util.restrictOutputs(context.getIdAllocator(), result, ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols()))
                        .orElse(result));
    }

    /**
     * Returns {@code true} if {@code node} is a SINGLE-step aggregation for which
     * {@link AggregationNode#hasSingleGlobalAggregation()} holds.
     */
    private static boolean isGlobalAggregation(PlanNode node) {
        if (!(node instanceof AggregationNode)) {
            return false;
        }
        AggregationNode aggregationNode = (AggregationNode) node;
        return aggregationNode.hasSingleGlobalAggregation()
                && aggregationNode.getStep() == AggregationNode.Step.SINGLE;
    }

    /**
     * Returns {@code true} if {@code node} is a SINGLE-step aggregation that has a non-empty grouping set
     * and exactly one grouping set.
     */
    private static boolean isGroupedAggregation(PlanNode node) {
        if (!(node instanceof AggregationNode)) {
            return false;
        }
        AggregationNode aggregationNode = (AggregationNode) node;
        return aggregationNode.hasNonEmptyGroupingSet()
                && aggregationNode.getGroupingSetCount() == 1
                && aggregationNode.getStep() == AggregationNode.Step.SINGLE;
    }

    /**
     * Returns {@code true} if {@code node} is an unnest this rule can decorrelate. The unnest must:
     * <ul>
     * <li>be a LEFT unnest with no replicate symbols;</li>
     * <li>have a source that {@link QueryCardinalityUtil#isScalar} accepts;</li>
     * <li>take its inputs from the correlation. Either every unnested symbol is a correlation symbol,
     * or the source is a projection whose expressions reference only correlation symbols.</li>
     * </ul>
     *
     * @param node candidate node
     * @param correlation correlation symbols of the enclosing correlated join
     * @param lookup lookup used to resolve the unnest source
     */
    private static boolean isSupportedUnnest(PlanNode node, List<Symbol> correlation, Lookup lookup) {
        if (!(node instanceof UnnestNode)) {
            return false;
        }
        UnnestNode unnestNode = (UnnestNode) node;
        Set<Symbol> correlationSymbols = ImmutableSet.copyOf(correlation);

        List<Symbol> unnestSymbols = unnestNode.getMappings().stream()
                .map(UnnestNode.Mapping::getInput)
                .collect(ImmutableList.toImmutableList());
        PlanNode unnestSource = lookup.resolve(unnestNode.getSource());
        boolean basedOnCorrelation = correlationSymbols.containsAll(unnestSymbols)
                || (unnestSource instanceof ProjectNode
                        && correlationSymbols.containsAll(
                                SymbolsExtractor.extractUnique(((ProjectNode) unnestSource).getAssignments().getExpressions())));

        return QueryCardinalityUtil.isScalar(unnestNode.getSource(), lookup)
                && unnestNode.getReplicateSymbols().isEmpty()
                && basedOnCorrelation
                && unnestNode.getJoinType() == JoinType.LEFT;
    }

    /**
     * Rebuilds the single-source chain from {@code root} down to the node with id
     * {@code correlatedUnnestId} and substitutes {@code sequenceSource} for that node.
     * <p>
     * Nodes above it are rebuilt as follows:
     * <ul>
     * <li>aggregations additionally group by {@code leftOutputs};</li>
     * <li>projections additionally pass {@code leftOutputs} through as identities.</li>
     * </ul>
     *
     * @throws IllegalStateException if the chain contains a node that is neither an aggregation nor a projection
     */
    private static PlanNode rewriteNodeSequence(PlanNode root, List<Symbol> leftOutputs, PlanNode sequenceSource, PlanNodeId correlatedUnnestId, Lookup lookup) {
        if (root.getId().equals(correlatedUnnestId)) {
            return sequenceSource;
        }

        // Iterables.getOnlyElement: each node on the path is expected to have exactly one source.
        PlanNode source = rewriteNodeSequence(
                lookup.resolve(Iterables.getOnlyElement(root.getSources())),
                leftOutputs,
                sequenceSource,
                correlatedUnnestId,
                lookup);

        if (root instanceof AggregationNode) {
            return withGrouping((AggregationNode) root, leftOutputs, source);
        }
        if (root instanceof ProjectNode) {
            ProjectNode projectNode = (ProjectNode) root;
            return new ProjectNode(
                    projectNode.getId(),
                    source,
                    Assignments.builder()
                            .putAll(projectNode.getAssignments())
                            .putIdentities(leftOutputs)
                            .build());
        }
        throw new IllegalStateException("unexpected node: " + root);
    }

    /**
     * Rebuilds {@code aggregationNode} over {@code source} with a single grouping set. The grouping keys are
     * {@code groupingSymbols} followed by the aggregation's own grouping keys, with duplicates removed in
     * encounter order.
     * <p>
     * The node id and aggregations are kept. The node is rebuilt via
     * {@link AggregationNode#singleAggregation}, so no other properties of the original node are copied.
     * Presumably the result is SINGLE step, matching the step this rule requires of every aggregation it rewrites.
     */
    private static AggregationNode withGrouping(AggregationNode aggregationNode, List<Symbol> groupingSymbols, PlanNode source) {
        List<Symbol> groupingKeys = Streams.concat(groupingSymbols.stream(), aggregationNode.getGroupingKeys().stream())
                .distinct()
                .collect(ImmutableList.toImmutableList());
        return AggregationNode.singleAggregation(
                aggregationNode.getId(),
                source,
                aggregationNode.getAggregations(),
                AggregationNode.singleGroupingSet(groupingKeys));
    }
}

