/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.cost;

import com.facebook.presto.cost.BaseStatsCalculatorTest;
import com.facebook.presto.cost.ComposableStatsCalculator;
import com.facebook.presto.cost.FilterStatsCalculator;
import com.facebook.presto.cost.JoinStatsRule;
import com.facebook.presto.cost.PlanNodeStatsAssertion;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.cost.ScalarStatsCalculator;
import com.facebook.presto.cost.StatsNormalizer;
import com.facebook.presto.cost.SymbolStatsEstimate;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.spi.type.BigintType;
import com.facebook.presto.spi.type.DoubleType;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.LongLiteral;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Optional;
import org.testng.annotations.Test;

/**
 * Tests for {@link JoinStatsRule}: output-row-count and per-symbol statistics estimates for INNER, LEFT,
 * RIGHT and FULL joins, plus the join-complement helpers {@code calculateJoinComplementStats} and
 * {@code addJoinComplementStats}.
 *
 * <p>Expected numbers are kept as literals so the assertions stay exact; the comment next to each one
 * shows the arithmetic it is consistent with, in terms of the input statistics declared below.
 */
public class TestJoinStatsRule
        extends BaseStatsCalculatorTest {
    private static final String LEFT_JOIN_COLUMN = "left_join_column";
    private static final String LEFT_JOIN_COLUMN_2 = "left_join_column_2";
    private static final String RIGHT_JOIN_COLUMN = "right_join_column";
    private static final String RIGHT_JOIN_COLUMN_2 = "right_join_column_2";
    private static final String LEFT_OTHER_COLUMN = "left_column";
    private static final String RIGHT_OTHER_COLUMN = "right_column";

    private static final double LEFT_ROWS_COUNT = 500.0;
    private static final double RIGHT_ROWS_COUNT = 1000.0;
    private static final double TOTAL_ROWS_COUNT = LEFT_ROWS_COUNT + RIGHT_ROWS_COUNT;

    private static final double LEFT_JOIN_COLUMN_NULLS = 0.3;
    private static final double LEFT_JOIN_COLUMN_2_NULLS = 0.4;
    private static final double LEFT_JOIN_COLUMN_NON_NULLS = 1.0 - LEFT_JOIN_COLUMN_NULLS;
    private static final double LEFT_JOIN_COLUMN_2_NON_NULLS = 1.0 - LEFT_JOIN_COLUMN_2_NULLS;
    private static final int LEFT_JOIN_COLUMN_NDV = 20;
    private static final int LEFT_JOIN_COLUMN_2_NDV = 50;

    private static final double RIGHT_JOIN_COLUMN_NULLS = 0.6;
    private static final double RIGHT_JOIN_COLUMN_2_NULLS = 0.8;
    private static final double RIGHT_JOIN_COLUMN_NON_NULLS = 1.0 - RIGHT_JOIN_COLUMN_NULLS;
    private static final double RIGHT_JOIN_COLUMN_2_NON_NULLS = 1.0 - RIGHT_JOIN_COLUMN_2_NULLS;
    private static final int RIGHT_JOIN_COLUMN_NDV = 15;
    private static final int RIGHT_JOIN_COLUMN_2_NDV = 15;

    // Input column statistics: symbolStatistics(name, low, high, nullsFraction, distinctValuesCount).
    private static final SymbolStatistics LEFT_JOIN_COLUMN_STATS =
            symbolStatistics(LEFT_JOIN_COLUMN, 0.0, 20.0, LEFT_JOIN_COLUMN_NULLS, LEFT_JOIN_COLUMN_NDV);
    private static final SymbolStatistics LEFT_JOIN_COLUMN_2_STATS =
            symbolStatistics(LEFT_JOIN_COLUMN_2, 0.0, 200.0, LEFT_JOIN_COLUMN_2_NULLS, LEFT_JOIN_COLUMN_2_NDV);
    private static final SymbolStatistics LEFT_OTHER_COLUMN_STATS =
            symbolStatistics(LEFT_OTHER_COLUMN, 42.0, 42.0, 0.42, 1.0);
    private static final SymbolStatistics RIGHT_JOIN_COLUMN_STATS =
            symbolStatistics(RIGHT_JOIN_COLUMN, 5.0, 20.0, RIGHT_JOIN_COLUMN_NULLS, RIGHT_JOIN_COLUMN_NDV);
    private static final SymbolStatistics RIGHT_JOIN_COLUMN_2_STATS =
            symbolStatistics(RIGHT_JOIN_COLUMN_2, 100.0, 200.0, RIGHT_JOIN_COLUMN_2_NULLS, RIGHT_JOIN_COLUMN_2_NDV);
    private static final SymbolStatistics RIGHT_OTHER_COLUMN_STATS =
            symbolStatistics(RIGHT_OTHER_COLUMN, 24.0, 24.0, 0.24, 1.0);

    private static final PlanNodeStatsEstimate LEFT_STATS =
            planNodeStats(LEFT_ROWS_COUNT, LEFT_JOIN_COLUMN_STATS, LEFT_OTHER_COLUMN_STATS);
    private static final PlanNodeStatsEstimate RIGHT_STATS =
            planNodeStats(RIGHT_ROWS_COUNT, RIGHT_JOIN_COLUMN_STATS, RIGHT_OTHER_COLUMN_STATS);

    private static final MetadataManager METADATA = MetadataManager.createTestMetadataManager();
    private static final StatsNormalizer NORMALIZER = new StatsNormalizer();
    // NOTE(review): the meaning of the trailing 1.0 is defined by JoinStatsRule's constructor - confirm there.
    private static final JoinStatsRule JOIN_STATS_RULE = new JoinStatsRule(
            new FilterStatsCalculator(METADATA, new ScalarStatsCalculator(METADATA), NORMALIZER),
            NORMALIZER,
            1.0);

    @Test
    public void testStatsForInnerJoin() {
        // 7000 = 500 * 1000 * 0.7 * 0.4 / 20, i.e. both row counts scaled by their join-column
        // non-null fractions and divided by the larger join-column NDV (20 vs. 15).
        double innerJoinRowCount = 7000.0;
        PlanNodeStatsEstimate innerJoinStats = planNodeStats(
                innerJoinRowCount,
                // Both join columns are expected to be narrowed to the overlapping range [5, 20], with no nulls.
                symbolStatistics(LEFT_JOIN_COLUMN, 5.0, 20.0, 0.0, 15.0),
                symbolStatistics(RIGHT_JOIN_COLUMN, 5.0, 20.0, 0.0, 15.0),
                LEFT_OTHER_COLUMN_STATS,
                RIGHT_OTHER_COLUMN_STATS);
        assertJoinStats(JoinNode.Type.INNER, LEFT_STATS, RIGHT_STATS, innerJoinStats);
    }

    @Test
    public void testStatsForInnerJoinWithRepeatedClause() {
        // 6300 = 7000 * 0.9: repeating the same clause leaves the column statistics of
        // testStatsForInnerJoin unchanged and only scales the row count by 0.9.
        double innerJoinRowCount = 6300.0;
        PlanNodeStatsEstimate innerJoinStats = planNodeStats(
                innerJoinRowCount,
                symbolStatistics(LEFT_JOIN_COLUMN, 5.0, 20.0, 0.0, 15.0),
                symbolStatistics(RIGHT_JOIN_COLUMN, 5.0, 20.0, 0.0, 15.0),
                LEFT_OTHER_COLUMN_STATS,
                RIGHT_OTHER_COLUMN_STATS);
        tester().assertStatsFor(pb -> {
            Symbol leftJoinColumnSymbol = pb.symbol(LEFT_JOIN_COLUMN, BigintType.BIGINT);
            Symbol rightJoinColumnSymbol = pb.symbol(RIGHT_JOIN_COLUMN, DoubleType.DOUBLE);
            Symbol leftOtherColumnSymbol = pb.symbol(LEFT_OTHER_COLUMN, BigintType.BIGINT);
            Symbol rightOtherColumnSymbol = pb.symbol(RIGHT_OTHER_COLUMN, DoubleType.DOUBLE);
            return pb.join(
                    JoinNode.Type.INNER,
                    pb.values(leftJoinColumnSymbol, leftOtherColumnSymbol),
                    pb.values(rightJoinColumnSymbol, rightOtherColumnSymbol),
                    new JoinNode.EquiJoinClause(leftJoinColumnSymbol, rightJoinColumnSymbol),
                    new JoinNode.EquiJoinClause(leftJoinColumnSymbol, rightJoinColumnSymbol));
        })
                .withSourceStats(0, LEFT_STATS)
                .withSourceStats(1, RIGHT_STATS)
                .check(stats -> stats.equalTo(innerJoinStats));
    }

    @Test
    public void testStatsForInnerJoinWithTwoEquiClauses() {
        // ~= 1080 = 500 * 1000 * 0.6 * 0.2 / 50 * 0.9: the *_2 clause scaled like the single-clause
        // case, then a further 0.9 for the second clause (floating-point result kept verbatim).
        double innerJoinRowCount = 1079.9999999999998;
        PlanNodeStatsEstimate innerJoinStats = planNodeStats(
                innerJoinRowCount,
                symbolStatistics(LEFT_JOIN_COLUMN, 5.0, 20.0, 0.0, 15.0),
                symbolStatistics(RIGHT_JOIN_COLUMN, 5.0, 20.0, 0.0, 15.0),
                symbolStatistics(LEFT_JOIN_COLUMN_2, 100.0, 200.0, 0.0, 15.0),
                symbolStatistics(RIGHT_JOIN_COLUMN_2, 100.0, 200.0, 0.0, 15.0));
        PlanNodeStatsEstimate leftStats = planNodeStats(LEFT_ROWS_COUNT, LEFT_JOIN_COLUMN_STATS, LEFT_JOIN_COLUMN_2_STATS);
        PlanNodeStatsEstimate rightStats = planNodeStats(RIGHT_ROWS_COUNT, RIGHT_JOIN_COLUMN_STATS, RIGHT_JOIN_COLUMN_2_STATS);
        tester().assertStatsFor(pb -> {
            Symbol leftJoinColumnSymbol = pb.symbol(LEFT_JOIN_COLUMN, BigintType.BIGINT);
            Symbol rightJoinColumnSymbol = pb.symbol(RIGHT_JOIN_COLUMN, DoubleType.DOUBLE);
            Symbol leftJoinColumnSymbol2 = pb.symbol(LEFT_JOIN_COLUMN_2, BigintType.BIGINT);
            Symbol rightJoinColumnSymbol2 = pb.symbol(RIGHT_JOIN_COLUMN_2, DoubleType.DOUBLE);
            return pb.join(
                    JoinNode.Type.INNER,
                    pb.values(leftJoinColumnSymbol, leftJoinColumnSymbol2),
                    pb.values(rightJoinColumnSymbol, rightJoinColumnSymbol2),
                    new JoinNode.EquiJoinClause(leftJoinColumnSymbol2, rightJoinColumnSymbol2),
                    new JoinNode.EquiJoinClause(leftJoinColumnSymbol, rightJoinColumnSymbol));
        })
                .withSourceStats(0, leftStats)
                .withSourceStats(1, rightStats)
                .check(stats -> stats.equalTo(innerJoinStats));
    }

    @Test
    public void testStatsForInnerJoinWithTwoEquiClausesAndNonEqualityFunction() {
        // ~= 1080 * 4.9999999995 / 15: on top of testStatsForInnerJoinWithTwoEquiClauses, the extra
        // `left_join_column < 10` filter narrows left_join_column from [5, 20] to [5, 10] with ~5 distinct values.
        double innerJoinRowCount = 359.9999999639999;
        PlanNodeStatsEstimate innerJoinStats = planNodeStats(
                innerJoinRowCount,
                symbolStatistics(LEFT_JOIN_COLUMN, 5.0, 10.0, 0.0, 4.9999999995),
                symbolStatistics(RIGHT_JOIN_COLUMN, 5.0, 20.0, 0.0, 15.0),
                symbolStatistics(LEFT_JOIN_COLUMN_2, 100.0, 200.0, 0.0, 15.0),
                symbolStatistics(RIGHT_JOIN_COLUMN_2, 100.0, 200.0, 0.0, 15.0));
        PlanNodeStatsEstimate leftStats = planNodeStats(LEFT_ROWS_COUNT, LEFT_JOIN_COLUMN_STATS, LEFT_JOIN_COLUMN_2_STATS);
        PlanNodeStatsEstimate rightStats = planNodeStats(RIGHT_ROWS_COUNT, RIGHT_JOIN_COLUMN_STATS, RIGHT_JOIN_COLUMN_2_STATS);
        tester().assertStatsFor(pb -> {
            Symbol leftJoinColumnSymbol = pb.symbol(LEFT_JOIN_COLUMN, BigintType.BIGINT);
            Symbol rightJoinColumnSymbol = pb.symbol(RIGHT_JOIN_COLUMN, DoubleType.DOUBLE);
            Symbol leftJoinColumnSymbol2 = pb.symbol(LEFT_JOIN_COLUMN_2, BigintType.BIGINT);
            Symbol rightJoinColumnSymbol2 = pb.symbol(RIGHT_JOIN_COLUMN_2, DoubleType.DOUBLE);
            Expression leftJoinColumnLessThanTen = new ComparisonExpression(
                    ComparisonExpression.Operator.LESS_THAN,
                    leftJoinColumnSymbol.toSymbolReference(),
                    new LongLiteral("10"));
            // Type arguments are inferred from the join(...) parameters; the decompiled
            // (List<...>) ImmutableList.of((Object) ...) casts were inconvertible and did not compile.
            return pb.join(
                    JoinNode.Type.INNER,
                    pb.values(leftJoinColumnSymbol, leftJoinColumnSymbol2),
                    pb.values(rightJoinColumnSymbol, rightJoinColumnSymbol2),
                    ImmutableList.of(
                            new JoinNode.EquiJoinClause(leftJoinColumnSymbol2, rightJoinColumnSymbol2),
                            new JoinNode.EquiJoinClause(leftJoinColumnSymbol, rightJoinColumnSymbol)),
                    ImmutableList.of(leftJoinColumnSymbol, leftJoinColumnSymbol2, rightJoinColumnSymbol, rightJoinColumnSymbol2),
                    Optional.of(leftJoinColumnLessThanTen));
        })
                .withSourceStats(0, leftStats)
                .withSourceStats(1, rightStats)
                .check(stats -> stats.equalTo(innerJoinStats));
    }

    @Test
    public void testJoinComplementStats() {
        // 237.5 = 150 NULL-key left rows (500 * 0.3) + 87.5 non-null rows (350 * 5 / 20), i.e. the 5 of 20
        // left key values beyond the right side's 15; nulls fraction = 150 / 237.5.
        PlanNodeStatsEstimate joinComplementStats = planNodeStats(
                237.5,
                symbolStatistics(LEFT_JOIN_COLUMN, 0.0, 20.0, 0.631578947368421, 5.0),
                LEFT_OTHER_COLUMN_STATS);
        PlanNodeStatsAssertion.assertThat(JOIN_STATS_RULE.calculateJoinComplementStats(
                Optional.empty(),
                ImmutableList.of(equiJoinClause(LEFT_JOIN_COLUMN, RIGHT_JOIN_COLUMN)),
                LEFT_STATS,
                RIGHT_STATS))
                .equalTo(joinComplementStats);
    }

    @Test
    public void testRightJoinComplementStats() {
        // 600 = 1000 * 0.6: the expected complement is exactly the right side's NULL-key rows,
        // so the join column is entirely NULL (no range, 0 distinct values).
        PlanNodeStatsEstimate joinComplementStats = planNodeStats(
                600.0,
                symbolStatistics(RIGHT_JOIN_COLUMN, Double.NaN, Double.NaN, 1.0, 0.0),
                RIGHT_OTHER_COLUMN_STATS);
        PlanNodeStatsAssertion.assertThat(JOIN_STATS_RULE.calculateJoinComplementStats(
                Optional.empty(),
                ImmutableList.of(equiJoinClause(RIGHT_JOIN_COLUMN, LEFT_JOIN_COLUMN)),
                RIGHT_STATS,
                LEFT_STATS))
                .equalTo(joinComplementStats);
    }

    @Test
    public void testLeftJoinComplementStatsWithNoClauses() {
        // Without equi-join clauses the complement is expected to be empty: same symbol stats, zero rows.
        PlanNodeStatsAssertion.assertThat(JOIN_STATS_RULE.calculateJoinComplementStats(
                Optional.empty(),
                ImmutableList.of(),
                LEFT_STATS,
                RIGHT_STATS))
                .equalTo(LEFT_STATS.mapOutputRowCount(rowCount -> 0.0));
    }

    @Test
    public void testLeftJoinComplementStatsWithMultipleClauses() {
        // Same as testJoinComplementStats, with the row count additionally divided by 0.9 for the second clause.
        PlanNodeStatsEstimate joinComplementStats = planNodeStats(
                237.5,
                symbolStatistics(LEFT_JOIN_COLUMN, 0.0, 20.0, 0.631578947368421, 5.0),
                LEFT_OTHER_COLUMN_STATS)
                .mapOutputRowCount(rowCount -> rowCount / 0.9);
        PlanNodeStatsAssertion.assertThat(JOIN_STATS_RULE.calculateJoinComplementStats(
                Optional.empty(),
                ImmutableList.of(
                        equiJoinClause(LEFT_JOIN_COLUMN, RIGHT_JOIN_COLUMN),
                        equiJoinClause(LEFT_OTHER_COLUMN, RIGHT_OTHER_COLUMN)),
                LEFT_STATS,
                RIGHT_STATS))
                .equalTo(joinComplementStats);
    }

    @Test
    public void testStatsForLeftAndRightJoin() {
        // Inner-join rows (testStatsForInnerJoin) plus unmatched left rows (testJoinComplementStats);
        // the right-side columns are NULL on every unmatched left row.
        double innerJoinRowCount = 7000.0;
        double joinComplementRowCount = 237.5;
        double joinComplementColumnNulls = 0.631578947368421;
        double totalRowCount = innerJoinRowCount + joinComplementRowCount;
        PlanNodeStatsEstimate leftJoinStats = planNodeStats(
                totalRowCount,
                symbolStatistics(LEFT_JOIN_COLUMN, 0.0, 20.0, joinComplementColumnNulls * joinComplementRowCount / totalRowCount, 20.0),
                LEFT_OTHER_COLUMN_STATS,
                symbolStatistics(RIGHT_JOIN_COLUMN, 5.0, 20.0, joinComplementRowCount / totalRowCount, 15.0),
                symbolStatistics(RIGHT_OTHER_COLUMN, 24.0, 24.0, (0.24 * innerJoinRowCount + joinComplementRowCount) / totalRowCount, 1.0));
        assertJoinStats(JoinNode.Type.LEFT, LEFT_STATS, RIGHT_STATS, leftJoinStats);
        // A RIGHT join with the inputs swapped is expected to produce the identical estimate.
        assertJoinStats(
                JoinNode.Type.RIGHT,
                RIGHT_JOIN_COLUMN,
                RIGHT_OTHER_COLUMN,
                LEFT_JOIN_COLUMN,
                LEFT_OTHER_COLUMN,
                RIGHT_STATS,
                LEFT_STATS,
                leftJoinStats);
    }

    @Test
    public void testStatsForFullJoin() {
        // Inner-join rows plus the unmatched rows of both sides (testJoinComplementStats and
        // testRightJoinComplementStats); each side's columns are NULL on the other side's unmatched rows.
        double innerJoinRowCount = 7000.0;
        double leftJoinComplementRowCount = 237.5;
        double leftJoinComplementColumnNulls = 0.631578947368421;
        double rightJoinComplementRowCount = 600.0;
        double rightJoinComplementColumnNulls = 1.0;
        double totalRowCount = innerJoinRowCount + leftJoinComplementRowCount + rightJoinComplementRowCount;
        PlanNodeStatsEstimate fullJoinStats = planNodeStats(
                totalRowCount,
                symbolStatistics(LEFT_JOIN_COLUMN, 0.0, 20.0, (leftJoinComplementColumnNulls * leftJoinComplementRowCount + rightJoinComplementRowCount) / totalRowCount, 20.0),
                symbolStatistics(LEFT_OTHER_COLUMN, 42.0, 42.0, (0.42 * (innerJoinRowCount + leftJoinComplementRowCount) + rightJoinComplementRowCount) / totalRowCount, 1.0),
                symbolStatistics(RIGHT_JOIN_COLUMN, 5.0, 20.0, (rightJoinComplementColumnNulls * rightJoinComplementRowCount + leftJoinComplementRowCount) / totalRowCount, 15.0),
                symbolStatistics(RIGHT_OTHER_COLUMN, 24.0, 24.0, (0.24 * (innerJoinRowCount + rightJoinComplementRowCount) + leftJoinComplementRowCount) / totalRowCount, 1.0));
        assertJoinStats(JoinNode.Type.FULL, LEFT_STATS, RIGHT_STATS, fullJoinStats);
    }

    @Test
    public void testAddJoinComplementStats() {
        // Row counts add up (500 + 1000). Null fractions are row-weighted: (500 * 0.3 + 1000 * 0.2) / 1500 for the
        // join column; LEFT_OTHER_COLUMN is absent from statsToAdd and counts as NULL there: (500 * 0.42 + 1000) / 1500.
        double statsToAddNdv = 5.0;
        PlanNodeStatsEstimate statsToAdd = planNodeStats(
                1000.0,
                symbolStatistics(LEFT_JOIN_COLUMN, 0.0, 5.0, 0.2, statsToAddNdv));
        PlanNodeStatsEstimate addedStats = planNodeStats(
                TOTAL_ROWS_COUNT,
                symbolStatistics(LEFT_JOIN_COLUMN, 0.0, 20.0, 0.23333333333333334, 20.0),
                symbolStatistics(LEFT_OTHER_COLUMN, 42.0, 42.0, 0.8066666666666666, 1.0));
        PlanNodeStatsAssertion.assertThat(JOIN_STATS_RULE.addJoinComplementStats(LEFT_STATS, LEFT_STATS, statsToAdd))
                .equalTo(addedStats);
    }

    /**
     * Asserts the stats {@link #JOIN_STATS_RULE} produces for {@code left JOIN right ON left_join_column = right_join_column}
     * using the default column names.
     */
    private void assertJoinStats(JoinNode.Type joinType, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, PlanNodeStatsEstimate resultStats) {
        assertJoinStats(joinType, LEFT_JOIN_COLUMN, LEFT_OTHER_COLUMN, RIGHT_JOIN_COLUMN, RIGHT_OTHER_COLUMN, leftStats, rightStats, resultStats);
    }

    /**
     * Builds a two-column {@code left JOIN right} plan with a single equi-join clause on the given join
     * columns, feeds it the given source stats, and asserts that {@link #JOIN_STATS_RULE} produces
     * {@code resultStats}. Left-side symbols are typed BIGINT and right-side symbols DOUBLE.
     */
    private void assertJoinStats(
            JoinNode.Type joinType,
            String leftJoinColumn,
            String leftOtherColumn,
            String rightJoinColumn,
            String rightOtherColumn,
            PlanNodeStatsEstimate leftStats,
            PlanNodeStatsEstimate rightStats,
            PlanNodeStatsEstimate resultStats) {
        tester().assertStatsFor(pb -> {
            Symbol leftJoinColumnSymbol = pb.symbol(leftJoinColumn, BigintType.BIGINT);
            Symbol rightJoinColumnSymbol = pb.symbol(rightJoinColumn, DoubleType.DOUBLE);
            Symbol leftOtherColumnSymbol = pb.symbol(leftOtherColumn, BigintType.BIGINT);
            Symbol rightOtherColumnSymbol = pb.symbol(rightOtherColumn, DoubleType.DOUBLE);
            return pb.join(
                    joinType,
                    pb.values(leftJoinColumnSymbol, leftOtherColumnSymbol),
                    pb.values(rightJoinColumnSymbol, rightOtherColumnSymbol),
                    new JoinNode.EquiJoinClause(leftJoinColumnSymbol, rightJoinColumnSymbol));
        })
                .withSourceStats(0, leftStats)
                .withSourceStats(1, rightStats)
                .check(JOIN_STATS_RULE, stats -> stats.equalTo(resultStats));
    }

    /** Creates an equi-join clause {@code leftSymbol = rightSymbol} from symbol names. */
    private static JoinNode.EquiJoinClause equiJoinClause(String leftSymbol, String rightSymbol) {
        return new JoinNode.EquiJoinClause(new Symbol(leftSymbol), new Symbol(rightSymbol));
    }

    /** Builds a plan-node estimate with the given output row count and per-symbol statistics. */
    private static PlanNodeStatsEstimate planNodeStats(double rowCount, SymbolStatistics... symbolStatistics) {
        PlanNodeStatsEstimate.Builder builder = PlanNodeStatsEstimate.builder().setOutputRowCount(rowCount);
        for (SymbolStatistics symbolStatistic : symbolStatistics) {
            builder.addSymbolStatistics(symbolStatistic.symbol, symbolStatistic.estimate);
        }
        return builder.build();
    }

    /** Builds the statistics of one symbol: value range [low, high], nulls fraction and distinct-values count. */
    private static SymbolStatistics symbolStatistics(String symbolName, double low, double high, double nullsFraction, double ndv) {
        return new SymbolStatistics(
                new Symbol(symbolName),
                SymbolStatsEstimate.builder()
                        .setLowValue(low)
                        .setHighValue(high)
                        .setNullsFraction(nullsFraction)
                        .setDistinctValuesCount(ndv)
                        .build());
    }

    /** Immutable pairing of a symbol with its statistics estimate. */
    private static class SymbolStatistics {
        final Symbol symbol;
        final SymbolStatsEstimate estimate;

        SymbolStatistics(Symbol symbol, SymbolStatsEstimate estimate) {
            this.symbol = symbol;
            this.estimate = estimate;
        }
    }
}

