/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableSet;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.iterative.rule.Util;
import io.trino.sql.planner.optimizations.Cardinality;
import io.trino.sql.planner.optimizations.QueryCardinalityUtil;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import java.util.List;
import java.util.Set;

/**
 * Removes a join input that is scalar (per {@link Cardinality#isScalar()}) and produces no output
 * symbols, replacing the join with the other input when that cannot change the result.
 * <ul>
 *   <li>INNER: either side may be the redundant one. The join filter, if present, is kept as a
 *   {@link FilterNode} on top of the remaining source.</li>
 *   <li>LEFT: only a redundant right side is removed. Every left row is emitted exactly once
 *   regardless of the filter, so the filter is dropped.</li>
 *   <li>RIGHT: symmetric to LEFT.</li>
 *   <li>FULL: the redundant side is removed only when the join has no filter and the other side
 *   is known to produce at least one row. With a filter, the scalar row could be left unmatched
 *   and would be emitted as an extra null-padded row, so the rewrite would be incorrect.</li>
 * </ul>
 * The replacement is restricted to the join's output symbols coming from the remaining side.
 * If either input is known to be empty, the rule makes no change.
 */
public class ReplaceRedundantJoinWithSource
implements Rule<JoinNode> {
    private static final Pattern<JoinNode> PATTERN = Patterns.join();

    @Override
    public Pattern<JoinNode> getPattern() {
        return PATTERN;
    }

    /**
     * Attempts to replace {@code node} with one of its sources.
     *
     * @param node the matched join
     * @param captures unused by this rule
     * @param context rule context, used for plan lookup and id allocation
     * @return the replacement plan, or {@link Rule.Result#empty()} if the join cannot be removed
     */
    @Override
    public Rule.Result apply(JoinNode node, Captures captures, Rule.Context context) {
        Cardinality leftCardinality = QueryCardinalityUtil.extractCardinality(node.getLeft(), context.getLookup());
        if (leftCardinality.isEmpty()) {
            return Rule.Result.empty();
        }
        Cardinality rightCardinality = QueryCardinalityUtil.extractCardinality(node.getRight(), context.getLookup());
        if (rightCardinality.isEmpty()) {
            return Rule.Result.empty();
        }
        boolean leftSourceScalarWithNoOutputs = node.getLeft().getOutputSymbols().isEmpty() && leftCardinality.isScalar();
        boolean rightSourceScalarWithNoOutputs = node.getRight().getOutputSymbols().isEmpty() && rightCardinality.isScalar();

        return switch (node.getType()) {
            case INNER -> {
                PlanNode source;
                List<Symbol> sourceOutputs;
                if (leftSourceScalarWithNoOutputs) {
                    source = node.getRight();
                    sourceOutputs = node.getRightOutputSymbols();
                }
                else if (rightSourceScalarWithNoOutputs) {
                    source = node.getLeft();
                    sourceOutputs = node.getLeftOutputSymbols();
                }
                else {
                    yield Rule.Result.empty();
                }
                // In an inner join the filter decides which rows of the remaining source survive,
                // so it must be preserved on top of that source.
                if (node.getFilter().isPresent()) {
                    source = new FilterNode(context.getIdAllocator().getNextId(), source, node.getFilter().get());
                }
                yield replaceWith(context, source, sourceOutputs);
            }
            case LEFT -> rightSourceScalarWithNoOutputs
                    ? replaceWith(context, node.getLeft(), node.getLeftOutputSymbols())
                    : Rule.Result.empty();
            case RIGHT -> leftSourceScalarWithNoOutputs
                    ? replaceWith(context, node.getRight(), node.getRightOutputSymbols())
                    : Rule.Result.empty();
            case FULL -> {
                // With a filter, the scalar side's single row may match nothing. A FULL join then
                // emits it as an extra null-padded row, which the replacement would lose.
                if (node.getFilter().isPresent()) {
                    yield Rule.Result.empty();
                }
                if (leftSourceScalarWithNoOutputs && rightCardinality.isAtLeastScalar()) {
                    yield replaceWith(context, node.getRight(), node.getRightOutputSymbols());
                }
                if (rightSourceScalarWithNoOutputs && leftCardinality.isAtLeastScalar()) {
                    yield replaceWith(context, node.getLeft(), node.getLeftOutputSymbols());
                }
                yield Rule.Result.empty();
            }
        };
    }

    /**
     * Builds a result that replaces the join with {@code source}, restricted to {@code outputs}.
     * If {@link Util#restrictOutputs} returns no restriction, {@code source} is used as-is.
     */
    private static Rule.Result replaceWith(Rule.Context context, PlanNode source, List<Symbol> outputs) {
        return Rule.Result.ofPlanNode(
                Util.restrictOutputs(context.getIdAllocator(), source, ImmutableSet.copyOf(outputs))
                        .orElse(source));
    }
}

