/*
 * Decompiled with CFR 0.152.
 */
package org.apache.iceberg.spark.source;

import java.io.File;
import java.io.IOException;
import java.io.Serializable;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.apache.avro.generic.GenericData;
import org.apache.hadoop.conf.Configuration;
import org.apache.iceberg.Files;
import org.apache.iceberg.Parameter;
import org.apache.iceberg.ParameterizedTestExtension;
import org.apache.iceberg.Parameters;
import org.apache.iceberg.PartitionSpec;
import org.apache.iceberg.Schema;
import org.apache.iceberg.Snapshot;
import org.apache.iceberg.Table;
import org.apache.iceberg.avro.Avro;
import org.apache.iceberg.avro.AvroIterable;
import org.apache.iceberg.hadoop.HadoopTables;
import org.apache.iceberg.io.CloseableIterator;
import org.apache.iceberg.io.FileAppender;
import org.apache.iceberg.io.InputFile;
import org.apache.iceberg.io.OutputFile;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.spark.SparkSchemaUtil;
import org.apache.iceberg.spark.data.ParameterizedAvroDataTest;
import org.apache.iceberg.spark.data.RandomData;
import org.apache.iceberg.spark.data.SparkAvroReader;
import org.apache.iceberg.spark.data.TestHelpers;
import org.apache.iceberg.types.Type;
import org.apache.iceberg.types.Types;
import org.apache.spark.SparkContext;
import org.apache.spark.SparkException;
import org.apache.spark.TaskContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.MapPartitionsFunction;
import org.apache.spark.sql.DataFrameWriter;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.assertj.core.api.AbstractBooleanAssert;
import org.assertj.core.api.AbstractStringAssert;
import org.assertj.core.api.Assertions;
import org.assertj.core.api.Assumptions;
import org.assertj.core.api.ListAssert;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.TestTemplate;
import org.junit.jupiter.api.extension.ExtendWith;

@ExtendWith(value={ParameterizedTestExtension.class})
public class TestDataFrameWrites
extends ParameterizedAvroDataTest {
    private static final Configuration CONF = new Configuration();
    @Parameter
    private String format;
    private static SparkSession spark = null;
    private static JavaSparkContext sc = null;
    private Map<String, String> tableProperties;
    private StructType sparkSchema = new StructType(new StructField[]{new StructField("optionalField", DataTypes.StringType, true, Metadata.empty()), new StructField("requiredField", DataTypes.StringType, false, Metadata.empty())});
    private Schema icebergSchema = new Schema(new Types.NestedField[]{Types.NestedField.optional((int)1, (String)"optionalField", (Type)Types.StringType.get()), Types.NestedField.required((int)2, (String)"requiredField", (Type)Types.StringType.get())});
    private List<String> data0 = Arrays.asList("{\"optionalField\": \"a1\", \"requiredField\": \"bid_001\"}", "{\"optionalField\": \"a2\", \"requiredField\": \"bid_002\"}");
    private List<String> data1 = Arrays.asList("{\"optionalField\": \"d1\", \"requiredField\": \"bid_101\"}", "{\"optionalField\": \"d2\", \"requiredField\": \"bid_102\"}", "{\"optionalField\": \"d3\", \"requiredField\": \"bid_103\"}", "{\"optionalField\": \"d4\", \"requiredField\": \"bid_104\"}");

    @Parameters(name="format = {0}")
    public static Collection<String> parameters() {
        return Arrays.asList("parquet", "avro", "orc");
    }

    @BeforeAll
    public static void startSpark() {
        spark = SparkSession.builder().master("local[2]").getOrCreate();
        sc = JavaSparkContext.fromSparkContext((SparkContext)spark.sparkContext());
    }

    @AfterAll
    public static void stopSpark() {
        SparkSession currentSpark = spark;
        spark = null;
        sc = null;
        currentSpark.stop();
    }

    @Override
    protected void writeAndValidate(Schema schema) throws IOException {
        File location = this.createTableFolder();
        Table table = this.createTable(schema, location);
        this.writeAndValidateWithLocations(table, location, new File(location, "data"));
    }

    @TestTemplate
    public void testWriteWithCustomDataLocation() throws IOException {
        File location = this.createTableFolder();
        File tablePropertyDataLocation = this.temp.resolve("test-table-property-data-dir").toFile();
        Table table = this.createTable(new Schema(SUPPORTED_PRIMITIVES.fields()), location);
        table.updateProperties().set("write.data.path", tablePropertyDataLocation.getAbsolutePath()).commit();
        this.writeAndValidateWithLocations(table, location, tablePropertyDataLocation);
    }

    private File createTableFolder() throws IOException {
        File parent = this.temp.resolve("parquet").toFile();
        File location = new File(parent, "test");
        ((AbstractBooleanAssert)Assertions.assertThat((boolean)location.mkdirs()).as("Mkdir should succeed", new Object[0])).isTrue();
        return location;
    }

    private Table createTable(Schema schema, File location) {
        HadoopTables tables = new HadoopTables(CONF);
        return tables.create(schema, PartitionSpec.unpartitioned(), location.toString());
    }

    private void writeAndValidateWithLocations(Table table, File location, File expectedDataDir) throws IOException {
        Schema tableSchema = table.schema();
        table.updateProperties().set("write.format.default", this.format).commit();
        Iterable<GenericData.Record> expected = RandomData.generate(tableSchema, 100, 0L);
        this.writeData(expected, tableSchema, location.toString());
        table.refresh();
        List<Row> actual = this.readTable(location.toString());
        Iterator<GenericData.Record> expectedIter = expected.iterator();
        Iterator<Row> actualIter = actual.iterator();
        while (expectedIter.hasNext() && actualIter.hasNext()) {
            TestHelpers.assertEqualsSafe(tableSchema.asStruct(), expectedIter.next(), actualIter.next());
        }
        ((AbstractBooleanAssert)Assertions.assertThat((boolean)actualIter.hasNext()).as("Both iterators should be exhausted", new Object[0])).isEqualTo(expectedIter.hasNext());
        table.currentSnapshot().addedDataFiles(table.io()).forEach(dataFile -> {
            AbstractStringAssert cfr_ignored_0 = (AbstractStringAssert)((AbstractStringAssert)Assertions.assertThat((String)URI.create(dataFile.path().toString()).getPath()).as(String.format("File should have the parent directory %s, but has: %s.", expectedDataDir.getAbsolutePath(), dataFile.path()), new Object[0])).startsWith((CharSequence)expectedDataDir.getAbsolutePath());
        });
    }

    private List<Row> readTable(String location) {
        Dataset result = spark.read().format("iceberg").load(location);
        return result.collectAsList();
    }

    private void writeData(Iterable<GenericData.Record> records, Schema schema, String location) throws IOException {
        Dataset<Row> df = this.createDataset(records, schema);
        DataFrameWriter writer = df.write().format("iceberg").mode("append");
        writer.save(location);
    }

    private void writeDataWithFailOnPartition(Iterable<GenericData.Record> records, Schema schema, String location) throws IOException, SparkException {
        int numPartitions = 10;
        int partitionToFail = new Random().nextInt(10);
        MapPartitionsFunction & Serializable failOnFirstPartitionFunc = (MapPartitionsFunction & Serializable)input -> {
            int partitionId = TaskContext.getPartitionId();
            if (partitionId == partitionToFail) {
                throw new SparkException(String.format("Intended exception in partition %d !", partitionId));
            }
            return input;
        };
        Dataset df = this.createDataset(records, schema).repartition(10).mapPartitions((MapPartitionsFunction)failOnFirstPartitionFunc, Encoders.row((StructType)SparkSchemaUtil.convert((Schema)schema)));
        Dataset convertedDf = df.sqlContext().createDataFrame(df.rdd(), SparkSchemaUtil.convert((Schema)schema));
        DataFrameWriter writer = convertedDf.write().format("iceberg").mode("append");
        writer.save(location);
    }

    private Dataset<Row> createDataset(Iterable<GenericData.Record> records, Schema schema) throws IOException {
        File testFile = File.createTempFile("junit", null, this.temp.toFile());
        ((AbstractBooleanAssert)Assertions.assertThat((boolean)testFile.delete()).as("Delete should succeed", new Object[0])).isTrue();
        try (FileAppender writer = Avro.write((OutputFile)Files.localOutput((File)testFile)).schema(schema).named("test").build();){
            for (GenericData.Record rec : records) {
                writer.add((Object)rec);
            }
        }
        ArrayList rows = Lists.newArrayList();
        Object object = null;
        try (AvroIterable reader = Avro.read((InputFile)Files.localInput((File)testFile)).createReaderFunc(SparkAvroReader::new).project(schema).build();){
            Iterator<GenericData.Record> recordIter = records.iterator();
            CloseableIterator readIter = reader.iterator();
            while (recordIter.hasNext() && readIter.hasNext()) {
                InternalRow row = (InternalRow)readIter.next();
                TestHelpers.assertEqualsUnsafe(schema.asStruct(), recordIter.next(), row);
                rows.add(row);
            }
            ((AbstractBooleanAssert)Assertions.assertThat((boolean)readIter.hasNext()).as("Both iterators should be exhausted", new Object[0])).isEqualTo(recordIter.hasNext());
        }
        catch (Throwable throwable) {
            object = throwable;
            throw throwable;
        }
        JavaRDD rdd = sc.parallelize((List)rows);
        return spark.internalCreateDataFrame(JavaRDD.toRDD((JavaRDD)rdd), SparkSchemaUtil.convert((Schema)schema), false);
    }

    @TestTemplate
    public void testNullableWithWriteOption() throws IOException {
        ((AbstractStringAssert)Assumptions.assumeThat((String)spark.version()).as("Spark 3 rejects writing nulls to a required column", new Object[0])).startsWith((CharSequence)"2");
        File location = this.temp.resolve("parquet").resolve("test").toFile();
        String sourcePath = String.format("%s/nullable_poc/sourceFolder/", location.toString());
        String targetPath = String.format("%s/nullable_poc/targetFolder/", location.toString());
        this.tableProperties = ImmutableMap.of((Object)"write.data.path", (Object)targetPath);
        spark.read().schema(this.sparkSchema).json(JavaSparkContext.fromSparkContext((SparkContext)spark.sparkContext()).parallelize(this.data1)).write().parquet(sourcePath);
        new HadoopTables(spark.sessionState().newHadoopConf()).create(this.icebergSchema, PartitionSpec.builderFor((Schema)this.icebergSchema).identity("requiredField").build(), this.tableProperties, targetPath);
        spark.read().schema(this.sparkSchema).json(JavaSparkContext.fromSparkContext((SparkContext)spark.sparkContext()).parallelize(this.data0)).write().format("iceberg").mode(SaveMode.Append).save(targetPath);
        spark.read().schema(SparkSchemaUtil.convert((Schema)this.icebergSchema)).parquet(sourcePath).write().format("iceberg").option("check-nullability", false).mode(SaveMode.Append).save(targetPath);
        List rows = spark.read().format("iceberg").load(targetPath).collectAsList();
        ((ListAssert)Assumptions.assumeThat((List)rows).as("Should contain 6 rows", new Object[0])).hasSize(6);
    }

    @TestTemplate
    public void testNullableWithSparkSqlOption() throws IOException {
        ((AbstractStringAssert)Assumptions.assumeThat((String)spark.version()).as("Spark 3 rejects writing nulls to a required column", new Object[0])).startsWith((CharSequence)"2");
        File location = this.temp.resolve("parquet").resolve("test").toFile();
        String sourcePath = String.format("%s/nullable_poc/sourceFolder/", location.toString());
        String targetPath = String.format("%s/nullable_poc/targetFolder/", location.toString());
        this.tableProperties = ImmutableMap.of((Object)"write.data.path", (Object)targetPath);
        spark.read().schema(this.sparkSchema).json(JavaSparkContext.fromSparkContext((SparkContext)spark.sparkContext()).parallelize(this.data1)).write().parquet(sourcePath);
        SparkSession newSparkSession = SparkSession.builder().master("local[2]").appName("NullableTest").config("spark.sql.iceberg.check-nullability", false).getOrCreate();
        new HadoopTables(newSparkSession.sessionState().newHadoopConf()).create(this.icebergSchema, PartitionSpec.builderFor((Schema)this.icebergSchema).identity("requiredField").build(), this.tableProperties, targetPath);
        newSparkSession.read().schema(this.sparkSchema).json(JavaSparkContext.fromSparkContext((SparkContext)spark.sparkContext()).parallelize(this.data0)).write().format("iceberg").mode(SaveMode.Append).save(targetPath);
        newSparkSession.read().schema(SparkSchemaUtil.convert((Schema)this.icebergSchema)).parquet(sourcePath).write().format("iceberg").mode(SaveMode.Append).save(targetPath);
        List rows = newSparkSession.read().format("iceberg").load(targetPath).collectAsList();
        ((ListAssert)Assumptions.assumeThat((List)rows).as("Should contain 6 rows", new Object[0])).hasSize(6);
    }

    @TestTemplate
    public void testFaultToleranceOnWrite() throws IOException {
        File location = this.createTableFolder();
        Schema schema = new Schema(SUPPORTED_PRIMITIVES.fields());
        Table table = this.createTable(schema, location);
        Iterable<GenericData.Record> records = RandomData.generate(schema, 100, 0L);
        this.writeData(records, schema, location.toString());
        table.refresh();
        Snapshot snapshotBeforeFailingWrite = table.currentSnapshot();
        List<Row> resultBeforeFailingWrite = this.readTable(location.toString());
        Iterable<GenericData.Record> records2 = RandomData.generate(schema, 100, 0L);
        Assertions.assertThatThrownBy(() -> this.writeDataWithFailOnPartition(records2, schema, location.toString())).isInstanceOf(SparkException.class);
        table.refresh();
        Snapshot snapshotAfterFailingWrite = table.currentSnapshot();
        List<Row> resultAfterFailingWrite = this.readTable(location.toString());
        Assertions.assertThat((Object)snapshotBeforeFailingWrite).isEqualTo((Object)snapshotAfterFailingWrite);
        Assertions.assertThat(resultBeforeFailingWrite).isEqualTo(resultAfterFailingWrite);
    }
}

