/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.Session;
import io.trino.SessionTestUtils;
import io.trino.cost.CostProvider;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.StatsProvider;
import io.trino.cost.SymbolStatsEstimate;
import io.trino.cost.TaskCountEstimator;
import io.trino.execution.warnings.WarningCollector;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.Type;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.iterative.rule.DistinctAggregationStrategyChooser;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.sql.planner.plan.ValuesNode;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.parallel.Execution;
import org.junit.jupiter.api.parallel.ExecutionMode;

/**
 * Tests for {@link DistinctAggregationStrategyChooser}: given source statistics for a
 * grouped aggregation, checks which distinct-aggregation strategy the chooser reports
 * ({@code shouldAddMarkDistinct} / {@code shouldUsePreAggregate}) for a cluster whose
 * task count is estimated from {@link #NODE_COUNT}.
 */
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
@Execution(ExecutionMode.CONCURRENT)
public class TestDistinctAggregationStrategyChooser {
    private static final int NODE_COUNT = 6;
    // Derive the estimator from NODE_COUNT instead of repeating the literal, so the two cannot drift apart.
    private static final TaskCountEstimator TASK_COUNT_ESTIMATOR = new TaskCountEstimator(() -> NODE_COUNT);
    // Row count of the synthetic source used by every test case.
    private static final int SOURCE_ROW_COUNT = 1_000_000;

    @Test
    public void testSingleStepPreferredForHighCardinalitySingleGroupByKey() {
        DistinctAggregationStrategyChooser aggregationStrategyChooser =
                DistinctAggregationStrategyChooser.createDistinctAggregationStrategyChooser(TASK_COUNT_ESTIMATOR);
        SymbolAllocator symbolAllocator = new SymbolAllocator();
        Symbol groupingKey = symbolAllocator.newSymbol("groupingKey", BigintType.BIGINT);
        ValuesNode source = new ValuesNode(new PlanNodeId("source"), SOURCE_ROW_COUNT);
        AggregationNode aggregationNode = aggregation(source, ImmutableList.of(groupingKey));

        // The only grouping key has as many distinct values as the source has rows.
        Map<PlanNode, PlanNodeStatsEstimate> stats = ImmutableMap.of(
                source,
                new PlanNodeStatsEstimate(SOURCE_ROW_COUNT, ImmutableMap.of(groupingKey, distinctValues(SOURCE_ROW_COUNT))));
        Rule.Context context = context(stats, symbolAllocator);

        Assertions.assertThat(aggregationStrategyChooser.shouldAddMarkDistinct(aggregationNode, context.getSession(), context.getStatsProvider())).isFalse();
    }

    @Test
    public void testSingleStepPreferredForHighCardinalityMultipleGroupByKeys() {
        DistinctAggregationStrategyChooser aggregationStrategyChooser =
                DistinctAggregationStrategyChooser.createDistinctAggregationStrategyChooser(TASK_COUNT_ESTIMATOR);
        SymbolAllocator symbolAllocator = new SymbolAllocator();
        Symbol lowCardinalityGroupingKey = symbolAllocator.newSymbol("lowCardinalityGroupingKey", BigintType.BIGINT);
        Symbol highCardinalityGroupingKey = symbolAllocator.newSymbol("highCardinalityGroupingKey", BigintType.BIGINT);
        ValuesNode source = new ValuesNode(new PlanNodeId("source"), SOURCE_ROW_COUNT);
        AggregationNode aggregationNode = aggregation(source, ImmutableList.of(lowCardinalityGroupingKey, highCardinalityGroupingKey));

        // One key has 10 distinct values; the other has one distinct value per source row.
        Map<PlanNode, PlanNodeStatsEstimate> stats = ImmutableMap.of(
                source,
                new PlanNodeStatsEstimate(SOURCE_ROW_COUNT, ImmutableMap.of(
                        lowCardinalityGroupingKey, distinctValues(10),
                        highCardinalityGroupingKey, distinctValues(SOURCE_ROW_COUNT))));
        Rule.Context context = context(stats, symbolAllocator);

        Assertions.assertThat(aggregationStrategyChooser.shouldAddMarkDistinct(aggregationNode, context.getSession(), context.getStatsProvider())).isFalse();
    }

    @Test
    public void testPreAggregatePreferredForLowCardinality2GroupByKeys() {
        DistinctAggregationStrategyChooser aggregationStrategyChooser =
                DistinctAggregationStrategyChooser.createDistinctAggregationStrategyChooser(TASK_COUNT_ESTIMATOR);
        SymbolAllocator symbolAllocator = new SymbolAllocator();
        List<Symbol> groupingKeys = ImmutableList.of(
                symbolAllocator.newSymbol("key1", BigintType.BIGINT),
                symbolAllocator.newSymbol("key2", BigintType.BIGINT));
        ValuesNode source = new ValuesNode(new PlanNodeId("source"), SOURCE_ROW_COUNT);
        AggregationNode aggregationNode = aggregation(source, groupingKeys);

        // Every grouping key has only 10 distinct values.
        Map<PlanNode, PlanNodeStatsEstimate> stats = ImmutableMap.of(
                source,
                new PlanNodeStatsEstimate(SOURCE_ROW_COUNT, uniformDistinctValues(groupingKeys, 10)));
        // Use the allocator that created the grouping keys, consistent with the other tests.
        Rule.Context context = context(stats, symbolAllocator);

        Assertions.assertThat(aggregationStrategyChooser.shouldUsePreAggregate(aggregationNode, context.getSession(), context.getStatsProvider())).isTrue();
        Assertions.assertThat(aggregationStrategyChooser.shouldAddMarkDistinct(aggregationNode, context.getSession(), context.getStatsProvider())).isTrue();
    }

    @Test
    public void testMarkDistinctPreferredForLowCardinality3GroupByKeys() {
        DistinctAggregationStrategyChooser aggregationStrategyChooser =
                DistinctAggregationStrategyChooser.createDistinctAggregationStrategyChooser(TASK_COUNT_ESTIMATOR);
        SymbolAllocator symbolAllocator = new SymbolAllocator();
        List<Symbol> groupingKeys = ImmutableList.of(
                symbolAllocator.newSymbol("key1", BigintType.BIGINT),
                symbolAllocator.newSymbol("key2", BigintType.BIGINT),
                symbolAllocator.newSymbol("key3", BigintType.BIGINT));
        ValuesNode source = new ValuesNode(new PlanNodeId("source"), SOURCE_ROW_COUNT);
        AggregationNode aggregationNode = aggregation(source, groupingKeys);

        // Every grouping key has only 10 distinct values.
        Map<PlanNode, PlanNodeStatsEstimate> stats = ImmutableMap.of(
                source,
                new PlanNodeStatsEstimate(SOURCE_ROW_COUNT, uniformDistinctValues(groupingKeys, 10)));
        // Use the allocator that created the grouping keys, consistent with the other tests.
        Rule.Context context = context(stats, symbolAllocator);

        Assertions.assertThat(aggregationStrategyChooser.shouldAddMarkDistinct(aggregationNode, context.getSession(), context.getStatsProvider())).isTrue();
    }

    /** Builds a single-step aggregation with no aggregate functions over {@code source}, grouped by {@code groupingKeys}. */
    private static AggregationNode aggregation(PlanNode source, List<Symbol> groupingKeys) {
        return AggregationNode.singleAggregation(
                new PlanNodeId("aggregation"),
                source,
                ImmutableMap.of(),
                AggregationNode.singleGroupingSet(groupingKeys));
    }

    /** Builds symbol statistics that carry only a distinct-values count. */
    private static SymbolStatsEstimate distinctValues(double distinctValuesCount) {
        return SymbolStatsEstimate.builder().setDistinctValuesCount(distinctValuesCount).build();
    }

    /** Maps every symbol in {@code symbols} to statistics with the same distinct-values count. */
    private static Map<Symbol, SymbolStatsEstimate> uniformDistinctValues(List<Symbol> symbols, double distinctValuesCount) {
        return symbols.stream()
                .collect(ImmutableMap.toImmutableMap(Function.identity(), symbol -> distinctValues(distinctValuesCount)));
    }

    /**
     * Creates a minimal {@link Rule.Context} for the chooser under test.
     *
     * <p>Statistics are looked up directly in {@code stats}; a node absent from the map yields
     * {@code null}. The cost provider, timeout check and warning collector are not supported and
     * throw {@link UnsupportedOperationException}.
     *
     * @param stats per-node statistics returned by the context's stats provider
     * @param symbolAllocator allocator exposed through {@link Rule.Context#getSymbolAllocator()}
     */
    private static Rule.Context context(Map<PlanNode, PlanNodeStatsEstimate> stats, SymbolAllocator symbolAllocator) {
        PlanNodeIdAllocator planNodeIdAllocator = new PlanNodeIdAllocator();
        return new Rule.Context() {
            @Override
            public Lookup getLookup() {
                return Lookup.noLookup();
            }

            @Override
            public PlanNodeIdAllocator getIdAllocator() {
                return planNodeIdAllocator;
            }

            @Override
            public SymbolAllocator getSymbolAllocator() {
                return symbolAllocator;
            }

            @Override
            public Session getSession() {
                return SessionTestUtils.TEST_SESSION;
            }

            @Override
            public StatsProvider getStatsProvider() {
                return stats::get;
            }

            @Override
            public CostProvider getCostProvider() {
                throw new UnsupportedOperationException();
            }

            @Override
            public void checkTimeoutNotExhausted() {
                throw new UnsupportedOperationException();
            }

            @Override
            public WarningCollector getWarningCollector() {
                throw new UnsupportedOperationException();
            }
        };
    }
}

