/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.utils.spark;

import com.google.common.collect.AbstractIterator;
import com.google.common.collect.Iterators;
import com.google.common.collect.Lists;
import com.google.common.collect.PeekingIterator;
import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMSequenceRecord;
import htsjdk.samtools.SAMTextHeaderCodec;
import htsjdk.samtools.util.BinaryCodec;
import htsjdk.samtools.util.BlockCompressedOutputStream;
import htsjdk.samtools.util.BlockCompressedStreamConstants;
import htsjdk.samtools.util.RuntimeIOException;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.io.Serializable;
import java.io.StringWriter;
import java.io.Writer;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.io.FileUtils;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaRDDLike;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.FlatMapFunction2;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.broadcast.Broadcast;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.read.ReadCoordinateComparator;
import org.broadinstitute.hellbender.utils.read.ReadQueryNameComparator;
import org.broadinstitute.hellbender.utils.read.ReadUtils;
import scala.Tuple2;

public final class SparkUtils {
    private static final Logger logger = LogManager.getLogger(SparkUtils.class);

    /**
     * Calls {@code destroy()} on the given broadcast on a best-effort basis.
     *
     * <p>Any {@link Exception} thrown by the call is logged at WARN level (together with
     * {@code whatBroadcast}) and is not propagated to the caller.
     *
     * @param broadcast     the broadcast to destroy
     * @param whatBroadcast description of the broadcast; only used in the warning message
     * @param <T>           type of the value held by the broadcast
     */
    public static <T> void destroyBroadcast(final Broadcast<T> broadcast, final String whatBroadcast) {
        try {
            broadcast.destroy();
        } catch (final Exception destroyFailure) {
            // Intentionally not rethrown: a failed destroy is reported but must not fail the caller.
            logger.warn("Failed to destroy broadcast for " + whatBroadcast, destroyFailure);
        }
    }

    /** Private constructor: the members of this class visible here are all static, so it is never instantiated. */
    private SparkUtils() {
    }

    /**
     * Builds {@code destination} from a headerless BAM shard by writing, in this order:
     * <ol>
     *   <li>a BAM header generated from {@code header},</li>
     *   <li>the bytes of {@code bamShard}, copied unchanged,</li>
     *   <li>{@link BlockCompressedStreamConstants#EMPTY_GZIP_BLOCK}.</li>
     * </ol>
     *
     * @param bamShard    the shard file whose contents are copied after the header
     * @param header      header to encode at the start of {@code destination}
     * @param destination file to create or overwrite
     * @throws UserException wrapping any {@link IOException} raised while opening, writing or closing
     *                       the destination stream
     */
    public static void convertHeaderlessHadoopBamShardToBam(final File bamShard, final SAMFileHeader header, final File destination) {
        try (final FileOutputStream destinationStream = new FileOutputStream(destination)) {
            writeBAMHeaderToStream(header, destinationStream);
            FileUtils.copyFile(bamShard, destinationStream);
            destinationStream.write(BlockCompressedStreamConstants.EMPTY_GZIP_BLOCK);
        } catch (final IOException ioe) {
            throw new UserException("Error writing to " + destination.getAbsolutePath(), ioe);
        }
    }

    /**
     * Encodes {@code samFileHeader} onto {@code outputStream} through a
     * {@link BlockCompressedOutputStream}.
     *
     * <p>The data is emitted in this order: {@link ReadUtils#BAM_MAGIC}, the text form of the header
     * as produced by {@link SAMTextHeaderCodec}, the number of sequences in the dictionary, and then
     * the name and length of every sequence. The compressing stream is flushed at the end but is
     * not closed by this method, so {@code outputStream} stays open for the caller.
     *
     * @param samFileHeader header to encode
     * @param outputStream  stream that receives the compressed header bytes
     * @throws RuntimeIOException if flushing the compressing stream fails
     */
    private static void writeBAMHeaderToStream(final SAMFileHeader samFileHeader, final OutputStream outputStream) {
        // The (File) cast on null is kept from the original call to select the same constructor overload.
        final BlockCompressedOutputStream compressedStream = new BlockCompressedOutputStream(outputStream, (File) null);
        final BinaryCodec codec = new BinaryCodec(new DataOutputStream(compressedStream));

        // Render the textual header first so it can be written as a single string.
        final Writer textHeaderWriter = new StringWriter();
        new SAMTextHeaderCodec().encode(textHeaderWriter, samFileHeader, true);
        final String textHeader = textHeaderWriter.toString();

        codec.writeBytes(ReadUtils.BAM_MAGIC);
        codec.writeString(textHeader, true, false);

        // Sequence dictionary: count followed by one (name, length) entry per sequence.
        codec.writeInt(samFileHeader.getSequenceDictionary().size());
        for (final SAMSequenceRecord sequence : samFileHeader.getSequenceDictionary().getSequences()) {
            codec.writeString(sequence.getSequenceName(), true, true);
            codec.writeInt(sequence.getSequenceLength());
        }

        try {
            compressedStream.flush();
        } catch (final IOException flushFailure) {
            throw new RuntimeIOException(flushFailure);
        }
    }

    /**
     * Checks whether {@code targetURI} exists on the Hadoop file system that owns it.
     *
     * <p>The file system is resolved from the URI using the Hadoop configuration of {@code ctx}.
     *
     * @param ctx       Spark context supplying the Hadoop configuration; must not be null
     * @param targetURI location to test; must not be null
     * @return the result of {@code FileSystem.exists} for the path built from {@code targetURI}
     * @throws UserException if an {@link IOException} occurs while resolving the file system or
     *                       checking the path; the {@code IOException} is attached as the cause
     */
    public static boolean hadoopPathExists(final JavaSparkContext ctx, final URI targetURI) {
        Utils.nonNull(ctx);
        Utils.nonNull(targetURI);
        try {
            final Path targetHadoopPath = new Path(targetURI);
            final FileSystem fs = targetHadoopPath.getFileSystem(ctx.hadoopConfiguration());
            return fs.exists(targetHadoopPath);
        } catch (final IOException e) {
            // Same message as before, but the original exception is chained so its stack trace survives.
            throw new UserException("Error validating existence of path " + targetURI + ": " + e.getMessage(), e);
        }
    }

    /**
     * Sorts {@code reads} into the order declared by {@code header.getSortOrder()}.
     *
     * <ul>
     *   <li>{@code coordinate}: sorts with a {@link ReadCoordinateComparator} built from the header.</li>
     *   <li>{@code queryname}: sorts with a {@link ReadQueryNameComparator}, then passes the result
     *       through {@link #putReadsWithTheSameNameInTheSamePartition}.</li>
     * </ul>
     *
     * @param reads       reads to sort
     * @param header      header whose sort order selects the comparator
     * @param numReducers forwarded unchanged to {@link #sortUsingElementsAsKeys}
     * @return the sorted reads
     * @throws GATKException if the header's sort order is neither coordinate nor queryname
     */
    public static JavaRDD<GATKRead> sortReadsAccordingToHeader(final JavaRDD<GATKRead> reads, final SAMFileHeader header, final int numReducers) {
        final SAMFileHeader.SortOrder requestedOrder = header.getSortOrder();
        switch (requestedOrder) {
            case coordinate: {
                return sortUsingElementsAsKeys(reads, new ReadCoordinateComparator(header), numReducers);
            }
            case queryname: {
                final JavaRDD<GATKRead> nameSorted = sortUsingElementsAsKeys(reads, new ReadQueryNameComparator(), numReducers);
                final JavaSparkContext sparkContext = JavaSparkContext.fromSparkContext(reads.context());
                return putReadsWithTheSameNameInTheSamePartition(header, nameSorted, sparkContext);
            }
            default: {
                throw new GATKException("Sort order: " + requestedOrder + " is not supported.");
            }
        }
    }

    /**
     * Sorts an RDD by its own elements using the supplied comparator.
     *
     * <p>Each element is turned into a pair whose key is the element and whose value is
     * {@code null}; the pairs are sorted by key and the keys are returned.
     *
     * <p>NOTE(review): the comparator is handed to Spark's {@code sortByKey}; presumably it has to
     * be serializable for that to work on a cluster - confirm against callers.
     *
     * @param elements    the RDD to sort; must not be null
     * @param comparator  ordering to apply; must not be null
     * @param numReducers number of partitions to request for the sort when positive; when zero or
     *                    negative the single-argument {@code sortByKey(comparator)} is used instead
     * @param <T>         element type
     * @return an RDD with the same elements, sorted by {@code comparator}
     */
    public static <T> JavaRDD<T> sortUsingElementsAsKeys(final JavaRDD<T> elements, final Comparator<T> comparator, final int numReducers) {
        Utils.nonNull(comparator);
        Utils.nonNull(elements);

        // Values are null so only the keys carry data through the sort.
        final JavaPairRDD<T, Void> keyedByElement = elements.mapToPair(element -> new Tuple2<>(element, (Void) null));

        final JavaPairRDD<T, Void> sortedPairs = numReducers > 0
                ? keyedByElement.sortByKey(comparator, true, numReducers)
                : keyedByElement.sortByKey(comparator);
        return sortedPairs.keys();
    }

    /**
     * Moves reads across partition boundaries so that a run of reads sharing a name that was split
     * between consecutive partitions ends up in a single partition.
     *
     * <p>Algorithm:
     * <ol>
     *   <li>Collect, for every partition, its leading run of reads that share the first read's name
     *       (an empty list for an empty partition).</li>
     *   <li>Record how many reads each partition must drop from its front (none for partition 0).</li>
     *   <li>Shift the collected groups left by one so partition {@code i} is paired with the leading
     *       group of partition {@code i + 1}; the last partition is paired with an empty list.</li>
     *   <li>If a group has the same name as the nearest earlier non-empty group, fold it into that
     *       earlier group, so a name spanning whole partitions is gathered in one place.</li>
     *   <li>Drop the leading group from each partition and append the (possibly merged) group that
     *       was assigned to it.</li>
     * </ol>
     *
     * <p>The drop counts are taken <em>before</em> step 4. The group lists are shared between the
     * collected list and the shifted list, and step 4 mutates them; counting afterwards would leave
     * a merged-away group with size 0, so its reads would stay in their original partition while
     * also being appended to an earlier one.
     *
     * <p>NOTE(review): when an empty partition sits between two partitions that share a read name,
     * the later partition's leading group is appended to the empty partition rather than to the
     * earlier non-empty one - confirm whether callers can produce empty interior partitions.
     *
     * @param header header describing the reads; must report a queryname grouped or sorted order
     * @param reads  reads to repartition
     * @param ctx    context used to parallelize the per-partition groups
     * @return an RDD with the same number of partitions; {@code reads} itself when it has no partitions
     * @throws IllegalArgumentException presumably, via {@code Utils.validateArg}, when the header is
     *                                  not read-name grouped
     */
    public static JavaRDD<GATKRead> putReadsWithTheSameNameInTheSamePartition(final SAMFileHeader header, final JavaRDD<GATKRead> reads, final JavaSparkContext ctx) {
        Utils.validateArg(ReadUtils.isReadNameGroupedBam(header), () -> "Reads must be queryname grouped or sorted. Actual sort:" + header.getSortOrder() + "  Actual grouping:" + header.getGroupOrder());

        final int numPartitions = reads.getNumPartitions();
        if (numPartitions == 0) {
            // Nothing to fix up; the shifting below would otherwise fail on subList(1, 0).
            return reads;
        }

        // Step 1: leading same-name run of every partition (exactly one list per partition).
        final FlatMapFunction<Iterator<GATKRead>, List<GATKRead>> extractLeadingGroup = partitionReads -> {
            if (!partitionReads.hasNext()) {
                return Iterators.singletonIterator(Collections.<GATKRead>emptyList());
            }
            final List<GATKRead> leadingGroup = new ArrayList<>(2);
            final GATKRead leadingRead = partitionReads.next();
            leadingGroup.add(leadingRead);
            final String leadingName = leadingRead.getName();
            while (partitionReads.hasNext()) {
                final GATKRead candidate = partitionReads.next();
                if (!leadingName.equals(candidate.getName())) {
                    break;
                }
                leadingGroup.add(candidate);
            }
            return Iterators.singletonIterator(leadingGroup);
        };
        final List<List<GATKRead>> firstReadNameGroupInEachPartition = reads.mapPartitions(extractLeadingGroup).collect();

        // Step 2: sizes must be captured before the merge loop mutates the shared lists.
        final int[] firstGroupSizes = firstReadNameGroupInEachPartition.stream().mapToInt(List::size).toArray();
        firstGroupSizes[0] = 0; // partition 0 has no predecessor to take over its leading group

        // Step 3: shift left so each partition is paired with the leading group of the next one.
        final List<List<GATKRead>> firstGroupFromNextPartition = new ArrayList<>(firstReadNameGroupInEachPartition.subList(1, numPartitions));
        firstGroupFromNextPartition.add(Collections.emptyList()); // nothing follows the last partition

        // Step 4: fold a group into the nearest earlier non-empty group when both carry the same name.
        for (int curIdx = numPartitions - 1; curIdx >= 1; --curIdx) {
            final List<GATKRead> curGroup = firstGroupFromNextPartition.get(curIdx);
            if (curGroup.isEmpty()) {
                continue;
            }
            final String curName = curGroup.get(0).getName();
            for (int prevIdx = curIdx - 1; prevIdx >= 0; --prevIdx) {
                final List<GATKRead> prevGroup = firstGroupFromNextPartition.get(prevIdx);
                if (prevGroup.isEmpty()) {
                    continue;
                }
                if (curName.equals(prevGroup.get(0).getName())) {
                    prevGroup.addAll(curGroup);
                    curGroup.clear();
                }
                // Only the nearest non-empty predecessor is ever considered.
                break;
            }
        }

        // Step 5a: drop each partition's leading group (it is re-attached to an earlier partition).
        final Function2<Integer, Iterator<GATKRead>, Iterator<GATKRead>> dropLeadingGroup = (partitionIndex, itr) -> {
            int remaining = firstGroupSizes[partitionIndex];
            while (itr.hasNext() && remaining-- > 0) {
                itr.next();
            }
            return itr;
        };
        final JavaRDD<GATKRead> readsSansFirstGroup = reads.mapPartitionsWithIndex(dropLeadingGroup, true);

        // Step 5b: append to each partition the group assigned to it.
        final FlatMapFunction2<Iterator<GATKRead>, Iterator<List<GATKRead>>, GATKRead> appendAssignedGroup =
                (ownReads, assignedGroup) -> Iterators.concat(ownReads, assignedGroup.next().iterator());
        return readsSansFirstGroup.zipPartitions(ctx.parallelize(firstGroupFromNextPartition, numPartitions), appendAssignedGroup);
    }

    /**
     * Groups the values of an RDD of pairs by applying {@link #getSpanningIterator} to each
     * partition independently.
     *
     * <p>Because the grouping is done per partition and only over consecutive elements, pairs with
     * the same key that are not adjacent, or that live in different partitions, are not combined
     * into one group.
     *
     * @param rdd the pairs to group
     * @param <K> key type
     * @param <V> value type
     * @return an RDD with one (key, values) pair per run of consecutive equal keys within a partition
     */
    public static <K, V> JavaPairRDD<K, Iterable<V>> spanByKey(JavaPairRDD<K, V> rdd) {
        return rdd.mapPartitionsToPair(SparkUtils::getSpanningIterator);
    }

    /**
     * Wraps an iterator of (key, value) pairs in an iterator of (key, values) pairs, where each
     * output element gathers a run of consecutive input pairs that share the same key.
     *
     * <p>The input is consumed lazily: one run is read each time the returned iterator advances.
     * Keys are compared with {@code equals}; a {@code null} key is treated as an ordinary key that
     * is equal only to another {@code null} key.
     *
     * @param iterator the pairs to group; consumed by the returned iterator
     * @param <K>      key type
     * @param <V>      value type
     * @return an iterator over the runs, in input order, each carrying its values in input order
     */
    public static <K, V> Iterator<Tuple2<K, Iterable<V>>> getSpanningIterator(final Iterator<Tuple2<K, V>> iterator) {
        final PeekingIterator<Tuple2<K, V>> iter = Iterators.peekingIterator(iterator);
        return new AbstractIterator<Tuple2<K, Iterable<V>>>() {

            @Override
            protected Tuple2<K, Iterable<V>> computeNext() {
                if (!iter.hasNext()) {
                    return endOfData();
                }

                // The first pair always opens a new run, whatever its key is (including null).
                final Tuple2<K, V> first = iter.next();
                final K key = first._1();
                final List<V> group = Lists.newArrayList();
                group.add(first._2());

                // Peek so that the pair starting the next run is left unconsumed.
                while (iter.hasNext()) {
                    final K nextKey = iter.peek()._1();
                    final boolean sameKey = (nextKey == null) ? (key == null) : nextKey.equals(key);
                    if (!sameKey) {
                        break;
                    }
                    group.add(iter.next()._2());
                }
                return new Tuple2<>(key, group);
            }
        };
    }

    /**
     * Returns reads that are grouped by read name, sorting them only when the header does not
     * already declare such a grouping.
     *
     * <p>Side effect: when a sort is needed, {@code header} is modified in place - its sort order is
     * set to {@code queryname} before the reads are sorted with {@link #sortReadsAccordingToHeader}.
     *
     * @param reads       the reads to return or sort
     * @param numReducers forwarded to {@link #sortReadsAccordingToHeader} when sorting
     * @param header      header describing {@code reads}; may be mutated as described above
     * @return {@code reads} unchanged when {@code ReadUtils.isReadNameGroupedBam(header)} is true,
     *         otherwise the queryname-sorted reads
     */
    public static JavaRDD<GATKRead> querynameSortReadsIfNecessary(final JavaRDD<GATKRead> reads, final int numReducers, final SAMFileHeader header) {
        if (ReadUtils.isReadNameGroupedBam(header)) {
            return reads;
        }
        header.setSortOrder(SAMFileHeader.SortOrder.queryname);
        return sortReadsAccordingToHeader(reads, header, numReducers);
    }
}

