/*
 * Decompiled with CFR 0.152.
 */
package com.lancedb.lance.spark.read;

import com.lancedb.lance.spark.TestUtils;
import java.util.function.Function;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Assumptions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

/**
 * Base test that checks the Lance Spark connector returns the same results as Spark's built-in
 * Parquet reader over the same {@code lineitem} data.
 *
 * <p>The tests are skipped (via JUnit assumptions) unless both the {@code DB_PATH} and
 * {@code PARQUET_PATH} environment variables are set to non-empty values. {@code DB_PATH} is
 * resolved to the Lance dataset {@code lineitem_10} through {@link TestUtils#getDatasetUri};
 * {@code PARQUET_PATH} is read directly as Parquet. Both are also registered as temp views
 * ({@code lance_dataset} and {@code parquet_dataset}) for the SQL-based checks.
 */
public abstract class BaseSparkConnectorLineItemTest {
    private static SparkSession spark;
    private static String dbPath;
    private static String parquetPath;
    private static Dataset<Row> lanceData;
    private static Dataset<Row> parquetData;

    /**
     * Reads the dataset locations from the environment, starts a local Spark session with the
     * Lance catalog registered, and loads both datasets and their temp views.
     */
    @BeforeAll
    static void setup() {
        dbPath = System.getenv("DB_PATH");
        parquetPath = System.getenv("PARQUET_PATH");
        // Checked before the session is created so a skipped run does not start Spark.
        Assumptions.assumeTrue(isNonEmpty(dbPath), "DB_PATH environment variable is not set");
        Assumptions.assumeTrue(isNonEmpty(parquetPath), "PARQUET_PATH environment variable is not set");

        spark = SparkSession.builder()
                .appName("spark-lance-connector-test")
                .master("local")
                .config("spark.sql.catalog.lance", "com.lancedb.lance.spark.LanceCatalog")
                .getOrCreate();

        lanceData = spark.read()
                .format("lance")
                .option("path", TestUtils.getDatasetUri(dbPath, "lineitem_10"))
                .load();
        lanceData.createOrReplaceTempView("lance_dataset");

        parquetData = spark.read().parquet(parquetPath);
        parquetData.createOrReplaceTempView("parquet_dataset");
    }

    /** Stops the Spark session if {@link #setup()} got far enough to create one. */
    @AfterAll
    static void tearDown() {
        if (spark != null) {
            spark.stop();
        }
    }

    /** Compares a series of DataFrame-API queries between the Lance and Parquet datasets. */
    @Test
    public void test() {
        validateResults(data -> data.filter("l_orderkey == 1"));
        validateResults(data -> data.filter("l_shipmode = 'TRUCK'").limit(10));
        validateResults(data -> data.filter("l_shipmode IS NULL").selectExpr("count(*) as count"));
        validateResults(data -> data.select("l_shipmode").limit(100));
        validateResults(data -> data.select("l_orderkey", "l_partkey", "l_quantity", "l_extendedprice").limit(10));
        validateResults(data -> data.groupBy("l_linestatus").avg("l_discount"));
        validateResults(data -> data.groupBy("l_partkey")
                .sum("l_quantity")
                .orderBy(functions.desc("sum(l_quantity)"))
                .limit(5));
        validateResults(data -> data.select("l_shipmode").distinct());
        validateResults(data -> data.select("l_orderkey", "l_comment").filter("l_comment LIKE '%express%'"));
        validateResults(data -> data.select("l_orderkey", "l_partkey", "l_quantity"));
        validateResults(data -> data.filter("l_quantity > 30").select("l_orderkey", "l_partkey", "l_quantity"));
        validateResults(data -> data.groupBy("l_returnflag").count());
        validateResults(data -> data.filter("l_quantity BETWEEN 5 AND 30"));

        // This query is only printed, not asserted. NOTE(review): presumably because ties on
        // l_commitdate can make the rows picked by limit(10) differ between sources — confirm.
        Function<Dataset<Row>, Dataset<Row>> commitDateSample =
                data -> data.select("l_orderkey", "l_commitdate").orderBy("l_commitdate").limit(10);
        commitDateSample.apply(lanceData).show();
        commitDateSample.apply(parquetData).show();

        validateResults(data -> data.groupBy("l_orderkey")
                .sum("l_extendedprice")
                .orderBy(functions.desc("sum(l_extendedprice)")));

        // Parquet is treated as the expected side everywhere, matching the validate* helpers.
        Assertions.assertEquals(parquetData.count(), lanceData.count());
        Assertions.assertEquals(
                parquetData.select("l_orderkey").count(), lanceData.select("l_orderkey").count());
    }

    /** Compares a series of SQL queries between the {@code parquet_dataset} and {@code lance_dataset} views. */
    @Test
    public void sql() {
        validateSQLResults("SELECT * FROM parquet_dataset LIMIT 10");
        validateSQLResults("SELECT l_orderkey, l_partkey FROM parquet_dataset LIMIT 10");
        validateSQLResults("SELECT l_extendedprice, l_discount, l_tax FROM parquet_dataset LIMIT 10");
        validateSQLResults("SELECT l_shipmode, COUNT(*) AS count FROM parquet_dataset GROUP BY l_shipmode");
        validateSQLResults("SELECT l_orderkey, SUM(l_extendedprice) AS total_extendedprice FROM parquet_dataset GROUP BY l_orderkey ORDER BY total_extendedprice DESC LIMIT 10");
        validateSQLResults("SELECT l_suppkey, SUM(l_tax) AS total_tax FROM parquet_dataset GROUP BY l_suppkey ORDER BY total_tax DESC LIMIT 5");
        validateSQLResults("SELECT l_orderkey, year(l_shipdate) AS ship_year FROM parquet_dataset GROUP BY l_orderkey, ship_year ORDER BY ship_year LIMIT 10");
        validateSQLResults("SELECT l_orderkey, l_partkey, l_quantity FROM parquet_dataset WHERE l_quantity IS NULL");
        validateSQLResults("SELECT * FROM parquet_dataset WHERE (l_quantity > 30) AND (l_comment IS NOT NULL)");
    }

    /** Returns {@code true} when {@code value} is neither {@code null} nor the empty string. */
    private static boolean isNonEmpty(String value) {
        return value != null && !value.isEmpty();
    }

    /**
     * Applies {@code operation} to both datasets and asserts the collected rows are equal.
     *
     * <p>The comparison is on {@code List<Row>}, so it is order-sensitive. NOTE(review): several
     * callers have no ORDER BY (e.g. groupBy/distinct/limit), so this relies on both sources
     * producing rows in the same order — consider an order-insensitive comparison if it flakes.
     *
     * @param operation transformation applied identically to the Lance and Parquet datasets
     */
    private void validateResults(Function<Dataset<Row>, Dataset<Row>> operation) {
        Dataset<Row> resultLance = operation.apply(lanceData);
        Dataset<Row> resultParquet = operation.apply(parquetData);
        Assertions.assertEquals(
                resultParquet.collectAsList(),
                resultLance.collectAsList(),
                "Results differ between Lance and Parquet datasets");
    }

    /**
     * Runs {@code sqlQuery} against {@code parquet_dataset}, and again with every occurrence of
     * that text replaced by {@code lance_dataset}, then asserts the collected rows are equal.
     *
     * <p>The rewrite is a plain textual replace, so the query should not contain
     * {@code parquet_dataset} anywhere other than as a table name. The same order-sensitivity
     * caveat as {@link #validateResults} applies.
     *
     * @param sqlQuery a query written against the {@code parquet_dataset} temp view
     */
    private void validateSQLResults(String sqlQuery) {
        Dataset<Row> resultLance = spark.sql(sqlQuery.replace("parquet_dataset", "lance_dataset"));
        Dataset<Row> resultParquet = spark.sql(sqlQuery);
        Assertions.assertEquals(
                resultParquet.collectAsList(),
                resultLance.collectAsList(),
                "Results differ between Lance and Parquet datasets for query: " + sqlQuery);
    }
}

