/*
 * Decompiled with CFR 0.152.
 */
package com.lancedb.lance.spark.read;

import java.nio.file.Path;
import java.util.Arrays;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;

/**
 * Shared tests for aggregate queries (COUNT(*), COUNT(column), COUNT(DISTINCT), SUM, AVG) over
 * tables created through the {@code lance} catalog configured in {@link #setup()}.
 *
 * <p>Each test writes a small table with {@code writeTo(...).create()} and asserts on the values
 * Spark returns. The assertions check query results only; they do not inspect the physical plan,
 * so on their own they do not prove whether an aggregate was (or was not) pushed down.
 */
public abstract class BaseSparkConnectorAggPushdownTest {
    /** Session shared by all tests of a concrete subclass; created in setup, stopped in tearDown. */
    private static SparkSession spark;

    /** Directory passed to the catalog as {@code spark.sql.catalog.lance.root}. */
    @TempDir
    static Path tempDir;

    /** Builds a local Spark session with the {@code lance} catalog rooted at {@link #tempDir}. */
    @BeforeAll
    static void setup() {
        spark =
            SparkSession.builder()
                .appName("LanceAggregatePushDownTest")
                .master("local[*]")
                .config("spark.ui.enabled", "false")
                .config("spark.sql.catalog.lance", "com.lancedb.lance.spark.LanceNamespaceSparkCatalog")
                .config("spark.sql.catalog.lance.impl", "dir")
                .config("spark.sql.catalog.lance.root", tempDir.toString())
                .getOrCreate();
    }

    /** Stops the shared session, if one was created. */
    @AfterAll
    static void tearDown() {
        if (spark != null) {
            spark.stop();
            // Do not keep a handle to a stopped session in shared static state.
            spark = null;
        }
    }

    /**
     * Writes {@code rows} (with the given {@code schema}) to a new table and refreshes the
     * catalog's cached metadata for it.
     *
     * @param tableName fully qualified table name, e.g. {@code lance.default.t}
     * @param schema schema of the rows
     * @param rows rows to write
     * @throws Exception if the table cannot be created (e.g. it already exists)
     */
    private static void createTable(String tableName, StructType schema, Row... rows)
            throws Exception {
        spark.createDataFrame(Arrays.asList(rows), schema).writeTo(tableName).create();
        spark.catalog().refreshTable(tableName);
    }

    /**
     * Evaluates one aggregate SQL expression over {@code dataset} and returns it as a long.
     *
     * @param dataset dataset to aggregate
     * @param expression a single aggregate expression such as {@code count(*)}
     * @return the first column of the single result row, read as a long
     */
    private static long aggregateAsLong(Dataset<Row> dataset, String expression) {
        return dataset.selectExpr(expression).first().getLong(0);
    }

    /** COUNT(*) via SQL expression and {@code Dataset.count()} over 100 rows. */
    @Test
    public void testCountStarPushDown() throws Exception {
        String tableName = "lance.default.count_test_dataset";
        spark.range(0L, 100L).toDF("id").repartition(4).writeTo(tableName).create();
        Dataset<Row> lanceDataset = spark.table(tableName);

        // Exercise both the SQL-expression path and the Dataset API path for COUNT(*).
        long countFromSelectExpr = aggregateAsLong(lanceDataset, "count(*)");
        long count = lanceDataset.count();

        Assertions.assertEquals(100L, countFromSelectExpr, "Count(*) should return 100");
        Assertions.assertEquals(100L, count, "Count should return 100 rows");
    }

    /** COUNT over filtered data: a single predicate and a conjunction of two predicates. */
    @Test
    public void testCountStarWithFilter() throws Exception {
        String tableName = "lance.default.count_filter_test_dataset";
        spark.range(0L, 100L)
            .selectExpr("id", "id % 10 as category", "id * 2 as value")
            .repartition(4)
            .writeTo(tableName)
            .create();
        Dataset<Row> lanceDataset = spark.table(tableName);

        // ids 0..99 with category = id % 10: exactly ten ids have category 5.
        long filteredCount = lanceDataset.filter("category = 5").count();
        Assertions.assertEquals(10L, filteredCount, "Filtered count should return 10 rows");

        // value = id * 2 < 150 keeps ids 0..74; categories 6..9 each occur 7 times in that range.
        long complexFilteredCount = lanceDataset.filter("category > 5 AND value < 150").count();
        Assertions.assertEquals(
            28L, complexFilteredCount, "Complex filtered count should return 28 rows");
    }

    /** COUNT(*), SUM and AVG evaluated together in one projection. */
    @Test
    public void testMultipleAggregates() throws Exception {
        String tableName = "lance.default.multiple_agg_test_dataset";
        spark.range(1L, 101L)
            .selectExpr("id", "id * 10 as value")
            .repartition(4)
            .writeTo(tableName)
            .create();
        Dataset<Row> lanceDataset = spark.table(tableName);

        Row result =
            lanceDataset
                .selectExpr("count(*) as cnt", "sum(value) as total", "avg(value) as average")
                .first();

        // ids 1..100 give value = 10, 20, ..., 1000: sum = 10 * 5050 = 50500, avg = 505.
        Assertions.assertEquals(100L, result.getLong(0), "Count should be 100");
        Assertions.assertEquals(50500L, result.getLong(1), "Sum should be 50500");
        Assertions.assertEquals(505.0, result.getDouble(2), 0.001, "Average should be 505");
    }

    /** COUNT(column) must skip nulls and therefore differ from COUNT(*) on this data. */
    @Test
    public void testCountColumnNotPushedDown() throws Exception {
        String tableName = "lance.default.count_column_test_dataset";
        createTable(
            tableName,
            new StructType().add("id", DataTypes.LongType).add("name", DataTypes.StringType),
            RowFactory.create(1L, "a"),
            RowFactory.create(2L, null),
            RowFactory.create(3L, "c"),
            RowFactory.create(4L, null),
            RowFactory.create(5L, "e"));
        Dataset<Row> lanceDataset = spark.table(tableName);

        // Two of the five names are null.
        long countName = aggregateAsLong(lanceDataset, "count(name)");
        Assertions.assertEquals(3L, countName, "Count(name) should be 3 (excluding nulls)");

        long countStar = aggregateAsLong(lanceDataset, "count(*)");
        Assertions.assertEquals(5L, countStar, "Count(*) should be 5");
    }

    /** COUNT(DISTINCT column) over a column with repeated values. */
    @Test
    public void testCountDistinctNotPushedDown() throws Exception {
        String tableName = "lance.default.count_distinct_test_dataset";
        createTable(
            tableName,
            new StructType().add("id", DataTypes.LongType).add("category", DataTypes.StringType),
            RowFactory.create(1L, "a"),
            RowFactory.create(2L, "b"),
            RowFactory.create(3L, "a"),
            RowFactory.create(4L, "b"),
            RowFactory.create(5L, "c"));
        Dataset<Row> lanceDataset = spark.table(tableName);

        // Distinct categories are {a, b, c}.
        long countDistinct = aggregateAsLong(lanceDataset, "count(distinct category)");
        Assertions.assertEquals(3L, countDistinct, "Count(distinct category) should be 3");
    }
}

