/*
 * Decompiled with CFR 0.152.
 */
package io.prestosql.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.prestosql.spi.Plugin;
import io.prestosql.spi.type.BigintType;
import io.prestosql.spi.type.DateType;
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.VarcharType;
import io.prestosql.sql.planner.FunctionCallBuilder;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.assertions.ExpectedValueProvider;
import io.prestosql.sql.planner.assertions.ExpressionMatcher;
import io.prestosql.sql.planner.assertions.PlanMatchPattern;
import io.prestosql.sql.planner.iterative.rule.ExpressionRewriteRuleSet;
import io.prestosql.sql.planner.iterative.rule.test.BaseRuleTest;
import io.prestosql.sql.planner.iterative.rule.test.PlanBuilder;
import io.prestosql.sql.planner.plan.Assignments;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.ExpressionRewriter;
import io.prestosql.sql.tree.ExpressionTreeRewriter;
import io.prestosql.sql.tree.FunctionCall;
import io.prestosql.sql.tree.LongLiteral;
import io.prestosql.sql.tree.QualifiedName;
import io.prestosql.sql.tree.Row;
import io.prestosql.sql.tree.SymbolReference;
import java.util.List;
import java.util.Map;
import org.testng.annotations.Test;

public class TestExpressionRewriteRuleSet
extends BaseRuleTest {
    private ExpressionRewriteRuleSet zeroRewriter = new ExpressionRewriteRuleSet((expression, context) -> ExpressionTreeRewriter.rewriteWith((ExpressionRewriter)new ExpressionRewriter<Void>(){

        protected Expression rewriteExpression(Expression node, Void context, ExpressionTreeRewriter<Void> treeRewriter) {
            return new LongLiteral("0");
        }

        public Expression rewriteRow(Row node, Void context, ExpressionTreeRewriter<Void> treeRewriter) {
            return new Row((List)node.getItems().stream().map(item -> new LongLiteral("0")).collect(ImmutableList.toImmutableList()));
        }
    }, (Expression)expression));

    public TestExpressionRewriteRuleSet() {
        super(new Plugin[0]);
    }

    @Test
    public void testProjectionExpressionRewrite() {
        this.tester().assertThat(this.zeroRewriter.projectExpressionRewrite()).on(p -> p.project(Assignments.of((Symbol)p.symbol("y"), (Expression)PlanBuilder.expression("x IS NOT NULL")), (PlanNode)p.values(p.symbol("x")))).matches(PlanMatchPattern.project((Map<String, ExpressionMatcher>)ImmutableMap.of((Object)"y", (Object)PlanMatchPattern.expression("0")), PlanMatchPattern.values("x")));
    }

    @Test
    public void testProjectionExpressionNotRewritten() {
        this.tester().assertThat(this.zeroRewriter.projectExpressionRewrite()).on(p -> p.project(Assignments.of((Symbol)p.symbol("y"), (Expression)PlanBuilder.expression("0")), (PlanNode)p.values(p.symbol("x")))).doesNotFire();
    }

    @Test
    public void testAggregationExpressionRewrite() {
        ExpressionRewriteRuleSet functionCallRewriter = new ExpressionRewriteRuleSet((expression, context) -> new FunctionCallBuilder(this.tester().getMetadata()).setName(QualifiedName.of((String)"count")).addArgument((Type)VarcharType.VARCHAR, (Expression)new SymbolReference("y")).build());
        this.tester().assertThat(functionCallRewriter.aggregationExpressionRewrite()).on(p -> p.aggregation(a -> a.globalGrouping().addAggregation(p.symbol("count_1", (Type)BigintType.BIGINT), (Expression)new FunctionCallBuilder(this.tester().getMetadata()).setName(QualifiedName.of((String)"count")).addArgument((Type)VarcharType.VARCHAR, (Expression)new SymbolReference("x")).build(), (List<Type>)ImmutableList.of((Object)BigintType.BIGINT)).source((PlanNode)p.values(p.symbol("x"), p.symbol("y"))))).matches(PlanMatchPattern.aggregation((Map<String, ExpectedValueProvider<FunctionCall>>)ImmutableMap.of((Object)"count_1", aliases -> new FunctionCallBuilder(this.tester().getMetadata()).setName(QualifiedName.of((String)"count")).addArgument((Type)VarcharType.VARCHAR, (Expression)new SymbolReference("y")).build()), PlanMatchPattern.values("x", "y")));
    }

    @Test
    public void testAggregationExpressionNotRewritten() {
        FunctionCall nowCall = new FunctionCallBuilder(this.tester().getMetadata()).setName(QualifiedName.of((String)"now")).build();
        ExpressionRewriteRuleSet functionCallRewriter = new ExpressionRewriteRuleSet((expression, context) -> nowCall);
        this.tester().assertThat(functionCallRewriter.aggregationExpressionRewrite()).on(p -> p.aggregation(a -> a.globalGrouping().addAggregation(p.symbol("count_1", (Type)DateType.DATE), (Expression)nowCall, (List<Type>)ImmutableList.of()).source((PlanNode)p.values(new Symbol[0])))).doesNotFire();
    }

    @Test
    public void testFilterExpressionRewrite() {
        this.tester().assertThat(this.zeroRewriter.filterExpressionRewrite()).on(p -> p.filter((Expression)new LongLiteral("1"), (PlanNode)p.values(new Symbol[0]))).matches(PlanMatchPattern.filter("0", PlanMatchPattern.values(new String[0])));
    }

    @Test
    public void testFilterExpressionNotRewritten() {
        this.tester().assertThat(this.zeroRewriter.filterExpressionRewrite()).on(p -> p.filter((Expression)new LongLiteral("0"), (PlanNode)p.values(new Symbol[0]))).doesNotFire();
    }

    @Test
    public void testValueExpressionRewrite() {
        this.tester().assertThat(this.zeroRewriter.valuesExpressionRewrite()).on(p -> p.values((List<Symbol>)ImmutableList.of((Object)p.symbol("a")), (List<List<Expression>>)ImmutableList.of((Object)ImmutableList.of((Object)PlanBuilder.expression("1"))))).matches(PlanMatchPattern.values((List<String>)ImmutableList.of((Object)"a"), (List<List<Expression>>)ImmutableList.of((Object)ImmutableList.of((Object)new LongLiteral("0")))));
    }

    @Test
    public void testValueExpressionNotRewritten() {
        this.tester().assertThat(this.zeroRewriter.valuesExpressionRewrite()).on(p -> p.values((List<Symbol>)ImmutableList.of((Object)p.symbol("a")), (List<List<Expression>>)ImmutableList.of((Object)ImmutableList.of((Object)PlanBuilder.expression("0"))))).doesNotFire();
    }
}

