/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.tests.product.deltalake;

import io.trino.tempto.assertions.QueryAssert;
import org.assertj.core.api.SoftAssertions;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

import java.util.Collection;
import java.util.List;
import java.util.stream.Stream;

import static io.trino.tempto.assertions.QueryAssert.assertQueryFailure;
import static io.trino.tempto.assertions.QueryAssert.assertThat;
import static io.trino.tests.product.TestGroups.DELTA_LAKE_DATABRICKS;
import static io.trino.tests.product.TestGroups.PROFILE_SPECIFIC_TESTS;
import static io.trino.tests.product.hive.util.TemporaryHiveTable.randomTableSuffix;
import static io.trino.tests.product.utils.QueryExecutors.onDelta;
import static io.trino.tests.product.utils.QueryExecutors.onTrino;
import static java.lang.String.format;
import static java.util.function.Predicate.not;
import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.toList;

/**
 * Product tests that create and populate Delta tables through Databricks, modify them through
 * Trino ({@code UPDATE} / {@code DELETE}), and then verify that Databricks and Trino both observe
 * the same table contents.
 */
public class TestDeltaLakeWriteDatabricksCompatibility
        extends BaseTestDeltaLakeS3Storage
{
    @Test(groups = {DELTA_LAKE_DATABRICKS, PROFILE_SPECIFIC_TESTS})
    public void testUpdateCompatibility()
    {
        String tableName = "test_update_compatibility_" + randomTableSuffix();

        onDelta().executeQuery(format(
                "CREATE TABLE default.%1$s (a int, b int, c int) USING DELTA LOCATION '%2$s%1$s'",
                tableName,
                getBaseLocation()));

        try {
            onDelta().executeQuery("INSERT INTO default." + tableName + " VALUES (1, 2, 3), (2, 3, 4), (3, 4, 5), (4, 5, 6), (5, 6, 7)");
            onTrino().executeQuery("UPDATE delta.default." + tableName + " SET b = b * 2 WHERE a % 2 = 1");

            List<QueryAssert.Row> expectedRows = List.of(
                    row(1, 4, 3),
                    row(2, 3, 4),
                    row(3, 8, 5),
                    row(4, 5, 6),
                    row(5, 12, 7));

            assertTableContents(tableName, expectedRows);
        }
        finally {
            onDelta().executeQuery("DROP TABLE default." + tableName);
        }
    }

    @Test(groups = {DELTA_LAKE_DATABRICKS, PROFILE_SPECIFIC_TESTS})
    public void testDeleteCompatibility()
    {
        String tableName = "test_delete_compatibility_" + randomTableSuffix();

        onDelta().executeQuery(format(
                "CREATE TABLE default.%1$s (a int, b int) USING DELTA LOCATION '%2$s%1$s'",
                tableName,
                getBaseLocation()));

        try {
            onDelta().executeQuery("INSERT INTO default." + tableName + " VALUES (1, 2), (2, 3), (3, 4), (4, 5), (5, 6)");
            onTrino().executeQuery("DELETE FROM delta.default." + tableName + " WHERE a % 2 = 0");

            List<QueryAssert.Row> expectedRows = List.of(
                    row(1, 2),
                    row(3, 4),
                    row(5, 6));

            assertTableContents(tableName, expectedRows);
        }
        finally {
            onDelta().executeQuery("DROP TABLE default." + tableName);
        }
    }

    @Test(groups = {DELTA_LAKE_DATABRICKS, PROFILE_SPECIFIC_TESTS})
    public void testDeleteOnPartitionedTableCompatibility()
    {
        String tableName = "test_delete_on_partitioned_table_compatibility_" + randomTableSuffix();

        onDelta().executeQuery(format(
                "CREATE TABLE default.%1$s (a int, b int) USING DELTA LOCATION '%2$s%1$s' PARTITIONED BY (b)",
                tableName,
                getBaseLocation()));

        try {
            onDelta().executeQuery("INSERT INTO default." + tableName + " VALUES (1, 2), (2, 3), (3, 4), (4, 5), (5, 6)");
            // Predicate on a regular (non-partition) column of a partitioned table
            onTrino().executeQuery("DELETE FROM delta.default." + tableName + " WHERE a % 2 = 0");

            List<QueryAssert.Row> expectedRows = List.of(
                    row(1, 2),
                    row(3, 4),
                    row(5, 6));

            assertTableContents(tableName, expectedRows);
        }
        finally {
            onDelta().executeQuery("DROP TABLE default." + tableName);
        }
    }

    @Test(groups = {DELTA_LAKE_DATABRICKS, PROFILE_SPECIFIC_TESTS})
    public void testDeleteOnPartitionKeyCompatibility()
    {
        // Distinct prefix from testDeleteOnPartitionedTableCompatibility so leftover tables are attributable
        String tableName = "test_delete_on_partition_key_compatibility_" + randomTableSuffix();

        onDelta().executeQuery(format(
                "CREATE TABLE default.%1$s (a int, b int) USING DELTA LOCATION '%2$s%1$s' PARTITIONED BY (b)",
                tableName,
                getBaseLocation()));

        try {
            onDelta().executeQuery("INSERT INTO default." + tableName + " VALUES (1, 2), (2, 3), (3, 4), (4, 5), (5, 6)");
            // Predicate on the partition column itself
            onTrino().executeQuery("DELETE FROM delta.default." + tableName + " WHERE b % 2 = 0");

            List<QueryAssert.Row> expectedRows = List.of(row(2, 3), row(4, 5));

            assertTableContents(tableName, expectedRows);
        }
        finally {
            onDelta().executeQuery("DROP TABLE default." + tableName);
        }
    }

    // Test partition case sensitivity when updating
    @Test(groups = {DELTA_LAKE_DATABRICKS, PROFILE_SPECIFIC_TESTS}, dataProvider = "partition_column_names")
    public void testCaseUpdateInPartition(String partitionColumn)
    {
        try (CaseTestTable table = new CaseTestTable("update_case_compat", partitionColumn, List.of(
                row(1, 1, 0),
                row(2, 2, 0),
                row(3, 3, 1)))) {
            onTrino().executeQuery(format("UPDATE delta.default.%s SET upper = 0 WHERE lower = 1", table.name()));

            assertTable(table, table.rows().map(row -> row.lower() == 1 ? row.withUpper(0) : row));
        }
    }

    // Test that the correct error is generated when attempting to update the partition columns
    @Test(groups = {DELTA_LAKE_DATABRICKS, PROFILE_SPECIFIC_TESTS}, dataProvider = "partition_column_names")
    public void testCaseUpdatePartitionColumnFails(String partitionColumn)
    {
        try (CaseTestTable table = new CaseTestTable("update_case_compat", partitionColumn, List.of(row(1, 1, 1)))) {
            // TODO: The test fails for uppercase columns because the statement analyzer compares the column name case-sensitively.
            //   Remove the part of the regex after the '|' once that's changed.
            assertQueryFailure(() -> onTrino().executeQuery(format("UPDATE delta.default.%s SET %s = 0 WHERE lower = 1", table.name(), partitionColumn)))
                    .hasMessageMatching(".*(Updating table partition columns is not supported|The UPDATE SET target column .* doesn't exist)");
        }
    }

    // Delete within a partition
    @Test(groups = {DELTA_LAKE_DATABRICKS, PROFILE_SPECIFIC_TESTS}, dataProvider = "partition_column_names")
    public void testCaseDeletePartialPartition(String partitionColumn)
    {
        try (CaseTestTable table = new CaseTestTable("delete_case_compat", partitionColumn, List.of(
                row(1, 1, 0),
                row(2, 2, 0),
                row(3, 3, 1)))) {
            onTrino().executeQuery(format("DELETE FROM delta.default.%s WHERE lower = 1", table.name()));
            assertTable(table, table.rows().filter(not(row -> row.lower() == 1)));
        }
    }

    // Delete an entire partition
    @Test(groups = {DELTA_LAKE_DATABRICKS, PROFILE_SPECIFIC_TESTS}, dataProvider = "partition_column_names")
    public void testCaseDeleteEntirePartition(String partitionColumn)
    {
        try (CaseTestTable table = new CaseTestTable("delete_case_compat", partitionColumn, List.of(
                row(1, 1, 0),
                row(2, 2, 0),
                row(3, 3, 1)))) {
            onTrino().executeQuery(format("DELETE FROM delta.default.%s WHERE %s = 0", table.name(), partitionColumn));
            assertTable(table, table.rows().filter(not(row -> row.partition() == 0)));
        }
    }

    /**
     * Partition column names used by the case-sensitivity tests: one all-lowercase, one all-uppercase.
     */
    @DataProvider(name = "partition_column_names")
    public static Object[][] partitionColumns()
    {
        return new Object[][] {{"downpart"}, {"UPPART"}};
    }

    private static QueryAssert.Row row(Integer a, Integer b)
    {
        return QueryAssert.Row.row(a, b);
    }

    private static TestRow row(Integer lower, Integer upper, Integer partition)
    {
        return new TestRow(lower, upper, partition);
    }

    /**
     * Asserts that {@code default.<tableName>} contains exactly {@code expectedRows} when read
     * through Databricks and through Trino's {@code delta} catalog.
     */
    private static void assertTableContents(String tableName, List<QueryAssert.Row> expectedRows)
    {
        assertThat(onDelta().executeQuery("SELECT * FROM default." + tableName))
                .containsOnly(expectedRows);
        assertThat(onTrino().executeQuery("SELECT * FROM delta.default." + tableName))
                .containsOnly(expectedRows);
    }

    private static void assertTable(CaseTestTable table, Stream<? extends QueryAssert.Row> expectedRows)
    {
        assertTable(table, expectedRows.collect(toList()));
    }

    /**
     * Soft-asserts (reporting every failure, not just the first) that the table still has its
     * original column names as seen by Databricks, and that both engines return {@code expectedRows}.
     */
    private static void assertTable(CaseTestTable table, List<QueryAssert.Row> expectedRows)
    {
        SoftAssertions softly = new SoftAssertions();

        softly.check(() ->
                assertThat(onDelta().executeQuery("SHOW COLUMNS IN default." + table.name()))
                        .as("Correct columns after modification")
                        .containsOnly(table.columns().stream().map(QueryAssert.Row::row).collect(toList())));

        softly.check(() ->
                assertThat(onDelta().executeQuery("SELECT * FROM default." + table.name()))
                        .as("Data accessible via Databricks")
                        .containsOnly(expectedRows));

        softly.check(() ->
                assertThat(onTrino().executeQuery("SELECT * FROM delta.default." + table.name()))
                        .as("Data accessible via Trino")
                        .containsOnly(expectedRows));

        softly.assertAll();
    }

    private String getBaseLocation()
    {
        return "s3://" + bucketName + "/databricks-compatibility-test-";
    }

    /**
     * Creates a test table with three integer columns.
     *
     * <p>The first column is named {@code lower}, the second {@code UPPER},
     * and the third column is named according to the {@code partitionColumnName}
     * parameter. The table is partitioned on the third column.
     *
     * <p>The table is created and populated through Databricks and dropped on {@link #close()}.
     * If populating it fails, the constructor drops the table before rethrowing, because the
     * caller never receives an instance to close.
     */
    private final class CaseTestTable
            implements AutoCloseable
    {
        private final String name;
        private final List<String> columns;
        private final Collection<TestRow> rows;

        CaseTestTable(String namePrefix, String partitionColumnName, Collection<TestRow> rows)
        {
            this.name = namePrefix + "_" + randomTableSuffix();
            this.columns = List.of("lower", "UPPER", partitionColumnName);
            this.rows = List.copyOf(rows);

            onDelta().executeQuery(format(
                    "CREATE TABLE default.%1$s (lower int, UPPER int, %3$s int)\n"
                            + "USING DELTA\n"
                            + "PARTITIONED BY (%3$s)\n"
                            + "LOCATION '%2$s%1$s'\n",
                    name,
                    getBaseLocation(),
                    partitionColumnName));

            if (this.rows.isEmpty()) {
                // An INSERT with an empty VALUES list is not valid SQL; leave the table empty
                return;
            }

            try {
                onDelta().executeQuery(format(
                        "INSERT INTO default.%s VALUES %s",
                        name,
                        this.rows.stream().map(TestRow::asValues).collect(joining(", "))));
            }
            catch (RuntimeException insertFailure) {
                // try-with-resources never sees this instance, so clean up the created table here
                try {
                    dropTable();
                }
                catch (RuntimeException dropFailure) {
                    insertFailure.addSuppressed(dropFailure);
                }
                throw insertFailure;
            }
        }

        String name()
        {
            return name;
        }

        List<String> columns()
        {
            return columns;
        }

        Stream<TestRow> rows()
        {
            return rows.stream();
        }

        @Override
        public void close()
        {
            dropTable();
        }

        // Drops the backing table via Databricks; shared by close() and constructor cleanup
        private void dropTable()
        {
            onDelta().executeQuery("DROP TABLE default." + name);
        }
    }

    /**
     * Immutable expected row for {@link CaseTestTable}: {@code (lower, UPPER, partition)}.
     */
    private static final class TestRow
            extends QueryAssert.Row
    {
        private final Integer lower;
        private final Integer upper;
        private final Integer partition;

        private TestRow(Integer lower, Integer upper, Integer partition)
        {
            super(List.of(lower, upper, partition));
            this.lower = lower;
            this.upper = upper;
            this.partition = partition;
        }

        public Integer lower()
        {
            return lower;
        }

        public Integer upper()
        {
            return upper;
        }

        public Integer partition()
        {
            return partition;
        }

        public TestRow withLower(Integer newValue)
        {
            return new TestRow(newValue, upper, partition);
        }

        public TestRow withUpper(Integer newValue)
        {
            return new TestRow(lower, newValue, partition);
        }

        public TestRow withPartition(Integer newValue)
        {
            return new TestRow(lower, upper, newValue);
        }

        /**
         * Renders this row as a SQL {@code VALUES} tuple, e.g. {@code (1, 2, 0)}.
         */
        public String asValues()
        {
            return format("(%s, %s, %s)", lower(), upper(), partition());
        }
    }
}
