/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.GlobalFunctionCatalog;
import io.trino.spi.function.CatalogSchemaFunctionName;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.Type;
import io.trino.sql.PlannerContext;
import io.trino.sql.ir.Booleans;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.IrUtils;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.iterative.rule.AggregationDecorrelation;
import io.trino.sql.planner.optimizations.PlanNodeDecorrelator;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.AssignUniqueId;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.CorrelatedJoinNode;
import io.trino.sql.planner.plan.DynamicFilterId;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.JoinType;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

public class TransformCorrelatedGlobalAggregationWithProjection
implements Rule<CorrelatedJoinNode> {
    private static final CatalogSchemaFunctionName BOOL_OR = GlobalFunctionCatalog.builtinFunctionName("bool_or");
    private static final Capture<ProjectNode> PROJECTION = Capture.newCapture();
    private static final Capture<AggregationNode> AGGREGATION = Capture.newCapture();
    private static final Capture<PlanNode> SOURCE = Capture.newCapture();
    private static final Pattern<CorrelatedJoinNode> PATTERN = Patterns.correlatedJoin().with(Pattern.nonEmpty(Patterns.CorrelatedJoin.correlation())).with(Patterns.CorrelatedJoin.filter().equalTo((Object)Booleans.TRUE)).with(Patterns.CorrelatedJoin.subquery().matching(Patterns.project().capturedAs(PROJECTION).with(Patterns.source().matching(Patterns.aggregation().with(Pattern.empty(Patterns.Aggregation.groupingColumns())).with(Patterns.source().capturedAs(SOURCE)).capturedAs(AGGREGATION)))));
    private final PlannerContext plannerContext;

    public TransformCorrelatedGlobalAggregationWithProjection(PlannerContext plannerContext) {
        this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
    }

    @Override
    public Pattern<CorrelatedJoinNode> getPattern() {
        return PATTERN;
    }

    @Override
    public Rule.Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Rule.Context context) {
        Preconditions.checkArgument((correlatedJoinNode.getType() == JoinType.INNER || correlatedJoinNode.getType() == JoinType.LEFT ? 1 : 0) != 0, (String)"unexpected correlated join type: %s", (Object)((Object)correlatedJoinNode.getType()));
        PlanNode source = (PlanNode)captures.get(SOURCE);
        AggregationNode distinct = null;
        PlanNodeDecorrelator decorrelator = new PlanNodeDecorrelator(this.plannerContext, context.getSymbolAllocator(), context.getLookup());
        Optional<PlanNodeDecorrelator.DecorrelatedNode> decorrelatedSource = decorrelator.decorrelateFilters(source, correlatedJoinNode.getCorrelation());
        if (decorrelatedSource.isEmpty()) {
            if (AggregationDecorrelation.isDistinctOperator(source)) {
                distinct = (AggregationNode)source;
                source = distinct.getSource();
                decorrelatedSource = decorrelator.decorrelateFilters(source, correlatedJoinNode.getCorrelation());
            }
            if (decorrelatedSource.isEmpty()) {
                return Rule.Result.empty();
            }
        }
        source = decorrelatedSource.get().getNode();
        Optional<Object> nonNull = Optional.empty();
        AggregationNode globalAggregation = (AggregationNode)captures.get(AGGREGATION);
        if (!TransformCorrelatedGlobalAggregationWithProjection.isNullInsensitiveAggregation(globalAggregation)) {
            nonNull = Optional.of(context.getSymbolAllocator().newSymbol("non_null", (Type)BooleanType.BOOLEAN));
            source = new ProjectNode(context.getIdAllocator().getNextId(), source, Assignments.builder().putIdentities(source.getOutputSymbols()).put((Symbol)nonNull.get(), Booleans.TRUE).build());
        }
        AssignUniqueId inputWithUniqueId = new AssignUniqueId(context.getIdAllocator().getNextId(), correlatedJoinNode.getInput(), context.getSymbolAllocator().newSymbol("unique", (Type)BigintType.BIGINT));
        JoinNode join = new JoinNode(context.getIdAllocator().getNextId(), JoinType.LEFT, inputWithUniqueId, source, (List<JoinNode.EquiJoinClause>)ImmutableList.of(), ((PlanNode)inputWithUniqueId).getOutputSymbols(), source.getOutputSymbols(), false, decorrelatedSource.get().getCorrelatedPredicates(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), (Map<DynamicFilterId, Symbol>)ImmutableMap.of(), Optional.empty());
        PlanNode root = join;
        if (distinct != null) {
            ImmutableList.Builder distinctSymbols = ImmutableList.builder().addAll(join.getLeftOutputSymbols()).addAll(distinct.getGroupingKeys());
            nonNull.ifPresent(arg_0 -> ((ImmutableList.Builder)distinctSymbols).add(arg_0));
            root = AggregationDecorrelation.restoreDistinctAggregation(distinct, join, (List<Symbol>)distinctSymbols.build());
        }
        Map<Symbol, AggregationNode.Aggregation> aggregations = globalAggregation.getAggregations();
        if (nonNull.isPresent()) {
            ImmutableMap.Builder masksBuilder = ImmutableMap.builder();
            Assignments.Builder assignmentsBuilder = Assignments.builder();
            for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : globalAggregation.getAggregations().entrySet()) {
                AggregationNode.Aggregation aggregation = entry.getValue();
                if (aggregation.getMask().isPresent()) {
                    Symbol newMask = context.getSymbolAllocator().newSymbol("mask", (Type)BooleanType.BOOLEAN);
                    Expression expression = IrUtils.and(aggregation.getMask().get().toSymbolReference(), ((Symbol)nonNull.get()).toSymbolReference());
                    assignmentsBuilder.put(newMask, expression);
                    masksBuilder.put((Object)entry.getKey(), (Object)newMask);
                    continue;
                }
                masksBuilder.put((Object)entry.getKey(), (Object)((Symbol)nonNull.get()));
            }
            Assignments maskAssignments = assignmentsBuilder.build();
            if (!maskAssignments.isEmpty()) {
                root = new ProjectNode(context.getIdAllocator().getNextId(), root, Assignments.builder().putIdentities(root.getOutputSymbols()).putAll(maskAssignments).build());
            }
            aggregations = AggregationDecorrelation.rewriteWithMasks(globalAggregation.getAggregations(), (Map<Symbol, Symbol>)masksBuilder.buildOrThrow());
        }
        globalAggregation = new AggregationNode(globalAggregation.getId(), root, aggregations, AggregationNode.singleGroupingSet((List<Symbol>)ImmutableList.builder().addAll(join.getLeftOutputSymbols()).addAll(globalAggregation.getGroupingKeys()).build()), (List<Symbol>)ImmutableList.of(), globalAggregation.getStep(), Optional.empty(), Optional.empty());
        HashSet<Symbol> outputSymbols = new HashSet<Symbol>(correlatedJoinNode.getOutputSymbols());
        List expectedAggregationOutputs = (List)globalAggregation.getOutputSymbols().stream().filter(outputSymbols::contains).collect(ImmutableList.toImmutableList());
        Assignments assignments = Assignments.builder().putIdentities(expectedAggregationOutputs).putAll(((ProjectNode)captures.get(PROJECTION)).getAssignments()).build();
        return Rule.Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), globalAggregation, assignments));
    }

    private static boolean isNullInsensitiveAggregation(AggregationNode node) {
        if (node.getAggregations().size() != 1) {
            return false;
        }
        AggregationNode.Aggregation aggregation = (AggregationNode.Aggregation)Iterables.getOnlyElement(node.getAggregations().values());
        if (aggregation.getFilter().isPresent() || aggregation.getMask().isPresent()) {
            return false;
        }
        return aggregation.getResolvedFunction().name().equals((Object)BOOL_OR) && aggregation.getArguments().getFirst() instanceof Reference;
    }
}

