/*
 * Decompiled with CFR 0.152.
 */
package com.lancedb.lance.spark.read;

import com.lancedb.lance.spark.TestUtils;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

/**
 * Shared read tests for the Lance Spark connector that exercise the {@code _rowid} column.
 *
 * <p>Loads {@code test_dataset1} under {@link TestUtils.TestTable1Config#dbPath} through the
 * {@code lance} data source once per test class. Each test compares the collected rows
 * cell-by-cell against the expected values held in {@link TestUtils.TestTable1Config}.
 * Concrete subclasses inherit every test method.
 */
public abstract class BaseSparkConnectorReadWithRowIdTest {
    /*
     * Column positions inside each row of TestTable1Config.expectedValuesWithRowId. That list
     * lines up with data.select("x", "y", "b", "c", "_rowid") (see readAllWithRowId).
     */
    private static final int Y_INDEX = 1;
    private static final int B_INDEX = 2;
    private static final int ROW_ID_INDEX = 4;

    /* Column positions inside a projected (y, b, _rowid) row. */
    private static final int PROJECTED_Y_INDEX = 0;
    private static final int PROJECTED_ROW_ID_INDEX = 2;

    private static SparkSession spark;
    private static String dbPath;
    private static Dataset<Row> data;

    /** Starts a local Spark session and loads the test dataset through the lance format. */
    @BeforeAll
    static void setup() {
        spark = SparkSession.builder()
            .appName("spark-lance-connector-test")
            .master("local")
            .config("spark.sql.catalog.lance", "com.lancedb.lance.spark.LanceCatalog")
            .getOrCreate();
        dbPath = TestUtils.TestTable1Config.dbPath;
        data = spark.read()
            .format("lance")
            .option("path", TestUtils.getDatasetUri(dbPath, "test_dataset1"))
            .load();
    }

    /** Stops the Spark session, if one was started, and drops the static references to it. */
    @AfterAll
    static void tearDown() {
        if (spark != null) {
            spark.stop();
        }
        // Don't keep handles to a stopped session or a Dataset bound to it.
        spark = null;
        data = null;
    }

    /**
     * Collects {@code actual} and asserts that it matches {@code expectedValues} exactly.
     *
     * <p>The comparison is positional: row order, column count and every cell value (read with
     * {@link Row#getLong(int)}) must match.
     *
     * @param actual the dataset under test; every selected column must be readable as a long
     * @param expectedValues expected rows, in the order the dataset is expected to yield them
     */
    private void validateData(Dataset<Row> actual, List<List<Long>> expectedValues) {
        List<Row> rows = actual.collectAsList();
        Assertions.assertEquals(expectedValues.size(), rows.size(), "Row count mismatch");
        for (int i = 0; i < rows.size(); ++i) {
            Row row = rows.get(i);
            List<Long> expectedRow = expectedValues.get(i);
            Assertions.assertEquals(expectedRow.size(), row.size(), "Column count mismatch at row " + i);
            for (int j = 0; j < expectedRow.size(); ++j) {
                long expectedValue = expectedRow.get(j);
                long actualValue = row.getLong(j);
                Assertions.assertEquals(expectedValue, actualValue, "Mismatch at row " + i + " column " + j);
            }
        }
    }

    /**
     * Projects each expected (x, y, b, c, _rowid) row down to (y, b, _rowid).
     *
     * @return a new list with the projected rows, in their original order
     */
    private static List<List<Long>> expectedYBRowId() {
        return TestUtils.TestTable1Config.expectedValuesWithRowId.stream()
            .map(row -> Arrays.asList(row.get(Y_INDEX), row.get(B_INDEX), row.get(ROW_ID_INDEX)))
            .collect(Collectors.toList());
    }

    /** Selects the (y, b, _rowid) columns from the loaded dataset. */
    private static Dataset<Row> selectYBRowId() {
        return data.select("y", "b", "_rowid");
    }

    @Test
    public void readAllWithoutRowId() {
        validateData(data, TestUtils.TestTable1Config.expectedValues);
    }

    @Test
    public void readAllWithRowId() {
        validateData(
            data.select("x", "y", "b", "c", "_rowid"),
            TestUtils.TestTable1Config.expectedValuesWithRowId);
    }

    @Test
    public void select() {
        validateData(selectYBRowId(), expectedYBRowId());
    }

    @Test
    public void filterSelect() {
        List<List<Long>> expected = expectedYBRowId().stream()
            .filter(row -> row.get(PROJECTED_Y_INDEX) > 3L)
            .collect(Collectors.toList());
        validateData(selectYBRowId().filter("y > 3"), expected);
    }

    @Test
    public void filterSelectByRowId() {
        List<List<Long>> expected = expectedYBRowId().stream()
            .filter(row -> row.get(PROJECTED_ROW_ID_INDEX) > 3L)
            .collect(Collectors.toList());
        validateData(selectYBRowId().filter("_rowid > 3"), expected);
    }
}

