/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.cost.TaskCountEstimator;
import io.trino.metadata.Metadata;
import io.trino.metadata.MetadataManager;
import io.trino.spi.Plugin;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.Type;
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.assertions.AggregationFunction;
import io.trino.sql.planner.assertions.ExpectedValueProvider;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.iterative.rule.MultipleDistinctAggregationToMarkDistinct;
import io.trino.sql.planner.iterative.rule.SingleDistinctAggregationToGroupBy;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.type.UnknownType;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.junit.jupiter.api.Test;

/**
 * Rule tests for {@link MultipleDistinctAggregationToMarkDistinct}.
 *
 * <p>Every scenario sets the {@code distinct_aggregations_strategy} session property to
 * {@code mark_distinct}. Each scenario then checks one of two outcomes:
 * <ul>
 *   <li>the rule declines the plan, or</li>
 *   <li>the rule rewrites the DISTINCT aggregations into a chain of {@code MarkDistinct} nodes
 *       beneath the aggregation.</li>
 * </ul>
 */
public class TestMultipleDistinctAggregationToMarkDistinct
        extends BaseRuleTest {
    private static final int NODES_COUNT = 4;
    // Derived from NODES_COUNT so the declared node count and the estimator cannot drift apart.
    private static final TaskCountEstimator TASK_COUNT_ESTIMATOR = new TaskCountEstimator(() -> NODES_COUNT);
    private static final Metadata METADATA = MetadataManager.createTestMetadataManager();

    // Session property name and value shared by every scenario in this class.
    private static final String DISTINCT_AGGREGATIONS_STRATEGY = "distinct_aggregations_strategy";
    private static final String MARK_DISTINCT = "mark_distinct";

    public TestMultipleDistinctAggregationToMarkDistinct() {
        // No additional plugins are required for these rule tests.
        super(new Plugin[0]);
    }

    /**
     * Plain (non-DISTINCT) aggregations give the rule nothing to rewrite.
     *
     * <p>This must exercise the rule under test itself, not a sibling rule.
     */
    @Test
    public void testNoDistinct() {
        tester().assertThat(markDistinctRule())
                .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, MARK_DISTINCT)
                .on(p -> p.aggregation(builder -> builder
                        .globalGrouping()
                        .addAggregation(
                                p.symbol("output1", BigintType.BIGINT),
                                PlanBuilder.aggregation("count", ImmutableList.of(new Reference(BigintType.BIGINT, "input1"))),
                                ImmutableList.of(BigintType.BIGINT))
                        .addAggregation(
                                p.symbol("output2", BigintType.BIGINT),
                                PlanBuilder.aggregation("count", ImmutableList.of(new Reference(BigintType.BIGINT, "input2"))),
                                ImmutableList.of(BigintType.BIGINT))
                        .source(p.values(
                                p.symbol("input1", BigintType.BIGINT),
                                p.symbol("input2", BigintType.BIGINT)))))
                .doesNotFire();
    }

    /** A lone {@code count(DISTINCT input1)} must be left untouched by this rule. */
    @Test
    public void testSingleDistinct() {
        tester().assertThat(markDistinctRule())
                .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, MARK_DISTINCT)
                .on(p -> p.aggregation(builder -> builder
                        .globalGrouping()
                        .addAggregation(
                                p.symbol("output1", BigintType.BIGINT),
                                PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BigintType.BIGINT, "input1"))),
                                ImmutableList.of(BigintType.BIGINT))
                        .source(p.values(
                                p.symbol("input1", BigintType.BIGINT),
                                p.symbol("input2", BigintType.BIGINT)))))
                .doesNotFire();
    }

    /**
     * {@code count(DISTINCT input)} and {@code sum(DISTINCT input)} share the same single argument.
     * The rule must not fire.
     */
    @Test
    public void testMultipleAggregations() {
        tester().assertThat(markDistinctRule())
                .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, MARK_DISTINCT)
                .on(p -> p.aggregation(builder -> builder
                        .globalGrouping()
                        .addAggregation(
                                p.symbol("output1", BigintType.BIGINT),
                                PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BigintType.BIGINT, "input"))),
                                ImmutableList.of(BigintType.BIGINT))
                        .addAggregation(
                                p.symbol("output2", BigintType.BIGINT),
                                PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BigintType.BIGINT, "input"))),
                                ImmutableList.of(BigintType.BIGINT))
                        .source(p.values(p.symbol("input", BigintType.BIGINT)))))
                .doesNotFire();
    }

    /**
     * DISTINCT aggregations carrying a FILTER must not be rewritten. Two scenarios are covered:
     * <ol>
     *   <li>every DISTINCT aggregation has a filter;</li>
     *   <li>only one of the two DISTINCT aggregations has a filter.</li>
     * </ol>
     */
    @Test
    public void testDistinctWithFilter() {
        // Scenario 1: both DISTINCT counts are filtered.
        tester().assertThat(markDistinctRule())
                .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, MARK_DISTINCT)
                .on(p -> p.aggregation(builder -> builder
                        .globalGrouping()
                        .addAggregation(
                                p.symbol("output1", BigintType.BIGINT),
                                PlanBuilder.aggregation(
                                        "count",
                                        true,
                                        ImmutableList.of(new Reference(BigintType.BIGINT, "input1")),
                                        p.symbol("filter1", BooleanType.BOOLEAN)),
                                ImmutableList.of(BigintType.BIGINT))
                        .addAggregation(
                                p.symbol("output2", BigintType.BIGINT),
                                PlanBuilder.aggregation(
                                        "count",
                                        true,
                                        ImmutableList.of(new Reference(BigintType.BIGINT, "input2")),
                                        p.symbol("filter2", BooleanType.BOOLEAN)),
                                ImmutableList.of(BigintType.BIGINT))
                        .source(filterProjection(p))))
                .doesNotFire();

        // Scenario 2: only the first DISTINCT count is filtered.
        tester().assertThat(markDistinctRule())
                .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, MARK_DISTINCT)
                .on(p -> p.aggregation(builder -> builder
                        .globalGrouping()
                        .addAggregation(
                                p.symbol("output1", BigintType.BIGINT),
                                PlanBuilder.aggregation(
                                        "count",
                                        true,
                                        ImmutableList.of(new Reference(BigintType.BIGINT, "input1")),
                                        p.symbol("filter1", BooleanType.BOOLEAN)),
                                ImmutableList.of(BigintType.BIGINT))
                        .addAggregation(
                                p.symbol("output2", BigintType.BIGINT),
                                PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BigintType.BIGINT, "input2"))),
                                ImmutableList.of(BigintType.BIGINT))
                        .source(filterProjection(p))))
                .doesNotFire();
    }

    /**
     * Two DISTINCT counts over different inputs are rewritten by the rule.
     *
     * <p>The expected plan stacks {@code mark_input1} (over {@code input1}) beneath
     * {@code mark_input2} (over {@code input2}). Above them sits a SINGLE-step global aggregation
     * that lists both mark symbols and computes {@code count(input1)} and {@code count(input2)}.
     */
    @Test
    public void testGlobalAggregation() {
        tester().assertThat(markDistinctRule())
                .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, MARK_DISTINCT)
                .on(p -> p.aggregation(builder -> builder
                        .globalGrouping()
                        .addAggregation(
                                p.symbol("output1", BigintType.BIGINT),
                                PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BigintType.BIGINT, "input1"))),
                                ImmutableList.of(BigintType.BIGINT))
                        .addAggregation(
                                p.symbol("output2", BigintType.BIGINT),
                                PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BigintType.BIGINT, "input2"))),
                                ImmutableList.of(BigintType.BIGINT))
                        .source(p.values(
                                p.symbol("input1", BigintType.BIGINT),
                                p.symbol("input2", BigintType.BIGINT)))))
                .matches(PlanMatchPattern.aggregation(
                        PlanMatchPattern.globalAggregation(),
                        ImmutableMap.of(
                                Optional.of("output1"), PlanMatchPattern.aggregationFunction("count", ImmutableList.of("input1")),
                                Optional.of("output2"), PlanMatchPattern.aggregationFunction("count", ImmutableList.of("input2"))),
                        ImmutableList.of(),
                        ImmutableList.of("mark_input1", "mark_input2"),
                        Optional.empty(),
                        AggregationNode.Step.SINGLE,
                        PlanMatchPattern.markDistinct(
                                "mark_input2",
                                ImmutableList.of("input2"),
                                PlanMatchPattern.markDistinct(
                                        "mark_input1",
                                        ImmutableList.of("input1"),
                                        PlanMatchPattern.values(ImmutableMap.of("input1", 0, "input2", 1))))));
    }

    /** Creates a fresh instance of the rule under test. */
    private static MultipleDistinctAggregationToMarkDistinct markDistinctRule() {
        return new MultipleDistinctAggregationToMarkDistinct(TASK_COUNT_ESTIMATOR, METADATA);
    }

    /**
     * Builds the source used by the FILTER scenarios: a projection over
     * {@code values(input1, input2)}. It passes both BIGINT inputs through unchanged and adds two
     * BOOLEAN columns, {@code filter1 := input2 > 0} and {@code filter2 := input1 > 0}.
     */
    private static PlanNode filterProjection(PlanBuilder p) {
        return p.project(
                Assignments.builder()
                        .putIdentity(p.symbol("input1", BigintType.BIGINT))
                        .putIdentity(p.symbol("input2", BigintType.BIGINT))
                        .put(p.symbol("filter1", BooleanType.BOOLEAN), greaterThanZero("input2"))
                        .put(p.symbol("filter2", BooleanType.BOOLEAN), greaterThanZero("input1"))
                        .build(),
                p.values(
                        p.symbol("input1", BigintType.BIGINT),
                        p.symbol("input2", BigintType.BIGINT)));
    }

    /**
     * Returns {@code column > 0}, typed as BIGINT to match the BIGINT columns produced by
     * {@link #filterProjection}.
     */
    private static Expression greaterThanZero(String column) {
        return new Comparison(
                Comparison.Operator.GREATER_THAN,
                new Reference(BigintType.BIGINT, column),
                new Constant(BigintType.BIGINT, 0L));
    }
}

