/*
 * Decompiled with CFR 0.152.
 */
package com.lancedb.lance.spark.write;

import com.lancedb.lance.spark.TestUtils;
import java.io.File;
import java.lang.reflect.Method;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.List;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInfo;
import org.junit.jupiter.api.io.TempDir;

public abstract class BaseSparkConnectorWriteTest {
    private static SparkSession spark;
    private static Dataset<Row> testData;
    @TempDir
    static Path dbPath;

    @BeforeAll
    static void setup() {
        spark = SparkSession.builder().appName("spark-lance-connector-test").master("local").config("spark.sql.catalog.lance", "com.lancedb.lance.spark.LanceCatalog").config("spark.sql.catalog.lance.max_row_per_file", "1").getOrCreate();
        StructType schema = new StructType(new StructField[]{DataTypes.createStructField((String)"id", (DataType)DataTypes.IntegerType, (boolean)false), DataTypes.createStructField((String)"name", (DataType)DataTypes.StringType, (boolean)false), DataTypes.createStructField((String)"address", (DataType)new StructType(new StructField[]{DataTypes.createStructField((String)"city", (DataType)DataTypes.StringType, (boolean)true), DataTypes.createStructField((String)"country", (DataType)DataTypes.StringType, (boolean)true)}), (boolean)true)});
        Row row1 = RowFactory.create((Object[])new Object[]{1, "Alice", RowFactory.create((Object[])new Object[]{"Beijing", "China"})});
        Row row2 = RowFactory.create((Object[])new Object[]{2, "Bob", RowFactory.create((Object[])new Object[]{"New York", "USA"})});
        List<Row> data = Arrays.asList(row1, row2);
        testData = spark.createDataFrame(data, schema);
        testData.createOrReplaceTempView("tmp_view");
    }

    @AfterAll
    static void tearDown() {
        if (spark != null) {
            spark.stop();
        }
    }

    @Test
    public void defaultWrite(TestInfo testInfo) {
        String datasetName = ((Method)testInfo.getTestMethod().get()).getName();
        testData.write().format("lance").option("path", TestUtils.getDatasetUri(dbPath.toString(), datasetName)).save();
        this.validateData(datasetName, 1);
    }

    @Test
    public void errorIfExists(TestInfo testInfo) {
        String datasetName = ((Method)testInfo.getTestMethod().get()).getName();
        testData.write().format("lance").option("path", TestUtils.getDatasetUri(dbPath.toString(), datasetName)).save();
        Assertions.assertThrows(TableAlreadyExistsException.class, () -> testData.write().format("lance").option("path", TestUtils.getDatasetUri(dbPath.toString(), datasetName)).save());
    }

    @Test
    public void append(TestInfo testInfo) {
        String datasetName = ((Method)testInfo.getTestMethod().get()).getName();
        testData.write().format("lance").option("path", TestUtils.getDatasetUri(dbPath.toString(), datasetName)).save();
        testData.write().format("lance").option("path", TestUtils.getDatasetUri(dbPath.toString(), datasetName)).mode("append").save();
        this.validateData(datasetName, 2);
    }

    @Test
    public void appendErrorIfNotExist(TestInfo testInfo) {
        String datasetName = ((Method)testInfo.getTestMethod().get()).getName();
        Assertions.assertThrows(NoSuchTableException.class, () -> testData.write().format("lance").option("path", TestUtils.getDatasetUri(dbPath.toString(), datasetName)).mode("append").save());
    }

    @Test
    public void saveToPath(TestInfo testInfo) {
        String datasetName = ((Method)testInfo.getTestMethod().get()).getName();
        testData.write().format("lance").save(TestUtils.getDatasetUri(dbPath.toString(), datasetName));
        this.validateData(datasetName, 1);
    }

    @Test
    public void overwrite(TestInfo testInfo) {
        String datasetName = ((Method)testInfo.getTestMethod().get()).getName();
        testData.write().format("lance").option("path", TestUtils.getDatasetUri(dbPath.toString(), datasetName)).save();
        testData.write().format("lance").option("path", TestUtils.getDatasetUri(dbPath.toString(), datasetName)).mode("overwrite").save();
        this.validateData(datasetName, 1);
    }

    @Test
    public void appendAfterOverwrite(TestInfo testInfo) {
        String datasetName = ((Method)testInfo.getTestMethod().get()).getName();
        testData.write().format("lance").option("path", TestUtils.getDatasetUri(dbPath.toString(), datasetName)).save();
        testData.write().format("lance").option("path", TestUtils.getDatasetUri(dbPath.toString(), datasetName)).mode("overwrite").save();
        testData.write().format("lance").option("path", TestUtils.getDatasetUri(dbPath.toString(), datasetName)).mode("append").save();
        this.validateData(datasetName, 2);
    }

    @Test
    public void writeMultiFiles(TestInfo testInfo) {
        String datasetName = ((Method)testInfo.getTestMethod().get()).getName();
        String filePath = TestUtils.getDatasetUri(dbPath.toString(), datasetName);
        testData.write().format("lance").option("path", filePath).save();
        this.validateData(datasetName, 1);
        File directory = new File(filePath + "/data");
        Assertions.assertEquals((int)2, (int)directory.listFiles().length);
    }

    @Test
    public void writeEmptyTaskFiles(TestInfo testInfo) {
        String datasetName = ((Method)testInfo.getTestMethod().get()).getName();
        String filePath = TestUtils.getDatasetUri(dbPath.toString(), datasetName);
        testData.repartition(4).write().format("lance").option("path", filePath).save();
        File directory = new File(filePath + "/data");
        Assertions.assertEquals((int)2, (int)directory.listFiles().length);
    }

    private void validateData(String datasetName, int iteration) {
        Dataset data = spark.read().format("lance").option("path", TestUtils.getDatasetUri(dbPath.toString(), datasetName)).load();
        Assertions.assertEquals((long)(2 * iteration), (long)data.count());
        Assertions.assertEquals((long)iteration, (long)data.filter(functions.col((String)"id").equalTo((Object)1)).count());
        Assertions.assertEquals((long)iteration, (long)data.filter(functions.col((String)"id").equalTo((Object)2)).count());
        Dataset data1 = data.filter(functions.col((String)"id").equalTo((Object)1)).select("name", new String[]{"address"});
        Dataset data2 = data.filter(functions.col((String)"id").equalTo((Object)2)).select("name", new String[]{"address"});
        for (Row row : data1.collectAsList()) {
            Assertions.assertEquals((Object)"Alice", (Object)row.getString(0));
            Assertions.assertEquals((Object)"Beijing", (Object)row.getStruct(1).getString(0));
            Assertions.assertEquals((Object)"China", (Object)row.getStruct(1).getString(1));
        }
        for (Row row : data2.collectAsList()) {
            Assertions.assertEquals((Object)"Bob", (Object)row.getString(0));
            Assertions.assertEquals((Object)"New York", (Object)row.getStruct(1).getString(0));
            Assertions.assertEquals((Object)"USA", (Object)row.getStruct(1).getString(1));
        }
    }

    @Test
    public void dropAndReplaceTable(TestInfo testInfo) {
        String datasetName = ((Method)testInfo.getTestMethod().get()).getName();
        String path = TestUtils.getDatasetUri(dbPath.toString(), datasetName);
        spark.sql("CREATE OR REPLACE TABLE lance.`" + path + "` AS SELECT * FROM tmp_view");
        spark.sql("CREATE OR REPLACE TABLE lance.`" + path + "` AS SELECT * FROM tmp_view");
        spark.sql("DROP TABLE lance.`" + path + "`");
    }
}

