/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.common.type.DoubleType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.VarcharType;
import com.facebook.presto.spi.Plugin;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.assertions.ExpectedValueProvider;
import com.facebook.presto.sql.planner.assertions.ExpressionMatcher;
import com.facebook.presto.sql.planner.assertions.PlanMatchPattern;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.iterative.rule.RewriteAggregationIfToFilter;
import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest;
import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder;
import com.facebook.presto.sql.tree.FunctionCall;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import org.testng.annotations.Test;

public class TestRewriteAggregationIfToFilter
extends BaseRuleTest {
    /**
     * Creates the test, handing {@link BaseRuleTest} an empty {@link Plugin} array (no extra plugins).
     */
    public TestRewriteAggregationIfToFilter() {
        super(new Plugin[] {});
    }

    /**
     * The rule must not fire when the aggregated input is a plain boolean comparison
     * ({@code a := ds > '2021-07-01'}) rather than an {@code IF} expression.
     */
    @Test
    public void testDoesNotFireForNonIf() {
        tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", "filter_with_if")
                .on(p -> {
                    VariableReferenceExpression a = p.variable("a", BooleanType.BOOLEAN);
                    VariableReferenceExpression ds = p.variable("ds", VarcharType.VARCHAR);
                    return p.aggregation(ap -> ap.globalGrouping()
                            .step(AggregationNode.Step.FINAL)
                            .addAggregation(p.variable("expr"), p.rowExpression("count(a)"))
                            .source(p.project(
                                    PlanBuilder.assignment(a, p.rowExpression("ds > '2021-07-01'")),
                                    p.values(ds))));
                })
                .doesNotFire();
    }

    /**
     * The rule must not fire when the {@code IF} has an explicit ELSE branch
     * ({@code IF(ds > '2021-07-01', 1, 2)}).
     */
    @Test
    public void testDoesNotFireForIfWithElse() {
        tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", "filter_with_if")
                .on(p -> {
                    VariableReferenceExpression a = p.variable("a");
                    VariableReferenceExpression ds = p.variable("ds", VarcharType.VARCHAR);
                    return p.aggregation(ap -> ap.globalGrouping()
                            .step(AggregationNode.Step.FINAL)
                            .addAggregation(p.variable("expr"), p.rowExpression("count(a)"))
                            .source(p.project(
                                    PlanBuilder.assignment(a, p.rowExpression("IF(ds > '2021-07-01', 1, 2)")),
                                    p.values(ds))));
                })
                .doesNotFire();
    }

    /**
     * The rule must not fire when the {@code IF} involves a non-deterministic call, whether in the
     * true branch ({@code IF(ds > '2021-07-01', random())}) or in the condition
     * ({@code IF(random() > DOUBLE '0.1', 1)}).
     */
    @Test
    public void testDoesNotFireForNonDeterministicFunction() {
        // Non-deterministic true branch.
        tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", "filter_with_if")
                .on(p -> {
                    VariableReferenceExpression a = p.variable("a", DoubleType.DOUBLE);
                    VariableReferenceExpression ds = p.variable("ds", VarcharType.VARCHAR);
                    return p.aggregation(ap -> ap.globalGrouping()
                            .step(AggregationNode.Step.FINAL)
                            .addAggregation(p.variable("expr"), p.rowExpression("sum(a)"))
                            .source(p.project(
                                    PlanBuilder.assignment(a, p.rowExpression("IF(ds > '2021-07-01', random())")),
                                    p.values(ds))));
                })
                .doesNotFire();

        // Non-deterministic condition.
        tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", "filter_with_if")
                .on(p -> {
                    VariableReferenceExpression a = p.variable("a", BigintType.BIGINT);
                    VariableReferenceExpression ds = p.variable("ds", VarcharType.VARCHAR);
                    return p.aggregation(ap -> ap.globalGrouping()
                            .step(AggregationNode.Step.FINAL)
                            .addAggregation(p.variable("expr"), p.rowExpression("sum(a)"))
                            .source(p.project(
                                    PlanBuilder.assignment(a, p.rowExpression("IF(random() > DOUBLE '0.1', 1)")),
                                    p.values(ds))));
                })
                .doesNotFire();
    }

    /**
     * {@code count(IF(ds > '2021-07-01', 1))} fires under both unwrap strategies. The condition is
     * projected as {@code greater_than}, the IF is unwrapped to the constant {@code 1}
     * ({@code expr_0}), and the source is filtered on {@code greater_than}.
     */
    @Test
    public void testFireCount() {
        for (String strategy : ImmutableList.of("unwrap_if_safe", "unwrap_if")) {
            tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                    .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", strategy)
                    .on(p -> {
                        VariableReferenceExpression a = p.variable("a");
                        VariableReferenceExpression ds = p.variable("ds", VarcharType.VARCHAR);
                        return p.aggregation(ap -> ap.globalGrouping()
                                .step(AggregationNode.Step.FINAL)
                                .addAggregation(p.variable("expr"), p.rowExpression("count(a)"))
                                .source(p.project(
                                        PlanBuilder.assignment(a, p.rowExpression("IF(ds > '2021-07-01', 1)")),
                                        p.values(ds))));
                    })
                    .matches(PlanMatchPattern.aggregation(
                            PlanMatchPattern.globalAggregation(),
                            ImmutableMap.of(Optional.of("expr"), PlanMatchPattern.functionCall("count", ImmutableList.of("expr_0"))),
                            ImmutableMap.of(new Symbol("expr"), new Symbol("greater_than")),
                            Optional.empty(),
                            AggregationNode.Step.FINAL,
                            PlanMatchPattern.filter(
                                    "greater_than",
                                    PlanMatchPattern.project(
                                            ImmutableMap.of(
                                                    "a", PlanMatchPattern.expression("IF(ds > '2021-07-01', 1)"),
                                                    "greater_than", PlanMatchPattern.expression("ds > '2021-07-01'"),
                                                    "expr_0", PlanMatchPattern.expression("1")),
                                            PlanMatchPattern.values("ds")))));
        }
    }

    /**
     * Under both unwrap strategies, {@code count(IF(ds > '2021-07-01', 1))} counts the unwrapped
     * constant {@code 1} ({@code expr0}) over a source filtered on the extracted condition
     * {@code greater_than}.
     *
     * <p>NOTE(review): the input plan and expected pattern match {@code testFireCount} except for
     * the alias name ({@code expr0} vs {@code expr_0}). Consider merging the two, or pointing this
     * test at a distinct case.
     */
    @Test
    public void testUnwrapIf() {
        for (String strategy : ImmutableList.of("unwrap_if_safe", "unwrap_if")) {
            tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                    .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", strategy)
                    .on(p -> {
                        VariableReferenceExpression a = p.variable("a");
                        VariableReferenceExpression ds = p.variable("ds", VarcharType.VARCHAR);
                        return p.aggregation(ap -> ap.globalGrouping()
                                .step(AggregationNode.Step.FINAL)
                                .addAggregation(p.variable("expr"), p.rowExpression("count(a)"))
                                .source(p.project(
                                        PlanBuilder.assignment(a, p.rowExpression("IF(ds > '2021-07-01', 1)")),
                                        p.values(ds))));
                    })
                    .matches(PlanMatchPattern.aggregation(
                            PlanMatchPattern.globalAggregation(),
                            ImmutableMap.of(Optional.of("expr"), PlanMatchPattern.functionCall("count", ImmutableList.of("expr0"))),
                            ImmutableMap.of(new Symbol("expr"), new Symbol("greater_than")),
                            Optional.empty(),
                            AggregationNode.Step.FINAL,
                            PlanMatchPattern.filter(
                                    "greater_than",
                                    PlanMatchPattern.project(
                                            ImmutableMap.of(
                                                    "a", PlanMatchPattern.expression("IF(ds > '2021-07-01', 1)"),
                                                    "greater_than", PlanMatchPattern.expression("ds > '2021-07-01'"),
                                                    "expr0", PlanMatchPattern.expression("1")),
                                            PlanMatchPattern.values("ds")))));
        }
    }

    /** {@code MIN(IF(ds > '2021-06-01', column0))} is rewritten to {@code min} over an unwrapped copy of column0. */
    @Test
    public void testFireMin() {
        assertUnwrapsColumnAggregation("MIN(a)", "min");
    }

    /** {@code MAX(IF(ds > '2021-06-01', column0))} is rewritten to {@code max} over an unwrapped copy of column0. */
    @Test
    public void testFireMax() {
        assertUnwrapsColumnAggregation("MAX(a)", "max");
    }

    /** {@code ARBITRARY(IF(ds > '2021-06-01', column0))} is rewritten to {@code arbitrary} over an unwrapped copy of column0. */
    @Test
    public void testFireArbitrary() {
        assertUnwrapsColumnAggregation("ARBITRARY(a)", "arbitrary");
    }

    /** {@code SUM(IF(ds > '2021-06-01', column0))} is rewritten to {@code sum} over an unwrapped copy of column0. */
    @Test
    public void testFireSum() {
        assertUnwrapsColumnAggregation("SUM(a)", "sum");
    }

    /**
     * Shared body of the single-column aggregation tests above.
     *
     * <p>For each of {@code unwrap_if_safe} and {@code unwrap_if}, it builds
     * {@code expr0 := aggregationCall} over {@code a := IF(ds > '2021-06-01', column0)}. It then
     * asserts the rule produces the following:
     * <ul>
     *   <li>{@code expectedFunction} applied to a projected copy of {@code column0} ({@code column0_0});</li>
     *   <li>{@code expr0} associated with the projected condition {@code greater_than};</li>
     *   <li>a source filtered on {@code greater_than}.</li>
     * </ul>
     *
     * @param aggregationCall aggregation expression over {@code a}, e.g. {@code "MIN(a)"}
     * @param expectedFunction function name expected in the rewritten aggregation, e.g. {@code "min"}
     */
    private void assertUnwrapsColumnAggregation(String aggregationCall, String expectedFunction) {
        for (String strategy : ImmutableList.of("unwrap_if_safe", "unwrap_if")) {
            tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                    .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", strategy)
                    .on(p -> {
                        VariableReferenceExpression a = p.variable("a");
                        VariableReferenceExpression ds = p.variable("ds", VarcharType.VARCHAR);
                        VariableReferenceExpression column0 = p.variable("column0", BigintType.BIGINT);
                        return p.aggregation(ap -> ap.globalGrouping()
                                .step(AggregationNode.Step.FINAL)
                                .addAggregation(p.variable("expr0"), p.rowExpression(aggregationCall))
                                .source(p.project(
                                        PlanBuilder.assignment(a, p.rowExpression("IF(ds > '2021-06-01', column0)")),
                                        p.values(ds, column0))));
                    })
                    .matches(PlanMatchPattern.aggregation(
                            PlanMatchPattern.globalAggregation(),
                            ImmutableMap.of(Optional.of("expr0"), PlanMatchPattern.functionCall(expectedFunction, ImmutableList.of("column0_0"))),
                            ImmutableMap.of(new Symbol("expr0"), new Symbol("greater_than")),
                            Optional.empty(),
                            AggregationNode.Step.FINAL,
                            PlanMatchPattern.filter(
                                    "greater_than",
                                    PlanMatchPattern.project(
                                            ImmutableMap.<String, ExpressionMatcher>builder()
                                                    .put("a", PlanMatchPattern.expression("IF(ds > '2021-06-01', column0)"))
                                                    .put("greater_than", PlanMatchPattern.expression("ds > '2021-06-01'"))
                                                    .put("column0_0", PlanMatchPattern.expression("column0"))
                                                    .build(),
                                            PlanMatchPattern.values("ds", "column0")))));
        }
    }

    /**
     * With {@code filter_with_if}, the rule must not fire for {@code MAX_BY(a, a)} when {@code a} is
     * the {@code IF(ds > '2021-06-01', column0)} projection.
     */
    @Test
    public void testDoesNotFireForMaxBy() {
        tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", "filter_with_if")
                .on(p -> {
                    VariableReferenceExpression a = p.variable("a");
                    VariableReferenceExpression ds = p.variable("ds", VarcharType.VARCHAR);
                    VariableReferenceExpression column0 = p.variable("column0", BigintType.BIGINT);
                    return p.aggregation(ap -> ap.globalGrouping()
                            .step(AggregationNode.Step.FINAL)
                            .addAggregation(p.variable("expr0"), p.rowExpression("MAX_BY(a, a)"))
                            .source(p.project(
                                    PlanBuilder.assignment(a, p.rowExpression("IF(ds > '2021-06-01', column0)")),
                                    p.values(ds, column0))));
                })
                .doesNotFire();
    }

    /**
     * With {@code filter_with_if}, the rule must not fire for {@code MIN_BY(a, a)} when {@code a} is
     * the {@code IF(ds > '2021-06-01', column0)} projection.
     */
    @Test
    public void testDoesNotFireForMinBy() {
        tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", "filter_with_if")
                .on(p -> {
                    VariableReferenceExpression a = p.variable("a");
                    VariableReferenceExpression ds = p.variable("ds", VarcharType.VARCHAR);
                    VariableReferenceExpression column0 = p.variable("column0", BigintType.BIGINT);
                    return p.aggregation(ap -> ap.globalGrouping()
                            .step(AggregationNode.Step.FINAL)
                            .addAggregation(p.variable("expr0"), p.rowExpression("MIN_BY(a, a)"))
                            .source(p.project(
                                    PlanBuilder.assignment(a, p.rowExpression("IF(ds > '2021-06-01', column0)")),
                                    p.values(ds, column0))));
                })
                .doesNotFire();
    }

    /**
     * Two counts over IFs with different conditions are each rewritten independently.
     *
     * <p>{@code expr0} is tied to {@code greater_than} and {@code expr1} to {@code greater_than_0}.
     * Each counts its own unwrapped constant. The source is filtered on the OR of both conditions.
     */
    @Test
    public void testFireTwoAggregations() {
        for (String strategy : ImmutableList.of("unwrap_if_safe", "unwrap_if")) {
            tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                    .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", strategy)
                    .on(p -> {
                        VariableReferenceExpression a = p.variable("a");
                        VariableReferenceExpression b = p.variable("b");
                        VariableReferenceExpression ds = p.variable("ds", VarcharType.VARCHAR);
                        return p.aggregation(ap -> ap.globalGrouping()
                                .step(AggregationNode.Step.FINAL)
                                .addAggregation(p.variable("expr0"), p.rowExpression("count(a)"))
                                .addAggregation(p.variable("expr1"), p.rowExpression("count(b)"))
                                .source(p.project(
                                        PlanBuilder.assignment(
                                                a, p.rowExpression("IF(ds > '2021-07-01', 1)"),
                                                b, p.rowExpression("IF(ds > '2021-06-01', 2)")),
                                        p.values(ds))));
                    })
                    .matches(PlanMatchPattern.aggregation(
                            PlanMatchPattern.globalAggregation(),
                            ImmutableMap.of(
                                    Optional.of("expr0"), PlanMatchPattern.functionCall("count", ImmutableList.of("expr")),
                                    Optional.of("expr1"), PlanMatchPattern.functionCall("count", ImmutableList.of("expr_1"))),
                            ImmutableMap.of(
                                    new Symbol("expr0"), new Symbol("greater_than"),
                                    new Symbol("expr1"), new Symbol("greater_than_0")),
                            Optional.empty(),
                            AggregationNode.Step.FINAL,
                            PlanMatchPattern.filter(
                                    "greater_than or greater_than_0",
                                    PlanMatchPattern.project(
                                            ImmutableMap.<String, ExpressionMatcher>builder()
                                                    .put("a", PlanMatchPattern.expression("IF(ds > '2021-07-01', 1)"))
                                                    .put("b", PlanMatchPattern.expression("IF(ds > '2021-06-01', 2)"))
                                                    .put("greater_than", PlanMatchPattern.expression("ds > '2021-07-01'"))
                                                    .put("expr", PlanMatchPattern.expression("1"))
                                                    .put("greater_than_0", PlanMatchPattern.expression("ds > '2021-06-01'"))
                                                    .put("expr_1", PlanMatchPattern.expression("2"))
                                                    .build(),
                                            PlanMatchPattern.values("ds")))));
        }
    }

    /**
     * MIN and MAX over the same IF projection share a single projected condition
     * ({@code greater_than}) and a single unwrapped copy of column0 ({@code column0_0}).
     */
    @Test
    public void testFireTwoAggregationsWithSharedInput() {
        for (String strategy : ImmutableList.of("unwrap_if_safe", "unwrap_if")) {
            tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                    .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", strategy)
                    .on(p -> {
                        VariableReferenceExpression a = p.variable("a");
                        VariableReferenceExpression ds = p.variable("ds", VarcharType.VARCHAR);
                        VariableReferenceExpression column0 = p.variable("column0", BigintType.BIGINT);
                        return p.aggregation(ap -> ap.globalGrouping()
                                .step(AggregationNode.Step.FINAL)
                                .addAggregation(p.variable("expr0"), p.rowExpression("MIN(a)"))
                                .addAggregation(p.variable("expr1"), p.rowExpression("MAX(a)"))
                                .source(p.project(
                                        PlanBuilder.assignment(a, p.rowExpression("IF(ds > '2021-06-01', column0)")),
                                        p.values(ds, column0))));
                    })
                    .matches(PlanMatchPattern.aggregation(
                            PlanMatchPattern.globalAggregation(),
                            ImmutableMap.of(
                                    Optional.of("expr0"), PlanMatchPattern.functionCall("min", ImmutableList.of("column0_0")),
                                    Optional.of("expr1"), PlanMatchPattern.functionCall("max", ImmutableList.of("column0_0"))),
                            ImmutableMap.of(
                                    new Symbol("expr0"), new Symbol("greater_than"),
                                    new Symbol("expr1"), new Symbol("greater_than")),
                            Optional.empty(),
                            AggregationNode.Step.FINAL,
                            PlanMatchPattern.filter(
                                    "greater_than",
                                    PlanMatchPattern.project(
                                            ImmutableMap.<String, ExpressionMatcher>builder()
                                                    .put("a", PlanMatchPattern.expression("IF(ds > '2021-06-01', column0)"))
                                                    .put("greater_than", PlanMatchPattern.expression("ds > '2021-06-01'"))
                                                    .put("column0_0", PlanMatchPattern.expression("column0"))
                                                    .build(),
                                            PlanMatchPattern.values("ds", "column0")))));
        }
    }

    /**
     * Only {@code count(a)}, where {@code a} is an IF, is rewritten. {@code count(b)} with
     * {@code b := ds} keeps reading {@code b} and gets no associated condition, so the source filter
     * is {@code true}.
     */
    @Test
    public void testFireForOneOfTwoAggregations() {
        for (String strategy : ImmutableList.of("unwrap_if_safe", "unwrap_if")) {
            tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                    .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", strategy)
                    .on(p -> {
                        VariableReferenceExpression a = p.variable("a");
                        VariableReferenceExpression b = p.variable("b");
                        VariableReferenceExpression ds = p.variable("ds", VarcharType.VARCHAR);
                        return p.aggregation(ap -> ap.globalGrouping()
                                .step(AggregationNode.Step.FINAL)
                                .addAggregation(p.variable("expr0"), p.rowExpression("count(a)"))
                                .addAggregation(p.variable("expr1"), p.rowExpression("count(b)"))
                                .source(p.project(
                                        PlanBuilder.assignment(
                                                a, p.rowExpression("IF(ds > '2021-07-01', 1)"),
                                                b, p.rowExpression("ds")),
                                        p.values(ds))));
                    })
                    .matches(PlanMatchPattern.aggregation(
                            PlanMatchPattern.globalAggregation(),
                            ImmutableMap.of(
                                    Optional.of("expr0"), PlanMatchPattern.functionCall("count", ImmutableList.of("expr")),
                                    Optional.of("expr1"), PlanMatchPattern.functionCall("count", ImmutableList.of("b"))),
                            ImmutableMap.of(new Symbol("expr0"), new Symbol("greater_than")),
                            Optional.empty(),
                            AggregationNode.Step.FINAL,
                            PlanMatchPattern.filter(
                                    "true",
                                    PlanMatchPattern.project(
                                            ImmutableMap.<String, ExpressionMatcher>builder()
                                                    .put("a", PlanMatchPattern.expression("IF(ds > '2021-07-01', 1)"))
                                                    .put("b", PlanMatchPattern.expression("ds"))
                                                    .put("greater_than", PlanMatchPattern.expression("ds > '2021-07-01'"))
                                                    .put("expr", PlanMatchPattern.expression("1"))
                                                    .build(),
                                            PlanMatchPattern.values("ds")))));
        }
    }

    /**
     * An IF guarding an array subscript, {@code IF(CARDINALITY(arrayColumn) > 0, arrayColumn[1])}, is
     * not unwrapped under {@code filter_with_if}, {@code unwrap_if_safe} or {@code unwrap_if}.
     *
     * <p>{@code SUM} keeps reading {@code arrayElement} (the IF projection). Only the
     * {@code CARDINALITY(arrayColumn) > 0} condition is projected and used as the filter.
     */
    @Test
    public void testArrayOffset() {
        for (String strategy : ImmutableList.of("filter_with_if", "unwrap_if_safe", "unwrap_if")) {
            tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                    .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", strategy)
                    .on(p -> {
                        VariableReferenceExpression arrayColumn = p.variable("arrayColumn", new ArrayType(BigintType.BIGINT));
                        VariableReferenceExpression arrayElement = p.variable("arrayElement", BigintType.BIGINT);
                        return p.aggregation(aggregationBuilder -> aggregationBuilder.globalGrouping()
                                .step(AggregationNode.Step.FINAL)
                                .addAggregation(p.variable("expr0"), p.rowExpression("SUM(arrayElement)"))
                                .source(p.project(
                                        PlanBuilder.assignment(arrayElement, p.rowExpression("IF(CARDINALITY(arrayColumn) > 0, arrayColumn[1])")),
                                        p.values(arrayColumn))));
                    })
                    .matches(PlanMatchPattern.aggregation(
                            PlanMatchPattern.globalAggregation(),
                            ImmutableMap.of(Optional.of("expr0"), PlanMatchPattern.functionCall("SUM", ImmutableList.of("arrayElement"))),
                            ImmutableMap.of(new Symbol("expr0"), new Symbol("greater_than")),
                            Optional.empty(),
                            AggregationNode.Step.FINAL,
                            PlanMatchPattern.filter(
                                    "greater_than",
                                    PlanMatchPattern.project(
                                            ImmutableMap.<String, ExpressionMatcher>builder()
                                                    .put("arrayElement", PlanMatchPattern.expression("IF(CARDINALITY(arrayColumn) > 0, arrayColumn[1])"))
                                                    .put("greater_than", PlanMatchPattern.expression("CARDINALITY(arrayColumn) > 0"))
                                                    .build(),
                                            PlanMatchPattern.values("arrayColumn")))));
        }
    }

    /**
     * Checks how IF-wrapped divisions are handled.
     *
     * <p>{@code IF(b != 0, a / b)} is not unwrapped under {@code filter_with_if},
     * {@code unwrap_if_safe} or {@code unwrap_if}. {@code SUM} keeps reading {@code result} (the IF
     * projection), and only the {@code not_equal} condition is added as the filter.
     *
     * <p>With {@code unwrap_if} and a condition that does not involve the divisor
     * ({@code IF(ds > '2021-07-01', a / b)}), the aggregation input is projected as {@code a / b}.
     */
    @Test
    public void testDivide() {
        for (String strategy : ImmutableList.of("filter_with_if", "unwrap_if_safe", "unwrap_if")) {
            tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                    .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", strategy)
                    .on(p -> {
                        VariableReferenceExpression a = p.variable("a", BigintType.BIGINT);
                        VariableReferenceExpression b = p.variable("b", BigintType.BIGINT);
                        VariableReferenceExpression result = p.variable("result", BigintType.BIGINT);
                        return p.aggregation(aggregationBuilder -> aggregationBuilder.globalGrouping()
                                .step(AggregationNode.Step.FINAL)
                                .addAggregation(p.variable("expr0"), p.rowExpression("SUM(result)"))
                                .source(p.project(
                                        PlanBuilder.assignment(result, p.rowExpression("IF(b != 0, a / b)")),
                                        p.values(a, b))));
                    })
                    .matches(PlanMatchPattern.aggregation(
                            PlanMatchPattern.globalAggregation(),
                            ImmutableMap.of(Optional.of("expr0"), PlanMatchPattern.functionCall("SUM", ImmutableList.of("result"))),
                            ImmutableMap.of(new Symbol("expr0"), new Symbol("not_equal")),
                            Optional.empty(),
                            AggregationNode.Step.FINAL,
                            PlanMatchPattern.filter(
                                    "not_equal",
                                    PlanMatchPattern.project(
                                            ImmutableMap.<String, ExpressionMatcher>builder()
                                                    .put("result", PlanMatchPattern.expression("IF(b != 0, a / b)"))
                                                    .put("not_equal", PlanMatchPattern.expression("b != 0"))
                                                    .build(),
                                            PlanMatchPattern.values("a", "b")))));
        }

        tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", "unwrap_if")
                .on(p -> {
                    VariableReferenceExpression a = p.variable("a", BigintType.BIGINT);
                    VariableReferenceExpression b = p.variable("b", BigintType.BIGINT);
                    VariableReferenceExpression ds = p.variable("ds", VarcharType.VARCHAR);
                    VariableReferenceExpression result = p.variable("result", BigintType.BIGINT);
                    return p.aggregation(aggregationBuilder -> aggregationBuilder.globalGrouping()
                            .step(AggregationNode.Step.FINAL)
                            .addAggregation(p.variable("expr0"), p.rowExpression("SUM(result)"))
                            .source(p.project(
                                    PlanBuilder.assignment(result, p.rowExpression("IF(ds > '2021-07-01', a / b)")),
                                    p.values(ds, a, b))));
                })
                .matches(PlanMatchPattern.aggregation(
                        PlanMatchPattern.globalAggregation(),
                        ImmutableMap.of(Optional.of("expr0"), PlanMatchPattern.functionCall("SUM", ImmutableList.of("result"))),
                        ImmutableMap.of(new Symbol("expr0"), new Symbol("greater_than")),
                        Optional.empty(),
                        AggregationNode.Step.FINAL,
                        PlanMatchPattern.filter(
                                "greater_than",
                                PlanMatchPattern.project(
                                        ImmutableMap.<String, ExpressionMatcher>builder()
                                                .put("result", PlanMatchPattern.expression("a / b"))
                                                .put("greater_than", PlanMatchPattern.expression("ds > '2021-07-01'"))
                                                .build(),
                                        PlanMatchPattern.values("ds", "a", "b")))));
    }

    /**
     * Under {@code unwrap_if}, two IF inputs are treated differently.
     *
     * <p>{@code IF(b != 0, a / b)} stays wrapped, so {@code count} reads {@code result0}.
     * {@code IF(b > 0, b)} is unwrapped, so {@code count} reads {@code b_0}. Each aggregation is
     * tied to its own condition, and the source is filtered on the OR of both.
     */
    @Test
    public void testUnwrapIfForOneOfTwoAggregations() {
        tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", "unwrap_if")
                .on(p -> {
                    VariableReferenceExpression result0 = p.variable("result0", BigintType.BIGINT);
                    VariableReferenceExpression result1 = p.variable("result1", BigintType.BIGINT);
                    VariableReferenceExpression a = p.variable("a", BigintType.BIGINT);
                    VariableReferenceExpression b = p.variable("b", BigintType.BIGINT);
                    return p.aggregation(aggregationBuilder -> aggregationBuilder.globalGrouping()
                            .step(AggregationNode.Step.FINAL)
                            .addAggregation(p.variable("expr0"), p.rowExpression("count(result0)"))
                            .addAggregation(p.variable("expr1"), p.rowExpression("count(result1)"))
                            .source(p.project(
                                    PlanBuilder.assignment(
                                            result0, p.rowExpression("IF(b != 0, a / b)"),
                                            result1, p.rowExpression("IF(b > 0, b)")),
                                    p.values(a, b))));
                })
                .matches(PlanMatchPattern.aggregation(
                        PlanMatchPattern.globalAggregation(),
                        ImmutableMap.of(
                                Optional.of("expr0"), PlanMatchPattern.functionCall("count", ImmutableList.of("result0")),
                                Optional.of("expr1"), PlanMatchPattern.functionCall("count", ImmutableList.of("b_0"))),
                        ImmutableMap.of(
                                new Symbol("expr0"), new Symbol("not_equal"),
                                new Symbol("expr1"), new Symbol("greater_than")),
                        Optional.empty(),
                        AggregationNode.Step.FINAL,
                        PlanMatchPattern.filter(
                                "greater_than or not_equal",
                                PlanMatchPattern.project(
                                        ImmutableMap.<String, ExpressionMatcher>builder()
                                                .put("result0", PlanMatchPattern.expression("IF(b != 0, a / b)"))
                                                .put("result1", PlanMatchPattern.expression("IF(b > 0, b)"))
                                                .put("b_0", PlanMatchPattern.expression("b"))
                                                .put("not_equal", PlanMatchPattern.expression("b != 0"))
                                                .put("greater_than", PlanMatchPattern.expression("b > 0"))
                                                .build(),
                                        PlanMatchPattern.values("a", "b")))));
    }

    /**
     * Compares every rewrite strategy on the same plan,
     * {@code expr0 := SUM(a)} over {@code a := IF(column0 > 1, column0)}.
     *
     * <ul>
     *   <li>{@code disabled}: the rule does not fire.</li>
     *   <li>{@code filter_with_if} and {@code unwrap_if_safe}: the source is filtered on
     *       {@code column0 > 1}, but {@code sum} still reads the IF projection {@code a}.</li>
     *   <li>{@code unwrap_if}: {@code sum} reads an unwrapped copy of column0 ({@code column0_0}).</li>
     * </ul>
     */
    @Test
    public void testRewriteStrategies() {
        Function<PlanBuilder, PlanNode> planProvider = p -> {
            VariableReferenceExpression a = p.variable("a");
            VariableReferenceExpression column0 = p.variable("column0", BigintType.BIGINT);
            return p.aggregation(ap -> ap.globalGrouping()
                    .step(AggregationNode.Step.FINAL)
                    .addAggregation(p.variable("expr0"), p.rowExpression("SUM(a)"))
                    .source(p.project(
                            PlanBuilder.assignment(a, p.rowExpression("IF(column0 > 1, column0)")),
                            p.values(column0))));
        };

        tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", "disabled")
                .on(planProvider)
                .doesNotFire();

        // Both strategies expect the identical pattern: filtered source, IF projection kept as input.
        for (String strategy : ImmutableList.of("filter_with_if", "unwrap_if_safe")) {
            tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                    .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", strategy)
                    .on(planProvider)
                    .matches(PlanMatchPattern.aggregation(
                            PlanMatchPattern.globalAggregation(),
                            ImmutableMap.of(Optional.of("expr0"), PlanMatchPattern.functionCall("sum", ImmutableList.of("a"))),
                            ImmutableMap.of(new Symbol("expr0"), new Symbol("greater_than")),
                            Optional.empty(),
                            AggregationNode.Step.FINAL,
                            PlanMatchPattern.filter(
                                    "greater_than",
                                    PlanMatchPattern.project(
                                            ImmutableMap.<String, ExpressionMatcher>builder()
                                                    .put("a", PlanMatchPattern.expression("IF(column0 > 1, column0)"))
                                                    .put("greater_than", PlanMatchPattern.expression("column0 > 1"))
                                                    .build(),
                                            PlanMatchPattern.values("column0")))));
        }

        tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", "unwrap_if")
                .on(planProvider)
                .matches(PlanMatchPattern.aggregation(
                        PlanMatchPattern.globalAggregation(),
                        ImmutableMap.of(Optional.of("expr0"), PlanMatchPattern.functionCall("sum", ImmutableList.of("column0_0"))),
                        ImmutableMap.of(new Symbol("expr0"), new Symbol("greater_than")),
                        Optional.empty(),
                        AggregationNode.Step.FINAL,
                        PlanMatchPattern.filter(
                                "greater_than",
                                PlanMatchPattern.project(
                                        ImmutableMap.<String, ExpressionMatcher>builder()
                                                .put("a", PlanMatchPattern.expression("IF(column0 > 1, column0)"))
                                                .put("greater_than", PlanMatchPattern.expression("column0 > 1"))
                                                .put("column0_0", PlanMatchPattern.expression("column0"))
                                                .build(),
                                        PlanMatchPattern.values("column0")))));
    }

    /**
     * Verifies the rewrite when the aggregated IF expression is wrapped in a CAST.
     *
     * <p>The input plan is {@code SUM(a)} over a projection
     * {@code a = CAST(IF(ds > '2021-06-01', column0) AS bigint)}.
     *
     * <p>With {@code filter_with_if}, the original {@code a} projection is kept. The aggregation is
     * masked by a new {@code greater_than} projection that holds the IF condition, and that
     * condition is also applied as a filter. With {@code unwrap_if_safe} and {@code unwrap_if},
     * the aggregation's argument is additionally switched to a new {@code cast} projection,
     * {@code CAST(column0 AS bigint)}: the CAST applied to the IF's true branch.
     */
    @Test
    public void testCast() {
        // Input plan shared by every strategy below. The lambda captures no state, so reusing
        // it is equivalent to re-declaring it for each assertion.
        Function<PlanBuilder, PlanNode> planProvider = p -> {
            VariableReferenceExpression a = p.variable("a");
            VariableReferenceExpression ds = p.variable("ds", VarcharType.VARCHAR);
            VariableReferenceExpression column0 = p.variable("column0", BigintType.BIGINT);
            return p.aggregation(ap -> ap.globalGrouping()
                    .step(AggregationNode.Step.FINAL)
                    .addAggregation(p.variable("expr0"), p.rowExpression("SUM(a)"))
                    .source(p.project(
                            PlanBuilder.assignment(a, p.rowExpression("CAST(IF(ds > '2021-06-01', column0) AS bigint)")),
                            p.values(ds, column0))));
        };

        // filter_with_if: SUM still reads the original CAST(IF(...)) projection, masked by the condition.
        tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", "filter_with_if")
                .on(planProvider)
                .matches(PlanMatchPattern.aggregation(
                        PlanMatchPattern.globalAggregation(),
                        ImmutableMap.of(Optional.of("expr0"), PlanMatchPattern.functionCall("sum", ImmutableList.of("a"))),
                        ImmutableMap.of(new Symbol("expr0"), new Symbol("greater_than")),
                        Optional.empty(),
                        AggregationNode.Step.FINAL,
                        PlanMatchPattern.filter(
                                "greater_than",
                                PlanMatchPattern.project(
                                        new ImmutableMap.Builder<String, ExpressionMatcher>()
                                                .put("a", PlanMatchPattern.expression("CAST(IF(ds > '2021-06-01', column0) as bigint)"))
                                                .put("greater_than", PlanMatchPattern.expression("ds > '2021-06-01'"))
                                                .build(),
                                        PlanMatchPattern.values("ds", "column0")))));

        // Unwrap strategies: SUM instead reads a new "cast" projection, CAST applied to the IF's true branch.
        for (String strategy : ImmutableList.of("unwrap_if_safe", "unwrap_if")) {
            tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                    .setSystemProperty("aggregation_if_to_filter_rewrite_strategy", strategy)
                    .on(planProvider)
                    .matches(PlanMatchPattern.aggregation(
                            PlanMatchPattern.globalAggregation(),
                            ImmutableMap.of(Optional.of("expr0"), PlanMatchPattern.functionCall("sum", ImmutableList.of("cast"))),
                            ImmutableMap.of(new Symbol("expr0"), new Symbol("greater_than")),
                            Optional.empty(),
                            AggregationNode.Step.FINAL,
                            PlanMatchPattern.filter(
                                    "greater_than",
                                    PlanMatchPattern.project(
                                            new ImmutableMap.Builder<String, ExpressionMatcher>()
                                                    .put("a", PlanMatchPattern.expression("CAST(IF(ds > '2021-06-01', column0) as bigint)"))
                                                    .put("greater_than", PlanMatchPattern.expression("ds > '2021-06-01'"))
                                                    .put("cast", PlanMatchPattern.expression("CAST(column0 AS bigint)"))
                                                    .build(),
                                            PlanMatchPattern.values("ds", "column0")))));
        }
    }
}

