/*
 * Decompiled with CFR 0.152.
 */
package com.lancedb.lance.spark.read;

import com.lancedb.lance.spark.TestUtils;
import java.util.Arrays;
import java.util.List;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

/**
 * Shared read-path tests for the Lance Spark connector.
 *
 * <p>{@code setup()} starts a local {@link SparkSession} with the {@code lance} catalog set to
 * {@code com.lancedb.lance.spark.LanceCatalog}, loads the {@code test_dataset1} Lance dataset and
 * registers it as the temp view {@code test_dataset1}. Each test compares what Spark returns with
 * {@code TestUtils.TestTable1Config.expectedValues}, applying the equivalent filter/projection to
 * the expected rows in plain Java.
 */
public abstract class BaseSparkConnectorReadTest {
    // Positions of columns x, y, b, c inside each expected row; these match the column names
    // used in the Spark filter expressions below.
    private static final int COL_X = 0;
    private static final int COL_Y = 1;
    private static final int COL_B = 2;
    private static final int COL_C = 3;

    private static SparkSession spark;
    private static String dbPath;
    private static Dataset<Row> data;

    /** Creates the shared Spark session and loads {@code test_dataset1} into {@code data}. */
    @BeforeAll
    static void setup() {
        spark = SparkSession.builder()
            .appName("spark-lance-connector-test")
            .master("local")
            .config("spark.sql.catalog.lance", "com.lancedb.lance.spark.LanceCatalog")
            .getOrCreate();
        dbPath = TestUtils.TestTable1Config.dbPath;
        data = spark.read()
            .format("lance")
            .option("path", TestUtils.getDatasetUri(dbPath, "test_dataset1"))
            .load();
        data.createOrReplaceTempView("test_dataset1");
    }

    /** Stops the shared Spark session if {@code setup()} managed to create it. */
    @AfterAll
    static void tearDown() {
        if (spark != null) {
            spark.stop();
        }
    }

    /**
     * Asserts that {@code actual} contains exactly {@code expectedValues}, in order, comparing
     * every cell as a {@code long}.
     *
     * @param actual the dataset to collect and check
     * @param expectedValues expected rows; each inner list holds the column values in select order
     */
    private void validateData(Dataset<Row> actual, List<List<Long>> expectedValues) {
        List<Row> rows = actual.collectAsList();
        Assertions.assertEquals(expectedValues.size(), rows.size(), "Row count mismatch");
        for (int i = 0; i < rows.size(); ++i) {
            Row row = rows.get(i);
            List<Long> expectedRow = expectedValues.get(i);
            Assertions.assertEquals(expectedRow.size(), row.size(), "Column count mismatch at row " + i);
            for (int j = 0; j < expectedRow.size(); ++j) {
                // Fail with a readable message instead of an NPE from getLong on a null cell.
                Assertions.assertFalse(row.isNullAt(j), "Unexpected null at row " + i + " column " + j);
                long expectedValue = expectedRow.get(j);
                long actualValue = row.getLong(j);
                Assertions.assertEquals(expectedValue, actualValue, "Mismatch at row " + i + " column " + j);
            }
        }
    }

    /**
     * Returns the expected rows of test table 1 that satisfy {@code predicate}, preserving order.
     *
     * @param predicate Java equivalent of the Spark filter expression under test
     * @return the matching expected rows
     */
    private static List<List<Long>> expectedRowsMatching(Predicate<List<Long>> predicate) {
        return TestUtils.TestTable1Config.expectedValues.stream()
            .filter(predicate)
            .collect(Collectors.toList());
    }

    /** Projects an expected row onto its (y, b) columns, mirroring {@code select("y", "b")}. */
    private static List<Long> projectYAndB(List<Long> row) {
        return Arrays.asList(row.get(COL_Y), row.get(COL_B));
    }

    /** Reading the whole dataset returns every expected row. */
    @Test
    public void readAll() {
        validateData(data, TestUtils.TestTable1Config.expectedValues);
    }

    /** Single, chained, conjunctive and disjunctive filters return the matching rows. */
    @Test
    public void filter() {
        validateData(data.filter("x > 1"), expectedRowsMatching(row -> row.get(COL_X) > 1L));
        validateData(data.filter("y == 4"), expectedRowsMatching(row -> row.get(COL_Y) == 4L));
        validateData(data.filter("b >= 6"), expectedRowsMatching(row -> row.get(COL_B) >= 6L));
        validateData(data.filter("c < -1"), expectedRowsMatching(row -> row.get(COL_C) < -1L));
        validateData(data.filter("c <= -1"), expectedRowsMatching(row -> row.get(COL_C) <= -1L));
        validateData(data.filter("c == -2"), expectedRowsMatching(row -> row.get(COL_C) == -2L));
        validateData(
            data.filter("x > 1").filter("y < 6"),
            expectedRowsMatching(row -> row.get(COL_X) > 1L && row.get(COL_Y) < 6L));
        validateData(
            data.filter("x > 1 and y < 6"),
            expectedRowsMatching(row -> row.get(COL_X) > 1L && row.get(COL_Y) < 6L));
        validateData(
            data.filter("x > 1 or y < 6"),
            expectedRowsMatching(row -> row.get(COL_X) > 1L || row.get(COL_Y) < 6L));
        validateData(
            data.filter("(x >= 1 and x <= 2) or (c >= -2 and c < 0)"),
            expectedRowsMatching(row ->
                (row.get(COL_X) >= 1L && row.get(COL_X) <= 2L)
                    || (row.get(COL_C) >= -2L && row.get(COL_C) < 0L)));
    }

    /** Projecting to (y, b) returns those two columns for every row. */
    @Test
    public void select() {
        List<List<Long>> expected = TestUtils.TestTable1Config.expectedValues.stream()
            .map(BaseSparkConnectorReadTest::projectYAndB)
            .collect(Collectors.toList());
        validateData(data.select("y", "b"), expected);
    }

    /** A filter applied after a projection sees the projected column layout. */
    @Test
    public void filterSelect() {
        List<List<Long>> expected = TestUtils.TestTable1Config.expectedValues.stream()
            .map(BaseSparkConnectorReadTest::projectYAndB)
            .filter(row -> row.get(0) > 3L) // index 0 is y after the (y, b) projection
            .collect(Collectors.toList());
        validateData(data.select("y", "b").filter("y > 3"), expected);
    }

    /** {@code load(path)} works as an alternative to {@code option("path", ...)}. */
    @Test
    public void supportDataSourceLoadPath() {
        Dataset<Row> df = spark.read()
            .format("lance")
            .load(TestUtils.getDatasetUri(dbPath, "test_dataset1"));
        validateData(df, TestUtils.TestTable1Config.expectedValues);
    }

    /** Joining two Lance-backed views produces a physical plan containing a BroadcastHashJoin. */
    @Test
    public void supportBroadcastJoin() {
        Dataset<Row> df = spark.read()
            .format("lance")
            .load(TestUtils.getDatasetUri(dbPath, "test_dataset3"));
        df.createOrReplaceTempView("test_dataset3");
        List<Row> desc = spark
            .sql("explain select t1.* from test_dataset1 t1 join test_dataset3 t3 on t1.x = t3.x")
            .collectAsList();
        Assertions.assertEquals(1, desc.size());
        String plan = desc.get(0).getString(0);
        Assertions.assertTrue(
            plan.contains("BroadcastHashJoin"), "Expected BroadcastHashJoin in plan:\n" + plan);
    }
}

