/*
 * Decompiled with CFR 0.152.
 */
package io.prestosql.cost;

import com.google.common.collect.ImmutableList;
import io.prestosql.cost.BaseStatsCalculatorTest;
import io.prestosql.cost.PlanNodeStatsAssertion;
import io.prestosql.cost.PlanNodeStatsEstimate;
import io.prestosql.cost.StatsCalculatorAssertion;
import io.prestosql.cost.SymbolStatsEstimate;
import io.prestosql.spi.type.BigintType;
import io.prestosql.spi.type.Type;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.iterative.rule.test.PlanBuilder;
import io.prestosql.sql.planner.plan.PlanNode;
import java.util.List;
import java.util.function.Consumer;
import org.testng.annotations.Test;

/**
 * Tests the statistics estimated for aggregation plan nodes.
 *
 * <p>Each test builds an aggregation over a {@code values} source, injects synthetic source
 * statistics, and asserts on the estimated output statistics.
 */
public class TestAggregationStatsRule
        extends BaseStatsCalculatorTest {
    @Test
    public void testAggregationWhenAllStatisticsAreKnown() {
        // With a known distinct-values count for grouping key z, the output row count and the
        // grouping-key statistics are expected to be estimated.
        // NOTE(review): 15 looks like (4 distinct z values + 1 null group) * 3 distinct y values,
        // and the z nulls fraction 0.2 like 1 / (4 + 1) -- confirm against the rule implementation.
        Consumer<PlanNodeStatsAssertion> outputRowCountAndZStatsAreCalculated = check -> check
                .outputRowsCount(15.0)
                .symbolStats("z", symbolStatsAssertion -> symbolStatsAssertion
                        .lowValue(10.0)
                        .highValue(15.0)
                        .distinctValuesCount(4.0)
                        .nullsFraction(0.2))
                .symbolStats("y", symbolStatsAssertion -> symbolStatsAssertion
                        .lowValue(0.0)
                        .highValue(3.0)
                        .distinctValuesCount(3.0)
                        .nullsFraction(0.0));

        // z with an explicit nulls fraction.
        testAggregation(SymbolStatsEstimate.builder()
                .setLowValue(10.0)
                .setHighValue(15.0)
                .setDistinctValuesCount(4.0)
                .setNullsFraction(0.1)
                .build())
                .check(outputRowCountAndZStatsAreCalculated);

        // z without a nulls fraction: the same estimate is expected.
        testAggregation(SymbolStatsEstimate.builder()
                .setLowValue(10.0)
                .setHighValue(15.0)
                .setDistinctValuesCount(4.0)
                .build())
                .check(outputRowCountAndZStatsAreCalculated);

        // Without a distinct-values count for z, the row count and grouping-key stats are unknown.
        Consumer<PlanNodeStatsAssertion> outputRowsCountAndZStatsAreNotFullyCalculated = check -> check
                .outputRowsCountUnknown()
                .symbolStats("z", symbolStatsAssertion -> symbolStatsAssertion
                        .unknownRange()
                        .distinctValuesCountUnknown()
                        .nullsFractionUnknown())
                .symbolStats("y", symbolStatsAssertion -> symbolStatsAssertion
                        .unknownRange()
                        .nullsFractionUnknown()
                        .distinctValuesCountUnknown());

        testAggregation(SymbolStatsEstimate.builder()
                .setLowValue(10.0)
                .setHighValue(15.0)
                .setNullsFraction(0.1)
                .build())
                .check(outputRowsCountAndZStatsAreNotFullyCalculated);

        testAggregation(SymbolStatsEstimate.builder()
                .setLowValue(10.0)
                .setHighValue(15.0)
                .build())
                .check(outputRowsCountAndZStatsAreNotFullyCalculated);
    }

    /**
     * Estimates stats for an aggregation and checks its aggregate outputs.
     *
     * <p>The aggregation computes {@code sum(x)}, {@code count()} and {@code count(x)} over the
     * single grouping set (y, z) of a {@code values(x, y, z)} source. The source is given 100 rows,
     * fixed statistics for x and y, and {@code zStats} for z. The aggregate outputs and x are
     * checked to have fully unknown statistics.
     *
     * @param zStats statistics to attach to source symbol z
     * @return the assertion, so callers can chain further checks
     */
    private StatsCalculatorAssertion testAggregation(SymbolStatsEstimate zStats) {
        return tester()
                .assertStatsFor(pb -> pb.aggregation(ab -> ab
                        .addAggregation(
                                pb.symbol("sum", BigintType.BIGINT),
                                PlanBuilder.expression("sum(x)"),
                                ImmutableList.<Type>of(BigintType.BIGINT))
                        .addAggregation(
                                pb.symbol("count", BigintType.BIGINT),
                                PlanBuilder.expression("count()"),
                                ImmutableList.<Type>of())
                        .addAggregation(
                                pb.symbol("count_on_x", BigintType.BIGINT),
                                PlanBuilder.expression("count(x)"),
                                ImmutableList.<Type>of(BigintType.BIGINT))
                        .singleGroupingSet(
                                pb.symbol("y", BigintType.BIGINT),
                                pb.symbol("z", BigintType.BIGINT))
                        .source(pb.values(
                                pb.symbol("x", BigintType.BIGINT),
                                pb.symbol("y", BigintType.BIGINT),
                                pb.symbol("z", BigintType.BIGINT)))))
                .withSourceStats(PlanNodeStatsEstimate.builder()
                        .setOutputRowCount(100.0)
                        .addSymbolStatistics(new Symbol("x"), SymbolStatsEstimate.builder()
                                .setLowValue(1.0)
                                .setHighValue(10.0)
                                .setDistinctValuesCount(5.0)
                                .setNullsFraction(0.3)
                                .build())
                        .addSymbolStatistics(new Symbol("y"), SymbolStatsEstimate.builder()
                                .setLowValue(0.0)
                                .setHighValue(3.0)
                                .setDistinctValuesCount(3.0)
                                .setNullsFraction(0.0)
                                .build())
                        .addSymbolStatistics(new Symbol("z"), zStats)
                        .build())
                .check(check -> check
                        .symbolStats("sum", symbolStatsAssertion -> symbolStatsAssertion
                                .lowValueUnknown()
                                .highValueUnknown()
                                .distinctValuesCountUnknown()
                                .nullsFractionUnknown())
                        .symbolStats("count", symbolStatsAssertion -> symbolStatsAssertion
                                .lowValueUnknown()
                                .highValueUnknown()
                                .distinctValuesCountUnknown()
                                .nullsFractionUnknown())
                        .symbolStats("count_on_x", symbolStatsAssertion -> symbolStatsAssertion
                                .lowValueUnknown()
                                .highValueUnknown()
                                .distinctValuesCountUnknown()
                                .nullsFractionUnknown())
                        // x is not an output of the aggregation, so no stats are expected for it.
                        .symbolStats("x", symbolStatsAssertion -> symbolStatsAssertion
                                .lowValueUnknown()
                                .highValueUnknown()
                                .distinctValuesCountUnknown()
                                .nullsFractionUnknown()));
    }

    /**
     * Checks that the estimated row count does not exceed the input row count.
     *
     * <p>Each grouping key (y, z) has 50 distinct values. Their product is far above the 100
     * source rows, so the estimate is expected to be 100.
     */
    @Test
    public void testAggregationStatsCappedToInputRows() {
        tester()
                .assertStatsFor(pb -> pb.aggregation(ab -> ab
                        .addAggregation(
                                pb.symbol("count_on_x", BigintType.BIGINT),
                                PlanBuilder.expression("count(x)"),
                                ImmutableList.<Type>of(BigintType.BIGINT))
                        .singleGroupingSet(
                                pb.symbol("y", BigintType.BIGINT),
                                pb.symbol("z", BigintType.BIGINT))
                        .source(pb.values(
                                pb.symbol("x", BigintType.BIGINT),
                                pb.symbol("y", BigintType.BIGINT),
                                pb.symbol("z", BigintType.BIGINT)))))
                .withSourceStats(PlanNodeStatsEstimate.builder()
                        .setOutputRowCount(100.0)
                        .addSymbolStatistics(new Symbol("y"), SymbolStatsEstimate.builder()
                                .setDistinctValuesCount(50.0)
                                .build())
                        .addSymbolStatistics(new Symbol("z"), SymbolStatsEstimate.builder()
                                .setDistinctValuesCount(50.0)
                                .build())
                        .build())
                .check(check -> check.outputRowsCount(100.0));
    }
}

