/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.spark.transforms.markduplicates;

import com.esotericsoftware.kryo.DefaultSerializer;
import com.esotericsoftware.kryo.serializers.FieldSerializer;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.LinkedListMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMReadGroupRecord;
import htsjdk.samtools.metrics.MetricBase;
import htsjdk.samtools.metrics.MetricsFile;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.EnumMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.broadcast.Broadcast;
import org.broadinstitute.hellbender.engine.filters.ReadFilterLibrary;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.metrics.MetricsUtils;
import org.broadinstitute.hellbender.tools.spark.transforms.markduplicates.MarkDuplicatesSpark;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.read.ReadUtils;
import org.broadinstitute.hellbender.utils.read.SAMRecordToGATKReadAdapter;
import org.broadinstitute.hellbender.utils.read.markduplicates.GATKDuplicationMetrics;
import org.broadinstitute.hellbender.utils.read.markduplicates.LibraryIdGenerator;
import org.broadinstitute.hellbender.utils.read.markduplicates.MarkDuplicatesScoringStrategy;
import org.broadinstitute.hellbender.utils.read.markduplicates.ReadsKey;
import org.broadinstitute.hellbender.utils.read.markduplicates.sparkrecords.EmptyFragment;
import org.broadinstitute.hellbender.utils.read.markduplicates.sparkrecords.Fragment;
import org.broadinstitute.hellbender.utils.read.markduplicates.sparkrecords.MarkDuplicatesSparkRecord;
import org.broadinstitute.hellbender.utils.read.markduplicates.sparkrecords.Pair;
import org.broadinstitute.hellbender.utils.read.markduplicates.sparkrecords.PairedEnds;
import org.broadinstitute.hellbender.utils.read.markduplicates.sparkrecords.Passthrough;
import org.broadinstitute.hellbender.utils.read.markduplicates.sparkrecords.TransientFieldPhysicalLocation;
import org.broadinstitute.hellbender.utils.spark.SparkUtils;
import picard.analysis.MergeableMetricBase;
import picard.sam.markduplicates.util.OpticalDuplicateFinder;
import picard.sam.util.PhysicalLocation;
import scala.Tuple2;

public class MarkDuplicatesSparkUtils {
    /** Name of the transient read attribute that {@code generateMetrics} reads and adds to the optical duplicate count. */
    public static final String OPTICAL_DUPLICATE_TOTAL_ATTRIBUTE_NAME = "OD";
    /**
     * Ordering used to choose the representative record of a duplicate group: compares by
     * {@code getScore()} first and breaks ties with the reverse of
     * {@link TransientFieldPhysicalLocationComparator}. Callers in this class take the maximum under it.
     */
    private static final Comparator<TransientFieldPhysicalLocation> PAIRED_ENDS_SCORE_COMPARATOR = Comparator.comparing(PairedEnds::getScore).thenComparing(TransientFieldPhysicalLocationComparator.INSTANCE.reversed());

    /**
     * Resolves the library of a read through its read group record in the header.
     *
     * @param read           read whose library is wanted
     * @param header         header used to look up the read's read group record
     * @param defaultLibrary value returned when the read group record carries no library
     * @return the read group's library, or {@code defaultLibrary} when that library is {@code null}
     * @throws UserException.ReadMissingReadGroup   if no record was found and the read has no read group
     * @throws UserException.HeaderMissingReadGroup if no record was found although the read names a read group
     */
    public static String getLibraryForRead(GATKRead read, SAMFileHeader header, String defaultLibrary) {
        final SAMReadGroupRecord groupRecord = ReadUtils.getSAMReadGroupRecord(read, header);
        if (groupRecord == null) {
            // No record in the header: report whichever side is missing the read group.
            if (read.getReadGroup() != null) {
                throw new UserException.HeaderMissingReadGroup(read);
            }
            throw new UserException.ReadMissingReadGroup(read);
        }
        final String libraryName = groupRecord.getLibrary();
        if (libraryName != null) {
            return libraryName;
        }
        return defaultLibrary;
    }

    /**
     * Builds duplicate-marking records from the mapped reads and hands them to {@code markDuplicateRecords}.
     *
     * <p>Reads are filtered with {@code ReadFilterLibrary.MAPPED}, grouped via {@code getReadsGroupedByName},
     * turned into fragment / empty-fragment / pair / passthrough records per group, grouped again by record
     * key, and finally passed to {@code markDuplicateRecords}.
     *
     * @param header          header used for the sort-order check, the read group index and the library index
     * @param scoringStrategy strategy forwarded to the fragment and pair record factories
     * @param finder          optical duplicate finder forwarded to {@code markDuplicateRecords}
     * @param reads           input reads; unmapped reads are filtered out first
     * @param numReducers     forwarded to {@code getReadsGroupedByName}
     * @param markOpticalDups forwarded to {@code markDuplicateRecords}
     * @return the RDD produced by {@code markDuplicateRecords}
     * @throws UserException.UnimplementedFeature if a group contains more than two primary reads (raised inside the Spark function)
     */
    static JavaPairRDD<IndexPair<String>, Integer> transformToDuplicateNames(SAMFileHeader header, MarkDuplicatesScoringStrategy scoringStrategy, OpticalDuplicateFinder finder, JavaRDD<GATKRead> reads, int numReducers, boolean markOpticalDups) {
        JavaRDD mappedReads = reads.filter(ReadFilterLibrary.MAPPED::test);
        JavaPairRDD<String, Iterable<IndexPair<GATKRead>>> keyedReads = MarkDuplicatesSparkUtils.getReadsGroupedByName(header, (JavaRDD<GATKRead>)mappedReads, numReducers);
        // Both lookup tables are derived from the header once and broadcast for use inside the function below.
        Broadcast headerReadGroupIndexMap = JavaSparkContext.fromSparkContext((SparkContext)reads.context()).broadcast(MarkDuplicatesSparkUtils.getHeaderReadGroupIndexMap(header));
        Broadcast libraryIndex = JavaSparkContext.fromSparkContext((SparkContext)reads.context()).broadcast(MarkDuplicatesSparkUtils.constructLibraryIndex(header));
        JavaPairRDD pairedEnds = keyedReads.flatMapToPair((PairFlatMapFunction & Serializable)keyedRead -> {
            ArrayList out = Lists.newArrayList();
            // One-element array so the lambda below can record a non-primary read it encountered.
            IndexPair[] hadNonPrimaryRead = new IndexPair[]{null};
            // The peek step emits one record per primary read as a side effect (an empty fragment when the
            // read has a mapped mate, a scored fragment otherwise); the filter then keeps only primary reads.
            List primaryReads = Utils.stream((Iterable)keyedRead._2()).peek(readWithIndex -> {
                GATKRead read = (GATKRead)readWithIndex.getValue();
                if (!read.isSecondaryAlignment() && !read.isSupplementaryAlignment()) {
                    EmptyFragment fragment = ReadUtils.readHasMappedMate(read) ? MarkDuplicatesSparkRecord.newEmptyFragment(read, header, (Map)libraryIndex.getValue()) : MarkDuplicatesSparkRecord.newFragment(read, header, readWithIndex.getIndex(), scoringStrategy, (Map)libraryIndex.getValue());
                    out.add(new Tuple2((Object)((MarkDuplicatesSparkRecord)fragment).key(), (Object)fragment));
                } else {
                    hadNonPrimaryRead[0] = readWithIndex;
                }
            }).filter(indexPair -> !((GATKRead)indexPair.getValue()).isSecondaryAlignment() && !((GATKRead)indexPair.getValue()).isSupplementaryAlignment()).collect(Collectors.toList());
            // No primary reads in the group: emit a single passthrough built from the last non-primary read seen.
            if (primaryReads.isEmpty()) {
                Passthrough pass = MarkDuplicatesSparkRecord.getPassthrough((GATKRead)hadNonPrimaryRead[0].getValue(), hadNonPrimaryRead[0].getIndex());
                out.add(new Tuple2((Object)((MarkDuplicatesSparkRecord)pass).key(), (Object)pass));
                return out.iterator();
            }
            if (primaryReads.size() > 2) {
                throw new UserException.UnimplementedFeature(String.format("MarkDuplicatesSpark only supports singleton fragments and pairs. We found the following group with >2 primary reads: ( %d number of reads). \n%s.", primaryReads.size(), primaryReads.stream().map(Object::toString).collect(Collectors.joining("\n"))));
            }
            List mappedPair = primaryReads.stream().filter(readWithIndex -> ReadUtils.readHasMappedMate((GATKRead)readWithIndex.getValue())).collect(Collectors.toList());
            if (mappedPair.size() == 2) {
                // Two primary reads with mapped mates: emit a pair record, using the second read's partition index.
                GATKRead firstRead = (GATKRead)((IndexPair)mappedPair.get(0)).getValue();
                IndexPair secondRead = (IndexPair)mappedPair.get(1);
                Pair pair = MarkDuplicatesSparkRecord.newPair(firstRead, (GATKRead)secondRead.getValue(), header, secondRead.getIndex(), scoringStrategy, (Map)libraryIndex.getValue());
                Short readGroup = (Short)((Map)headerReadGroupIndexMap.getValue()).get(firstRead.getReadGroup());
                if (readGroup == null) {
                    throw firstRead.getReadGroup() == null ? new UserException.ReadMissingReadGroup(firstRead) : new UserException.HeaderMissingReadGroup(firstRead);
                }
                pair.setReadGroup(readGroup);
                out.add(new Tuple2((Object)pair.key(), (Object)pair));
            } else if (mappedPair.size() == 1) {
                // Only one primary read claims a mapped mate: emit it as a passthrough.
                IndexPair firstRead = (IndexPair)mappedPair.get(0);
                Passthrough pass = MarkDuplicatesSparkRecord.getPassthrough((GATKRead)firstRead.getValue(), firstRead.getIndex());
                out.add(new Tuple2((Object)((MarkDuplicatesSparkRecord)pass).key(), (Object)pass));
            }
            return out.iterator();
        });
        JavaPairRDD keyedPairs = pairedEnds.groupByKey();
        return MarkDuplicatesSparkUtils.markDuplicateRecords((JavaPairRDD<ReadsKey, Iterable<MarkDuplicatesSparkRecord>>)keyedPairs, finder, markOpticalDups);
    }

    /**
     * Assigns a compact {@code Byte} id to every distinct library named by the header's read groups.
     * Read groups without a library are counted under the name {@code "Unknown Library"}.
     *
     * <p>Ids are the list positions narrowed to a byte, so positions 128 and above are stored as
     * negative byte values; they are still unique per library.
     *
     * @param header header whose read groups are scanned
     * @return map from library name to its byte id
     * @throws GATKException if more than 255 distinct libraries are found
     */
    public static Map<String, Byte> constructLibraryIndex(SAMFileHeader header) {
        final List<String> discoveredLibraries = header.getReadGroups().stream().map(r -> {
            String library = r.getLibrary();
            return library == null ? "Unknown Library" : library;
        }).distinct().collect(Collectors.toList());
        // NOTE(review): the message says "up to 256" while this check rejects anything above 255 — confirm which is intended.
        if (discoveredLibraries.size() > 255) {
            throw new GATKException("Detected too many read libraries among read groups header, currently MarkDuplicatesSpark only supports up to 256 unique readgroup libraries but " + discoveredLibraries.size() + " were found");
        }
        final Iterator<Byte> iterator = IntStream.range(0, discoveredLibraries.size()).boxed().map(Integer::byteValue).iterator();
        // Convert the byte back to its unsigned position: a plain widening would turn ids >= 128 into
        // negative list indices and make List.get throw.
        return Maps.uniqueIndex(iterator, idx -> discoveredLibraries.get(Byte.toUnsignedInt(idx)));
    }

    /**
     * Assigns a compact {@code Short} id to every read group id in the header.
     *
     * <p>Ids are the list positions narrowed to a short, so positions 32768 and above are stored as
     * negative short values; they are still unique per read group.
     *
     * @param header header whose read groups are indexed
     * @return map from read group id to its short index
     * @throws GATKException          if the header has more than 65535 read groups
     * @throws UserException.BadInput if the header has no read groups
     */
    private static Map<String, Short> getHeaderReadGroupIndexMap(SAMFileHeader header) {
        final List<SAMReadGroupRecord> readGroups = header.getReadGroups();
        if (readGroups.size() > 65535) {
            throw new GATKException("Detected too many read groups in the header, currently MarkDuplicatesSpark only supports up to 65535 unique readgroup IDs but " + readGroups.size() + " were found");
        }
        if (readGroups.isEmpty()) {
            throw new UserException.BadInput("Sam file header missing Read Group fields. MarkDuplicatesSpark currently requires reads to be labeled with read group tags, please add read groups tags to your reads");
        }
        final Iterator<Short> iterator = IntStream.range(0, readGroups.size()).boxed().map(Integer::shortValue).iterator();
        // Convert the short back to its unsigned position: a plain widening would turn ids >= 32768 into
        // negative list indices and make List.get throw.
        return Maps.uniqueIndex(iterator, idx -> readGroups.get(Short.toUnsignedInt(idx)).getId());
    }

    /**
     * Tags every read with the index of the partition it came from and groups the tagged reads.
     *
     * @param header      header whose ordering is validated with {@code ReadUtils.isReadNameGroupedBam}
     * @param reads       reads to group
     * @param numReducers not referenced by this method
     * @return the result of {@code spanReadsByKey} applied to the index-tagged reads
     * @throws GATKException if the header fails the {@code ReadUtils.isReadNameGroupedBam} check; the per-read
     *                       class check raises the same exception type from inside the mapping function
     */
    private static JavaPairRDD<String, Iterable<IndexPair<GATKRead>>> getReadsGroupedByName(SAMFileHeader header, JavaRDD<GATKRead> reads, int numReducers) {
        // Wrap each read together with its partition index; only SAMRecordToGATKReadAdapter reads are accepted.
        JavaRDD indexedReads = reads.mapPartitionsWithIndex((Function2 & Serializable)(index, iter) -> Utils.stream(iter).map(read -> {
            if (read.getClass() != SAMRecordToGATKReadAdapter.class) {
                throw new GATKException(String.format("MarkDuplicatesSpark currently only supports SAMRecords as an underlying reads data source class, %s found instead", read.getClass().toString()));
            }
            return new IndexPair<GATKRead>((GATKRead)read, (int)index);
        }).iterator(), false);
        // Grouping by spans below is only attempted when the header declares the reads grouped by name.
        if (!ReadUtils.isReadNameGroupedBam(header)) {
            throw new GATKException(String.format("MarkDuplicatesSparkUtils.mark() requires input reads to be queryname sorted or querygrouped, yet the header indicated it was in %s order instead", header.getSortOrder()));
        }
        JavaPairRDD<String, Iterable<IndexPair<GATKRead>>> keyedReads = MarkDuplicatesSparkUtils.spanReadsByKey((JavaRDD<IndexPair<GATKRead>>)indexedReads);
        return keyedReads;
    }

    /**
     * Keys every read by its name, groups them with {@code SparkUtils.spanByKey}, and then re-buckets each
     * group by {@code ReadsKey.keyForRead}, emitting one tuple per distinct key.
     *
     * @param reads index-tagged reads
     * @return pairs of key and the list of reads that produced that key within one name group
     */
    private static JavaPairRDD<String, Iterable<IndexPair<GATKRead>>> spanReadsByKey(JavaRDD<IndexPair<GATKRead>> reads) {
        JavaPairRDD nameReadPairs = reads.mapToPair((PairFunction & Serializable)read -> new Tuple2((Object)((GATKRead)read.getValue()).getName(), read));
        // NOTE(review): spanByKey presumably groups adjacent reads sharing a name — confirm against SparkUtils.
        return SparkUtils.spanByKey(nameReadPairs).flatMapToPair((PairFlatMapFunction & Serializable)namedRead -> {
            ArrayList out = Lists.newArrayList();
            // Collect the reads of this group under the key ReadsKey.keyForRead computes for each of them.
            LinkedListMultimap multi = LinkedListMultimap.create();
            for (IndexPair read : (Iterable)namedRead._2()) {
                multi.put((Object)ReadsKey.keyForRead((GATKRead)read.getValue()), (Object)read);
            }
            // Emit one (key, reads) tuple per distinct key, copying the multimap view into a fresh list.
            for (String key : multi.keySet()) {
                out.add(new Tuple2((Object)key, (Object)Lists.newArrayList((Iterable)multi.get((Object)key))));
            }
            return out.iterator();
        });
    }

    /**
     * For every key group, splits the records by type and collects the tuples produced by
     * {@code handleFragments}, {@code handlePairs} and {@code handlePassthroughs}.
     *
     * @param keyedPairs      records grouped by key
     * @param finder          optical duplicate finder forwarded to the fragment and pair handlers
     * @param markOpticalDups forwarded to {@code handlePairs}
     * @return the tuples emitted by the handlers for all groups
     */
    private static JavaPairRDD<IndexPair<String>, Integer> markDuplicateRecords(JavaPairRDD<ReadsKey, Iterable<MarkDuplicatesSparkRecord>> keyedPairs, OpticalDuplicateFinder finder, boolean markOpticalDups) {
        return keyedPairs.flatMapToPair((PairFlatMapFunction & Serializable)keyedPair -> {
            Iterable pairGroups = (Iterable)keyedPair._2();
            ArrayList nonDuplicates = Lists.newArrayList();
            Map<MarkDuplicatesSparkRecord.Type, List<MarkDuplicatesSparkRecord>> stratifiedByType = MarkDuplicatesSparkUtils.splitByType(pairGroups);
            // Any of these lists is null when the group holds no record of that type.
            List<MarkDuplicatesSparkRecord> emptyFragments = stratifiedByType.get((Object)MarkDuplicatesSparkRecord.Type.EMPTY_FRAGMENT);
            List<MarkDuplicatesSparkRecord> fragments = stratifiedByType.get((Object)MarkDuplicatesSparkRecord.Type.FRAGMENT);
            List<MarkDuplicatesSparkRecord> pairs = stratifiedByType.get((Object)MarkDuplicatesSparkRecord.Type.PAIR);
            List<MarkDuplicatesSparkRecord> passthroughs = stratifiedByType.get((Object)MarkDuplicatesSparkRecord.Type.PASSTHROUGH);
            // Fragments contribute a best record only when the group contains no empty fragments.
            if (Utils.isNonEmpty(fragments) && !Utils.isNonEmpty(emptyFragments)) {
                Tuple2<IndexPair<String>, Integer> bestFragment = MarkDuplicatesSparkUtils.handleFragments(fragments, finder);
                nonDuplicates.add(bestFragment);
            }
            if (Utils.isNonEmpty(pairs)) {
                nonDuplicates.addAll(MarkDuplicatesSparkUtils.handlePairs(pairs, finder, markOpticalDups));
            }
            if (Utils.isNonEmpty(passthroughs)) {
                nonDuplicates.addAll(MarkDuplicatesSparkUtils.handlePassthroughs(passthroughs));
            }
            return nonDuplicates.iterator();
        });
    }

    /**
     * Buckets the records of one duplicate group by their {@code getType()} value.
     *
     * @param duplicateGroup records to bucket
     * @return map from record type to the records of that type, in encounter order; types that do not
     *         occur in the group have no entry
     */
    private static Map<MarkDuplicatesSparkRecord.Type, List<MarkDuplicatesSparkRecord>> splitByType(Iterable<MarkDuplicatesSparkRecord> duplicateGroup) {
        final Map<MarkDuplicatesSparkRecord.Type, List<MarkDuplicatesSparkRecord>> buckets = new EnumMap<>(MarkDuplicatesSparkRecord.Type.class);
        for (final MarkDuplicatesSparkRecord member : duplicateGroup) {
            // Create the bucket on first sight of a type, then append to it.
            buckets.computeIfAbsent(member.getType(), type -> new ArrayList<>()).add(member);
        }
        return buckets;
    }

    /**
     * Emits one tuple per passthrough record, pairing its name and partition index with
     * {@code MarkDuplicatesSpark.NO_OPTICAL_MARKER}.
     *
     * @param passthroughs records to convert
     * @return one tuple per input record, in input order
     */
    private static List<Tuple2<IndexPair<String>, Integer>> handlePassthroughs(List<MarkDuplicatesSparkRecord> passthroughs) {
        final List<Tuple2<IndexPair<String>, Integer>> marked = new ArrayList<Tuple2<IndexPair<String>, Integer>>();
        for (final MarkDuplicatesSparkRecord passthrough : passthroughs) {
            final IndexPair<String> nameAndPartition = new IndexPair<String>(passthrough.getName(), passthrough.getPartitionIndex());
            marked.add((Tuple2<IndexPair<String>, Integer>)new Tuple2(nameAndPartition, (Object)MarkDuplicatesSpark.NO_OPTICAL_MARKER));
        }
        return marked;
    }

    /**
     * Picks the best pair of a duplicate group and counts the optical duplicates around it.
     *
     * <p>A single-pair group short-circuits to one tuple with a count of 0. Otherwise location
     * information is added to every pair, the maximum under {@code PAIRED_ENDS_SCORE_COMPARATOR} is
     * selected, and optical duplicates are counted. When both orientation groups 3 and 5 are present
     * they are counted separately and summed; otherwise the whole list is counted at once.
     *
     * @param pairs           pairs sharing one duplicate key
     * @param finder          finder used to add location information and detect optical duplicates
     * @param markOpticalDups whether each optical duplicate is also emitted as its own tuple
     * @return the optical duplicate tuples (only if requested) followed by the best pair with its count
     */
    private static List<Tuple2<IndexPair<String>, Integer>> handlePairs(List<Pair> pairs, OpticalDuplicateFinder finder, boolean markOpticalDups) {
        if (pairs.size() == 1) {
            final Pair only = pairs.get(0);
            return Collections.singletonList(new Tuple2(new IndexPair<String>(only.getName(), only.getPartitionIndex()), (Object)0));
        }

        // Add location information to each pair and track the maximum; on ties the earlier pair wins.
        Pair bestPair = null;
        for (final Pair candidate : pairs) {
            finder.addLocationInformation(candidate.getName(), (PhysicalLocation)candidate);
            if (bestPair == null || PAIRED_ENDS_SCORE_COMPARATOR.compare(bestPair, candidate) < 0) {
                bestPair = candidate;
            }
        }
        if (bestPair == null) {
            throw new GATKException.ShouldNeverReachHereException("There was no best pair because the stream was empty, but it shouldn't have been empty.");
        }

        final List<Tuple2<IndexPair<String>, Integer>> results = new ArrayList<Tuple2<IndexPair<String>, Integer>>();
        // Optical duplicates are appended to the results only when marking was requested.
        final List<Tuple2<IndexPair<String>, Integer>> opticalSink = markOpticalDups ? results : null;

        final Map<Byte, List<Pair>> byOrientation = pairs.stream().collect(Collectors.groupingBy(Pair::getOrientationForOpticalDuplicates));
        final List<Pair> orientationThree = byOrientation.get((byte)3);
        final List<Pair> orientationFive = byOrientation.get((byte)5);

        final int opticalCount;
        if (orientationThree != null && orientationFive != null) {
            final List<Pair> copyThree = new ArrayList<Pair>(orientationThree);
            final List<Pair> copyFive = new ArrayList<Pair>(orientationFive);
            final int countThree = MarkDuplicatesSparkUtils.countOpticalDuplicates(finder, copyThree, bestPair, opticalSink);
            final int countFive = MarkDuplicatesSparkUtils.countOpticalDuplicates(finder, copyFive, bestPair, opticalSink);
            opticalCount = countThree + countFive;
        } else {
            opticalCount = MarkDuplicatesSparkUtils.countOpticalDuplicates(finder, pairs, bestPair, opticalSink);
        }

        results.add((Tuple2<IndexPair<String>, Integer>)new Tuple2(new IndexPair<String>(bestPair.getName(), bestPair.getPartitionIndex()), (Object)opticalCount));
        return results;
    }

    /**
     * Counts the entries of {@code scored} that the finder flags as optical duplicates relative to {@code best}.
     *
     * @param finder               finder whose {@code findOpticalDuplicates} result is inspected
     * @param scored               candidate pairs; flag {@code i} refers to {@code scored.get(i)}
     * @param best                 pair the candidates are compared against
     * @param opticalDuplicateList if non-null, receives one tuple per flagged pair carrying
     *                             {@code MarkDuplicatesSpark.OPTICAL_DUPLICATE_MARKER}
     * @return number of flagged pairs
     */
    private static int countOpticalDuplicates(OpticalDuplicateFinder finder, List<Pair> scored, Pair best, List<Tuple2<IndexPair<String>, Integer>> opticalDuplicateList) {
        final boolean[] flags = finder.findOpticalDuplicates(scored, (PhysicalLocation)best);
        int flaggedCount = 0;
        for (int position = 0; position < flags.length; position++) {
            if (flags[position]) {
                flaggedCount++;
                if (opticalDuplicateList != null) {
                    final IndexPair<String> nameAndPartition = new IndexPair<String>(scored.get(position).getName(), scored.get(position).getPartitionIndex());
                    opticalDuplicateList.add((Tuple2<IndexPair<String>, Integer>)new Tuple2(nameAndPartition, (Object)MarkDuplicatesSpark.OPTICAL_DUPLICATE_MARKER));
                }
            }
        }
        return flaggedCount;
    }

    /**
     * Adds location information to every fragment of the group and returns the one that is maximal
     * under {@code PAIRED_ENDS_SCORE_COMPARATOR}.
     *
     * @param duplicateFragmentGroup records of one duplicate group; each is cast to {@code Fragment}
     * @param finder                 finder used to add location information to each fragment
     * @return tuple of the best fragment's name and partition index with the value -1, or
     *         {@code null} when the group is empty
     */
    private static Tuple2<IndexPair<String>, Integer> handleFragments(List<MarkDuplicatesSparkRecord> duplicateFragmentGroup, OpticalDuplicateFinder finder) {
        Fragment bestFragment = null;
        for (final MarkDuplicatesSparkRecord member : duplicateFragmentGroup) {
            final Fragment candidate = (Fragment)member;
            finder.addLocationInformation(candidate.getName(), (PhysicalLocation)candidate);
            // Replace only on a strictly greater candidate, so the earlier fragment wins ties.
            if (bestFragment == null || PAIRED_ENDS_SCORE_COMPARATOR.compare(bestFragment, candidate) < 0) {
                bestFragment = candidate;
            }
        }
        if (bestFragment == null) {
            return null;
        }
        return (Tuple2<IndexPair<String>, Integer>)new Tuple2(new IndexPair<String>(bestFragment.getName(), bestFragment.getPartitionIndex()), (Object)-1);
    }

    /**
     * Computes duplication metrics per library from the given reads.
     *
     * <p>Each read yields a one-read {@code GATKDuplicationMetrics} keyed by its library name; these are
     * merged per library and then post-processed (pair counts halved, derived fields calculated).
     *
     * @param header header used to resolve each read's library name
     * @param reads  reads to summarize
     * @return metrics keyed by library name
     * @throws GATKException if two metrics with different {@code LIBRARY} values are merged (raised inside the fold)
     */
    static JavaPairRDD<String, GATKDuplicationMetrics> generateMetrics(SAMFileHeader header, JavaRDD<GATKRead> reads) {
        return reads.mapToPair((PairFunction & Serializable)read -> {
            String library = LibraryIdGenerator.getLibraryName(header, read.getReadGroup());
            GATKDuplicationMetrics metrics = new GATKDuplicationMetrics();
            metrics.LIBRARY = library;
            metrics.updateMetrics((GATKRead)read);
            // Reads carrying the "OD" transient attribute contribute its value to the optical duplicate total.
            if (read.getTransientAttribute(OPTICAL_DUPLICATE_TOTAL_ATTRIBUTE_NAME) != null) {
                metrics.READ_PAIR_OPTICAL_DUPLICATES += (long)((Integer)read.getTransientAttribute(OPTICAL_DUPLICATE_TOTAL_ATTRIBUTE_NAME)).intValue();
            }
            return new Tuple2((Object)library, (Object)metrics);
        }).foldByKey((Object)new GATKDuplicationMetrics(), (Function2 & Serializable)(metricsSum, m) -> {
            // NOTE(review): the fold starts from a fresh GATKDuplicationMetrics, so the LIBRARY check below
            // relies on merge(...) having populated metricsSum.LIBRARY — confirm against MergeableMetricBase.
            metricsSum.merge((MergeableMetricBase)m);
            if (!metricsSum.LIBRARY.equals(m.LIBRARY)) {
                throw new GATKException("Two different libraries encountered while summing metrics: " + metricsSum.LIBRARY + " and " + m.LIBRARY);
            }
            return metricsSum;
        }).mapValues((Function & Serializable)metrics -> {
            // Work on a copy so the folded value itself is left untouched.
            GATKDuplicationMetrics copy = metrics.copy();
            // NOTE(review): the pair counters are halved, presumably because both reads of a pair were counted — confirm.
            copy.READ_PAIRS_EXAMINED = metrics.READ_PAIRS_EXAMINED / 2L;
            copy.READ_PAIR_DUPLICATES = metrics.READ_PAIR_DUPLICATES / 2L;
            copy.calculateDerivedFields();
            // Replace a missing library size estimate with 0.
            if (copy.ESTIMATED_LIBRARY_SIZE == null) {
                copy.ESTIMATED_LIBRARY_SIZE = 0L;
            }
            return copy;
        });
    }

    /**
     * Collects the per-library metrics, adds one row per library to {@code result}, and saves the file.
     *
     * <p>Rows cover the union of the libraries known to {@code LibraryIdGenerator} for the header and
     * the libraries present in {@code metricsRDD}, sorted with {@code Utils.COMPARE_STRINGS_NULLS_FIRST}.
     * Metrics from the RDD take precedence over the generator's entry for the same library. A histogram
     * is attached only when the RDD contains exactly one library.
     *
     * @param result            metrics file that receives the rows and, possibly, the histogram
     * @param header            header used to build the {@code LibraryIdGenerator}
     * @param metricsRDD        metrics keyed by library name; collected to the driver
     * @param metricsOutputPath destination passed to {@code MetricsUtils.saveMetrics}
     */
    public static void saveMetricsRDD(MetricsFile<GATKDuplicationMetrics, Double> result, SAMFileHeader header, JavaPairRDD<String, GATKDuplicationMetrics> metricsRDD, String metricsOutputPath) {
        final LibraryIdGenerator libraryIdGenerator = new LibraryIdGenerator(header);
        final Map<String, GATKDuplicationMetrics> nonEmptyMetricsByLibrary = metricsRDD.collectAsMap();
        final Map<String, GATKDuplicationMetrics> emptyMapByLibrary = libraryIdGenerator.getMetricsByLibraryMap();

        final List<String> sortedListOfLibraryNames = new ArrayList<>(Sets.union(emptyMapByLibrary.keySet(), nonEmptyMetricsByLibrary.keySet()));
        sortedListOfLibraryNames.sort(Utils.COMPARE_STRINGS_NULLS_FIRST);

        for (final String library : sortedListOfLibraryNames) {
            // Prefer the metrics computed from reads; fall back to the generator's entry for that library.
            final GATKDuplicationMetrics metricsToAdd = nonEmptyMetricsByLibrary.containsKey(library)
                    ? nonEmptyMetricsByLibrary.get(library)
                    : emptyMapByLibrary.get(library);
            metricsToAdd.calculateDerivedFields();
            result.addMetric(metricsToAdd);
        }

        if (nonEmptyMetricsByLibrary.size() == 1) {
            result.setHistogram(nonEmptyMetricsByLibrary.values().iterator().next().calculateRoiHistogram());
        }

        MetricsUtils.saveMetrics(result, metricsOutputPath);
    }

    /**
     * Orders {@code TransientFieldPhysicalLocation} records by read-1 strand (reverse strand first),
     * then tile, then x, then y, and finally by name when both names are non-null.
     * Stateless singleton; use {@link #INSTANCE}.
     */
    public static final class TransientFieldPhysicalLocationComparator
    implements Comparator<TransientFieldPhysicalLocation>,
    Serializable {
        private static final long serialVersionUID = 1L;
        public static final TransientFieldPhysicalLocationComparator INSTANCE = new TransientFieldPhysicalLocationComparator();

        private TransientFieldPhysicalLocationComparator() {
        }

        /**
         * Compares two records field by field and returns at the first field that differs.
         *
         * @param first  left operand
         * @param second right operand
         * @return negative, zero or positive as {@code first} sorts before, equal to or after {@code second};
         *         records that agree on strand, tile, x and y compare as equal when either name is {@code null}
         */
        @Override
        public int compare(TransientFieldPhysicalLocation first, TransientFieldPhysicalLocation second) {
            if (first.isRead1ReverseStrand() != second.isRead1ReverseStrand()) {
                return first.isRead1ReverseStrand() ? -1 : 1;
            }
            // Integer.compare instead of subtraction: a difference can overflow and flip the sign.
            if (first.getTile() != second.getTile()) {
                return Integer.compare(first.getTile(), second.getTile());
            }
            if (first.getX() != second.getX()) {
                return Integer.compare(first.getX(), second.getX());
            }
            if (first.getY() != second.getY()) {
                return Integer.compare(first.getY(), second.getY());
            }
            if (first.getName() != null && second.getName() != null) {
                return first.getName().compareTo(second.getName());
            }
            return 0;
        }
    }

    /**
     * Immutable pairing of a value with an int index; in this file the index is a partition index.
     * Does not override {@code equals}/{@code hashCode}, so instances compare by identity.
     *
     * @param <T> type of the wrapped value
     */
    @DefaultSerializer(value=FieldSerializer.class)
    public static class IndexPair<T> {
        private final T value;
        private final int index;

        /** @return the wrapped value */
        public T getValue() {
            return this.value;
        }

        /** @return the index stored alongside the value */
        public int getIndex() {
            return this.index;
        }

        /**
         * @param value value to wrap
         * @param index index to associate with the value
         */
        @VisibleForTesting
        IndexPair(T value, int index) {
            this.value = value;
            this.index = index;
        }

        /**
         * @return {@code "indexpair[<index>,<value>]"}; a {@code null} value is rendered as {@code "null"}
         */
        @Override
        public String toString() {
            // Plain concatenation is null-safe, unlike calling value.toString() directly.
            return "indexpair[" + this.index + "," + this.value + "]";
        }
    }
}

