/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Streams;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.Type;
import io.trino.sql.ExpressionUtils;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.iterative.rule.AggregationDecorrelation;
import io.trino.sql.planner.iterative.rule.Util;
import io.trino.sql.planner.optimizations.PlanNodeSearcher;
import io.trino.sql.planner.optimizations.QueryCardinalityUtil;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.AssignUniqueId;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.CorrelatedJoinNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.UnnestNode;
import io.trino.sql.tree.BooleanLiteral;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.IsNotNullPredicate;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;

/**
 * Optimizer rule that removes a correlated join whose subquery consists of a correlated
 * unnest placed under a global aggregation.
 *
 * <p>The rule matches {@code INNER} and {@code LEFT} correlated joins that have a non-empty
 * correlation list and a {@code TRUE} filter. The subquery must be a chain of projections and
 * global aggregations, followed (below the last global aggregation found) by a chain of
 * projections and grouped aggregations ending in an unnest accepted by
 * {@link #isSupportedUnnest(PlanNode, List, Lookup)}.
 *
 * <p>The rewritten plan is built as follows:
 * <ol>
 * <li>the correlated join input is wrapped in {@link AssignUniqueId};</li>
 * <li>if the unnest source was a projection, its assignments are re-applied on top of the input;</li>
 * <li>a {@code LEFT} unnest with an ordinality symbol replaces the correlated unnest;</li>
 * <li>a boolean mask ({@code ordinality IS NOT NULL}) is projected;</li>
 * <li>the projections and aggregations of the subquery are re-created on top, with the input
 * symbols added as grouping keys / identity projections;</li>
 * <li>outputs are restricted to the output symbols of the original correlated join.</li>
 * </ol>
 */
public class DecorrelateInnerUnnestWithGlobalAggregation
implements Rule<CorrelatedJoinNode> {
    private static final Pattern<CorrelatedJoinNode> PATTERN = Patterns.correlatedJoin()
            .with(Pattern.nonEmpty(Patterns.CorrelatedJoin.correlation()))
            .with(Patterns.CorrelatedJoin.filter().equalTo(BooleanLiteral.TRUE_LITERAL))
            .matching(node -> node.getType() == CorrelatedJoinNode.Type.INNER || node.getType() == CorrelatedJoinNode.Type.LEFT);

    @Override
    public Pattern<CorrelatedJoinNode> getPattern() {
        return PATTERN;
    }

    /**
     * Attempts the rewrite described in the class documentation.
     *
     * @param correlatedJoinNode the matched correlated join
     * @param captures unused by this rule
     * @param context supplies the lookup, symbol allocator and plan node id allocator
     * @return the rewritten plan, or an empty result when the subquery contains no global
     *         aggregation or no supported unnest below it
     */
    @Override
    public Rule.Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Rule.Context context) {
        Lookup lookup = context.getLookup();

        // Collect global aggregations reachable from the subquery root through projections
        // and other global aggregations only.
        List<PlanNode> globalAggregations = PlanNodeSearcher.searchFrom(correlatedJoinNode.getSubquery(), lookup)
                .where(DecorrelateInnerUnnestWithGlobalAggregation::isGlobalAggregation)
                .recurseOnlyWhen(node -> node instanceof ProjectNode || isGlobalAggregation(node))
                .findAll();
        if (globalAggregations.isEmpty()) {
            return Rule.Result.empty();
        }

        // The last aggregation found is the "reducing" one: the unnest is searched for below it,
        // and it is the only aggregation that receives the mask.
        // The where(...) predicate only accepts AggregationNode, so the cast is safe.
        AggregationNode reducingAggregation = (AggregationNode) globalAggregations.get(globalAggregations.size() - 1);

        // Look for the correlated unnest below the reducing aggregation, walking through
        // projections and grouped aggregations only.
        Optional<PlanNode> subqueryUnnest = PlanNodeSearcher.searchFrom(reducingAggregation.getSource(), lookup)
                .where(node -> isSupportedUnnest(node, correlatedJoinNode.getCorrelation(), lookup))
                .recurseOnlyWhen(node -> node instanceof ProjectNode || isGroupedAggregation(node))
                .findFirst();
        if (subqueryUnnest.isEmpty()) {
            return Rule.Result.empty();
        }
        // isSupportedUnnest rejects everything that is not an UnnestNode, so the cast is safe.
        UnnestNode unnestNode = (UnnestNode) subqueryUnnest.get();

        // Tag every input row with a unique id; the id becomes part of the grouping keys
        // of the rewritten aggregations (it is one of the input's output symbols).
        PlanNode input = new AssignUniqueId(
                context.getIdAllocator().getNextId(),
                correlatedJoinNode.getInput(),
                context.getSymbolAllocator().newSymbol("unique", BigintType.BIGINT));

        // If the correlated unnest was fed by a projection, re-apply that projection's
        // assignments on top of the input so the unnest symbols are available.
        PlanNode unnestSource = lookup.resolve(unnestNode.getSource());
        if (unnestSource instanceof ProjectNode) {
            ProjectNode sourceProjection = (ProjectNode) unnestSource;
            input = new ProjectNode(
                    sourceProjection.getId(),
                    input,
                    Assignments.builder()
                            .putIdentities(input.getOutputSymbols())
                            .putAll(sourceProjection.getAssignments())
                            .build());
        }

        // Reuse the existing ordinality symbol when the unnest has one, otherwise allocate a new one.
        Symbol ordinalitySymbol = unnestNode.getOrdinalitySymbol()
                .orElseGet(() -> context.getSymbolAllocator().newSymbol("ordinality", BigintType.BIGINT));

        // Replace the correlated unnest with a LEFT unnest that replicates all input symbols.
        UnnestNode rewrittenUnnest = new UnnestNode(
                context.getIdAllocator().getNextId(),
                input,
                input.getOutputSymbols(),
                unnestNode.getMappings(),
                Optional.of(ordinalitySymbol),
                JoinNode.Type.LEFT,
                Optional.empty());

        // mask := ordinality IS NOT NULL.
        // NOTE(review): presumably this separates genuinely unnested rows from the null-padded
        // row a LEFT unnest emits for an empty input — confirm against UnnestNode semantics.
        Symbol mask = context.getSymbolAllocator().newSymbol("mask", BooleanType.BOOLEAN);
        ProjectNode sourceWithMask = new ProjectNode(
                context.getIdAllocator().getNextId(),
                rewrittenUnnest,
                Assignments.builder()
                        .putIdentities(rewrittenUnnest.getOutputSymbols())
                        .put(mask, new IsNotNullPredicate(ordinalitySymbol.toSymbolReference()))
                        .build());

        // Re-create the subquery's projections and aggregations on top of the rewritten source.
        PlanNode result = rewriteNodeSequence(
                lookup.resolve(correlatedJoinNode.getSubquery()),
                input.getOutputSymbols(),
                mask,
                sourceWithMask,
                reducingAggregation.getId(),
                unnestNode.getId(),
                context.getSymbolAllocator(),
                context.getIdAllocator(),
                lookup);

        // Drop helper symbols (unique id, mask, ...) that the original join did not output.
        return Rule.Result.ofPlanNode(
                Util.restrictOutputs(context.getIdAllocator(), result, ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols()))
                        .orElse(result));
    }

    /**
     * Returns {@code true} when {@code node} is a SINGLE-step aggregation with exactly one
     * grouping set and that grouping set is empty.
     */
    private static boolean isGlobalAggregation(PlanNode node) {
        if (!(node instanceof AggregationNode)) {
            return false;
        }
        AggregationNode aggregationNode = (AggregationNode) node;
        return aggregationNode.hasEmptyGroupingSet()
                && aggregationNode.getGroupingSetCount() == 1
                && aggregationNode.getStep() == AggregationNode.Step.SINGLE;
    }

    /**
     * Returns {@code true} when {@code node} is a SINGLE-step aggregation with exactly one
     * grouping set and that grouping set is non-empty.
     */
    private static boolean isGroupedAggregation(PlanNode node) {
        if (!(node instanceof AggregationNode)) {
            return false;
        }
        AggregationNode aggregationNode = (AggregationNode) node;
        return aggregationNode.hasNonEmptyGroupingSet()
                && aggregationNode.getGroupingSetCount() == 1
                && aggregationNode.getStep() == AggregationNode.Step.SINGLE;
    }

    /**
     * Checks whether {@code node} is an unnest this rule can decorrelate. All of the following
     * must hold:
     * <ul>
     * <li>the node is an {@link UnnestNode} whose source is scalar;</li>
     * <li>it has no replicate symbols;</li>
     * <li>it is based on correlation: either every unnested input symbol is a correlation
     * symbol, or its source is a projection whose expressions reference only correlation
     * symbols;</li>
     * <li>its join type is {@code INNER};</li>
     * <li>it has no filter, or the filter is the {@code TRUE} literal.</li>
     * </ul>
     *
     * @param node the candidate node
     * @param correlation correlation symbols of the correlated join
     * @param lookup used to resolve the unnest source
     */
    private static boolean isSupportedUnnest(PlanNode node, List<Symbol> correlation, Lookup lookup) {
        if (!(node instanceof UnnestNode)) {
            return false;
        }
        UnnestNode unnestNode = (UnnestNode) node;

        List<Symbol> unnestSymbols = unnestNode.getMappings().stream()
                .map(UnnestNode.Mapping::getInput)
                .collect(ImmutableList.toImmutableList());
        PlanNode unnestSource = lookup.resolve(unnestNode.getSource());
        ImmutableSet<Symbol> correlationSymbols = ImmutableSet.copyOf(correlation);

        boolean unnestsCorrelationDirectly = correlationSymbols.containsAll(unnestSymbols);
        boolean basedOnCorrelation = unnestsCorrelationDirectly
                || (unnestSource instanceof ProjectNode
                        && correlationSymbols.containsAll(SymbolsExtractor.extractUnique(((ProjectNode) unnestSource).getAssignments().getExpressions())));

        boolean hasNoEffectiveFilter = unnestNode.getFilter().isEmpty()
                || unnestNode.getFilter().get().equals(BooleanLiteral.TRUE_LITERAL);

        return QueryCardinalityUtil.isScalar(unnestNode.getSource(), lookup)
                && unnestNode.getReplicateSymbols().isEmpty()
                && basedOnCorrelation
                && unnestNode.getJoinType() == JoinNode.Type.INNER
                && hasNoEffectiveFilter;
    }

    /**
     * Recursively rebuilds the single-source node chain from {@code root} down to the correlated
     * unnest, substituting {@code sequenceSource} for the unnest.
     *
     * <ul>
     * <li>Global aggregations are grouped by {@code leftOutputs}; the one identified by
     * {@code reducingAggregationId} additionally has {@code mask} applied to its aggregations.</li>
     * <li>Grouped aggregations are additionally grouped by {@code leftOutputs} and {@code mask}.</li>
     * <li>Projections additionally pass through {@code leftOutputs}, and {@code mask} when the
     * rewritten source still outputs it.</li>
     * </ul>
     *
     * @param root current node of the subquery chain (already resolved)
     * @param leftOutputs output symbols of the rewritten correlated join input
     * @param mask boolean mask symbol produced below the chain
     * @param sequenceSource node that replaces the correlated unnest
     * @param reducingAggregationId id of the global aggregation that receives the mask
     * @param correlatedUnnestId id of the unnest at the bottom of the chain
     * @param symbolAllocator used to allocate combined mask symbols
     * @param idAllocator used to allocate ids for helper projections
     * @param lookup used to resolve the sources while descending
     * @return the rewritten chain
     * @throws IllegalStateException if a node other than a projection or a supported aggregation
     *         is encountered
     */
    private static PlanNode rewriteNodeSequence(PlanNode root, List<Symbol> leftOutputs, Symbol mask, PlanNode sequenceSource, PlanNodeId reducingAggregationId, PlanNodeId correlatedUnnestId, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, Lookup lookup) {
        // Bottom of the chain reached: plug in the rewritten source.
        if (root.getId().equals(correlatedUnnestId)) {
            return sequenceSource;
        }

        PlanNode child = lookup.resolve(Iterables.getOnlyElement(root.getSources()));
        PlanNode source = rewriteNodeSequence(child, leftOutputs, mask, sequenceSource, reducingAggregationId, correlatedUnnestId, symbolAllocator, idAllocator, lookup);

        if (isGlobalAggregation(root)) {
            AggregationNode aggregationNode = (AggregationNode) root;
            if (aggregationNode.getId().equals(reducingAggregationId)) {
                return withGroupingAndMask(aggregationNode, leftOutputs, mask, source, symbolAllocator, idAllocator);
            }
            return withGrouping(aggregationNode, leftOutputs, source);
        }

        if (isGroupedAggregation(root)) {
            AggregationNode aggregationNode = (AggregationNode) root;
            // Grouped aggregations sit below the reducing aggregation, so the mask has to be
            // kept as a grouping key for it to remain available above.
            List<Symbol> groupingSymbols = ImmutableList.<Symbol>builder()
                    .addAll(leftOutputs)
                    .add(mask)
                    .build();
            return withGrouping(aggregationNode, groupingSymbols, source);
        }

        if (root instanceof ProjectNode) {
            ProjectNode projectNode = (ProjectNode) root;
            // Forward the mask only while the rewritten source still outputs it.
            List<Symbol> maskPassThrough = source.getOutputSymbols().contains(mask)
                    ? ImmutableList.of(mask)
                    : ImmutableList.of();
            return new ProjectNode(
                    projectNode.getId(),
                    source,
                    Assignments.builder()
                            .putAll(projectNode.getAssignments())
                            .putIdentities(leftOutputs)
                            .putIdentities(maskPassThrough)
                            .build());
        }

        throw new IllegalStateException("unexpected node: " + root);
    }

    /**
     * Rebuilds {@code aggregationNode} as a SINGLE-step aggregation grouped by
     * {@code groupingSymbols}, with {@code mask} applied to every aggregation.
     *
     * <p>An aggregation that already has a mask gets a freshly allocated mask symbol defined as
     * {@code existingMask AND mask}; those definitions are added in a projection inserted
     * between {@code source} and the aggregation. Aggregations without a mask use {@code mask}
     * directly.
     */
    private static AggregationNode withGroupingAndMask(AggregationNode aggregationNode, List<Symbol> groupingSymbols, Symbol mask, PlanNode source, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) {
        ImmutableMap.Builder<Symbol, Symbol> masks = ImmutableMap.builder();
        Assignments.Builder combinedMasks = Assignments.builder();

        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : aggregationNode.getAggregations().entrySet()) {
            Optional<Symbol> existingMask = entry.getValue().getMask();
            if (existingMask.isPresent()) {
                Symbol newMask = symbolAllocator.newSymbol("mask", BooleanType.BOOLEAN);
                Expression conjunction = ExpressionUtils.and(existingMask.get().toSymbolReference(), mask.toSymbolReference());
                combinedMasks.put(newMask, conjunction);
                masks.put(entry.getKey(), newMask);
            }
            else {
                masks.put(entry.getKey(), mask);
            }
        }

        Assignments maskAssignments = combinedMasks.build();
        PlanNode aggregationSource = source;
        if (!maskAssignments.isEmpty()) {
            // Compute the combined masks below the aggregation.
            aggregationSource = new ProjectNode(
                    idAllocator.getNextId(),
                    source,
                    Assignments.builder()
                            .putIdentities(source.getOutputSymbols())
                            .putAll(maskAssignments)
                            .build());
        }

        return new AggregationNode(
                aggregationNode.getId(),
                aggregationSource,
                AggregationDecorrelation.rewriteWithMasks(aggregationNode.getAggregations(), masks.buildOrThrow()),
                AggregationNode.singleGroupingSet(groupingSymbols),
                ImmutableList.of(),
                AggregationNode.Step.SINGLE,
                Optional.empty(),
                Optional.empty());
    }

    /**
     * Rebuilds {@code aggregationNode} as a SINGLE-step aggregation over {@code source} whose
     * grouping keys are {@code groupingSymbols} followed by the node's original grouping keys,
     * with duplicates removed. The aggregations themselves are kept unchanged.
     */
    private static AggregationNode withGrouping(AggregationNode aggregationNode, List<Symbol> groupingSymbols, PlanNode source) {
        List<Symbol> groupingKeys = Streams.concat(groupingSymbols.stream(), aggregationNode.getGroupingKeys().stream())
                .distinct()
                .collect(ImmutableList.toImmutableList());

        return new AggregationNode(
                aggregationNode.getId(),
                source,
                aggregationNode.getAggregations(),
                AggregationNode.singleGroupingSet(groupingKeys),
                ImmutableList.of(),
                AggregationNode.Step.SINGLE,
                Optional.empty(),
                Optional.empty());
    }
}

