/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.optimizations;

import com.google.common.collect.ImmutableList;
import io.trino.SessionTestUtils;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.TestingPlannerContext;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.optimizations.PlanNodeSearcher;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.ValuesNode;
import java.util.ArrayList;
import java.util.List;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

public class TestPlanNodeSearcher {
    private static final PlanBuilder BUILDER = new PlanBuilder(new PlanNodeIdAllocator(), TestingPlannerContext.PLANNER_CONTEXT, SessionTestUtils.TEST_SESSION);

    @Test
    public void testFindAll() {
        int size = 10;
        ProjectNode root = BUILDER.project(Assignments.of(), (PlanNode)BUILDER.values(new Symbol[0]));
        for (int i = 1; i < size; ++i) {
            root = BUILDER.project(Assignments.of(), (PlanNode)root);
        }
        ArrayList<PlanNodeId> rootToBottomIds = new ArrayList<PlanNodeId>();
        ProjectNode node = root;
        while (node instanceof ProjectNode) {
            rootToBottomIds.add(node.getId());
            node = node.getSource();
        }
        List findAllResult = (List)PlanNodeSearcher.searchFrom((PlanNode)root).where(ProjectNode.class::isInstance).findAll().stream().map(PlanNode::getId).collect(ImmutableList.toImmutableList());
        Assertions.assertThat(rootToBottomIds).isEqualTo((Object)findAllResult);
    }

    @Test
    public void testFindAllMultipleSources() {
        ArrayList<JoinNode> joins = new ArrayList<JoinNode>();
        for (int i = 0; i < 4; ++i) {
            joins.add(BUILDER.join(JoinNode.Type.INNER, (PlanNode)BUILDER.values(new Symbol[0]), (PlanNode)BUILDER.values(new Symbol[0]), new JoinNode.EquiJoinClause[0]));
        }
        JoinNode leftSource = BUILDER.join(JoinNode.Type.INNER, (PlanNode)joins.get(0), (PlanNode)joins.get(1), new JoinNode.EquiJoinClause[0]);
        JoinNode rightSource = BUILDER.join(JoinNode.Type.INNER, (PlanNode)joins.get(2), (PlanNode)joins.get(3), new JoinNode.EquiJoinClause[0]);
        JoinNode root = BUILDER.join(JoinNode.Type.INNER, (PlanNode)leftSource, (PlanNode)rightSource, new JoinNode.EquiJoinClause[0]);
        ImmutableList.Builder idsInPreOrder = ImmutableList.builder();
        TestPlanNodeSearcher.joinNodePreorder((PlanNode)root, (ImmutableList.Builder<PlanNodeId>)idsInPreOrder);
        List findAllResult = (List)PlanNodeSearcher.searchFrom((PlanNode)root).where(JoinNode.class::isInstance).findAll().stream().map(PlanNode::getId).collect(ImmutableList.toImmutableList());
        Assertions.assertThat((List)idsInPreOrder.build()).isEqualTo((Object)findAllResult);
    }

    private static void joinNodePreorder(PlanNode root, ImmutableList.Builder<PlanNodeId> builder) {
        if (root instanceof ValuesNode) {
            return;
        }
        if (root instanceof JoinNode) {
            JoinNode join = (JoinNode)root;
            builder.add((Object)root.getId());
            TestPlanNodeSearcher.joinNodePreorder(join.getLeft(), builder);
            TestPlanNodeSearcher.joinNodePreorder(join.getRight(), builder);
            return;
        }
        throw new IllegalArgumentException("unsupported node type: " + root.getClass().getSimpleName());
    }
}

