/*
 * Decompiled with CFR 0.152.
 */
package io.trino.plugin.deltalake;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.MoreCollectors;
import com.google.common.io.Resources;
import io.trino.execution.QueryStats;
import io.trino.operator.OperatorStats;
import io.trino.plugin.deltalake.DeltaLakeQueryRunner;
import io.trino.testing.AbstractTestQueryFramework;
import io.trino.testing.MaterializedResult;
import io.trino.testing.QueryRunner;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Set;
import java.util.function.Consumer;
import org.assertj.core.api.Assertions;
import org.intellij.lang.annotations.Language;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;

@TestInstance(value=TestInstance.Lifecycle.PER_CLASS)
public class TestSplitPruning
extends AbstractTestQueryFramework {
    private static final List<String> TABLES = ImmutableList.of((Object)"double_inf", (Object)"double_nan", (Object)"part", (Object)"float_nan", (Object)"float_inf", (Object)"no_stats", (Object)"nested_fields", (Object)"timestamp", (Object)"test_partitioning", (Object)"parquet_struct_statistics", (Object)"uppercase_columns_partitions", (Object)"uppercase_columns_json_statistics", (Object[])new String[]{"uppercase_columns_struct_statistics"});

    protected QueryRunner createQueryRunner() throws Exception {
        return DeltaLakeQueryRunner.builder().addDeltaProperty("delta.register-table-procedure.enabled", "true").build();
    }

    @BeforeAll
    public void registerTables() {
        for (String table : TABLES) {
            String dataPath = Resources.getResource((String)("databricks73/pruning/" + table)).toExternalForm();
            this.getQueryRunner().execute(String.format("CALL system.register_table(CURRENT_SCHEMA, '%s', '%s')", table, dataPath));
        }
    }

    @Test
    public void testStatsPruningInfinity() {
        for (String type : Arrays.asList("float", "double")) {
            String tableName = type + "_inf";
            this.assertResultAndSplitCount(String.format("SELECT name FROM %s WHERE val < 200", tableName), Set.of("a1", "b1", "a3", "b3"), 2L);
            this.assertResultAndSplitCount(String.format("SELECT name FROM %s WHERE val > 100", tableName), Set.of("a2", "b2", "b3", "d3"), 2L);
            this.assertResultAndSplitCount(String.format("SELECT name FROM %s WHERE val IS NULL", tableName), Set.of("c3", "a4"), 2L);
        }
    }

    @Test
    public void testStatsPruningNaN() {
        for (String type : Arrays.asList("float", "double")) {
            String tableName = type + "_nan";
            this.assertResultAndSplitCount(String.format("SELECT name FROM %s WHERE val < 100", tableName), Set.of(), 2L);
            this.assertResultAndSplitCount(String.format("SELECT name FROM %s WHERE val IS NULL", tableName), Set.of(), 0L);
            MaterializedResult result = this.getDistributedQueryRunner().execute(this.getSession(), String.format("SELECT name FROM %s WHERE val IS NOT NULL", tableName));
            Assertions.assertThat((Collection)result.getOnlyColumnAsSet()).isEqualTo(Set.of("a5", "b5", "a6", "b6"));
        }
    }

    @Test
    public void testNoStats() {
        String tableName = "no_stats";
        this.assertResultAndSplitCount(String.format("SELECT name FROM %s WHERE val < 200", tableName), Set.of("a1", "b1", "a3", "b3"), 4L);
        this.assertResultAndSplitCount(String.format("SELECT name FROM %s WHERE val > 100", tableName), Set.of("a2", "b2", "b3", "d3"), 4L);
        this.assertResultAndSplitCount(String.format("SELECT name FROM %s WHERE val IS NULL", tableName), Set.of("c3", "a4"), 4L);
    }

    @Test
    public void testPruningUsingPartitions() {
        String tableName = "part";
        this.assertResultAndSplitCount(String.format("SELECT name FROM %s WHERE part_key = 7", tableName), Set.of("a7"), 1L);
        this.assertResultAndSplitCount(String.format("SELECT name FROM %s WHERE part_key IS NOT NULL", tableName), Set.of("a7", "-Infinity", "+Infinity", "NaN"), 4L);
        this.assertResultAndSplitCount(String.format("SELECT name FROM %s WHERE part_key IS NULL", tableName), Set.of("null"), 1L);
        this.assertResultAndSplitCount(String.format("SELECT name FROM %s WHERE part_key > 0", tableName), Set.of("a7", "+Infinity"), 2L);
        this.assertResultAndSplitCount(String.format("SELECT name FROM %s WHERE part_key < 10", tableName), Set.of("a7", "-Infinity"), 2L);
    }

    @Test
    public void testPruningUsingPartitionsUppercase() {
        String tableName = "uppercase_columns_partitions";
        this.assertResultAndSplitCount(String.format("SELECT ala FROM %s WHERE ala > 0", tableName), (MaterializedResult result) -> {
            Assertions.assertThat((Collection)result.getOnlyColumnAsSet()).containsOnly(new Object[]{1L, 2L, 3L});
            Assertions.assertThat((int)result.getRowCount()).isEqualTo(5);
        }, 3L);
        this.assertResultAndSplitCount(String.format("SELECT ala FROM %s WHERE ala = 1", tableName), (MaterializedResult result) -> {
            Assertions.assertThat((Collection)result.getOnlyColumnAsSet()).containsOnly(new Object[]{1L});
            Assertions.assertThat((int)result.getRowCount()).isEqualTo(2);
        }, 1L);
        this.assertResultAndSplitCount(String.format("SELECT ala FROM %s WHERE ala > 1", tableName), (MaterializedResult result) -> {
            Assertions.assertThat((Collection)result.getOnlyColumnAsSet()).containsOnly(new Object[]{2L, 3L});
            Assertions.assertThat((int)result.getRowCount()).isEqualTo(3);
        }, 2L);
        this.assertResultAndSplitCount(String.format("SELECT kota FROM %s WHERE ala = 1", tableName), (MaterializedResult result) -> {
            Assertions.assertThat((Collection)result.getOnlyColumnAsSet()).containsOnly(new Object[]{1L, 2L});
            Assertions.assertThat((int)result.getRowCount()).isEqualTo(2);
        }, 1L);
    }

    @Test
    public void testPartitionPruningWithExpression() {
        this.assertResultAndSplitCount("SELECT id FROM test_partitioning WHERE t_varchar LIKE '%a%'", Set.of(Integer.valueOf(1)), 1L);
    }

    @Test
    public void testPartitionPruningWithExpressionAndDomainFilter() {
        this.assertResultAndSplitCount("SELECT id FROM test_partitioning WHERE t_varchar LIKE '%a%' AND id > 0", Set.of(Integer.valueOf(1)), 1L);
    }

    @Test
    public void testSplitGenerationError() {
        String dataPath = Resources.getResource((String)"databricks73/pruning/invalid_log").toExternalForm();
        this.getQueryRunner().execute(String.format("CALL system.register_table(CURRENT_SCHEMA, 'person', '%s')", dataPath));
        this.assertQueryFails("SELECT name FROM person WHERE income < 1000", "Failed to generate splits for tpch.person");
    }

    @Test
    public void testTimestampPruning() {
        String tableName = "timestamp";
        this.assertResultAndSplitCount(String.format("SELECT col_2 FROM %s WHERE col_0 = 'UTC' AND col_1 = CAST('1952-04-03 01:02:03.456 UTC' AS TIMESTAMP WITH TIME ZONE)", tableName), Set.of("1952-04-03 01:02:03.456789"), 1L);
        this.assertResultAndSplitCount(String.format("SELECT col_2 FROM %s WHERE col_0 = 'UTC' AND col_1 > CAST('1996-10-27 00:05:00.987 UTC' AS TIMESTAMP WITH TIME ZONE) AND col_1 < CAST('1996-10-27 02:05:00.987 UTC' AS TIMESTAMP WITH TIME ZONE)", tableName), Set.of("1996-10-27 01:05:00.987"), 1L);
        this.assertResultAndSplitCount(String.format("SELECT col_2 FROM %s WHERE col_0 = 'UTC' AND col_1 = ANY (VALUES CAST('1900-01-01 UTC' AS TIMESTAMP WITH TIME ZONE), CAST('1983-04-01 01:05:00.345 UTC' AS TIMESTAMP WITH TIME ZONE), CAST('1996-10-27 02:05:00.987 UTC' AS TIMESTAMP WITH TIME ZONE))", tableName), Set.of("1900-01-01 00:00:00.000", "1983-04-01 01:05:00.3456789", "1996-10-27 02:05:00.987"), 3L);
        this.assertResultAndSplitCount(String.format("SELECT col_2 FROM %s WHERE col_0 = 'UTC' AND col_1 BETWEEN CAST('1952-04-03 UTC' AS TIMESTAMP WITH TIME ZONE) AND CAST('1970-02-04 UTC' AS TIMESTAMP WITH TIME ZONE) AND col_3 >= CAST('1970-01-01 UTC' AS TIMESTAMP WITH TIME ZONE)", tableName), Set.of("1970-01-01 01:05:00.123456789", "1970-01-01 00:05:00.123456789", "1970-01-01 00:00:00.000", "1970-02-03 04:05:06.789"), 4L);
        this.assertResultAndSplitCount(String.format("SELECT col_2 FROM %s WHERE col_2 LIKE '2%%' AND col_3 > CAST('1999-12-31 UTC' AS TIMESTAMP WITH TIME ZONE)", tableName), Set.of("2017-07-01 00:00:00.000"), 1L);
        this.assertResultAndSplitCount(String.format("SELECT col_2 FROM %s WHERE col_2 > '1999'", tableName), Set.of("2017-07-01 00:00:00.000", "9999-12-31 23:59:59.999999999"), 2L);
    }

    @Test
    public void testParquetStatisticsPruning() {
        this.testCountQuery("SELECT count(*) FROM parquet_struct_statistics WHERE ts = TIMESTAMP '2960-10-31 01:00:00.000 UTC'", 3L, 3L);
        this.testCountQuery("SELECT count(*) FROM parquet_struct_statistics WHERE ts = TIMESTAMP '2960-10-31 01:00:00.000 UTC'", 3L, 3L);
        this.testCountQuery("SELECT count(*) FROM parquet_struct_statistics WHERE str = 'a'", 3L, 3L);
        this.testCountQuery("SELECT count(*) FROM parquet_struct_statistics WHERE dec_short = 10.1", 3L, 3L);
        this.testCountQuery("SELECT count(*) FROM parquet_struct_statistics WHERE dec_long = -999999999999.123", 3L, 3L);
        this.testCountQuery("SELECT count(*) FROM parquet_struct_statistics WHERE l = 0", 3L, 3L);
        this.testCountQuery("SELECT count(*) FROM parquet_struct_statistics WHERE \"in\" = -20000000", 3L, 3L);
        this.testCountQuery("SELECT count(*) FROM parquet_struct_statistics WHERE byt = 42", 3L, 3L);
        this.testCountQuery("SELECT count(*) FROM parquet_struct_statistics WHERE fl = 0.123", 3L, 3L);
        this.testCountQuery("SELECT count(*) FROM parquet_struct_statistics WHERE dou = -0.321", 3L, 3L);
        this.testCountQuery("SELECT count(*) FROM parquet_struct_statistics WHERE bool = true", 9L, 9L);
        this.testCountQuery("SELECT count(*) FROM parquet_struct_statistics WHERE bin = X'00 02'", 3L, 9L);
        this.testCountQuery("SELECT count(*) FROM parquet_struct_statistics WHERE dat = DATE '5000-01-01'", 3L, 3L);
        this.testCountQuery("SELECT count(*) FROM parquet_struct_statistics WHERE arr = ARRAY[5]", 3L, 9L);
        this.testCountQuery("SELECT count(*) FROM parquet_struct_statistics WHERE m = MAP(ARRAY[1], ARRAY['a'])", 3L, 9L);
        this.testCountQuery("SELECT count(*) FROM parquet_struct_statistics WHERE row = ROW(2, 'b')", 3L, 9L);
    }

    @Test
    public void testPrimitiveFieldsInsideRowColumnPruning() {
        this.assertResultAndSplitCount("SELECT grandparent.parent1.child1 FROM nested_fields WHERE id > 6", Set.of(Double.valueOf(70.99), Double.valueOf(80.99), Double.valueOf(90.99), Double.valueOf(100.99)), 1L);
        this.assertResultAndSplitCount("SELECT grandparent.parent1.child1 FROM nested_fields WHERE id > 10", Set.of(), 0L);
        this.assertResultAndSplitCount("SELECT grandparent.parent1.child1 FROM nested_fields WHERE parent.child1 > 600", Set.of(Double.valueOf(70.99), Double.valueOf(80.99), Double.valueOf(90.99), Double.valueOf(100.99)), 2L);
        this.assertResultAndSplitCount("SELECT grandparent.parent1.child1 FROM nested_fields WHERE parent.child1 > 1000", Set.of(), 2L);
    }

    @Test
    public void testJsonStatisticsPruningUppercaseColumn() {
        this.testCountQuery("SELECT count(*) FROM uppercase_columns_json_statistics WHERE blah = 2", 1L, 1L);
        this.testCountQuery("SELECT count(*) FROM uppercase_columns_json_statistics WHERE blah = 3", 2L, 2L);
        this.testCountQuery("SELECT count(*) FROM uppercase_columns_json_statistics WHERE blah <= 10", 8L, 3L);
    }

    @Test
    public void testStructStatisticsPruningUppercaseColumn() {
        this.testCountQuery("SELECT count(*) FROM uppercase_columns_struct_statistics WHERE blah = 2", 1L, 1L);
        this.testCountQuery("SELECT count(*) FROM uppercase_columns_struct_statistics WHERE blah = 3", 2L, 2L);
        this.testCountQuery("SELECT count(*) FROM uppercase_columns_struct_statistics WHERE blah <= 10", 8L, 3L);
    }

    private void testCountQuery(@Language(value="SQL") String sql, long expectedRowCount, long expectedSplitCount) {
        this.assertResultAndSplitCount(sql, Set.of(Long.valueOf(expectedRowCount)), expectedSplitCount);
    }

    private void assertResultAndSplitCount(String query, Set<?> expectedResultColumn, long expectedSplits) {
        this.assertResultAndSplitCount(query, (MaterializedResult result) -> Assertions.assertThat((Collection)result.getOnlyColumnAsSet()).isEqualTo((Object)expectedResultColumn), expectedSplits);
    }

    private void assertResultAndSplitCount(String query, Consumer<MaterializedResult> resultAssertions, long expectedSplits) {
        if (expectedSplits == 0L) {
            this.assertQueryStats(this.getSession(), query, stats -> Assertions.assertThat((long)this.getOperatorStats((QueryStats)stats).getInputDataSize().toBytes()).isEqualTo(0L), resultAssertions);
        } else {
            this.assertQueryStats(this.getSession(), query, stats -> Assertions.assertThat((long)this.getOperatorStats((QueryStats)stats).getTotalDrivers()).isEqualTo(expectedSplits), resultAssertions);
        }
    }

    private OperatorStats getOperatorStats(QueryStats stats) {
        return (OperatorStats)stats.getOperatorSummaries().stream().filter(summary -> summary.getOperatorType().startsWith("Scan") || summary.getOperatorType().startsWith("TableScan")).collect(MoreCollectors.onlyElement());
    }
}

