/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.engine.spark;

import com.google.common.base.Function;
import com.google.common.collect.AbstractIterator;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterators;
import com.google.common.collect.PeekingIterator;
import htsjdk.samtools.SAMSequenceDictionary;
import htsjdk.samtools.SAMSequenceRecord;
import htsjdk.samtools.util.Locatable;
import htsjdk.samtools.util.OverlapDetector;
import java.io.Serializable;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import javax.annotation.Nullable;
import org.apache.spark.Partitioner;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaRDDLike;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.FlatMapFunction2;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.rdd.PartitionCoalescer;
import org.apache.spark.rdd.RDD;
import org.broadinstitute.hellbender.engine.Shard;
import org.broadinstitute.hellbender.engine.ShardBoundary;
import org.broadinstitute.hellbender.engine.spark.RangePartitionCoalescer;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.utils.IntervalUtils;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.Utils;
import scala.Option;
import scala.Tuple2;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;

public class SparkSharder {
    /**
     * Shards {@code locatables} by a local list of shard boundaries without using a shuffle.
     *
     * <p>Convenience entry point: forwards every argument unchanged to the overload that takes
     * an explicit {@code useShuffle} flag, with that flag set to {@code false}.
     *
     * @param ctx                the Spark context, passed through
     * @param locatables         the RDD of locatables to shard
     * @param locatableClass     the runtime class of the locatables, passed through
     * @param sequenceDictionary the sequence dictionary, passed through
     * @param intervals          the shard boundaries
     * @param maxLocatableLength the maximum permitted locatable length, passed through
     * @return the RDD of shards produced by the delegate
     */
    public static <L extends Locatable, SB extends ShardBoundary> JavaRDD<Shard<L>> shard(JavaSparkContext ctx, JavaRDD<L> locatables, Class<L> locatableClass, SAMSequenceDictionary sequenceDictionary, List<SB> intervals, int maxLocatableLength) {
        final boolean useShuffle = false;
        return shard(ctx, locatables, locatableClass, sequenceDictionary, intervals, maxLocatableLength, useShuffle);
    }

    /**
     * Shards {@code locatables} by an RDD of shard boundaries without using a shuffle.
     *
     * <p>Convenience entry point: forwards every argument unchanged to the overload that takes
     * an explicit {@code useShuffle} flag, with that flag set to {@code false}.
     *
     * @param ctx                the Spark context, passed through
     * @param locatables         the RDD of locatables to shard
     * @param locatableClass     the runtime class of the locatables, passed through
     * @param sequenceDictionary the sequence dictionary, passed through
     * @param intervals          the shard boundaries, as an RDD
     * @param maxLocatableLength the maximum permitted locatable length, passed through
     * @return the RDD of shards produced by the delegate
     */
    public static <L extends Locatable, SB extends ShardBoundary> JavaRDD<Shard<L>> shard(JavaSparkContext ctx, JavaRDD<L> locatables, Class<L> locatableClass, SAMSequenceDictionary sequenceDictionary, JavaRDD<SB> intervals, int maxLocatableLength) {
        final boolean useShuffle = false;
        return shard(ctx, locatables, locatableClass, sequenceDictionary, intervals, maxLocatableLength, useShuffle);
    }

    /**
     * Shards {@code locatables} by a local list of shard boundaries, optionally using a shuffle.
     *
     * <p>Each boundary is first replaced by its padded boundary. With {@code useShuffle} set, an
     * {@link OverlapDetector} over the padded boundaries is broadcast, every locatable is keyed by
     * each boundary it overlaps, and the pairs are grouped by key. Otherwise the work is delegated
     * to {@code joinOverlapping}. In both cases each (boundary, locatables) pair is turned into a
     * shard via {@link ShardBoundary#createShard}.
     *
     * @param ctx                the Spark context used for broadcasting / parallelizing
     * @param locatables         the RDD of locatables to shard
     * @param locatableClass     the runtime class of the locatables (used only on the non-shuffle path)
     * @param sequenceDictionary the sequence dictionary (used only on the non-shuffle path)
     * @param intervals          the shard boundaries
     * @param maxLocatableLength the maximum permitted locatable length (used only on the non-shuffle path)
     * @param useShuffle         whether to group by key (shuffle) instead of joining partition-wise
     * @return an RDD with one shard per (padded boundary, overlapping locatables) group
     */
    public static <L extends Locatable, SB extends ShardBoundary> JavaRDD<Shard<L>> shard(JavaSparkContext ctx, JavaRDD<L> locatables, Class<L> locatableClass, SAMSequenceDictionary sequenceDictionary, List<SB> intervals, int maxLocatableLength, boolean useShuffle) {
        List<ShardBoundary> paddedIntervals = intervals.stream()
                .map(ShardBoundary::paddedShardBoundary)
                .collect(Collectors.toList());

        if (useShuffle) {
            OverlapDetector<ShardBoundary> overlapDetector = OverlapDetector.create(paddedIntervals);
            Broadcast<OverlapDetector<ShardBoundary>> overlapDetectorBroadcast = ctx.broadcast(overlapDetector);

            // Emit one (boundary, locatable) pair for every padded boundary the locatable overlaps.
            JavaPairRDD<ShardBoundary, L> intervalsToLocatables = locatables.flatMapToPair(
                    (PairFlatMapFunction<L, ShardBoundary, L>) locatable -> {
                        Set<ShardBoundary> overlaps = overlapDetectorBroadcast.getValue().getOverlaps(locatable);
                        return overlaps.stream()
                                .map(key -> new Tuple2<ShardBoundary, L>(key, locatable))
                                .collect(Collectors.toList())
                                .iterator();
                    });

            JavaPairRDD<ShardBoundary, Iterable<L>> grouped = intervalsToLocatables.groupByKey();
            return grouped.map(
                    (org.apache.spark.api.java.function.Function<Tuple2<ShardBoundary, Iterable<L>>, Shard<L>>)
                            value -> value._1().createShard(value._2()));
        }

        return joinOverlapping(ctx, locatables, locatableClass, sequenceDictionary, paddedIntervals, maxLocatableLength, new MapFunction<Tuple2<ShardBoundary, Iterable<L>>, Shard<L>>() {
            private static final long serialVersionUID = 1L;

            @Override
            public Shard<L> call(Tuple2<ShardBoundary, Iterable<L>> value) {
                return value._1().createShard(value._2());
            }
        });
    }

    /**
     * Shards {@code locatables} by an RDD of shard boundaries.
     *
     * <p>Each boundary is replaced by its padded boundary and the work is delegated to
     * {@code joinOverlapping}; every resulting (boundary, locatables) pair is turned into a shard
     * via {@link ShardBoundary#createShard}.
     *
     * @param ctx                the Spark context, passed through
     * @param locatables         the RDD of locatables to shard
     * @param locatableClass     the runtime class of the locatables, passed through
     * @param sequenceDictionary the sequence dictionary, passed through
     * @param intervals          the shard boundaries, as an RDD
     * @param maxLocatableLength the maximum permitted locatable length, passed through
     * @param useShuffle         must be {@code false}; the shuffle strategy is not available here
     * @return an RDD of shards
     * @throws UnsupportedOperationException if {@code useShuffle} is {@code true}
     */
    private static <L extends Locatable, SB extends ShardBoundary> JavaRDD<Shard<L>> shard(JavaSparkContext ctx, JavaRDD<L> locatables, Class<L> locatableClass, SAMSequenceDictionary sequenceDictionary, JavaRDD<SB> intervals, int maxLocatableLength, boolean useShuffle) {
        // Reject the unsupported mode before building any RDD lineage.
        if (useShuffle) {
            throw new UnsupportedOperationException("Shuffle not supported when sharding an RDD of intervals.");
        }
        JavaRDD<ShardBoundary> paddedIntervals = intervals.map(ShardBoundary::paddedShardBoundary);
        return joinOverlapping(ctx, locatables, locatableClass, sequenceDictionary, paddedIntervals, maxLocatableLength, new MapFunction<Tuple2<ShardBoundary, Iterable<L>>, Shard<L>>() {
            private static final long serialVersionUID = 1L;

            @Override
            public Shard<L> call(Tuple2<ShardBoundary, Iterable<L>> value) {
                return value._1().createShard(value._2());
            }
        });
    }

    /**
     * Joins locatables with a local list of intervals and maps each (interval, overlapping
     * locatables) pair through {@code f}.
     *
     * <p>Adapts the per-pair {@code f} into a per-partition function: within each zipped
     * partition, {@code locatablesPerShard} produces the pairs and {@code f} is applied to each
     * one lazily via {@link Iterators#transform}.
     *
     * <p>Unchecked exceptions thrown by {@code f} propagate unchanged; checked exceptions are
     * wrapped in a {@link RuntimeException} with the original as its cause.
     *
     * @param ctx                the Spark context, passed through
     * @param locatables         the RDD of locatables
     * @param locatableClass     the runtime class of the locatables, passed through
     * @param sequenceDictionary the sequence dictionary used when pairing locatables with intervals
     * @param intervals          the intervals to join against
     * @param maxLocatableLength the maximum permitted locatable length
     * @param f                  the function applied to each (interval, locatables) pair
     * @return an RDD of the values returned by {@code f}
     */
    private static <L extends Locatable, I extends Locatable, T> JavaRDD<T> joinOverlapping(JavaSparkContext ctx, JavaRDD<L> locatables, Class<L> locatableClass, SAMSequenceDictionary sequenceDictionary, List<I> intervals, int maxLocatableLength, final MapFunction<Tuple2<I, Iterable<L>>, T> f) {
        FlatMapFunction2<Iterator<L>, Iterator<I>, T> perPartition = (locatablesIterator, shardsIterator) -> Iterators.transform(
                locatablesPerShard(locatablesIterator, shardsIterator, sequenceDictionary, maxLocatableLength),
                new Function<Tuple2<I, Iterable<L>>, T>() {

                    @Nullable
                    @Override
                    public T apply(@Nullable Tuple2<I, Iterable<L>> input) {
                        try {
                            return f.call(input);
                        } catch (RuntimeException e) {
                            // Already unchecked: rethrow as-is so its type is not hidden by a wrapper.
                            throw e;
                        } catch (Exception e) {
                            throw new RuntimeException(e);
                        }
                    }
                });
        return joinOverlapping(ctx, locatables, locatableClass, sequenceDictionary, intervals, maxLocatableLength, perPartition);
    }

    /**
     * Joins locatables with an RDD of intervals and maps each (interval, overlapping locatables)
     * pair through {@code f}.
     *
     * <p>Adapts the per-pair {@code f} into a per-partition function: within each zipped
     * partition, {@code locatablesPerShard} produces the pairs and {@code f} is applied to each
     * one lazily via {@link Iterators#transform}.
     *
     * <p>Unchecked exceptions thrown by {@code f} propagate unchanged; checked exceptions are
     * wrapped in a {@link RuntimeException} with the original as its cause.
     *
     * @param ctx                the Spark context, passed through
     * @param locatables         the RDD of locatables
     * @param locatableClass     the runtime class of the locatables, passed through
     * @param sequenceDictionary the sequence dictionary used when pairing locatables with intervals
     * @param intervals          the intervals to join against, as an RDD
     * @param maxLocatableLength the maximum permitted locatable length
     * @param f                  the function applied to each (interval, locatables) pair
     * @return an RDD of the values returned by {@code f}
     */
    private static <L extends Locatable, I extends Locatable, T> JavaRDD<T> joinOverlapping(JavaSparkContext ctx, JavaRDD<L> locatables, Class<L> locatableClass, SAMSequenceDictionary sequenceDictionary, JavaRDD<I> intervals, int maxLocatableLength, final MapFunction<Tuple2<I, Iterable<L>>, T> f) {
        FlatMapFunction2<Iterator<L>, Iterator<I>, T> perPartition = (locatablesIterator, shardsIterator) -> Iterators.transform(
                locatablesPerShard(locatablesIterator, shardsIterator, sequenceDictionary, maxLocatableLength),
                new Function<Tuple2<I, Iterable<L>>, T>() {

                    @Nullable
                    @Override
                    public T apply(@Nullable Tuple2<I, Iterable<L>> input) {
                        try {
                            return f.call(input);
                        } catch (RuntimeException e) {
                            // Already unchecked: rethrow as-is so its type is not hidden by a wrapper.
                            throw e;
                        } catch (Exception e) {
                            throw new RuntimeException(e);
                        }
                    }
                });
        return joinOverlapping(ctx, locatables, locatableClass, sequenceDictionary, intervals, maxLocatableLength, perPartition);
    }

    /**
     * Joins locatables with a local list of intervals using a per-partition function.
     *
     * <p>Turns {@code intervals} into an RDD with {@link JavaSparkContext#parallelize} and
     * delegates to the RDD-based overload; all other arguments are forwarded unchanged.
     *
     * @param ctx                the Spark context used to parallelize the intervals
     * @param locatables         the RDD of locatables
     * @param locatableClass     the runtime class of the locatables, passed through
     * @param sequenceDictionary the sequence dictionary, passed through
     * @param intervals          the intervals to join against
     * @param maxLocatableLength the maximum permitted locatable length, passed through
     * @param f                  the function applied to each pair of zipped partition iterators
     * @return the RDD produced by the delegate
     */
    private static <L extends Locatable, I extends Locatable, T> JavaRDD<T> joinOverlapping(JavaSparkContext ctx, JavaRDD<L> locatables, Class<L> locatableClass, SAMSequenceDictionary sequenceDictionary, List<I> intervals, int maxLocatableLength, FlatMapFunction2<Iterator<L>, Iterator<I>, T> f) {
        JavaRDD<I> intervalsRdd = ctx.parallelize(intervals);
        return joinOverlapping(ctx, locatables, locatableClass, sequenceDictionary, intervalsRdd, maxLocatableLength, f);
    }

    /**
     * Joins locatables with an RDD of intervals partition-wise and applies {@code f} to each pair
     * of zipped partition iterators.
     *
     * <p>Steps, as implemented below:
     * <ol>
     *   <li>Compute the read extents of every locatable partition and broadcast both the extents'
     *       intervals and an {@link OverlapDetector} built over them.</li>
     *   <li>Tag each interval with the smallest and largest partition index among the extents it
     *       overlaps. If it overlaps none, it is assigned a single index derived from a binary
     *       search over the extent intervals.</li>
     *   <li>Repartition the tagged intervals by their start partition index, sorting within each
     *       partition by location.</li>
     *   <li>For each start partition, find the largest end partition index required and coalesce
     *       the locatables with a {@code RangePartitionCoalescer} built from those values.</li>
     *   <li>Zip the coalesced locatables with the repartitioned intervals and apply {@code f}.</li>
     * </ol>
     *
     * <p>NOTE(review): in the no-overlap fallback the binary-search insertion point is an index into
     * the extent list, yet it is used as a partition index. A partition can contribute several
     * extents, so the two need not coincide — confirm this is intended.
     *
     * @param ctx                the Spark context used for broadcasting
     * @param locatables         the RDD of locatables
     * @param locatableClass     the runtime class of the locatables, needed to rebuild the coalesced RDD
     * @param sequenceDictionary the sequence dictionary used for ordering and extent computation
     * @param intervals          the intervals to join against
     * @param maxLocatableLength the maximum permitted locatable length
     * @param f                  the function applied to each pair of zipped partition iterators
     * @return the RDD produced by zipping the coalesced locatables with the sorted intervals
     */
    private static <L extends Locatable, I extends Locatable, T> JavaRDD<T> joinOverlapping(JavaSparkContext ctx, JavaRDD<L> locatables, Class<L> locatableClass, SAMSequenceDictionary sequenceDictionary, JavaRDD<I> intervals, int maxLocatableLength, FlatMapFunction2<Iterator<L>, Iterator<I>, T> f) {
        List<PartitionLocatable<SimpleInterval>> partitionReadExtents = computePartitionReadExtents(locatables, sequenceDictionary, maxLocatableLength);
        List<SimpleInterval> firstLocatablesList = partitionReadExtents.stream()
                .map(PartitionLocatable::getLocatable)
                .collect(Collectors.toList());
        Broadcast<List<SimpleInterval>> firstLocatablesBroadcast = ctx.broadcast(firstLocatablesList);
        OverlapDetector<PartitionLocatable<SimpleInterval>> overlapDetector = OverlapDetector.create(partitionReadExtents);
        Broadcast<OverlapDetector<PartitionLocatable<SimpleInterval>>> overlapDetectorBroadcast = ctx.broadcast(overlapDetector);

        JavaRDD<PartitionLocatable<I>> indexedIntervals = intervals.map(
                (org.apache.spark.api.java.function.Function<I, PartitionLocatable<I>>) interval -> {
                    int[] partitionIndexes = overlapDetectorBroadcast.getValue().getOverlaps(interval).stream()
                            .mapToInt(PartitionLocatable::getPartitionIndex)
                            .toArray();
                    if (partitionIndexes.length == 0) {
                        // The interval overlaps no extent: place it by binary search instead.
                        List<SimpleInterval> firstLocatables = firstLocatablesBroadcast.getValue();
                        int i = Collections.binarySearch(firstLocatables, new SimpleInterval(interval),
                                (o1, o2) -> IntervalUtils.compareLocatables(o1, o2, sequenceDictionary));
                        if (i >= 0) {
                            throw new IllegalStateException("Interval " + interval
                                    + " compares equal to a partition read extent but overlaps none");
                        }
                        int insertionPoint = -i - 1;
                        // Clamp into [0, size - 1]. The lower bound matters when there are no extents
                        // at all (every partition empty): without it the index would be -1, which is
                        // not a valid partition.
                        insertionPoint = Math.max(0, Math.min(insertionPoint, firstLocatables.size() - 1));
                        return new PartitionLocatable<I>(insertionPoint, interval);
                    }
                    Arrays.sort(partitionIndexes);
                    int startIndex = partitionIndexes[0];
                    int endIndex = partitionIndexes[partitionIndexes.length - 1];
                    return new PartitionLocatable<I>(startIndex, endIndex, interval);
                });

        JavaRDD<PartitionLocatable<I>> indexedIntervalsRepartitioned = indexedIntervals
                .mapToPair((PairFunction<PartitionLocatable<I>, PartitionLocatable<I>, Void>) interval -> new Tuple2<>(interval, null))
                .repartitionAndSortWithinPartitions(
                        new PartitionLocatablePartitioner(locatables.getNumPartitions()),
                        new PartitionLocatableComparator<I>(sequenceDictionary))
                .keys();
        // Used twice below: once to collect the end indexes and once in the final zip.
        indexedIntervalsRepartitioned.cache();

        Map<Integer, Integer> maxEndPartitionIndexesMap = indexedIntervalsRepartitioned
                .mapToPair((PairFunction<PartitionLocatable<I>, Integer, Integer>) partitionLocatable ->
                        new Tuple2<>(partitionLocatable.getPartitionIndex(), partitionLocatable.getEndPartitionIndex()))
                .reduceByKey(Math::max)
                .collectAsMap();

        // Start from the identity mapping (each partition ends at itself), then widen where needed.
        List<Integer> maxEndPartitionIndexes = IntStream.range(0, locatables.getNumPartitions())
                .boxed()
                .collect(Collectors.toList());
        maxEndPartitionIndexesMap.forEach((startIndex, endIndex) -> {
            if (endIndex > maxEndPartitionIndexes.get(startIndex)) {
                maxEndPartitionIndexes.set(startIndex, endIndex);
            }
        });

        JavaRDD<L> coalescedRdd = coalesce(locatables, locatableClass, new RangePartitionCoalescer(maxEndPartitionIndexes));
        JavaRDD<I> sortedIntervals = indexedIntervalsRepartitioned.map(PartitionLocatable::getLocatable);
        return coalescedRdd.zipPartitions(sortedIntervals, f);
    }

    /**
     * Lazily pairs each shard with the locatables that overlap it.
     *
     * <p>Both input iterators are consumed in a single forward pass; the returned iterator emits
     * one tuple per shard, in the order the shards were read. A shard is emitted either once a
     * locatable has been seen that {@code IntervalUtils.isAfter} reports as after the oldest
     * pending shard, or once the locatables are exhausted. Shards with no overlapping locatables
     * are still emitted, with an empty collection.
     *
     * <p>NOTE(review): the single forward pass presumably relies on both iterators being sorted in
     * sequence-dictionary order — confirm against callers.
     *
     * @param locatables         the locatables to distribute over the shards
     * @param shards             the shard intervals
     * @param sequenceDictionary the dictionary passed to {@code IntervalUtils.isAfter}
     * @param maxLocatableLength the largest allowed locatable length ({@code end - start + 1})
     * @return an iterator of (shard, overlapping locatables) tuples; empty if there are no shards
     * @throws UserException (while iterating) if a locatable with a non-null contig is longer than
     *                       {@code maxLocatableLength}
     */
    static <L extends Locatable, I extends Locatable> Iterator<Tuple2<I, Iterable<L>>> locatablesPerShard(final Iterator<L> locatables, Iterator<I> shards, final SAMSequenceDictionary sequenceDictionary, final int maxLocatableLength) {
        // No shards: nothing to emit, and the locatables iterator is left untouched.
        if (!shards.hasNext()) {
            return Collections.emptyIterator();
        }
        final PeekingIterator peekingShards = Iterators.peekingIterator(shards);
        AbstractIterator iterator = new AbstractIterator<Tuple2<I, Iterable<L>>>(){
            // Shards that have been read from the input but not yet emitted, oldest first.
            Queue<PendingShard<L, I>> pendingShards = new ArrayDeque();

            protected Tuple2<I, Iterable<L>> computeNext() {
                Tuple2 nextShard = null;
                // Consume locatables until one of them completes the oldest pending shard.
                while (locatables.hasNext() && nextShard == null) {
                    int size;
                    Locatable locatable = (Locatable)locatables.next();
                    // The length check is skipped for locatables without a contig.
                    if (locatable.getContig() != null && (size = locatable.getEnd() - locatable.getStart() + 1) > maxLocatableLength) {
                        throw new UserException(String.format("Max size of locatable exceeded. Max size is %s, but locatable size is %s. Try increasing shard size and/or padding. Locatable: %s", maxLocatableLength, size, locatable));
                    }
                    // Queue every upcoming shard that isAfter does not report as after this locatable.
                    while (peekingShards.hasNext() && !IntervalUtils.isAfter((Locatable)peekingShards.peek(), locatable, sequenceDictionary)) {
                        this.pendingShards.add(new PendingShard((Locatable)peekingShards.next()));
                    }
                    // A locatable may overlap several pending shards; it is added to each of them.
                    for (PendingShard pendingShard : this.pendingShards) {
                        if (!IntervalUtils.overlaps(pendingShard, locatable)) continue;
                        pendingShard.addLocatable(locatable);
                    }
                    // Only the head of the queue is checked, so at most one shard is released per locatable.
                    if (this.pendingShards.isEmpty() || !IntervalUtils.isAfter(locatable, this.pendingShards.peek(), sequenceDictionary)) continue;
                    nextShard = this.pendingShards.poll().get();
                }
                if (!locatables.hasNext()) {
                    // Locatables are exhausted: queue all remaining shards so they are still emitted.
                    while (peekingShards.hasNext()) {
                        this.pendingShards.add(new PendingShard((Locatable)peekingShards.next()));
                    }
                    // Drain one pending shard per call, unless the loop above already produced one.
                    if (!this.pendingShards.isEmpty() && nextShard == null) {
                        nextShard = this.pendingShards.poll().get();
                    }
                }
                if (nextShard == null) {
                    return (Tuple2)this.endOfData();
                }
                return nextShard;
            }
        };
        return iterator;
    }

    /**
     * Tests the relative position of {@code interval} and {@code locatable} in dictionary order.
     *
     * <p>Returns {@code true} when the interval's contig index is smaller than the locatable's, or
     * when both are on the same contig index and the interval ends before the locatable starts.
     *
     * @param interval           the interval to test
     * @param locatable          the locatable to compare against
     * @param sequenceDictionary the dictionary supplying contig indexes
     * @return {@code true} if the condition above holds, {@code false} otherwise
     */
    private static <I extends Locatable, L extends Locatable> boolean toRightOf(I interval, L locatable, SAMSequenceDictionary sequenceDictionary) {
        final int intervalContigIndex = sequenceDictionary.getSequenceIndex(interval.getContig());
        final int locatableContigIndex = sequenceDictionary.getSequenceIndex(locatable.getContig());
        if (intervalContigIndex < locatableContigIndex) {
            return true;
        }
        if (intervalContigIndex != locatableContigIndex) {
            return false;
        }
        return interval.getEnd() < locatable.getStart();
    }

    /**
     * Computes, for every non-empty partition of {@code locatables}, the genomic extents that the
     * partition's contents may cover.
     *
     * <p>The first locatable of each partition is collected to the driver as a split point. For a
     * split point {@code current} and the following one {@code next}:
     * <ul>
     *   <li>on the same contig, one extent runs from {@code current.start} to
     *       {@code next.start + maxLocatableLength};</li>
     *   <li>otherwise one extent runs from {@code current.start} to the end of its contig, one
     *       extent covers each whole contig strictly in between, and (if {@code next} exists) one
     *       runs from position 1 to {@code next.start + maxLocatableLength} on the next contig.
     *       For the last split point the in-between contigs are all remaining contigs.</li>
     * </ul>
     * Partitions with no locatables contribute no extents.
     *
     * @param locatables         the RDD whose partitions are inspected
     * @param sequenceDictionary the dictionary supplying contig indexes and lengths
     * @param maxLocatableLength the margin added past the next partition's first start position
     * @return the extents, each tagged with the index of the partition it belongs to
     * @throws RuntimeException via {@code Utils.validate} if a contig is missing from the dictionary
     */
    static <L extends Locatable> List<PartitionLocatable<SimpleInterval>> computePartitionReadExtents(JavaRDD<L> locatables, SAMSequenceDictionary sequenceDictionary, int maxLocatableLength) {
        // One entry per partition, in partition order; the locatable is null for an empty partition.
        // The index is a placeholder here: the real partition index is the position in the collected list.
        List<PartitionLocatable<L>> allSplitPoints = locatables.mapPartitions(
                (FlatMapFunction<Iterator<L>, PartitionLocatable<L>>) it -> {
                    L first = it.hasNext() ? it.next() : null;
                    return ImmutableList.of(new PartitionLocatable<L>(-1, first)).iterator();
                }).collect();

        List<PartitionLocatable<L>> splitPoints = new ArrayList<>();
        for (int i = 0; i < allSplitPoints.size(); ++i) {
            L first = allSplitPoints.get(i).getLocatable();
            if (first != null) {
                splitPoints.add(new PartitionLocatable<L>(i, first));
            }
        }

        List<PartitionLocatable<SimpleInterval>> extents = new ArrayList<>();
        for (int i = 0; i < splitPoints.size(); ++i) {
            PartitionLocatable<L> splitPoint = splitPoints.get(i);
            int partitionIndex = splitPoint.getPartitionIndex();
            Locatable current = splitPoint.getLocatable();
            int intervalContigIndex = sequenceDictionary.getSequenceIndex(current.getContig());
            Utils.validate(intervalContigIndex != -1, "Contig not found in sequence dictionary: " + current.getContig());

            final Locatable next;
            final int nextContigIndex;
            if (i < splitPoints.size() - 1) {
                next = splitPoints.get(i + 1);
                nextContigIndex = sequenceDictionary.getSequenceIndex(next.getContig());
                Utils.validate(nextContigIndex != -1, "Contig not found in sequence dictionary: " + next.getContig());
            } else {
                // Last split point: treat "next" as lying past the final contig.
                next = null;
                nextContigIndex = sequenceDictionary.getSequences().size();
            }

            if (intervalContigIndex == nextContigIndex) {
                // Same contig; next cannot be null here because the sentinel index is out of range.
                addPartitionReadExtent(extents, partitionIndex, current.getContig(), current.getStart(), extentEnd(next.getStart(), maxLocatableLength));
                continue;
            }

            SAMSequenceRecord seq = sequenceDictionary.getSequence(current.getContig());
            Utils.validate(seq != null, "Contig not found in sequence dictionary: " + current.getContig());
            int contigEnd = seq.getSequenceLength();
            addPartitionReadExtent(extents, partitionIndex, current.getContig(), current.getStart(), contigEnd);

            for (int contigIndex = intervalContigIndex + 1; contigIndex < nextContigIndex; ++contigIndex) {
                SAMSequenceRecord sequence = sequenceDictionary.getSequence(contigIndex);
                Utils.validate(sequence != null, "Contig index not found in sequence dictionary: " + contigIndex);
                addPartitionReadExtent(extents, partitionIndex, sequence.getSequenceName(), 1, sequence.getSequenceLength());
            }

            if (next != null) {
                addPartitionReadExtent(extents, partitionIndex, next.getContig(), 1, extentEnd(next.getStart(), maxLocatableLength));
            }
        }
        return extents;
    }

    /** Returns {@code start + maxLocatableLength}, saturating at {@link Integer#MAX_VALUE} instead of overflowing. */
    private static int extentEnd(int start, int maxLocatableLength) {
        return (int) Math.min((long) start + maxLocatableLength, (long) Integer.MAX_VALUE);
    }

    /**
     * Appends a new extent to {@code extents}.
     *
     * <p>Builds a {@link SimpleInterval} from {@code contig}, {@code start} and {@code end}, tags it
     * with {@code partitionIndex} and adds it at the end of the list.
     *
     * @param extents        the list to append to
     * @param partitionIndex the partition index to record with the extent
     * @param contig         the contig name of the extent
     * @param start          the start position of the extent
     * @param end            the end position of the extent
     */
    private static void addPartitionReadExtent(List<PartitionLocatable<SimpleInterval>> extents, int partitionIndex, String contig, int start, int end) {
        extents.add(new PartitionLocatable<SimpleInterval>(partitionIndex, new SimpleInterval(contig, start, end)));
    }

    /**
     * Coalesces {@code rdd} with a custom {@link PartitionCoalescer}, without a shuffle.
     *
     * <p>Calls the underlying Scala {@code RDD.coalesce} with the RDD's current partition count,
     * {@code shuffle = false}, the given coalescer, and {@code null} for the final argument
     * (NOTE(review): presumably the implicit ordering parameter — confirm against the Spark API).
     * The result is wrapped back into a {@link JavaRDD} using a {@link ClassTag} built from {@code cls}.
     *
     * @param rdd                the RDD to coalesce
     * @param cls                the element class, used to build the class tag
     * @param partitionCoalescer the strategy deciding how partitions are grouped
     * @return the coalesced RDD
     */
    private static <T> JavaRDD<T> coalesce(JavaRDD<T> rdd, Class<T> cls, PartitionCoalescer partitionCoalescer) {
        RDD<T> coalescedRdd = rdd.rdd().coalesce(rdd.getNumPartitions(), false, Option.apply(partitionCoalescer), null);
        ClassTag<T> tag = ClassTag$.MODULE$.apply(cls);
        return new JavaRDD<>(coalescedRdd, tag);
    }

    static class PartitionLocatable<L extends Locatable>
    implements Locatable {
        private static final long serialVersionUID = 1L;
        private final int partitionIndex;
        private final int endPartitionIndex;
        private final L interval;

        public PartitionLocatable(int partitionIndex, L interval) {
            this(partitionIndex, partitionIndex, interval);
        }

        public PartitionLocatable(int partitionIndex, int endPartitionIndex, L interval) {
            this.partitionIndex = partitionIndex;
            this.endPartitionIndex = endPartitionIndex;
            this.interval = interval;
        }

        public int getPartitionIndex() {
            return this.partitionIndex;
        }

        public int getEndPartitionIndex() {
            return this.endPartitionIndex;
        }

        public L getLocatable() {
            return this.interval;
        }

        public String getContig() {
            return this.interval.getContig();
        }

        public int getStart() {
            return this.interval.getStart();
        }

        public int getEnd() {
            return this.interval.getEnd();
        }

        public String toString() {
            return "PartitionLocatable{partitionIndex=" + this.partitionIndex + ", interval='" + this.interval + '\'' + '}';
        }

        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            PartitionLocatable that = (PartitionLocatable)o;
            if (this.partitionIndex != that.partitionIndex) {
                return false;
            }
            return this.interval.equals(that.interval);
        }

        public int hashCode() {
            int result = this.partitionIndex;
            result = 31 * result + this.interval.hashCode();
            return result;
        }
    }

    /**
     * Orders {@link PartitionLocatable}s by their wrapped locatables, using
     * {@code IntervalUtils.compareLocatables} with a fixed sequence dictionary. Partition indexes
     * play no part in the ordering.
     *
     * @param <L> the type of the wrapped locatables
     */
    private static class PartitionLocatableComparator<L extends Locatable>
    implements Comparator<PartitionLocatable<L>>,
    Serializable {
        private static final long serialVersionUID = 1L;
        private final SAMSequenceDictionary sequenceDictionary;

        /**
         * @param sequenceDictionary the dictionary handed to every comparison
         */
        private PartitionLocatableComparator(SAMSequenceDictionary sequenceDictionary) {
            this.sequenceDictionary = sequenceDictionary;
        }

        /**
         * Compares the locatables wrapped by the two arguments.
         *
         * @return the result of {@code IntervalUtils.compareLocatables} for the wrapped locatables
         */
        @Override
        public int compare(PartitionLocatable<L> left, PartitionLocatable<L> right) {
            final L leftLocatable = left.getLocatable();
            final L rightLocatable = right.getLocatable();
            return IntervalUtils.compareLocatables(leftLocatable, rightLocatable, sequenceDictionary);
        }
    }

    private static class PartitionLocatablePartitioner
    extends Partitioner {
        private static final long serialVersionUID = 1L;
        private int numPartitions;

        public PartitionLocatablePartitioner(int numPartitions) {
            this.numPartitions = numPartitions;
        }

        public int numPartitions() {
            return this.numPartitions;
        }

        public int getPartition(Object key) {
            return ((PartitionLocatable)key).getPartitionIndex();
        }
    }

    private static class PendingShard<L extends Locatable, I extends Locatable>
    implements Locatable {
        private I interval;
        private List<L> locatables = new ArrayList<L>();

        public PendingShard(I interval) {
            this.interval = interval;
        }

        public void addLocatable(L locatable) {
            this.locatables.add(locatable);
        }

        public String getContig() {
            return this.interval.getContig();
        }

        public int getStart() {
            return this.interval.getStart();
        }

        public int getEnd() {
            return this.interval.getEnd();
        }

        public Tuple2<I, Iterable<L>> get() {
            return new Tuple2(this.interval, this.locatables);
        }
    }
}

