/*
 * Decompiled with CFR 0.152.
 */
package com.lancedb.lance.spark;

import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import scala.collection.Seq;

/**
 * DataFrame round-trip tests for array columns tagged with {@code arrow.fixed-size-list.size}
 * field metadata, written and read back through the Lance namespace catalog.
 *
 * <p>The class is abstract; concrete subclasses inherit and run these tests. Each test starts its
 * own local SparkSession whose {@code dir} catalog is rooted at the per-test {@link TempDir}.
 */
public abstract class BaseFixedSizeListDataFrameTest {
    /** Name under which the Lance namespace catalog is registered in every session. */
    private static final String CATALOG_NAME = "lance_test";

    /**
     * Field-metadata key that tags an array column with its fixed length. It is presumably what
     * the Lance writer reads to emit an Arrow fixed-size list (TODO confirm against the writer).
     */
    private static final String FIXED_SIZE_LIST_KEY = "arrow.fixed-size-list.size";

    @TempDir
    Path tempDir;

    /** Writes 10 rows with a 128-dim float vector and verifies the first row round-trips. */
    @Test
    public void testDataFrameWriteAndReadWithFixedSizeList() {
        SparkSession spark = createSparkSession("dataframe-fixedsizelist-test");
        try {
            String table = qualifiedTableName("df_vector_table");
            StructType schema = new StructType(new StructField[] {
                DataTypes.createStructField("id", DataTypes.IntegerType, false),
                DataTypes.createStructField("text", DataTypes.StringType, true),
                fixedSizeListField("embeddings", DataTypes.FloatType, 128)
            });
            List<Row> rows = new ArrayList<>();
            for (int i = 0; i < 10; ++i) {
                rows.add(RowFactory.create(i, "text_" + i, floatVector(128, i, 0.01f, 0.001f)));
            }
            Dataset<Row> df = spark.createDataFrame(rows, schema);
            df.writeTo(table).using("lance").createOrReplace();

            Dataset<Row> result = spark.table(table);
            Assertions.assertEquals(10L, result.count(), "Should have 10 rows");
            // Sort explicitly: the data is written from a multi-partition local DataFrame, so the
            // scan's row order is not something the assertions below should rely on.
            Row firstRow = result.orderBy("id").first();
            Assertions.assertEquals(0, firstRow.getInt(0));
            Assertions.assertEquals("text_0", firstRow.getString(1));
            assertFloatSeq(floatVector(128, 0, 0.01f, 0.001f), firstRow.getSeq(2), "Embeddings");
            spark.sql("DROP TABLE IF EXISTS " + table);
        } finally {
            spark.stop();
        }
    }

    /** Writes three float vector columns of different fixed sizes (32/128/256) in one table. */
    @Test
    public void testDataFrameMultipleVectorColumns() {
        SparkSession spark = createSparkSession("dataframe-multi-vector-test");
        try {
            String table = qualifiedTableName("df_multi_vector");
            StructType schema = new StructType(new StructField[] {
                DataTypes.createStructField("id", DataTypes.IntegerType, false),
                DataTypes.createStructField("name", DataTypes.StringType, true),
                fixedSizeListField("small_embedding", DataTypes.FloatType, 32),
                fixedSizeListField("medium_embedding", DataTypes.FloatType, 128),
                fixedSizeListField("large_embedding", DataTypes.FloatType, 256)
            });
            List<Row> rows = new ArrayList<>();
            for (int i = 0; i < 5; ++i) {
                rows.add(RowFactory.create(
                    i,
                    "entity_" + i,
                    floatVector(32, i, 0.01f, 0.001f),
                    floatVector(128, i, 0.005f, 5.0E-4f),
                    floatVector(256, i, 0.002f, 2.0E-4f)));
            }
            Dataset<Row> df = spark.createDataFrame(rows, schema);
            df.writeTo(table).using("lance").createOrReplace();

            Dataset<Row> result = spark.table(table);
            Assertions.assertEquals(5L, result.count(), "Should have 5 rows");
            Row firstRow = result.orderBy("id").first();
            Assertions.assertEquals(0, firstRow.getInt(0));
            assertFloatSeq(floatVector(32, 0, 0.01f, 0.001f), firstRow.getSeq(2), "Small embedding");
            assertFloatSeq(floatVector(128, 0, 0.005f, 5.0E-4f), firstRow.getSeq(3), "Medium embedding");
            assertFloatSeq(floatVector(256, 0, 0.002f, 2.0E-4f), firstRow.getSeq(4), "Large embedding");
            spark.sql("DROP TABLE IF EXISTS " + table);
        } finally {
            spark.stop();
        }
    }

    /** Writes a float vector and a double vector of the same fixed size side by side. */
    @Test
    public void testDataFrameMixedPrecisionVectors() {
        SparkSession spark = createSparkSession("dataframe-mixed-precision-test");
        try {
            String table = qualifiedTableName("df_mixed_precision");
            StructType schema = new StructType(new StructField[] {
                DataTypes.createStructField("id", DataTypes.IntegerType, false),
                DataTypes.createStructField("label", DataTypes.StringType, true),
                fixedSizeListField("float_embedding", DataTypes.FloatType, 64),
                fixedSizeListField("double_embedding", DataTypes.DoubleType, 64)
            });
            List<Row> rows = new ArrayList<>();
            for (int i = 0; i < 5; ++i) {
                rows.add(RowFactory.create(
                    i,
                    "label_" + i,
                    floatVector(64, i, 0.1f, 0.01f),
                    doubleVector(64, i, 0.1, 0.01)));
            }
            Dataset<Row> df = spark.createDataFrame(rows, schema);
            df.writeTo(table).using("lance").createOrReplace();

            Dataset<Row> result = spark.table(table);
            Assertions.assertEquals(5L, result.count(), "Should have 5 rows");
            Row firstRow = result.orderBy("id").first();
            Assertions.assertEquals(0, firstRow.getInt(0));
            assertFloatSeq(floatVector(64, 0, 0.1f, 0.01f), firstRow.getSeq(2), "Float embedding");
            assertDoubleSeq(doubleVector(64, 0, 0.1, 0.01), firstRow.getSeq(3), "Double embedding");
            spark.sql("DROP TABLE IF EXISTS " + table);
        } finally {
            spark.stop();
        }
    }

    /**
     * Builds (or reuses, via {@code getOrCreate}) a local SparkSession with the Lance namespace
     * catalog registered as {@link #CATALOG_NAME}, using the {@code dir} implementation rooted at
     * {@link #tempDir}.
     *
     * @param appName Spark application name for the session
     * @return the session; callers are responsible for stopping it
     */
    private SparkSession createSparkSession(String appName) {
        String catalogPrefix = "spark.sql.catalog." + CATALOG_NAME;
        return SparkSession.builder()
            .appName(appName)
            .master("local[*]")
            .config(catalogPrefix, "com.lancedb.lance.spark.LanceNamespaceSparkCatalog")
            .config(catalogPrefix + ".impl", "dir")
            .config(catalogPrefix + ".root", tempDir.toString())
            .getOrCreate();
    }

    /** Returns {@code <catalog>.default.<tableName>} for the test catalog. */
    private static String qualifiedTableName(String tableName) {
        return CATALOG_NAME + ".default." + tableName;
    }

    /**
     * Creates a non-nullable array column with non-nullable elements, tagged with fixed-size-list
     * metadata of the given length.
     */
    private static StructField fixedSizeListField(String name, DataType elementType, int size) {
        Metadata metadata = Metadata.fromJson("{\"" + FIXED_SIZE_LIST_KEY + "\":" + size + "}");
        return new StructField(name, DataTypes.createArrayType(elementType, false), false, metadata);
    }

    /** Returns a float vector whose element {@code j} is {@code row * rowStep + j * colStep}. */
    private static float[] floatVector(int size, int row, float rowStep, float colStep) {
        float[] vector = new float[size];
        for (int j = 0; j < size; ++j) {
            vector[j] = (float) row * rowStep + (float) j * colStep;
        }
        return vector;
    }

    /** Returns a double vector whose element {@code j} is {@code row * rowStep + j * colStep}. */
    private static double[] doubleVector(int size, int row, double rowStep, double colStep) {
        double[] vector = new double[size];
        for (int j = 0; j < size; ++j) {
            vector[j] = (double) row * rowStep + (double) j * colStep;
        }
        return vector;
    }

    /** Asserts {@code actual} has the same length as {@code expected} and matches it to 1e-4. */
    private static void assertFloatSeq(float[] expected, Seq<Float> actual, String column) {
        Assertions.assertEquals(expected.length, actual.size(),
            column + " should have " + expected.length + " elements");
        for (int j = 0; j < expected.length; ++j) {
            Assertions.assertEquals(expected[j], actual.apply(j).floatValue(), 1.0E-4f,
                column + " value mismatch at index " + j);
        }
    }

    /** Asserts {@code actual} has the same length as {@code expected} and matches it to 1e-7. */
    private static void assertDoubleSeq(double[] expected, Seq<Double> actual, String column) {
        Assertions.assertEquals(expected.length, actual.size(),
            column + " should have " + expected.length + " elements");
        for (int j = 0; j < expected.length; ++j) {
            Assertions.assertEquals(expected[j], actual.apply(j).doubleValue(), 1.0E-7,
                column + " value mismatch at index " + j);
        }
    }
}

