/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.cost.CostComparator;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.cost.TaskCountEstimator;
import com.facebook.presto.cost.VariableStatsEstimate;
import com.facebook.presto.spi.Plugin;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.analyzer.FeaturesConfig;
import com.facebook.presto.sql.planner.assertions.PlanMatchPattern;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.iterative.rule.DetermineSemiJoinDistributionType;
import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder;
import com.facebook.presto.sql.planner.iterative.rule.test.RuleAssert;
import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester;
import com.facebook.presto.sql.planner.plan.SemiJoinNode;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

@Test(singleThreaded=true)
public class TestDetermineSemiJoinDistributionType {
    private static final CostComparator COST_COMPARATOR = new CostComparator(1.0, 1.0, 1.0);
    private static final int NODES_COUNT = 4;
    private RuleTester tester;

    @BeforeClass
    public void setUp() {
        this.tester = new RuleTester((List<Plugin>)ImmutableList.of(), (Map<String, String>)ImmutableMap.of(), Optional.of(4));
    }

    @AfterClass(alwaysRun=true)
    public void tearDown() {
        this.tester.close();
        this.tester = null;
    }

    @Test
    public void testRetainDistributionType() {
        this.assertDetermineSemiJoinDistributionType().on(p -> p.semiJoin((PlanNode)p.values((List<VariableReferenceExpression>)ImmutableList.of((Object)p.variable("A1")), (List<List<RowExpression>>)ImmutableList.of(PlanBuilder.constantExpressions((Type)BigintType.BIGINT, 10L), PlanBuilder.constantExpressions((Type)BigintType.BIGINT, 11L))), (PlanNode)p.values((List<VariableReferenceExpression>)ImmutableList.of((Object)p.variable("B1")), (List<List<RowExpression>>)ImmutableList.of(PlanBuilder.constantExpressions((Type)BigintType.BIGINT, 50L), PlanBuilder.constantExpressions((Type)BigintType.BIGINT, 11L))), p.variable("A1"), p.variable("B1"), p.variable("output"), Optional.empty(), Optional.empty(), Optional.of(SemiJoinNode.DistributionType.REPLICATED))).doesNotFire();
    }

    @Test
    public void testPartitionWhenRequiredBySession() {
        int aRows = 10000;
        int bRows = 100;
        this.assertDetermineSemiJoinDistributionType().setSystemProperty("join_distribution_type", FeaturesConfig.JoinDistributionType.PARTITIONED.name()).overrideStats("valuesA", PlanNodeStatsEstimate.builder().setOutputRowCount((double)aRows).addVariableStatistics((Map)ImmutableMap.of((Object)new VariableReferenceExpression("A1", (Type)BigintType.BIGINT), (Object)new VariableStatsEstimate(0.0, 100.0, 0.0, 6400.0, 100.0))).build()).overrideStats("valuesB", PlanNodeStatsEstimate.builder().setOutputRowCount((double)bRows).addVariableStatistics((Map)ImmutableMap.of((Object)new VariableReferenceExpression("B1", (Type)BigintType.BIGINT), (Object)new VariableStatsEstimate(0.0, 100.0, 0.0, 640000.0, 100.0))).build()).on(p -> p.semiJoin((PlanNode)p.values(new PlanNodeId("valuesA"), aRows, p.variable("A1", (Type)BigintType.BIGINT)), (PlanNode)p.values(new PlanNodeId("valuesB"), bRows, p.variable("B1", (Type)BigintType.BIGINT)), p.variable("A1"), p.variable("B1"), p.variable("output"), Optional.empty(), Optional.empty(), Optional.empty())).matches(PlanMatchPattern.semiJoin("A1", "B1", "output", Optional.of(SemiJoinNode.DistributionType.PARTITIONED), PlanMatchPattern.values((Map<String, Integer>)ImmutableMap.of((Object)"A1", (Object)0)), PlanMatchPattern.values((Map<String, Integer>)ImmutableMap.of((Object)"B1", (Object)0))));
    }

    @Test
    public void testReplicatesWhenRequiredBySession() {
        int aRows = 10000;
        int bRows = 10000;
        this.assertDetermineSemiJoinDistributionType().setSystemProperty("join_distribution_type", FeaturesConfig.JoinDistributionType.BROADCAST.name()).setSystemProperty("join_max_broadcast_table_size", "1B").overrideStats("valuesA", PlanNodeStatsEstimate.builder().setOutputRowCount((double)aRows).addVariableStatistics((Map)ImmutableMap.of((Object)new VariableReferenceExpression("A1", (Type)BigintType.BIGINT), (Object)VariableStatsEstimate.unknown())).build()).overrideStats("valuesB", PlanNodeStatsEstimate.builder().setOutputRowCount((double)bRows).addVariableStatistics((Map)ImmutableMap.of((Object)new VariableReferenceExpression("B1", (Type)BigintType.BIGINT), (Object)VariableStatsEstimate.unknown())).build()).on(p -> p.semiJoin((PlanNode)p.values(new PlanNodeId("valuesA"), aRows, p.variable("A1", (Type)BigintType.BIGINT)), (PlanNode)p.values(new PlanNodeId("valuesB"), bRows, p.variable("B1", (Type)BigintType.BIGINT)), p.variable("A1"), p.variable("B1"), p.variable("output"), Optional.empty(), Optional.empty(), Optional.empty())).matches(PlanMatchPattern.semiJoin("A1", "B1", "output", Optional.of(SemiJoinNode.DistributionType.REPLICATED), PlanMatchPattern.values((Map<String, Integer>)ImmutableMap.of((Object)"A1", (Object)0)), PlanMatchPattern.values((Map<String, Integer>)ImmutableMap.of((Object)"B1", (Object)0))));
    }

    @Test
    public void testPartitionsWhenBothTablesEqual() {
        int aRows = 10000;
        int bRows = 10000;
        this.assertDetermineSemiJoinDistributionType().setSystemProperty("join_distribution_type", FeaturesConfig.JoinDistributionType.AUTOMATIC.name()).overrideStats("valuesA", PlanNodeStatsEstimate.builder().setOutputRowCount((double)aRows).addVariableStatistics((Map)ImmutableMap.of((Object)new VariableReferenceExpression("A1", (Type)BigintType.BIGINT), (Object)VariableStatsEstimate.unknown())).build()).overrideStats("valuesB", PlanNodeStatsEstimate.builder().setOutputRowCount((double)bRows).addVariableStatistics((Map)ImmutableMap.of((Object)new VariableReferenceExpression("B1", (Type)BigintType.BIGINT), (Object)VariableStatsEstimate.unknown())).build()).on(p -> p.semiJoin((PlanNode)p.values(new PlanNodeId("valuesA"), aRows, p.variable("A1", (Type)BigintType.BIGINT)), (PlanNode)p.values(new PlanNodeId("valuesB"), bRows, p.variable("B1", (Type)BigintType.BIGINT)), p.variable("A1"), p.variable("B1"), p.variable("output"), Optional.empty(), Optional.empty(), Optional.empty())).matches(PlanMatchPattern.semiJoin("A1", "B1", "output", Optional.of(SemiJoinNode.DistributionType.PARTITIONED), PlanMatchPattern.values((Map<String, Integer>)ImmutableMap.of((Object)"A1", (Object)0)), PlanMatchPattern.values((Map<String, Integer>)ImmutableMap.of((Object)"B1", (Object)0))));
    }

    @Test
    public void testReplicatesWhenFilterMuchSmaller() {
        int aRows = 10000;
        int bRows = 100;
        this.assertDetermineSemiJoinDistributionType().setSystemProperty("join_distribution_type", FeaturesConfig.JoinDistributionType.AUTOMATIC.name()).overrideStats("valuesA", PlanNodeStatsEstimate.builder().setOutputRowCount((double)aRows).addVariableStatistics((Map)ImmutableMap.of((Object)new VariableReferenceExpression("A1", (Type)BigintType.BIGINT), (Object)VariableStatsEstimate.unknown())).build()).overrideStats("valuesB", PlanNodeStatsEstimate.builder().setOutputRowCount((double)bRows).addVariableStatistics((Map)ImmutableMap.of((Object)new VariableReferenceExpression("B1", (Type)BigintType.BIGINT), (Object)VariableStatsEstimate.unknown())).build()).on(p -> p.semiJoin((PlanNode)p.values(new PlanNodeId("valuesA"), aRows, p.variable("A1", (Type)BigintType.BIGINT)), (PlanNode)p.values(new PlanNodeId("valuesB"), bRows, p.variable("B1", (Type)BigintType.BIGINT)), p.variable("A1"), p.variable("B1"), p.variable("output"), Optional.empty(), Optional.empty(), Optional.empty())).matches(PlanMatchPattern.semiJoin("A1", "B1", "output", Optional.of(SemiJoinNode.DistributionType.REPLICATED), PlanMatchPattern.values((Map<String, Integer>)ImmutableMap.of((Object)"A1", (Object)0)), PlanMatchPattern.values((Map<String, Integer>)ImmutableMap.of((Object)"B1", (Object)0))));
    }

    @Test
    public void testReplicatesWhenNotRestricted() {
        int aRows = 10000;
        int bRows = 10;
        PlanNodeStatsEstimate probeSideStatsEstimate = PlanNodeStatsEstimate.builder().setOutputRowCount((double)aRows).addVariableStatistics((Map)ImmutableMap.of((Object)new VariableReferenceExpression("A1", (Type)BigintType.BIGINT), (Object)new VariableStatsEstimate(0.0, 100.0, 0.0, 640000.0, 10.0))).build();
        PlanNodeStatsEstimate buildSideStatsEstimate = PlanNodeStatsEstimate.builder().setOutputRowCount((double)bRows).addVariableStatistics((Map)ImmutableMap.of((Object)new VariableReferenceExpression("B1", (Type)BigintType.BIGINT), (Object)new VariableStatsEstimate(0.0, 100.0, 0.0, 640000.0, 10.0))).build();
        this.assertDetermineSemiJoinDistributionType().setSystemProperty("join_distribution_type", FeaturesConfig.JoinDistributionType.AUTOMATIC.name()).setSystemProperty("join_max_broadcast_table_size", "100MB").overrideStats("valuesA", probeSideStatsEstimate).overrideStats("valuesB", buildSideStatsEstimate).on(p -> p.semiJoin((PlanNode)p.values(new PlanNodeId("valuesA"), aRows, p.variable("A1", (Type)BigintType.BIGINT)), (PlanNode)p.values(new PlanNodeId("valuesB"), bRows, p.variable("B1", (Type)BigintType.BIGINT)), p.variable("A1"), p.variable("B1"), p.variable("output"), Optional.empty(), Optional.empty(), Optional.empty())).matches(PlanMatchPattern.semiJoin("A1", "B1", "output", Optional.of(SemiJoinNode.DistributionType.REPLICATED), PlanMatchPattern.values((Map<String, Integer>)ImmutableMap.of((Object)"A1", (Object)0)), PlanMatchPattern.values((Map<String, Integer>)ImmutableMap.of((Object)"B1", (Object)0))));
        probeSideStatsEstimate = PlanNodeStatsEstimate.builder().setOutputRowCount((double)aRows).addVariableStatistics((Map)ImmutableMap.of((Object)new VariableReferenceExpression("A1", (Type)BigintType.BIGINT), (Object)new VariableStatsEstimate(0.0, 100.0, 0.0, 6.4E9, 10.0))).build();
        buildSideStatsEstimate = PlanNodeStatsEstimate.builder().setOutputRowCount((double)bRows).addVariableStatistics((Map)ImmutableMap.of((Object)new VariableReferenceExpression("B1", (Type)BigintType.BIGINT), (Object)new VariableStatsEstimate(0.0, 100.0, 0.0, 6.4E9, 10.0))).build();
        this.assertDetermineSemiJoinDistributionType().setSystemProperty("join_distribution_type", FeaturesConfig.JoinDistributionType.AUTOMATIC.name()).setSystemProperty("join_max_broadcast_table_size", "100MB").overrideStats("valuesA", probeSideStatsEstimate).overrideStats("valuesB", buildSideStatsEstimate).on(p -> p.semiJoin((PlanNode)p.values(new PlanNodeId("valuesA"), aRows, p.variable("A1", (Type)BigintType.BIGINT)), (PlanNode)p.values(new PlanNodeId("valuesB"), bRows, p.variable("B1", (Type)BigintType.BIGINT)), p.variable("A1"), p.variable("B1"), p.variable("output"), Optional.empty(), Optional.empty(), Optional.empty())).matches(PlanMatchPattern.semiJoin("A1", "B1", "output", Optional.of(SemiJoinNode.DistributionType.PARTITIONED), PlanMatchPattern.values((Map<String, Integer>)ImmutableMap.of((Object)"A1", (Object)0)), PlanMatchPattern.values((Map<String, Integer>)ImmutableMap.of((Object)"B1", (Object)0))));
    }

    private RuleAssert assertDetermineSemiJoinDistributionType() {
        return this.assertDetermineSemiJoinDistributionType(COST_COMPARATOR);
    }

    private RuleAssert assertDetermineSemiJoinDistributionType(CostComparator costComparator) {
        return this.tester.assertThat((Rule)new DetermineSemiJoinDistributionType(costComparator, new TaskCountEstimator(() -> 4)));
    }
}

