/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.spi.Plugin;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.Type;
import io.trino.sql.ir.BooleanLiteral;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.ExpressionRewriter;
import io.trino.sql.ir.ExpressionTreeRewriter;
import io.trino.sql.ir.IsNullPredicate;
import io.trino.sql.ir.NotExpression;
import io.trino.sql.ir.Row;
import io.trino.sql.ir.SymbolReference;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.assertions.AggregationFunction;
import io.trino.sql.planner.assertions.ExpectedValueProvider;
import io.trino.sql.planner.assertions.ExpressionMatcher;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.rule.ExpressionRewriteRuleSet;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.rowpattern.Patterns;
import io.trino.sql.planner.rowpattern.ir.IrRowPattern;
import java.util.List;
import java.util.Map;
import org.junit.jupiter.api.Test;

public class TestExpressionRewriteRuleSet
extends BaseRuleTest {
    private final ExpressionRewriteRuleSet zeroRewriter = new ExpressionRewriteRuleSet((expression, context) -> ExpressionTreeRewriter.rewriteWith((ExpressionRewriter)new ExpressionRewriter<Void>(this){

        protected Expression rewriteExpression(Expression node, Void context, ExpressionTreeRewriter<Void> treeRewriter) {
            return new Constant((Type)IntegerType.INTEGER, (Object)0L);
        }

        public Expression rewriteRow(Row node, Void context, ExpressionTreeRewriter<Void> treeRewriter) {
            return new Row((List)node.getItems().stream().map(item -> new Constant((Type)IntegerType.INTEGER, (Object)0L)).collect(ImmutableList.toImmutableList()));
        }
    }, (Expression)expression));

    public TestExpressionRewriteRuleSet() {
        super(new Plugin[0]);
    }

    @Test
    public void testProjectionExpressionRewrite() {
        this.tester().assertThat(this.zeroRewriter.projectExpressionRewrite()).on(p -> p.project(Assignments.of((Symbol)p.symbol("y"), (Expression)new NotExpression((Expression)new IsNullPredicate((Expression)new SymbolReference((Type)BigintType.BIGINT, "x")))), (PlanNode)p.values(p.symbol("x")))).matches(PlanMatchPattern.project((Map<String, ExpressionMatcher>)ImmutableMap.of((Object)"y", (Object)PlanMatchPattern.expression((Expression)new Constant((Type)IntegerType.INTEGER, (Object)0L))), PlanMatchPattern.values("x")));
    }

    @Test
    public void testProjectionExpressionNotRewritten() {
        this.tester().assertThat(this.zeroRewriter.projectExpressionRewrite()).on(p -> p.project(Assignments.of((Symbol)p.symbol("y"), (Expression)new Constant((Type)IntegerType.INTEGER, (Object)0L)), (PlanNode)p.values(p.symbol("x")))).doesNotFire();
    }

    @Test
    public void testAggregationExpressionRewrite() {
        ExpressionRewriteRuleSet functionCallRewriter = new ExpressionRewriteRuleSet((expression, context) -> new SymbolReference((Type)BigintType.BIGINT, "y"));
        this.tester().assertThat(functionCallRewriter.aggregationExpressionRewrite()).on(p -> p.aggregation(a -> a.globalGrouping().addAggregation(p.symbol("count_1", (Type)BigintType.BIGINT), PlanBuilder.aggregation("count", (List<Expression>)ImmutableList.of((Object)new SymbolReference((Type)BigintType.BIGINT, "x"))), (List<Type>)ImmutableList.of((Object)BigintType.BIGINT)).source((PlanNode)p.values(p.symbol("x"), p.symbol("y"))))).matches(PlanMatchPattern.aggregation((Map<String, ExpectedValueProvider<AggregationFunction>>)ImmutableMap.of((Object)"count_1", PlanMatchPattern.aggregationFunction("count", (List<String>)ImmutableList.of((Object)"y"))), PlanMatchPattern.values("x", "y")));
    }

    @Test
    public void testFilterExpressionRewrite() {
        this.tester().assertThat(this.zeroRewriter.filterExpressionRewrite()).on(p -> p.filter((Expression)new Constant((Type)IntegerType.INTEGER, (Object)1L), (PlanNode)p.values(new Symbol[0]))).matches(PlanMatchPattern.filter((Expression)new Constant((Type)IntegerType.INTEGER, (Object)0L), PlanMatchPattern.values(new String[0])));
    }

    @Test
    public void testFilterExpressionNotRewritten() {
        this.tester().assertThat(this.zeroRewriter.filterExpressionRewrite()).on(p -> p.filter((Expression)new Constant((Type)IntegerType.INTEGER, (Object)0L), (PlanNode)p.values(new Symbol[0]))).doesNotFire();
    }

    @Test
    public void testValueExpressionRewrite() {
        this.tester().assertThat(this.zeroRewriter.valuesExpressionRewrite()).on(p -> p.values((List<Symbol>)ImmutableList.of((Object)p.symbol("a")), (List<List<Expression>>)ImmutableList.of((Object)ImmutableList.of((Object)new Constant((Type)IntegerType.INTEGER, (Object)1L))))).matches(PlanMatchPattern.values((List<String>)ImmutableList.of((Object)"a"), (List<List<Expression>>)ImmutableList.of((Object)ImmutableList.of((Object)new Constant((Type)IntegerType.INTEGER, (Object)0L)))));
    }

    @Test
    public void testValueExpressionNotRewritten() {
        this.tester().assertThat(this.zeroRewriter.valuesExpressionRewrite()).on(p -> p.values((List<Symbol>)ImmutableList.of((Object)p.symbol("a")), (List<List<Expression>>)ImmutableList.of((Object)ImmutableList.of((Object)new Constant((Type)IntegerType.INTEGER, (Object)0L))))).doesNotFire();
    }

    @Test
    public void testPatternRecognitionExpressionRewrite() {
        this.tester().assertThat(this.zeroRewriter.patternRecognitionExpressionRewrite()).on(p -> p.patternRecognition(builder -> builder.addMeasure(p.symbol("measure_1", (Type)IntegerType.INTEGER), (Expression)new Constant((Type)IntegerType.INTEGER, (Object)1L)).pattern((IrRowPattern)Patterns.label("X")).addVariableDefinition(Patterns.label("X"), (Expression)BooleanLiteral.TRUE_LITERAL).source((PlanNode)p.values(p.symbol("a", (Type)IntegerType.INTEGER))))).matches(PlanMatchPattern.patternRecognition(builder -> builder.addMeasure("measure_1", (Expression)new Constant((Type)IntegerType.INTEGER, (Object)0L), (Type)IntegerType.INTEGER).pattern((IrRowPattern)Patterns.label("X")).addVariableDefinition(Patterns.label("X"), (Expression)new Constant((Type)IntegerType.INTEGER, (Object)0L)), PlanMatchPattern.values("a")));
    }

    @Test
    public void testPatternRecognitionExpressionNotRewritten() {
        this.tester().assertThat(this.zeroRewriter.patternRecognitionExpressionRewrite()).on(p -> p.patternRecognition(builder -> builder.addMeasure(p.symbol("measure_1"), (Expression)new Constant((Type)IntegerType.INTEGER, (Object)0L)).pattern((IrRowPattern)Patterns.label("X")).addVariableDefinition(Patterns.label("X"), (Expression)new Constant((Type)IntegerType.INTEGER, (Object)0L)).source((PlanNode)p.values(p.symbol("a"))))).doesNotFire();
    }
}

