/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.metadata.ResolvedFunction;
import io.trino.metadata.TestingFunctionResolution;
import io.trino.spi.Plugin;
import io.trino.spi.function.OperatorType;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.Type;
import io.trino.sql.ir.Call;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.assertions.ExpectedValueProvider;
import io.trino.sql.planner.assertions.ExpressionMatcher;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.GroupReference;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.iterative.rule.EliminateCrossJoins;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.optimizations.joins.JoinGraph;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.JoinType;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.ValuesNode;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

public class TestEliminateCrossJoins
extends BaseRuleTest {
    private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution();
    private static final ResolvedFunction ADD_BIGINT = FUNCTIONS.resolveOperator(OperatorType.ADD, (List<? extends Type>)ImmutableList.of((Object)BigintType.BIGINT, (Object)BigintType.BIGINT));
    private static final ResolvedFunction NEGATION_BIGINT = FUNCTIONS.resolveOperator(OperatorType.NEGATION, (List<? extends Type>)ImmutableList.of((Object)BigintType.BIGINT));
    private final PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator();

    public TestEliminateCrossJoins() {
        super(new Plugin[0]);
    }

    @Test
    public void testEliminateCrossJoin() {
        this.tester().assertThat((Rule<?>)new EliminateCrossJoins()).setSystemProperty("join_reordering_strategy", "ELIMINATE_CROSS_JOINS").on(this.crossJoinAndJoin(JoinType.INNER)).matches(PlanMatchPattern.join(JoinType.INNER, builder -> builder.equiCriteria((List<ExpectedValueProvider<JoinNode.EquiJoinClause>>)ImmutableList.of(aliases -> new JoinNode.EquiJoinClause(new Symbol((Type)BigintType.BIGINT, "cySymbol"), new Symbol((Type)BigintType.BIGINT, "bySymbol")))).left(PlanMatchPattern.join(JoinType.INNER, leftJoinBuilder -> leftJoinBuilder.equiCriteria((List<ExpectedValueProvider<JoinNode.EquiJoinClause>>)ImmutableList.of(aliases -> new JoinNode.EquiJoinClause(new Symbol((Type)BigintType.BIGINT, "axSymbol"), new Symbol((Type)BigintType.BIGINT, "cxSymbol")))).left(PlanMatchPattern.any(new PlanMatchPattern[0])).right(PlanMatchPattern.any(new PlanMatchPattern[0])))).right(PlanMatchPattern.any(new PlanMatchPattern[0]))));
    }

    @Test
    public void testRetainOutgoingGroupReferences() {
        this.tester().assertThat((Rule<?>)new EliminateCrossJoins()).setSystemProperty("join_reordering_strategy", "ELIMINATE_CROSS_JOINS").on(this.crossJoinAndJoin(JoinType.INNER)).matches(PlanMatchPattern.node(JoinNode.class, PlanMatchPattern.node(JoinNode.class, PlanMatchPattern.node(GroupReference.class, new PlanMatchPattern[0]), PlanMatchPattern.node(GroupReference.class, new PlanMatchPattern[0])), PlanMatchPattern.node(GroupReference.class, new PlanMatchPattern[0])));
    }

    @Test
    public void testDoNotReorderOuterJoin() {
        this.tester().assertThat((Rule<?>)new EliminateCrossJoins()).setSystemProperty("join_reordering_strategy", "ELIMINATE_CROSS_JOINS").on(this.crossJoinAndJoin(JoinType.LEFT)).doesNotFire();
    }

    @Test
    public void testIsOriginalOrder() {
        Assertions.assertThat((boolean)EliminateCrossJoins.isOriginalOrder((List)ImmutableList.of((Object)0, (Object)1, (Object)2, (Object)3, (Object)4))).isTrue();
        Assertions.assertThat((boolean)EliminateCrossJoins.isOriginalOrder((List)ImmutableList.of((Object)0, (Object)2, (Object)1, (Object)3, (Object)4))).isFalse();
    }

    @Test
    public void testJoinOrder() {
        JoinNode plan = this.joinNode((PlanNode)this.joinNode((PlanNode)this.values("a"), (PlanNode)this.values("b"), new String[0]), (PlanNode)this.values("c"), "a", "c", "b", "c");
        JoinGraph joinGraph = JoinGraph.buildFrom((PlanNode)plan, (Lookup)Lookup.noLookup(), (PlanNodeIdAllocator)new PlanNodeIdAllocator());
        Assertions.assertThat((List)EliminateCrossJoins.getJoinOrder((JoinGraph)joinGraph)).isEqualTo((Object)ImmutableList.of((Object)0, (Object)2, (Object)1));
    }

    @Test
    public void testJoinOrderWithRealCrossJoin() {
        JoinNode leftPlan = this.joinNode((PlanNode)this.joinNode((PlanNode)this.values("a"), (PlanNode)this.values("b"), new String[0]), (PlanNode)this.values("c"), "a", "c", "b", "c");
        JoinNode rightPlan = this.joinNode((PlanNode)this.joinNode((PlanNode)this.values("x"), (PlanNode)this.values("y"), new String[0]), (PlanNode)this.values("z"), "x", "z", "y", "z");
        JoinNode plan = this.joinNode((PlanNode)leftPlan, (PlanNode)rightPlan, new String[0]);
        JoinGraph joinGraph = JoinGraph.buildFrom((PlanNode)plan, (Lookup)Lookup.noLookup(), (PlanNodeIdAllocator)new PlanNodeIdAllocator());
        Assertions.assertThat((List)EliminateCrossJoins.getJoinOrder((JoinGraph)joinGraph)).isEqualTo((Object)ImmutableList.of((Object)0, (Object)2, (Object)1, (Object)3, (Object)5, (Object)4));
    }

    @Test
    public void testJoinOrderWithMultipleEdgesBetweenNodes() {
        JoinNode plan = this.joinNode((PlanNode)this.joinNode((PlanNode)this.values("a"), (PlanNode)this.values("b1", "b2"), new String[0]), (PlanNode)this.values("c1", "c2"), "a", "c1", "b1", "c1", "b2", "c2");
        JoinGraph joinGraph = JoinGraph.buildFrom((PlanNode)plan, (Lookup)Lookup.noLookup(), (PlanNodeIdAllocator)new PlanNodeIdAllocator());
        Assertions.assertThat((List)EliminateCrossJoins.getJoinOrder((JoinGraph)joinGraph)).isEqualTo((Object)ImmutableList.of((Object)0, (Object)2, (Object)1));
    }

    @Test
    public void testDoesNotChangeOrderWithoutCrossJoin() {
        JoinNode plan = this.joinNode((PlanNode)this.joinNode((PlanNode)this.values("a"), (PlanNode)this.values("b"), "a", "b"), (PlanNode)this.values("c"), "b", "c");
        JoinGraph joinGraph = JoinGraph.buildFrom((PlanNode)plan, (Lookup)Lookup.noLookup(), (PlanNodeIdAllocator)new PlanNodeIdAllocator());
        Assertions.assertThat((List)EliminateCrossJoins.getJoinOrder((JoinGraph)joinGraph)).isEqualTo((Object)ImmutableList.of((Object)0, (Object)1, (Object)2));
    }

    @Test
    public void testDoNotReorderCrossJoins() {
        JoinNode plan = this.joinNode((PlanNode)this.joinNode((PlanNode)this.values("a"), (PlanNode)this.values("b"), new String[0]), (PlanNode)this.values("c"), "b", "c");
        JoinGraph joinGraph = JoinGraph.buildFrom((PlanNode)plan, (Lookup)Lookup.noLookup(), (PlanNodeIdAllocator)new PlanNodeIdAllocator());
        Assertions.assertThat((List)EliminateCrossJoins.getJoinOrder((JoinGraph)joinGraph)).isEqualTo((Object)ImmutableList.of((Object)0, (Object)1, (Object)2));
    }

    @Test
    public void testEliminateCrossJoinWithNonIdentityProjections() {
        this.tester().assertThat((Rule<?>)new EliminateCrossJoins()).setSystemProperty("join_reordering_strategy", "ELIMINATE_CROSS_JOINS").on(p -> {
            Symbol a1 = p.symbol("a1");
            Symbol a2 = p.symbol("a2");
            Symbol b = p.symbol("b");
            Symbol c = p.symbol("c");
            Symbol d = p.symbol("d");
            Symbol e = p.symbol("e");
            Symbol f = p.symbol("f");
            return p.join(JoinType.INNER, (PlanNode)p.project(Assignments.of((Symbol)a2, (Expression)new Call(NEGATION_BIGINT, (List)ImmutableList.of((Object)new Reference((Type)BigintType.BIGINT, "a1"))), (Symbol)f, (Expression)new Reference((Type)BigintType.BIGINT, "f")), (PlanNode)p.join(JoinType.INNER, (PlanNode)p.project(Assignments.of((Symbol)a1, (Expression)new Reference((Type)BigintType.BIGINT, "a1"), (Symbol)f, (Expression)new Call(NEGATION_BIGINT, (List)ImmutableList.of((Object)new Reference((Type)BigintType.BIGINT, "b")))), (PlanNode)p.join(JoinType.INNER, (PlanNode)p.values(a1), (PlanNode)p.values(b), new JoinNode.EquiJoinClause[0])), (PlanNode)p.values(e), new JoinNode.EquiJoinClause(a1, e))), (PlanNode)p.values(c, d), new JoinNode.EquiJoinClause(a2, c), new JoinNode.EquiJoinClause(f, d));
        }).matches(PlanMatchPattern.node(ProjectNode.class, PlanMatchPattern.join(JoinType.INNER, builder -> builder.equiCriteria((List<ExpectedValueProvider<JoinNode.EquiJoinClause>>)ImmutableList.of(aliases -> new JoinNode.EquiJoinClause(new Symbol((Type)BigintType.BIGINT, "d"), new Symbol((Type)BigintType.BIGINT, "f")))).left(PlanMatchPattern.join(JoinType.INNER, leftJoinBuilder -> leftJoinBuilder.equiCriteria((List<ExpectedValueProvider<JoinNode.EquiJoinClause>>)ImmutableList.of(aliases -> new JoinNode.EquiJoinClause(new Symbol((Type)BigintType.BIGINT, "a2"), new Symbol((Type)BigintType.BIGINT, "c")))).left(PlanMatchPattern.join(JoinType.INNER, leftInnerJoinBuilder -> leftInnerJoinBuilder.equiCriteria((List<ExpectedValueProvider<JoinNode.EquiJoinClause>>)ImmutableList.of(aliases -> new JoinNode.EquiJoinClause(new Symbol((Type)BigintType.BIGINT, "a1"), new Symbol((Type)BigintType.BIGINT, "e")))).left(PlanMatchPattern.strictProject((Map<String, ExpressionMatcher>)ImmutableMap.of((Object)"a2", (Object)PlanMatchPattern.expression((Expression)new Call(NEGATION_BIGINT, (List)ImmutableList.of((Object)new Reference((Type)BigintType.BIGINT, "a1")))), (Object)"a1", (Object)PlanMatchPattern.expression((Expression)new Reference((Type)BigintType.BIGINT, "a1"))), PlanMatchPattern.values("a1"))).right(PlanMatchPattern.strictProject((Map<String, ExpressionMatcher>)ImmutableMap.of((Object)"e", (Object)PlanMatchPattern.expression((Expression)new Reference((Type)BigintType.BIGINT, "e"))), PlanMatchPattern.values("e"))))).right(PlanMatchPattern.any(new PlanMatchPattern[0])))).right(PlanMatchPattern.strictProject((Map<String, ExpressionMatcher>)ImmutableMap.of((Object)"f", (Object)PlanMatchPattern.expression((Expression)new Call(NEGATION_BIGINT, (List)ImmutableList.of((Object)new Reference((Type)BigintType.BIGINT, "b"))))), PlanMatchPattern.values("b"))))));
    }

    @Test
    public void testGiveUpOnComplexProjections() {
        JoinNode plan = this.joinNode(this.projectNode((PlanNode)this.joinNode((PlanNode)this.values("a1"), (PlanNode)this.values("b"), new String[0]), "a2", (Expression)new Call(ADD_BIGINT, (List)ImmutableList.of((Object)new Reference((Type)BigintType.BIGINT, "a1"), (Object)new Reference((Type)BigintType.BIGINT, "b"))), "b", (Expression)new Reference((Type)BigintType.BIGINT, "b")), (PlanNode)this.values("c"), "a2", "c", "b", "c");
        Assertions.assertThat((int)JoinGraph.buildFrom((PlanNode)plan, (Lookup)Lookup.noLookup(), (PlanNodeIdAllocator)new PlanNodeIdAllocator()).size()).isEqualTo(2);
    }

    private Function<PlanBuilder, PlanNode> crossJoinAndJoin(JoinType secondJoinType) {
        return p -> {
            Symbol axSymbol = p.symbol("axSymbol");
            Symbol bySymbol = p.symbol("bySymbol");
            Symbol cxSymbol = p.symbol("cxSymbol");
            Symbol cySymbol = p.symbol("cySymbol");
            return p.join(JoinType.INNER, (PlanNode)p.join(secondJoinType, (PlanNode)p.values(axSymbol), (PlanNode)p.values(bySymbol), new JoinNode.EquiJoinClause[0]), (PlanNode)p.values(cxSymbol, cySymbol), new JoinNode.EquiJoinClause(axSymbol, cxSymbol), new JoinNode.EquiJoinClause(bySymbol, cySymbol));
        };
    }

    private PlanNode projectNode(PlanNode source, String symbol1, Expression expression1, String symbol2, Expression expression2) {
        return new ProjectNode(this.idAllocator.getNextId(), source, Assignments.of((Symbol)new Symbol((Type)BigintType.BIGINT, symbol1), (Expression)expression1, (Symbol)new Symbol((Type)BigintType.BIGINT, symbol2), (Expression)expression2));
    }

    private JoinNode joinNode(PlanNode left, PlanNode right, String ... symbols) {
        Preconditions.checkArgument((symbols.length % 2 == 0 ? 1 : 0) != 0);
        ImmutableList.Builder criteria = ImmutableList.builder();
        for (int i = 0; i < symbols.length; i += 2) {
            criteria.add((Object)new JoinNode.EquiJoinClause(new Symbol((Type)BigintType.BIGINT, symbols[i]), new Symbol((Type)BigintType.BIGINT, symbols[i + 1])));
        }
        return new JoinNode(this.idAllocator.getNextId(), JoinType.INNER, left, right, (List)criteria.build(), left.getOutputSymbols(), right.getOutputSymbols(), false, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), (Map)ImmutableMap.of(), Optional.empty());
    }

    private ValuesNode values(String ... symbols) {
        return new ValuesNode(this.idAllocator.getNextId(), (List)Arrays.stream(symbols).map(name -> new Symbol((Type)BigintType.BIGINT, name)).collect(ImmutableList.toImmutableList()));
    }
}

