/*
 * Decompiled with CFR 0.152.
 */
package io.trino.cost;

import com.google.common.collect.ImmutableList;
import io.trino.Session;
import io.trino.cost.BaseStatsCalculatorTest;
import io.trino.cost.PlanNodeStatsAssertion;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.SymbolStatsEstimate;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.Type;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.sql.tree.Expression;
import io.trino.testing.TestingSession;
import java.util.List;
import java.util.function.Function;
import org.testng.annotations.Test;

public class TestFilterProjectAggregationStatsRule
extends BaseStatsCalculatorTest {
    private static final SymbolStatsEstimate SYMBOL_STATS_ESTIMATE_X = SymbolStatsEstimate.builder().setLowValue(0.0).setHighValue(100.0).setDistinctValuesCount(10.0).setNullsFraction(0.1).build();
    private static final SymbolStatsEstimate SYMBOL_STATS_ESTIMATE_Y = SymbolStatsEstimate.builder().setLowValue(0.0).setHighValue(10.0).setDistinctValuesCount(10.0).setNullsFraction(0.0).build();
    private static final Session APPROXIMATION_ENABLED = TestingSession.testSessionBuilder().setSystemProperty("non_estimatable_predicate_approximation_enabled", "true").build();
    private static final Session APPROXIMATION_DISABLED = TestingSession.testSessionBuilder().setSystemProperty("non_estimatable_predicate_approximation_enabled", "false").build();

    @Test
    public void testFilterOverAggregationStats() {
        Function<PlanBuilder, PlanNode> planProvider = pb -> pb.filter(PlanBuilder.expression("count_on_x > 0"), (PlanNode)pb.aggregation(ab -> ab.addAggregation(pb.symbol("count_on_x", (Type)BigintType.BIGINT), PlanBuilder.expression("count(x)"), (List<Type>)ImmutableList.of((Object)BigintType.BIGINT)).singleGroupingSet(pb.symbol("y", (Type)BigintType.BIGINT)).source((PlanNode)pb.values(pb.symbol("x", (Type)BigintType.BIGINT), pb.symbol("y", (Type)BigintType.BIGINT)))));
        PlanNodeStatsEstimate sourceStats = PlanNodeStatsEstimate.builder().setOutputRowCount(100.0).addSymbolStatistics(new Symbol("y"), SYMBOL_STATS_ESTIMATE_Y).build();
        this.tester().assertStatsFor(APPROXIMATION_ENABLED, planProvider).withSourceStats(sourceStats).check(check -> check.outputRowsCount(90.0).symbolStatsUnknown("count_on_x"));
        this.tester().assertStatsFor(APPROXIMATION_DISABLED, planProvider).withSourceStats(sourceStats).check(PlanNodeStatsAssertion::outputRowsCountUnknown);
        this.tester().assertStatsFor(APPROXIMATION_ENABLED, planProvider).withSourceStats(PlanNodeStatsEstimate.builder().addSymbolStatistics(new Symbol("y"), SymbolStatsEstimate.builder().setDistinctValuesCount(50.0).build()).build()).check(PlanNodeStatsAssertion::outputRowsCountUnknown);
        this.tester().assertStatsFor(APPROXIMATION_ENABLED, pb -> pb.filter(PlanBuilder.expression("y = 1"), (PlanNode)pb.aggregation(ab -> ab.addAggregation(pb.symbol("count_on_x", (Type)BigintType.BIGINT), PlanBuilder.expression("count(x)"), (List<Type>)ImmutableList.of((Object)BigintType.BIGINT)).singleGroupingSet(pb.symbol("y", (Type)BigintType.BIGINT)).source((PlanNode)pb.values(pb.symbol("x", (Type)BigintType.BIGINT), pb.symbol("y", (Type)BigintType.BIGINT)))))).withSourceStats(sourceStats).check(check -> check.outputRowsCount(10.0));
    }

    @Test
    public void testFilterAndProjectOverAggregationStats() {
        PlanNodeId aggregationId = new PlanNodeId("aggregation");
        PlanNodeStatsEstimate sourceStats = PlanNodeStatsEstimate.builder().setOutputRowCount(100.0).addSymbolStatistics(new Symbol("x"), SYMBOL_STATS_ESTIMATE_X).addSymbolStatistics(new Symbol("y"), SYMBOL_STATS_ESTIMATE_Y).build();
        this.tester().assertStatsFor(APPROXIMATION_ENABLED, pb -> {
            Symbol aggregatedOutput = pb.symbol("count_on_x", (Type)BigintType.BIGINT);
            return pb.filter(PlanBuilder.expression("count_on_x > 0"), (PlanNode)pb.project(Assignments.identity((Symbol[])new Symbol[]{aggregatedOutput}), (PlanNode)pb.aggregation(ab -> ab.addAggregation(aggregatedOutput, PlanBuilder.expression("count(x)"), (List<Type>)ImmutableList.of((Object)BigintType.BIGINT)).singleGroupingSet(pb.symbol("y", (Type)BigintType.BIGINT)).source((PlanNode)pb.values(pb.symbol("x", (Type)BigintType.BIGINT), pb.symbol("y", (Type)BigintType.BIGINT))).nodeId(aggregationId))));
        }).withSourceStats(sourceStats).withSourceStats(aggregationId, PlanNodeStatsEstimate.builder().setOutputRowCount(50.0).build()).check(check -> check.outputRowsCount(45.0));
        this.tester().assertStatsFor(APPROXIMATION_ENABLED, pb -> {
            Symbol aggregatedOutput = pb.symbol("count_on_x", (Type)BigintType.BIGINT);
            return pb.filter(PlanBuilder.expression("count_on_x > 0"), (PlanNode)pb.project(Assignments.of((Symbol)pb.symbol("x_1"), (Expression)PlanBuilder.expression("x + 1"), (Symbol)aggregatedOutput, (Expression)aggregatedOutput.toSymbolReference()), (PlanNode)pb.aggregation(ab -> ab.addAggregation(aggregatedOutput, PlanBuilder.expression("count(x)"), (List<Type>)ImmutableList.of((Object)BigintType.BIGINT)).singleGroupingSet(pb.symbol("y", (Type)BigintType.BIGINT)).source((PlanNode)pb.values(pb.symbol("x", (Type)BigintType.BIGINT), pb.symbol("y", (Type)BigintType.BIGINT))).nodeId(aggregationId))));
        }).withSourceStats(sourceStats).withSourceStats(aggregationId, PlanNodeStatsEstimate.builder().setOutputRowCount(50.0).build()).check(PlanNodeStatsAssertion::outputRowsCountUnknown);
    }
}

