/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule.test;

import com.google.common.collect.ImmutableList;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.TableHandle;
import io.trino.plugin.tpch.TpchTableHandle;
import io.trino.spi.connector.ConnectorTableHandle;
import io.trino.spi.connector.ConnectorTransactionHandle;
import io.trino.spi.connector.TestingColumnHandle;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.Type;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.SymbolReference;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.iterative.rule.test.RuleAssert;
import io.trino.sql.planner.iterative.rule.test.RuleTester;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.testing.TestingTransactionHandle;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.assertj.core.api.AbstractThrowableAssert;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/**
 * Tests that {@link RuleAssert} produces informative {@link AssertionError} messages in two cases:
 * a rule fires but yields a plan that does not match the expected pattern, or a rule does not
 * fire at all.
 */
public class TestRuleTester {
    @Test
    public void testReportWrongMatch() {
        try (RuleTester tester = RuleTester.defaultRuleTester()) {
            // The rule always fires and rebuilds the matched node from its own sources.
            // The resulting project-over-values plan therefore cannot match `expected`, which is a
            // values node with a different column.
            RuleAssert ruleAssert = tester.assertThat(TestRuleTester.rule(
                            "testReportWrongMatch rule",
                            Pattern.typeOf(PlanNode.class),
                            (node, captures, context) -> Rule.Result.ofPlanNode(node.replaceChildren(node.getSources()))))
                    .on(p -> p.project(
                            Assignments.of(p.symbol("y"), new SymbolReference(IntegerType.INTEGER, "x")),
                            p.values(
                                    List.of(p.symbol("x")),
                                    List.of(List.of(new Constant(IntegerType.INTEGER, 1L))))));

            PlanMatchPattern expected = PlanMatchPattern.values(List.of("different"), List.of());
            Assertions.assertThatThrownBy(() -> ruleAssert.matches(expected))
                    .isInstanceOf(AssertionError.class)
                    .hasMessageMatching("(?s)Plan does not match, expected .* but found .*");
        }
    }

    @Test
    public void testReportNoFire() {
        try (RuleTester tester = RuleTester.defaultRuleTester()) {
            // The rule matches every PlanNode but always returns an empty result, i.e. it never fires.
            RuleAssert ruleAssert = tester.assertThat(TestRuleTester.rule(
                            "testReportNoFire rule",
                            Pattern.typeOf(PlanNode.class),
                            (node, captures, context) -> Rule.Result.empty()))
                    .on(p -> p.values(
                            List.of(p.symbol("x")),
                            List.of(List.of(new Constant(IntegerType.INTEGER, 1L)))));

            PlanMatchPattern expected = PlanMatchPattern.values(List.of("whatever"), List.of());
            // The failure message is expected to start with the rule's toString() (its name).
            Assertions.assertThatThrownBy(() -> ruleAssert.matches(expected))
                    .isInstanceOf(AssertionError.class)
                    .hasMessageMatching("testReportNoFire rule did not fire for:(?s:.*)");
        }
    }

    @Test
    public void testReportNoFireWithTableScan() {
        try (RuleTester tester = RuleTester.defaultRuleTester()) {
            // A never-firing rule applied to a scan of the TPCH "nation" table (sf1). The failure
            // report is expected to include the plan's cost estimates, which the regex below pins.
            RuleAssert ruleAssert = tester.assertThat(TestRuleTester.rule(
                            "testReportNoFireWithTableScan rule",
                            Pattern.typeOf(PlanNode.class),
                            (node, captures, context) -> Rule.Result.empty()))
                    .on(p -> p.tableScan(
                            new TableHandle(
                                    tester.getCurrentCatalogHandle(),
                                    new TpchTableHandle("sf1", "nation", 1.0),
                                    TestingTransactionHandle.create()),
                            List.of(p.symbol("x")),
                            Map.of(p.symbol("x"), new TestingColumnHandle("column"))));

            PlanMatchPattern expected = PlanMatchPattern.values(List.of("whatever"), List.of());
            Assertions.assertThatThrownBy(() -> ruleAssert.matches(expected))
                    .isInstanceOf(AssertionError.class)
                    .hasMessageMatching("testReportNoFireWithTableScan rule did not fire for:\n(?s:.*)\\QEstimates: {rows: 25 (225B), cpu: 225, memory: 0B, network: 0B}\\E\n(?s:.*)");
        }
    }

    /**
     * Builds an ad-hoc {@link Rule} from a pattern and an apply callback.
     *
     * @param name value returned by the rule's {@code toString()}; the tests above expect it to
     *     appear in failure messages
     * @param pattern pattern the rule matches against
     * @param apply callback invoked by {@link Rule#apply}
     * @return a rule delegating to {@code pattern} and {@code apply}
     * @throws NullPointerException if any argument is null
     */
    private static <T> Rule<T> rule(String name, Pattern<T> pattern, RuleApplyImplementation<T> apply) {
        Objects.requireNonNull(name, "name is null");
        Objects.requireNonNull(pattern, "pattern is null");
        Objects.requireNonNull(apply, "apply is null");
        return new Rule<T>() {
            @Override
            public String toString() {
                return name;
            }

            @Override
            public Pattern<T> getPattern() {
                return pattern;
            }

            @Override
            public Rule.Result apply(T node, Captures captures, Rule.Context context) {
                return apply.apply(node, captures, context);
            }
        };
    }

    /** Lambda-friendly shape of {@link Rule#apply}, used to build test rules inline. */
    @FunctionalInterface
    private interface RuleApplyImplementation<T> {
        Rule.Result apply(T node, Captures captures, Rule.Context context);
    }
}

