/*
 * Decompiled with CFR 0.152.
 */
package io.prestosql.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.prestosql.matching.Capture;
import io.prestosql.matching.Captures;
import io.prestosql.matching.Pattern;
import io.prestosql.metadata.Metadata;
import io.prestosql.spi.type.BigintType;
import io.prestosql.spi.type.Type;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.iterative.Rule;
import io.prestosql.sql.planner.iterative.rule.AggregationDecorrelation;
import io.prestosql.sql.planner.iterative.rule.Util;
import io.prestosql.sql.planner.optimizations.PlanNodeDecorrelator;
import io.prestosql.sql.planner.plan.AggregationNode;
import io.prestosql.sql.planner.plan.AssignUniqueId;
import io.prestosql.sql.planner.plan.CorrelatedJoinNode;
import io.prestosql.sql.planner.plan.DynamicFilterId;
import io.prestosql.sql.planner.plan.JoinNode;
import io.prestosql.sql.planner.plan.Patterns;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.tree.BooleanLiteral;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

public class TransformCorrelatedDistinctAggregationWithoutProjection
implements Rule<CorrelatedJoinNode> {
    private static final Capture<AggregationNode> AGGREGATION = Capture.newCapture();
    private static final Pattern<CorrelatedJoinNode> PATTERN = Patterns.correlatedJoin().with(Patterns.CorrelatedJoin.type().equalTo((Object)CorrelatedJoinNode.Type.LEFT)).with(Pattern.nonEmpty(Patterns.CorrelatedJoin.correlation())).with(Patterns.CorrelatedJoin.filter().equalTo((Object)BooleanLiteral.TRUE_LITERAL)).with(Patterns.CorrelatedJoin.subquery().matching(Patterns.aggregation().matching(AggregationDecorrelation::isDistinctOperator).capturedAs(AGGREGATION)));
    private final Metadata metadata;

    public TransformCorrelatedDistinctAggregationWithoutProjection(Metadata metadata) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    @Override
    public Pattern<CorrelatedJoinNode> getPattern() {
        return PATTERN;
    }

    @Override
    public Rule.Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Rule.Context context) {
        PlanNodeDecorrelator decorrelator = new PlanNodeDecorrelator(this.metadata, context.getSymbolAllocator(), context.getLookup());
        Optional<PlanNodeDecorrelator.DecorrelatedNode> decorrelatedSource = decorrelator.decorrelateFilters(((AggregationNode)captures.get(AGGREGATION)).getSource(), correlatedJoinNode.getCorrelation());
        if (decorrelatedSource.isEmpty()) {
            return Rule.Result.empty();
        }
        PlanNode source = decorrelatedSource.get().getNode();
        AssignUniqueId inputWithUniqueId = new AssignUniqueId(context.getIdAllocator().getNextId(), correlatedJoinNode.getInput(), context.getSymbolAllocator().newSymbol("unique", (Type)BigintType.BIGINT));
        JoinNode join = new JoinNode(context.getIdAllocator().getNextId(), JoinNode.Type.LEFT, inputWithUniqueId, source, (List<JoinNode.EquiJoinClause>)ImmutableList.of(), ((PlanNode)inputWithUniqueId).getOutputSymbols(), source.getOutputSymbols(), decorrelatedSource.get().getCorrelatedPredicates(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), (Map<DynamicFilterId, Symbol>)ImmutableMap.of(), Optional.empty());
        AggregationNode aggregation = (AggregationNode)captures.get(AGGREGATION);
        aggregation = new AggregationNode(aggregation.getId(), join, aggregation.getAggregations(), AggregationNode.singleGroupingSet((List<Symbol>)ImmutableList.builder().addAll(join.getLeftOutputSymbols()).addAll(aggregation.getGroupingKeys()).build()), (List<Symbol>)ImmutableList.of(), aggregation.getStep(), Optional.empty(), Optional.empty());
        Optional<PlanNode> project = Util.restrictOutputs(context.getIdAllocator(), aggregation, (Set<Symbol>)ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols()));
        return Rule.Result.ofPlanNode(project.orElse(aggregation));
    }
}

