/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hudi.integ.testsuite.generator;

import java.io.IOException;
import java.io.Serializable;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.StreamSupport;
import org.apache.avro.generic.GenericRecord;
import org.apache.hudi.common.util.Option;
import org.apache.hudi.integ.testsuite.configuration.DFSDeltaConfig;
import org.apache.hudi.integ.testsuite.configuration.DeltaConfig;
import org.apache.hudi.integ.testsuite.converter.UpdateConverter;
import org.apache.hudi.integ.testsuite.generator.FlexibleSchemaRecordGenerationIterator;
import org.apache.hudi.integ.testsuite.generator.LazyRecordGeneratorIterator;
import org.apache.hudi.integ.testsuite.reader.DFSAvroDeltaInputReader;
import org.apache.hudi.integ.testsuite.reader.DFSDeltaInputReader;
import org.apache.hudi.integ.testsuite.reader.DFSHoodieDatasetInputReader;
import org.apache.hudi.integ.testsuite.writer.DeltaOutputMode;
import org.apache.hudi.integ.testsuite.writer.DeltaWriteStats;
import org.apache.hudi.integ.testsuite.writer.DeltaWriterAdapter;
import org.apache.hudi.integ.testsuite.writer.DeltaWriterFactory;
import org.apache.hudi.keygen.BuiltinKeyGenerator;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.storage.StorageLevel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

/**
 * Builds insert and update record batches for the integration test suite and writes them out
 * through a {@link DeltaWriterAdapter}.
 *
 * <p>The lambdas in {@link #writeRecords} and {@link #generateInserts} read instance fields, so
 * this object is captured by those closures; the Spark handles are {@code transient} and are
 * therefore not part of the serialized state.
 */
public class DeltaGenerator implements Serializable {
    private static final Logger log = LoggerFactory.getLogger(DeltaGenerator.class);

    private final DeltaConfig deltaOutputConfig;
    private transient JavaSparkContext jsc;
    private transient SparkSession sparkSession;
    private final String schemaStr;
    private final List<String> recordRowKeyFieldNames;
    private final List<String> partitionPathFieldNames;
    /** Incremented once per {@link #writeRecords} call. */
    private int batchId;

    /**
     * @param deltaOutputConfig output configuration handed to the writer factory and input readers
     * @param jsc               Spark context used to create and repartition RDDs
     * @param sparkSession      Spark session handed to the Avro delta input reader
     * @param schemaStr         schema string passed to the record generators, readers and converter
     * @param keyGenerator      source of the record-key and partition-path field names
     */
    public DeltaGenerator(DeltaConfig deltaOutputConfig, JavaSparkContext jsc, SparkSession sparkSession, String schemaStr, BuiltinKeyGenerator keyGenerator) {
        this.deltaOutputConfig = deltaOutputConfig;
        this.jsc = jsc;
        this.sparkSession = sparkSession;
        this.schemaStr = schemaStr;
        this.recordRowKeyFieldNames = keyGenerator.getRecordKeyFields();
        this.partitionPathFieldNames = keyGenerator.getPartitionPathFields();
    }

    /**
     * Writes every partition of {@code records} through a freshly created writer adapter and
     * returns the flattened write statistics. The batch id is incremented after the RDD is built.
     *
     * @param records records to write
     * @return RDD of the write statistics produced by each partition's writer
     * @throws UncheckedIOException (inside the Spark task) when creating the writer or writing fails
     */
    public JavaRDD<DeltaWriteStats> writeRecords(JavaRDD<GenericRecord> records) {
        // A new adapter is created inside the lambda, i.e. once per partition.
        // NOTE(review): the lambda reads this.batchId when it runs, while the increment below
        // happens as soon as the RDD is defined. Because RDD transformations are lazy, writers may
        // observe the post-increment value - confirm the intended numbering before changing this.
        JavaRDD<DeltaWriteStats> writeStats = records
            .mapPartitions(partition -> {
                try {
                    DeltaWriterAdapter<GenericRecord> writer =
                        DeltaWriterFactory.getDeltaWriterAdapter(this.deltaOutputConfig, this.batchId);
                    return Collections.singletonList(writer.write(partition)).iterator();
                } catch (IOException io) {
                    throw new UncheckedIOException(io);
                }
            })
            .flatMap(List::iterator);
        ++this.batchId;
        return writeStats;
    }

    /**
     * Creates an RDD whose partitions each lazily generate records.
     *
     * @param operation supplies the per-partition record count, the record size and the number of
     *                  insert partitions
     * @return RDD of generated records
     */
    public JavaRDD<GenericRecord> generateInserts(DeltaConfig.Config operation) {
        long recordsPerPartition = operation.getNumRecordsInsert();
        int minPayloadSize = operation.getRecordSize();
        // Start from an empty RDD: repartitioning it yields the requested number of (empty)
        // partitions, each of which is then replaced by a generating iterator.
        return this.jsc
            .parallelize(Collections.<GenericRecord>emptyList())
            .repartition(operation.getNumInsertPartitions())
            .mapPartitions(ignored -> new LazyRecordGeneratorIterator(
                new FlexibleSchemaRecordGenerationIterator(
                    recordsPerPartition, minPayloadSize, this.schemaStr, this.partitionPathFieldNames)));
    }

    /**
     * Produces update records by reading previously written data and converting it, optionally
     * unioned with freshly generated inserts when {@code config.getNumRecordsInsert() > 0}.
     *
     * @param config operation configuration (upsert/insert counts, partitions, files, record size)
     * @return the converted updates, unioned with inserts when any were requested
     * @throws IOException              declared for the underlying readers
     * @throws IllegalArgumentException when the output mode is not {@link DeltaOutputMode#DFS}
     */
    public JavaRDD<GenericRecord> generateUpdates(DeltaConfig.Config config) throws IOException {
        if (this.deltaOutputConfig.getDeltaOutputMode() != DeltaOutputMode.DFS) {
            throw new IllegalArgumentException("Other formats are not supported at the moment");
        }

        JavaRDD<GenericRecord> inserts = config.getNumRecordsInsert() > 0L ? this.generateInserts(config) : null;
        DFSDeltaConfig dfsConfig = (DFSDeltaConfig) this.deltaOutputConfig;

        DFSDeltaInputReader deltaInputReader;
        JavaRDD<GenericRecord> adjustedRDD;
        if (config.getNumUpsertPartitions() < 1) {
            // No upsert partitions requested: read from the delta base path and pad/trim the
            // result to exactly the requested number of upserts.
            deltaInputReader = new DFSAvroDeltaInputReader(this.sparkSession, this.schemaStr,
                dfsConfig.getDeltaBasePath(), Option.empty(), Option.empty());
            adjustedRDD = deltaInputReader.read(config.getNumRecordsUpsert());
            adjustedRDD = this.adjustRDDToGenerateExactNumUpdates(adjustedRDD, this.jsc, config.getNumRecordsUpsert());
        } else {
            deltaInputReader = new DFSHoodieDatasetInputReader(this.jsc, dfsConfig.getDatasetOutputPath(), this.schemaStr);
            if (config.getFractionUpsertPerFile() > 0.0) {
                adjustedRDD = deltaInputReader.read(config.getNumUpsertPartitions(), config.getNumUpsertFiles(),
                    config.getFractionUpsertPerFile());
            } else {
                adjustedRDD = deltaInputReader.read(config.getNumUpsertPartitions(), config.getNumUpsertFiles(),
                    config.getNumRecordsUpsert());
            }
        }

        log.info("Repartitioning records");
        adjustedRDD = adjustedRDD.repartition(this.jsc.defaultParallelism());
        log.info("Repartitioning records done");

        UpdateConverter converter = new UpdateConverter(this.schemaStr, config.getRecordSize(),
            this.partitionPathFieldNames, this.recordRowKeyFieldNames);
        JavaRDD<GenericRecord> updates = converter.convert(adjustedRDD);
        log.info("Records converted");
        updates.persist(StorageLevel.DISK_ONLY());

        return inserts != null ? inserts.union(updates) : updates;
    }

    /**
     * Counts the records in every partition of {@code records}.
     *
     * @param records RDD to count
     * @return map from partition index to the number of records in that partition
     */
    public Map<Integer, Long> getPartitionToCountMap(JavaRDD<GenericRecord> records) {
        return records
            .mapPartitionsWithIndex((Integer index, Iterator<GenericRecord> itr) -> {
                // Plain sequential drain: the iterator can only be advanced serially anyway, so a
                // parallel stream would just add batching buffers and thread hand-offs.
                long count = 0L;
                while (itr.hasNext()) {
                    itr.next();
                    count++;
                }
                return Collections.singletonList(new Tuple2<Integer, Long>(index, count)).iterator();
            }, true)
            .mapToPair(pair -> pair)
            .collectAsMap();
    }

    /**
     * Distributes a removal budget over partitions, in the iteration order of
     * {@code partitionCountMap}, and returns how many records each visited partition should keep.
     *
     * <p>Partitions smaller than the remaining budget are emptied (mapped to {@code 0}); the first
     * partition that can absorb the rest keeps {@code count - remaining}. Iteration stops once the
     * budget reaches zero, so partitions after that point are absent from the result;
     * {@link #adjustRDDToGenerateExactNumUpdates} leaves absent partitions untouched.
     *
     * @param partitionCountMap partition index to current record count
     * @param recordsToRemove   total number of records to drop
     * @return partition index to number of records to keep, for the partitions that were visited
     */
    public Map<Integer, Long> getAdjustedPartitionsCount(Map<Integer, Long> partitionCountMap, long recordsToRemove) {
        Map<Integer, Long> adjustedPartitionCountMap = new HashMap<>();
        long remainingRecordsToRemove = recordsToRemove;
        for (Map.Entry<Integer, Long> entry : partitionCountMap.entrySet()) {
            long partitionCount = entry.getValue();
            if (partitionCount < remainingRecordsToRemove) {
                remainingRecordsToRemove -= partitionCount;
                adjustedPartitionCountMap.put(entry.getKey(), 0L);
            } else {
                adjustedPartitionCountMap.put(entry.getKey(), partitionCount - remainingRecordsToRemove);
                remainingRecordsToRemove = 0L;
            }
            if (remainingRecordsToRemove == 0L) {
                break;
            }
        }
        return adjustedPartitionCountMap;
    }

    /**
     * Pads or trims {@code updates} so that it holds {@code totalRecordsRequired} records.
     *
     * <p>When more records are required than available, the RDD is unioned with itself until the
     * shortfall is no larger than its current size, and the remainder is topped up with records
     * taken from the RDD. When fewer are required, leading partitions are truncated according to
     * {@link #getAdjustedPartitionsCount}. If {@code updates} is empty there is nothing to
     * replicate, so it is returned unchanged (with a warning) instead of looping forever.
     *
     * @param updates              source records
     * @param jsc                  Spark context used to parallelize the top-up records
     * @param totalRecordsRequired desired number of records
     * @return the adjusted RDD
     */
    public JavaRDD<GenericRecord> adjustRDDToGenerateExactNumUpdates(JavaRDD<GenericRecord> updates, JavaSparkContext jsc, long totalRecordsRequired) {
        Map<Integer, Long> actualPartitionCountMap = this.getPartitionToCountMap(updates);
        long totalRecordsGenerated = actualPartitionCountMap.values().stream().mapToLong(Long::longValue).sum();

        if (this.isSafeToTake(totalRecordsRequired, totalRecordsGenerated)) {
            if (totalRecordsGenerated == 0L) {
                // Doubling an empty RDD never makes progress (0 * 2 == 0); bail out.
                log.warn("No source records available to generate {} updates; returning the empty input",
                    totalRecordsRequired);
                return updates;
            }
            long sizeOfUpdateRDD = totalRecordsGenerated;
            while (totalRecordsRequired != sizeOfUpdateRDD) {
                long deficit = totalRecordsRequired - sizeOfUpdateRDD;
                if (deficit > sizeOfUpdateRDD) {
                    // Still short by more than we have: double up.
                    updates = updates.union(updates);
                    sizeOfUpdateRDD *= 2L;
                } else {
                    // Shortfall fits within the current RDD: take exactly that many and finish.
                    // take() accepts an int, hence the narrowing cast.
                    List<GenericRecord> remainingUpdates = updates.take((int) deficit);
                    updates = updates.union(jsc.parallelize(remainingUpdates));
                    sizeOfUpdateRDD += deficit;
                }
            }
            return updates;
        }

        if (totalRecordsRequired < totalRecordsGenerated) {
            Map<Integer, Long> adjustedPartitionCountMap =
                this.getAdjustedPartitionsCount(actualPartitionCountMap, totalRecordsGenerated - totalRecordsRequired);
            return updates.mapPartitionsWithIndex((Integer index, Iterator<GenericRecord> itr) -> {
                Long recordsToKeepForThisPartition = adjustedPartitionCountMap.get(index);
                if (recordsToKeepForThisPartition == null) {
                    // Partition was not selected for trimming.
                    return itr;
                }
                List<GenericRecord> entriesToKeep = new ArrayList<>();
                long kept = 0L;
                while (kept < recordsToKeepForThisPartition && itr.hasNext()) {
                    entriesToKeep.add(itr.next());
                    kept++;
                }
                return entriesToKeep.iterator();
            }, true);
        }

        return updates;
    }

    /** Returns {@code true} when more records are required than were generated. */
    private boolean isSafeToTake(long totalRecords, long totalRecordsGenerated) {
        return totalRecords > totalRecordsGenerated;
    }
}

