/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.query;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.Session;
import io.trino.connector.MockConnectorFactory;
import io.trino.metadata.QualifiedObjectName;
import io.trino.plugin.tpch.TpchConnectorFactory;
import io.trino.spi.connector.ConnectorFactory;
import io.trino.spi.connector.ConnectorViewDefinition;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.security.Identity;
import io.trino.spi.security.ViewExpression;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.VarcharType;
import io.trino.sql.query.QueryAssertions;
import io.trino.testing.LocalQueryRunner;
import io.trino.testing.QueryRunner;
import io.trino.testing.TestingAccessControlManager;
import io.trino.testing.TestingSession;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.assertj.core.api.Assertions;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

@Test(singleThreaded=true)
public class TestRowFilter {
    private static final String CATALOG = "local";
    private static final String MOCK_CATALOG = "mock";
    private static final String USER = "user";
    private static final String VIEW_OWNER = "view-owner";
    private static final String RUN_AS_USER = "run-as-user";
    private static final Session SESSION = TestingSession.testSessionBuilder().setCatalog("local").setSchema("tiny").setIdentity(Identity.forUser((String)"user").build()).build();
    private QueryAssertions assertions;
    private TestingAccessControlManager accessControl;

    @BeforeClass
    public void init() {
        LocalQueryRunner runner = LocalQueryRunner.builder((Session)SESSION).build();
        runner.createCatalog(CATALOG, (ConnectorFactory)new TpchConnectorFactory(1), (Map)ImmutableMap.of());
        ConnectorViewDefinition view = new ConnectorViewDefinition("SELECT nationkey, name FROM local.tiny.nation", Optional.empty(), Optional.empty(), (List)ImmutableList.of((Object)new ConnectorViewDefinition.ViewColumn("nationkey", BigintType.BIGINT.getTypeId()), (Object)new ConnectorViewDefinition.ViewColumn("name", VarcharType.createVarcharType((int)25).getTypeId())), Optional.empty(), Optional.of(VIEW_OWNER), false);
        MockConnectorFactory mock = MockConnectorFactory.builder().withGetViews((s, prefix) -> ImmutableMap.builder().put((Object)new SchemaTableName("default", "nation_view"), (Object)view).build()).build();
        runner.createCatalog(MOCK_CATALOG, (ConnectorFactory)mock, (Map)ImmutableMap.of());
        this.assertions = new QueryAssertions((QueryRunner)runner);
        this.accessControl = this.assertions.getQueryRunner().getAccessControl();
    }

    @AfterClass(alwaysRun=true)
    public void teardown() {
        this.assertions.close();
        this.assertions = null;
    }

    @Test
    public void testSimpleFilter() {
        this.accessControl.reset();
        this.accessControl.rowFilter(new QualifiedObjectName(CATALOG, "tiny", "orders"), USER, new ViewExpression(USER, Optional.empty(), Optional.empty(), "orderkey < 10"));
        ((QueryAssertions.QueryAssert)((Object)Assertions.assertThat(this.assertions.query("SELECT count(*) FROM orders")))).matches("VALUES BIGINT '7'");
        this.accessControl.reset();
        this.accessControl.rowFilter(new QualifiedObjectName(CATALOG, "tiny", "orders"), USER, new ViewExpression(USER, Optional.empty(), Optional.empty(), "NULL"));
        ((QueryAssertions.QueryAssert)((Object)Assertions.assertThat(this.assertions.query("SELECT count(*) FROM orders")))).matches("VALUES BIGINT '0'");
    }

    @Test
    public void testMultipleFilters() {
        this.accessControl.reset();
        this.accessControl.rowFilter(new QualifiedObjectName(CATALOG, "tiny", "orders"), USER, new ViewExpression(USER, Optional.empty(), Optional.empty(), "orderkey < 10"));
        this.accessControl.rowFilter(new QualifiedObjectName(CATALOG, "tiny", "orders"), USER, new ViewExpression(USER, Optional.empty(), Optional.empty(), "orderkey > 5"));
        ((QueryAssertions.QueryAssert)((Object)Assertions.assertThat(this.assertions.query("SELECT count(*) FROM orders")))).matches("VALUES BIGINT '2'");
    }

    @Test
    public void testCorrelatedSubquery() {
        this.accessControl.reset();
        this.accessControl.rowFilter(new QualifiedObjectName(CATALOG, "tiny", "orders"), USER, new ViewExpression(USER, Optional.of(CATALOG), Optional.of("tiny"), "EXISTS (SELECT 1 FROM nation WHERE nationkey = orderkey)"));
        ((QueryAssertions.QueryAssert)((Object)Assertions.assertThat(this.assertions.query("SELECT count(*) FROM orders")))).matches("VALUES BIGINT '7'");
    }

    @Test
    public void testView() {
        this.accessControl.reset();
        this.accessControl.rowFilter(new QualifiedObjectName(CATALOG, "tiny", "nation"), VIEW_OWNER, new ViewExpression(VIEW_OWNER, Optional.empty(), Optional.empty(), "nationkey = 1"));
        ((QueryAssertions.QueryAssert)((Object)Assertions.assertThat(this.assertions.query(Session.builder((Session)SESSION).setIdentity(Identity.forUser((String)RUN_AS_USER).build()).build(), "SELECT name FROM mock.default.nation_view")))).matches("VALUES CAST('ARGENTINA' AS VARCHAR(25))");
        this.accessControl.reset();
        this.accessControl.rowFilter(new QualifiedObjectName(CATALOG, "tiny", "nation"), VIEW_OWNER, new ViewExpression(VIEW_OWNER, Optional.of(CATALOG), Optional.of("tiny"), "nationkey = 1"));
        ((QueryAssertions.QueryAssert)((Object)Assertions.assertThat(this.assertions.query(Session.builder((Session)SESSION).setIdentity(Identity.forUser((String)VIEW_OWNER).build()).build(), "SELECT name FROM mock.default.nation_view")))).matches("VALUES CAST('ARGENTINA' AS VARCHAR(25))");
        this.accessControl.reset();
        this.accessControl.rowFilter(new QualifiedObjectName(CATALOG, "tiny", "nation"), RUN_AS_USER, new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "nationkey = 1"));
        Session session = Session.builder((Session)SESSION).setIdentity(Identity.forUser((String)RUN_AS_USER).build()).build();
        ((QueryAssertions.QueryAssert)((Object)Assertions.assertThat(this.assertions.query(session, "SELECT count(*) FROM mock.default.nation_view")))).matches("VALUES BIGINT '25'");
        this.accessControl.reset();
        this.accessControl.rowFilter(new QualifiedObjectName(MOCK_CATALOG, "default", "nation_view"), USER, new ViewExpression(USER, Optional.of(CATALOG), Optional.of("tiny"), "nationkey = 1"));
        ((QueryAssertions.QueryAssert)((Object)Assertions.assertThat(this.assertions.query("SELECT name FROM mock.default.nation_view")))).matches("VALUES CAST('ARGENTINA' AS VARCHAR(25))");
    }

    @Test
    public void testTableReferenceInWithClause() {
        this.accessControl.reset();
        this.accessControl.rowFilter(new QualifiedObjectName(CATALOG, "tiny", "orders"), USER, new ViewExpression(USER, Optional.empty(), Optional.empty(), "orderkey = 1"));
        ((QueryAssertions.QueryAssert)((Object)Assertions.assertThat(this.assertions.query("WITH t AS (SELECT count(*) FROM orders) SELECT * FROM t")))).matches("VALUES BIGINT '1'");
    }

    @Test
    public void testOtherSchema() {
        this.accessControl.reset();
        this.accessControl.rowFilter(new QualifiedObjectName(CATALOG, "tiny", "orders"), USER, new ViewExpression(USER, Optional.of(CATALOG), Optional.of("sf1"), "(SELECT count(*) FROM customer) = 150000"));
        ((QueryAssertions.QueryAssert)((Object)Assertions.assertThat(this.assertions.query("SELECT count(*) FROM orders")))).matches("VALUES BIGINT '15000'");
    }

    @Test
    public void testDifferentIdentity() {
        this.accessControl.reset();
        this.accessControl.rowFilter(new QualifiedObjectName(CATALOG, "tiny", "orders"), RUN_AS_USER, new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "orderkey = 1"));
        this.accessControl.rowFilter(new QualifiedObjectName(CATALOG, "tiny", "orders"), USER, new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "orderkey IN (SELECT orderkey FROM orders)"));
        ((QueryAssertions.QueryAssert)((Object)Assertions.assertThat(this.assertions.query("SELECT count(*) FROM orders")))).matches("VALUES BIGINT '1'");
    }

    @Test
    public void testRecursion() {
        this.accessControl.reset();
        this.accessControl.rowFilter(new QualifiedObjectName(CATALOG, "tiny", "orders"), USER, new ViewExpression(USER, Optional.of(CATALOG), Optional.of("tiny"), "orderkey IN (SELECT orderkey FROM orders)"));
        Assertions.assertThatThrownBy(() -> this.assertions.query("SELECT count(*) FROM orders")).hasMessageMatching(".*\\QRow filter for 'local.tiny.orders' is recursive\\E.*");
        this.accessControl.reset();
        this.accessControl.rowFilter(new QualifiedObjectName(CATALOG, "tiny", "orders"), USER, new ViewExpression(USER, Optional.of(CATALOG), Optional.of("tiny"), "orderkey IN (SELECT local.tiny.orderkey FROM orders)"));
        Assertions.assertThatThrownBy(() -> this.assertions.query("SELECT count(*) FROM orders")).hasMessageMatching(".*\\QRow filter for 'local.tiny.orders' is recursive\\E.*");
        this.accessControl.reset();
        this.accessControl.rowFilter(new QualifiedObjectName(CATALOG, "tiny", "orders"), RUN_AS_USER, new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "orderkey IN (SELECT orderkey FROM orders)"));
        this.accessControl.rowFilter(new QualifiedObjectName(CATALOG, "tiny", "orders"), USER, new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "orderkey IN (SELECT orderkey FROM orders)"));
        Assertions.assertThatThrownBy(() -> this.assertions.query("SELECT count(*) FROM orders")).hasMessageMatching(".*\\QRow filter for 'local.tiny.orders' is recursive\\E.*");
    }

    @Test
    public void testLimitedScope() {
        this.accessControl.reset();
        this.accessControl.rowFilter(new QualifiedObjectName(CATALOG, "tiny", "customer"), USER, new ViewExpression(USER, Optional.of(CATALOG), Optional.of("tiny"), "orderkey = 1"));
        Assertions.assertThatThrownBy(() -> this.assertions.query("SELECT (SELECT min(name) FROM customer WHERE customer.custkey = orders.custkey) FROM orders")).hasMessage("line 1:31: Invalid row filter for 'local.tiny.customer': Column 'orderkey' cannot be resolved");
    }

    @Test
    public void testSqlInjection() {
        this.accessControl.reset();
        this.accessControl.rowFilter(new QualifiedObjectName(CATALOG, "tiny", "nation"), USER, new ViewExpression(USER, Optional.of(CATALOG), Optional.of("tiny"), "regionkey IN (SELECT regionkey FROM region WHERE name = 'ASIA')"));
        ((QueryAssertions.QueryAssert)((Object)Assertions.assertThat(this.assertions.query("WITH region(regionkey, name) AS (VALUES (0, 'ASIA'), (1, 'ASIA'), (2, 'ASIA'), (3, 'ASIA'), (4, 'ASIA'))SELECT name FROM nation ORDER BY name LIMIT 1")))).matches("VALUES CAST('CHINA' AS VARCHAR(25))");
    }

    @Test
    public void testInvalidFilter() {
        this.accessControl.reset();
        this.accessControl.rowFilter(new QualifiedObjectName(CATALOG, "tiny", "orders"), USER, new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "$$$"));
        Assertions.assertThatThrownBy(() -> this.assertions.query("SELECT count(*) FROM orders")).hasMessage("line 1:22: Invalid row filter for 'local.tiny.orders': mismatched input '$'. Expecting: <expression>");
        this.accessControl.reset();
        this.accessControl.rowFilter(new QualifiedObjectName(CATALOG, "tiny", "orders"), USER, new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "unknown_column"));
        Assertions.assertThatThrownBy(() -> this.assertions.query("SELECT count(*) FROM orders")).hasMessage("line 1:22: Invalid row filter for 'local.tiny.orders': Column 'unknown_column' cannot be resolved");
        this.accessControl.reset();
        this.accessControl.rowFilter(new QualifiedObjectName(CATALOG, "tiny", "orders"), USER, new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "1"));
        Assertions.assertThatThrownBy(() -> this.assertions.query("SELECT count(*) FROM orders")).hasMessage("line 1:22: Expected row filter for 'local.tiny.orders' to be of type BOOLEAN, but was integer");
        this.accessControl.reset();
        this.accessControl.rowFilter(new QualifiedObjectName(CATALOG, "tiny", "orders"), USER, new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "count(*) > 0"));
        Assertions.assertThatThrownBy(() -> this.assertions.query("SELECT count(*) FROM orders")).hasMessage("line 1:10: Row filter for 'local.tiny.orders' cannot contain aggregations, window functions or grouping operations: [count(*)]");
        this.accessControl.reset();
        this.accessControl.rowFilter(new QualifiedObjectName(CATALOG, "tiny", "orders"), USER, new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "row_number() OVER () > 0"));
        Assertions.assertThatThrownBy(() -> this.assertions.query("SELECT count(*) FROM orders")).hasMessage("line 1:22: Row filter for 'local.tiny.orders' cannot contain aggregations, window functions or grouping operations: [row_number() OVER ()]");
        this.accessControl.reset();
        this.accessControl.rowFilter(new QualifiedObjectName(CATALOG, "tiny", "orders"), USER, new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "grouping(orderkey) = 0"));
        Assertions.assertThatThrownBy(() -> this.assertions.query("SELECT count(*) FROM orders")).hasMessage("line 1:20: Row filter for 'local.tiny.orders' cannot contain aggregations, window functions or grouping operations: [GROUPING (orderkey)]");
    }

    @Test
    public void testShowStats() {
        this.accessControl.reset();
        this.accessControl.rowFilter(new QualifiedObjectName(CATALOG, "tiny", "orders"), USER, new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "orderkey = 0"));
        ((QueryAssertions.QueryAssert)((Object)Assertions.assertThat(this.assertions.query("SHOW STATS FOR (SELECT * FROM tiny.orders)")))).containsAll("VALUES (VARCHAR 'orderkey', 0e1, 0e1, 1e0, CAST(NULL AS double), CAST(NULL AS varchar), CAST(NULL AS varchar)),(VARCHAR 'custkey', 0e1, 0e1, 1e0, CAST(NULL AS double), CAST(NULL AS varchar), CAST(NULL AS varchar)),(NULL, NULL, NULL, NULL, 0e1, NULL, NULL)");
    }

    @Test
    public void testDelete() {
        this.accessControl.reset();
        this.accessControl.rowFilter(new QualifiedObjectName(CATALOG, "tiny", "orders"), USER, new ViewExpression(USER, Optional.empty(), Optional.empty(), "orderkey < 10"));
        Assertions.assertThatThrownBy(() -> this.assertions.query("DELETE FROM orders")).hasMessage("line 1:1: Delete from table with row filter");
        Assertions.assertThatThrownBy(() -> this.assertions.query("DELETE FROM orders WHERE orderkey IN (1, 2, 3)")).hasMessage("line 1:1: Delete from table with row filter");
        Assertions.assertThatThrownBy(() -> this.assertions.query("DELETE FROM orders WHERE orderkey IN (10, 20, 30)")).hasMessage("line 1:1: Delete from table with row filter");
    }

    @Test
    public void testUpdate() {
        this.accessControl.reset();
        this.accessControl.rowFilter(new QualifiedObjectName(CATALOG, "tiny", "orders"), USER, new ViewExpression(USER, Optional.empty(), Optional.empty(), "orderkey < 10"));
        Assertions.assertThatThrownBy(() -> this.assertions.query("UPDATE orders SET totalprice = totalprice * 2")).hasMessage("line 1:1: Updating a table with a row filter is not supported");
        Assertions.assertThatThrownBy(() -> this.assertions.query("UPDATE orders SET totalprice = totalprice * 2 WHERE orderkey IN (1, 2, 3)")).hasMessage("line 1:1: Updating a table with a row filter is not supported");
        Assertions.assertThatThrownBy(() -> this.assertions.query("UPDATE orders SET totalprice = totalprice * 2 WHERE orderkey IN (10, 20, 30)")).hasMessage("line 1:1: Updating a table with a row filter is not supported");
    }

    @Test
    public void testInsert() {
        this.accessControl.reset();
        this.accessControl.rowFilter(new QualifiedObjectName(CATALOG, "tiny", "nation"), USER, new ViewExpression(USER, Optional.empty(), Optional.empty(), "nationkey < 10"));
        Assertions.assertThatThrownBy(() -> this.assertions.query("INSERT INTO nation VALUES (26, 'POLAND', 0, 'No comment')")).hasMessage("Insert into table with a row filter is not supported");
    }
}

