/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.spi.Plugin;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.RealType;
import io.trino.spi.type.Type;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.assertions.AggregationFunction;
import io.trino.sql.planner.assertions.ExpectedValueProvider;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.iterative.rule.SingleDistinctAggregationToGroupBy;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.SymbolReference;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.junit.jupiter.api.Test;

/**
 * Rule tests for {@link SingleDistinctAggregationToGroupBy}.
 *
 * <p>The {@code doesNotFire} cases build global aggregations that the rule must leave
 * untouched. The {@code matches} cases check that one or more DISTINCT aggregations over the
 * same input symbols are rewritten into a non-distinct global aggregation. That outer
 * aggregation sits on top of an inner aggregation that only groups by those input symbols
 * and computes no aggregates.
 */
public class TestSingleDistinctAggregationToGroupBy extends BaseRuleTest {
    // No explicit constructor: the implicit default constructor calls the BaseRuleTest
    // varargs constructor with no plugins, exactly like the former `super(new Plugin[0])`.

    /** A single non-distinct {@code count(input1)} gives the rule nothing to rewrite. */
    @Test
    public void testNoDistinct() {
        tester().assertThat(new SingleDistinctAggregationToGroupBy())
                .on(p -> p.aggregation(builder -> builder
                        .globalGrouping()
                        .addAggregation(
                                p.symbol("output1"),
                                PlanBuilder.aggregation("count", ImmutableList.of(new SymbolReference("input1"))),
                                ImmutableList.of(BigintType.BIGINT))
                        .source(p.values(p.symbol("input1"), p.symbol("input2")))))
                .doesNotFire();
    }

    /** Two DISTINCT aggregations over different input symbols are not rewritten. */
    @Test
    public void testMultipleDistincts() {
        tester().assertThat(new SingleDistinctAggregationToGroupBy())
                .on(p -> p.aggregation(builder -> builder
                        .globalGrouping()
                        .addAggregation(
                                p.symbol("output1"),
                                PlanBuilder.aggregation("count", true, ImmutableList.of(new SymbolReference("input1"))),
                                ImmutableList.of(BigintType.BIGINT))
                        .addAggregation(
                                p.symbol("output2"),
                                PlanBuilder.aggregation("count", true, ImmutableList.of(new SymbolReference("input2"))),
                                ImmutableList.of(BigintType.BIGINT))
                        .source(p.values(p.symbol("input1"), p.symbol("input2")))))
                .doesNotFire();
    }

    /** A DISTINCT aggregation mixed with a non-distinct one is not rewritten. */
    @Test
    public void testMixedDistinctAndNonDistinct() {
        tester().assertThat(new SingleDistinctAggregationToGroupBy())
                .on(p -> p.aggregation(builder -> builder
                        .globalGrouping()
                        .addAggregation(
                                p.symbol("output1"),
                                PlanBuilder.aggregation("count", true, ImmutableList.of(new SymbolReference("input1"))),
                                ImmutableList.of(BigintType.BIGINT))
                        .addAggregation(
                                p.symbol("output2"),
                                PlanBuilder.aggregation("count", ImmutableList.of(new SymbolReference("input2"))),
                                ImmutableList.of(BigintType.BIGINT))
                        .source(p.values(p.symbol("input1"), p.symbol("input2")))))
                .doesNotFire();
    }

    /**
     * A DISTINCT aggregation carrying a filter symbol is not rewritten. Here the filter is
     * {@code filter1}, projected as {@code input2 > 0}.
     */
    @Test
    public void testDistinctWithFilter() {
        tester().assertThat(new SingleDistinctAggregationToGroupBy())
                .on(p -> p.aggregation(builder -> builder
                        .globalGrouping()
                        .addAggregation(
                                p.symbol("output"),
                                PlanBuilder.aggregation(
                                        "count",
                                        true,
                                        ImmutableList.of(new SymbolReference("input1")),
                                        new Symbol("filter1")),
                                ImmutableList.of(BigintType.BIGINT))
                        .source(p.project(
                                Assignments.builder()
                                        .putIdentity(p.symbol("input1"))
                                        .putIdentity(p.symbol("input2"))
                                        .put(p.symbol("filter1"), PlanBuilder.expression("input2 > 0"))
                                        .build(),
                                p.values(p.symbol("input1"), p.symbol("input2"))))))
                .doesNotFire();
    }

    /**
     * {@code count(DISTINCT input)} becomes a plain global {@code count(input)} over an inner
     * aggregation grouped by {@code input} with no aggregates.
     */
    @Test
    public void testSingleAggregation() {
        tester().assertThat(new SingleDistinctAggregationToGroupBy())
                .on(p -> p.aggregation(builder -> builder
                        .globalGrouping()
                        .addAggregation(
                                p.symbol("output"),
                                PlanBuilder.aggregation("count", true, ImmutableList.of(new SymbolReference("input"))),
                                ImmutableList.of(BigintType.BIGINT))
                        .source(p.values(p.symbol("input")))))
                .matches(PlanMatchPattern.aggregation(
                        PlanMatchPattern.globalAggregation(),
                        ImmutableMap.of(
                                Optional.of("output"),
                                PlanMatchPattern.aggregationFunction("count", ImmutableList.of("input"))),
                        Optional.empty(),
                        AggregationNode.Step.SINGLE,
                        PlanMatchPattern.aggregation(
                                PlanMatchPattern.singleGroupingSet("input"),
                                ImmutableMap.of(),
                                Optional.empty(),
                                AggregationNode.Step.SINGLE,
                                PlanMatchPattern.values("input"))));
    }

    /**
     * DISTINCT {@code count} and {@code sum} over the same symbol share one inner grouping on
     * {@code input}.
     */
    @Test
    public void testMultipleAggregations() {
        tester().assertThat(new SingleDistinctAggregationToGroupBy())
                .on(p -> p.aggregation(builder -> builder
                        .globalGrouping()
                        .addAggregation(
                                p.symbol("output1"),
                                PlanBuilder.aggregation("count", true, ImmutableList.of(new SymbolReference("input"))),
                                ImmutableList.of(BigintType.BIGINT))
                        .addAggregation(
                                p.symbol("output2"),
                                PlanBuilder.aggregation("sum", true, ImmutableList.of(new SymbolReference("input"))),
                                ImmutableList.of(BigintType.BIGINT))
                        .source(p.values(p.symbol("input")))))
                .matches(PlanMatchPattern.aggregation(
                        PlanMatchPattern.globalAggregation(),
                        // Explicit type witness: without it, builder() infers <Object, Object>
                        // and the built map is not assignable to the expected parameter type.
                        ImmutableMap.<Optional<String>, ExpectedValueProvider<AggregationFunction>>builder()
                                .put(Optional.of("output1"), PlanMatchPattern.aggregationFunction("count", ImmutableList.of("input")))
                                .put(Optional.of("output2"), PlanMatchPattern.aggregationFunction("sum", ImmutableList.of("input")))
                                .buildOrThrow(),
                        Optional.empty(),
                        AggregationNode.Step.SINGLE,
                        PlanMatchPattern.aggregation(
                                PlanMatchPattern.singleGroupingSet("input"),
                                ImmutableMap.of(),
                                Optional.empty(),
                                AggregationNode.Step.SINGLE,
                                PlanMatchPattern.values("input"))));
    }

    /**
     * Two-argument DISTINCT aggregations over the same symbol set ({@code x}, {@code y}) are
     * rewritten, and each outer aggregation keeps its original argument order.
     */
    @Test
    public void testMultipleInputs() {
        tester().assertThat(new SingleDistinctAggregationToGroupBy())
                .on(p -> p.aggregation(builder -> builder
                        .globalGrouping()
                        .addAggregation(
                                p.symbol("output1"),
                                PlanBuilder.aggregation(
                                        "corr",
                                        true,
                                        ImmutableList.of(new SymbolReference("x"), new SymbolReference("y"))),
                                ImmutableList.of(RealType.REAL, RealType.REAL))
                        .addAggregation(
                                p.symbol("output2"),
                                PlanBuilder.aggregation(
                                        "corr",
                                        true,
                                        ImmutableList.of(new SymbolReference("y"), new SymbolReference("x"))),
                                ImmutableList.of(RealType.REAL, RealType.REAL))
                        .source(p.values(p.symbol("x"), p.symbol("y")))))
                .matches(PlanMatchPattern.aggregation(
                        PlanMatchPattern.globalAggregation(),
                        ImmutableMap.<Optional<String>, ExpectedValueProvider<AggregationFunction>>builder()
                                .put(Optional.of("output1"), PlanMatchPattern.aggregationFunction("corr", ImmutableList.of("x", "y")))
                                .put(Optional.of("output2"), PlanMatchPattern.aggregationFunction("corr", ImmutableList.of("y", "x")))
                                .buildOrThrow(),
                        Optional.empty(),
                        AggregationNode.Step.SINGLE,
                        PlanMatchPattern.aggregation(
                                PlanMatchPattern.singleGroupingSet("x", "y"),
                                ImmutableMap.of(),
                                Optional.empty(),
                                AggregationNode.Step.SINGLE,
                                PlanMatchPattern.values("x", "y"))));
    }
}

