/*
 * Decompiled with CFR 0.152.
 */
package io.trino.tests.product.hive;

import com.google.inject.Inject;
import io.trino.tempto.ProductTest;
import io.trino.tempto.Requirement;
import io.trino.tempto.RequirementsProvider;
import io.trino.tempto.assertions.QueryAssert;
import io.trino.tempto.configuration.Configuration;
import io.trino.tempto.fulfillment.table.MutableTableRequirement;
import io.trino.tempto.fulfillment.table.MutableTablesState;
import io.trino.tempto.fulfillment.table.TableDefinition;
import io.trino.tempto.fulfillment.table.TableRequirements;
import io.trino.tempto.query.QueryExecutor;
import io.trino.tests.product.hive.HiveTableDefinitions;
import io.trino.tests.product.utils.QueryExecutors;
import org.assertj.core.api.AssertProvider;
import org.assertj.core.api.Assertions;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

public class TestHiveRequireQueryPartitionsFilter
extends ProductTest
implements RequirementsProvider {
    @Inject
    private MutableTablesState tablesState;

    public Requirement getRequirements(Configuration configuration) {
        return TableRequirements.mutableTable((TableDefinition)HiveTableDefinitions.NATION_PARTITIONED_BY_BIGINT_REGIONKEY, (String)"test_table", (MutableTableRequirement.State)MutableTableRequirement.State.LOADED);
    }

    @Test
    public void testRequiresQueryPartitionFilter() {
        String tableName = this.tablesState.get("test_table").getNameInDatabase();
        QueryExecutors.onTrino().executeQuery("SET SESSION hive.query_partition_filter_required = true", new QueryExecutor.QueryParam[0]);
        QueryAssert.assertQueryFailure(() -> QueryExecutors.onTrino().executeQuery("SELECT COUNT(*) FROM " + tableName, new QueryExecutor.QueryParam[0])).hasMessageMatching(String.format("Query failed \\(#\\w+\\): Filter required on default\\.%s for at least one partition column: p_regionkey", tableName));
        ((QueryAssert)Assertions.assertThat((AssertProvider)QueryExecutors.onTrino().executeQuery(String.format("SELECT COUNT(*) FROM %s WHERE p_regionkey = 1", tableName), new QueryExecutor.QueryParam[0]))).containsOnly(new QueryAssert.Row[]{QueryAssert.Row.row((Object[])new Object[]{5})});
    }

    @Test(dataProvider="queryPartitionFilterRequiredSchemasDataProvider")
    public void testRequiresQueryPartitionFilterOnSpecificSchema(String queryPartitionFilterRequiredSchemas) {
        String tableName = this.tablesState.get("test_table").getNameInDatabase();
        QueryExecutors.onTrino().executeQuery("SET SESSION hive.query_partition_filter_required = true", new QueryExecutor.QueryParam[0]);
        QueryExecutors.onTrino().executeQuery(String.format("SET SESSION hive.query_partition_filter_required_schemas = %s", queryPartitionFilterRequiredSchemas), new QueryExecutor.QueryParam[0]);
        QueryAssert.assertQueryFailure(() -> QueryExecutors.onTrino().executeQuery("SELECT COUNT(*) FROM " + tableName, new QueryExecutor.QueryParam[0])).hasMessageMatching(String.format("Query failed \\(#\\w+\\): Filter required on default\\.%s for at least one partition column: p_regionkey", tableName));
        ((QueryAssert)Assertions.assertThat((AssertProvider)QueryExecutors.onTrino().executeQuery(String.format("SELECT COUNT(*) FROM %s WHERE p_regionkey = 1", tableName), new QueryExecutor.QueryParam[0]))).containsOnly(new QueryAssert.Row[]{QueryAssert.Row.row((Object[])new Object[]{5})});
    }

    @DataProvider
    public Object[][] queryPartitionFilterRequiredSchemasDataProvider() {
        return new Object[][]{{"ARRAY['default']"}, {"ARRAY['DEFAULT']"}, {"ARRAY['deFAUlt']"}};
    }
}

