/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.cost.CostComparator;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.SymbolStatsEstimate;
import io.trino.cost.TaskCountEstimator;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import io.trino.sql.ir.Booleans;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.planner.OptimizerConfig;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.iterative.rule.DetermineSemiJoinDistributionType;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.iterative.rule.test.RuleBuilder;
import io.trino.sql.planner.iterative.rule.test.RuleTester;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.sql.planner.plan.SemiJoinNode;
import io.trino.type.UnknownType;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.parallel.Execution;
import org.junit.jupiter.api.parallel.ExecutionMode;

@TestInstance(value=TestInstance.Lifecycle.PER_CLASS)
@Execution(value=ExecutionMode.CONCURRENT)
public class TestDetermineSemiJoinDistributionType {
    private static final CostComparator COST_COMPARATOR = new CostComparator(1.0, 1.0, 1.0);
    private static final int NODES_COUNT = 4;
    private RuleTester tester;

    @BeforeAll
    public void setUp() {
        this.tester = RuleTester.builder().withNodeCountForStats(4).build();
    }

    @AfterAll
    public void tearDown() {
        this.tester.close();
        this.tester = null;
    }

    @Test
    public void testRetainDistributionType() {
        this.assertDetermineSemiJoinDistributionType().on(p -> p.semiJoin((PlanNode)p.values((List<Symbol>)ImmutableList.of((Object)p.symbol("A1")), (List<List<Expression>>)ImmutableList.of((Object)ImmutableList.of((Object)new Constant((Type)IntegerType.INTEGER, (Object)10L)), (Object)ImmutableList.of((Object)new Constant((Type)IntegerType.INTEGER, (Object)11L)))), (PlanNode)p.values((List<Symbol>)ImmutableList.of((Object)p.symbol("B1")), (List<List<Expression>>)ImmutableList.of((Object)ImmutableList.of((Object)new Constant((Type)IntegerType.INTEGER, (Object)50L)), (Object)ImmutableList.of((Object)new Constant((Type)IntegerType.INTEGER, (Object)11L)))), p.symbol("A1"), p.symbol("B1"), p.symbol("output"), Optional.empty(), Optional.empty(), Optional.of(SemiJoinNode.DistributionType.REPLICATED))).doesNotFire();
    }

    @Test
    public void testPartitionWhenRequiredBySession() {
        VarcharType symbolType = VarcharType.createUnboundedVarcharType();
        int aRows = 10000;
        int bRows = 100;
        this.assertDetermineSemiJoinDistributionType().setSystemProperty("join_distribution_type", OptimizerConfig.JoinDistributionType.PARTITIONED.name()).overrideStats("valuesA", PlanNodeStatsEstimate.builder().setOutputRowCount((double)aRows).addSymbolStatistics((Map)ImmutableMap.of((Object)new Symbol((Type)UnknownType.UNKNOWN, "A1"), (Object)new SymbolStatsEstimate(0.0, 100.0, 0.0, 6400.0, 100.0))).build()).overrideStats("valuesB", PlanNodeStatsEstimate.builder().setOutputRowCount((double)bRows).addSymbolStatistics((Map)ImmutableMap.of((Object)new Symbol((Type)UnknownType.UNKNOWN, "B1"), (Object)new SymbolStatsEstimate(0.0, 100.0, 0.0, 640000.0, 100.0))).build()).on(arg_0 -> TestDetermineSemiJoinDistributionType.lambda$testPartitionWhenRequiredBySession$1((Type)symbolType, aRows, bRows, arg_0)).matches(PlanMatchPattern.semiJoin("A1", "B1", "output", Optional.of(SemiJoinNode.DistributionType.PARTITIONED), PlanMatchPattern.values((Map<String, Integer>)ImmutableMap.of((Object)"A1", (Object)0)), PlanMatchPattern.values((Map<String, Integer>)ImmutableMap.of((Object)"B1", (Object)0))));
    }

    @Test
    public void testReplicatesWhenRequiredBySession() {
        int aRows = 10000;
        int bRows = 10000;
        this.assertDetermineSemiJoinDistributionType().setSystemProperty("join_distribution_type", OptimizerConfig.JoinDistributionType.BROADCAST.name()).setSystemProperty("join_max_broadcast_table_size", "1B").overrideStats("valuesA", PlanNodeStatsEstimate.builder().setOutputRowCount((double)aRows).addSymbolStatistics((Map)ImmutableMap.of((Object)new Symbol((Type)UnknownType.UNKNOWN, "A1"), (Object)SymbolStatsEstimate.unknown())).build()).overrideStats("valuesB", PlanNodeStatsEstimate.builder().setOutputRowCount((double)bRows).addSymbolStatistics((Map)ImmutableMap.of((Object)new Symbol((Type)UnknownType.UNKNOWN, "B1"), (Object)SymbolStatsEstimate.unknown())).build()).on(p -> p.semiJoin((PlanNode)p.values(new PlanNodeId("valuesA"), aRows, p.symbol("A1", (Type)BigintType.BIGINT)), (PlanNode)p.values(new PlanNodeId("valuesB"), bRows, p.symbol("B1", (Type)BigintType.BIGINT)), p.symbol("A1"), p.symbol("B1"), p.symbol("output"), Optional.empty(), Optional.empty(), Optional.empty())).matches(PlanMatchPattern.semiJoin("A1", "B1", "output", Optional.of(SemiJoinNode.DistributionType.REPLICATED), PlanMatchPattern.values((Map<String, Integer>)ImmutableMap.of((Object)"A1", (Object)0)), PlanMatchPattern.values((Map<String, Integer>)ImmutableMap.of((Object)"B1", (Object)0))));
    }

    @Test
    public void testPartitionsWhenBothTablesEqual() {
        int aRows = 10000;
        int bRows = 10000;
        this.assertDetermineSemiJoinDistributionType().setSystemProperty("join_distribution_type", OptimizerConfig.JoinDistributionType.AUTOMATIC.name()).overrideStats("valuesA", PlanNodeStatsEstimate.builder().setOutputRowCount((double)aRows).addSymbolStatistics((Map)ImmutableMap.of((Object)new Symbol((Type)UnknownType.UNKNOWN, "A1"), (Object)SymbolStatsEstimate.unknown())).build()).overrideStats("valuesB", PlanNodeStatsEstimate.builder().setOutputRowCount((double)bRows).addSymbolStatistics((Map)ImmutableMap.of((Object)new Symbol((Type)UnknownType.UNKNOWN, "B1"), (Object)SymbolStatsEstimate.unknown())).build()).on(p -> p.semiJoin((PlanNode)p.values(new PlanNodeId("valuesA"), aRows, p.symbol("A1", (Type)BigintType.BIGINT)), (PlanNode)p.values(new PlanNodeId("valuesB"), bRows, p.symbol("B1", (Type)BigintType.BIGINT)), p.symbol("A1"), p.symbol("B1"), p.symbol("output"), Optional.empty(), Optional.empty(), Optional.empty())).matches(PlanMatchPattern.semiJoin("A1", "B1", "output", Optional.of(SemiJoinNode.DistributionType.PARTITIONED), PlanMatchPattern.values((Map<String, Integer>)ImmutableMap.of((Object)"A1", (Object)0)), PlanMatchPattern.values((Map<String, Integer>)ImmutableMap.of((Object)"B1", (Object)0))));
    }

    @Test
    public void testReplicatesWhenFilterMuchSmaller() {
        int aRows = 10000;
        int bRows = 100;
        this.assertDetermineSemiJoinDistributionType().setSystemProperty("join_distribution_type", OptimizerConfig.JoinDistributionType.AUTOMATIC.name()).overrideStats("valuesA", PlanNodeStatsEstimate.builder().setOutputRowCount((double)aRows).addSymbolStatistics((Map)ImmutableMap.of((Object)new Symbol((Type)UnknownType.UNKNOWN, "A1"), (Object)SymbolStatsEstimate.unknown())).build()).overrideStats("valuesB", PlanNodeStatsEstimate.builder().setOutputRowCount((double)bRows).addSymbolStatistics((Map)ImmutableMap.of((Object)new Symbol((Type)UnknownType.UNKNOWN, "B1"), (Object)SymbolStatsEstimate.unknown())).build()).on(p -> p.semiJoin((PlanNode)p.values(new PlanNodeId("valuesA"), aRows, p.symbol("A1", (Type)BigintType.BIGINT)), (PlanNode)p.values(new PlanNodeId("valuesB"), bRows, p.symbol("B1", (Type)BigintType.BIGINT)), p.symbol("A1"), p.symbol("B1"), p.symbol("output"), Optional.empty(), Optional.empty(), Optional.empty())).matches(PlanMatchPattern.semiJoin("A1", "B1", "output", Optional.of(SemiJoinNode.DistributionType.REPLICATED), PlanMatchPattern.values((Map<String, Integer>)ImmutableMap.of((Object)"A1", (Object)0)), PlanMatchPattern.values((Map<String, Integer>)ImmutableMap.of((Object)"B1", (Object)0))));
    }

    @Test
    public void testReplicatesWhenNotRestricted() {
        VarcharType symbolType = VarcharType.createUnboundedVarcharType();
        int aRows = 10000;
        int bRows = 10;
        PlanNodeStatsEstimate probeSideStatsEstimate = PlanNodeStatsEstimate.builder().setOutputRowCount((double)aRows).addSymbolStatistics((Map)ImmutableMap.of((Object)new Symbol((Type)symbolType, "A1"), (Object)new SymbolStatsEstimate(0.0, 100.0, 0.0, 640000.0, 10.0))).build();
        PlanNodeStatsEstimate buildSideStatsEstimate = PlanNodeStatsEstimate.builder().setOutputRowCount((double)bRows).addSymbolStatistics((Map)ImmutableMap.of((Object)new Symbol((Type)symbolType, "B1"), (Object)new SymbolStatsEstimate(0.0, 100.0, 0.0, 640000.0, 10.0))).build();
        this.assertDetermineSemiJoinDistributionType().setSystemProperty("join_distribution_type", OptimizerConfig.JoinDistributionType.AUTOMATIC.name()).setSystemProperty("join_max_broadcast_table_size", "100MB").overrideStats("valuesA", probeSideStatsEstimate).overrideStats("valuesB", buildSideStatsEstimate).on(arg_0 -> TestDetermineSemiJoinDistributionType.lambda$testReplicatesWhenNotRestricted$5((Type)symbolType, aRows, bRows, arg_0)).matches(PlanMatchPattern.semiJoin("A1", "B1", "output", Optional.of(SemiJoinNode.DistributionType.REPLICATED), PlanMatchPattern.values((Map<String, Integer>)ImmutableMap.of((Object)"A1", (Object)0)), PlanMatchPattern.values((Map<String, Integer>)ImmutableMap.of((Object)"B1", (Object)0))));
        probeSideStatsEstimate = PlanNodeStatsEstimate.builder().setOutputRowCount((double)aRows).addSymbolStatistics((Map)ImmutableMap.of((Object)new Symbol((Type)symbolType, "A1"), (Object)new SymbolStatsEstimate(0.0, 100.0, 0.0, 6.4E9, 10.0))).build();
        buildSideStatsEstimate = PlanNodeStatsEstimate.builder().setOutputRowCount((double)bRows).addSymbolStatistics((Map)ImmutableMap.of((Object)new Symbol((Type)symbolType, "B1"), (Object)new SymbolStatsEstimate(0.0, 100.0, 0.0, 6.4E9, 10.0))).build();
        this.assertDetermineSemiJoinDistributionType().setSystemProperty("join_distribution_type", OptimizerConfig.JoinDistributionType.AUTOMATIC.name()).setSystemProperty("join_max_broadcast_table_size", "100MB").overrideStats("valuesA", probeSideStatsEstimate).overrideStats("valuesB", buildSideStatsEstimate).on(arg_0 -> TestDetermineSemiJoinDistributionType.lambda$testReplicatesWhenNotRestricted$6((Type)symbolType, aRows, bRows, arg_0)).matches(PlanMatchPattern.semiJoin("A1", "B1", "output", Optional.of(SemiJoinNode.DistributionType.PARTITIONED), PlanMatchPattern.values((Map<String, Integer>)ImmutableMap.of((Object)"A1", (Object)0)), PlanMatchPattern.values((Map<String, Integer>)ImmutableMap.of((Object)"B1", (Object)0))));
    }

    @Test
    public void testReplicatesWhenSourceIsSmall() {
        VarcharType symbolType = VarcharType.createUnboundedVarcharType();
        int aRows = 10000;
        int bRows = 10;
        PlanNodeStatsEstimate probeSideStatsEstimate = PlanNodeStatsEstimate.builder().setOutputRowCount((double)aRows).addSymbolStatistics((Map)ImmutableMap.of((Object)new Symbol((Type)UnknownType.UNKNOWN, "A1"), (Object)new SymbolStatsEstimate(0.0, 100.0, 0.0, 6.4E9, 10.0))).build();
        PlanNodeStatsEstimate buildSideStatsEstimate = PlanNodeStatsEstimate.builder().setOutputRowCount((double)bRows).addSymbolStatistics((Map)ImmutableMap.of((Object)new Symbol((Type)UnknownType.UNKNOWN, "B1"), (Object)new SymbolStatsEstimate(0.0, 100.0, 0.0, 6.4E9, 10.0))).build();
        PlanNodeStatsEstimate buildSideSourceStatsEstimate = PlanNodeStatsEstimate.builder().setOutputRowCount((double)bRows).addSymbolStatistics((Map)ImmutableMap.of((Object)new Symbol((Type)UnknownType.UNKNOWN, "B1"), (Object)new SymbolStatsEstimate(0.0, 100.0, 0.0, 64.0, 10.0))).build();
        this.assertDetermineSemiJoinDistributionType().setSystemProperty("join_distribution_type", OptimizerConfig.JoinDistributionType.AUTOMATIC.name()).setSystemProperty("join_max_broadcast_table_size", "100MB").overrideStats("valuesA", probeSideStatsEstimate).overrideStats("filterB", buildSideStatsEstimate).overrideStats("valuesB", buildSideSourceStatsEstimate).on(arg_0 -> TestDetermineSemiJoinDistributionType.lambda$testReplicatesWhenSourceIsSmall$7((Type)symbolType, aRows, bRows, arg_0)).matches(PlanMatchPattern.semiJoin("A1", "B1", "output", Optional.of(SemiJoinNode.DistributionType.REPLICATED), PlanMatchPattern.values((Map<String, Integer>)ImmutableMap.of((Object)"A1", (Object)0)), PlanMatchPattern.filter((Expression)Booleans.TRUE, PlanMatchPattern.values((Map<String, Integer>)ImmutableMap.of((Object)"B1", (Object)0)))));
    }

    private RuleBuilder assertDetermineSemiJoinDistributionType() {
        return this.assertDetermineSemiJoinDistributionType(COST_COMPARATOR);
    }

    private RuleBuilder assertDetermineSemiJoinDistributionType(CostComparator costComparator) {
        return this.tester.assertThat((Rule<?>)new DetermineSemiJoinDistributionType(costComparator, new TaskCountEstimator(() -> 4)));
    }

    private static /* synthetic */ PlanNode lambda$testReplicatesWhenSourceIsSmall$7(Type symbolType, int aRows, int bRows, PlanBuilder p) {
        Symbol a1 = p.symbol("A1", symbolType);
        Symbol b1 = p.symbol("B1", symbolType);
        return p.semiJoin((PlanNode)p.values(new PlanNodeId("valuesA"), aRows, a1), (PlanNode)p.filter(new PlanNodeId("filterB"), (Expression)Booleans.TRUE, (PlanNode)p.values(new PlanNodeId("valuesB"), bRows, b1)), a1, b1, p.symbol("output"), Optional.empty(), Optional.empty(), Optional.empty());
    }

    private static /* synthetic */ PlanNode lambda$testReplicatesWhenNotRestricted$6(Type symbolType, int aRows, int bRows, PlanBuilder p) {
        Symbol a1 = p.symbol("A1", symbolType);
        Symbol b1 = p.symbol("B1", symbolType);
        return p.semiJoin((PlanNode)p.values(new PlanNodeId("valuesA"), aRows, a1), (PlanNode)p.values(new PlanNodeId("valuesB"), bRows, b1), a1, b1, p.symbol("output", (Type)BigintType.BIGINT), Optional.empty(), Optional.empty(), Optional.empty());
    }

    private static /* synthetic */ PlanNode lambda$testReplicatesWhenNotRestricted$5(Type symbolType, int aRows, int bRows, PlanBuilder p) {
        Symbol a1 = p.symbol("A1", symbolType);
        Symbol b1 = p.symbol("B1", symbolType);
        return p.semiJoin((PlanNode)p.values(new PlanNodeId("valuesA"), aRows, a1), (PlanNode)p.values(new PlanNodeId("valuesB"), bRows, b1), a1, b1, p.symbol("output"), Optional.empty(), Optional.empty(), Optional.empty());
    }

    private static /* synthetic */ PlanNode lambda$testPartitionWhenRequiredBySession$1(Type symbolType, int aRows, int bRows, PlanBuilder p) {
        Symbol a1 = p.symbol("A1", symbolType);
        Symbol b1 = p.symbol("B1", symbolType);
        return p.semiJoin((PlanNode)p.values(new PlanNodeId("valuesA"), aRows, a1), (PlanNode)p.values(new PlanNodeId("valuesB"), bRows, b1), a1, b1, p.symbol("output"), Optional.empty(), Optional.empty(), Optional.empty());
    }
}

