/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.spark.validation;

import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import htsjdk.samtools.SAMFileHeader;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.broadcast.Broadcast;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.BetaFeature;
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import org.broadinstitute.barclay.help.DocumentedFeature;
import org.broadinstitute.hellbender.engine.GATKPath;
import org.broadinstitute.hellbender.engine.TraversalParameters;
import org.broadinstitute.hellbender.engine.filters.ReadFilter;
import org.broadinstitute.hellbender.engine.filters.ReadFilterLibrary;
import org.broadinstitute.hellbender.engine.spark.GATKSparkTool;
import org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSource;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.tools.spark.transforms.markduplicates.MarkDuplicatesSparkUtils;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.read.ReadCoordinateComparator;
import org.broadinstitute.hellbender.utils.read.ReadUtils;
import org.broadinstitute.hellbender.utils.read.markduplicates.ReadsKey;
import picard.cmdline.programgroups.DiagnosticsAndQCProgramGroup;
import scala.Tuple2;

/**
 * Compares two BAMs that are expected to contain the same reads and reports whether they agree on which reads are
 * marked as duplicates.
 *
 * <p>Both inputs are stripped of every attribute except the read group. Reads are keyed with the MarkDuplicates
 * fragment key (stranded unclipped start, strand, reference index, library), co-grouped across the two inputs, and
 * split further by stranded unclipped start. Each resulting pair of groups is classified as a {@link MatchType}. The
 * number of groups per type is optionally printed. Size or read-name mismatches always raise a
 * {@link UserException}; any other difference raises one only when {@code --throw-on-diff} is set.</p>
 */
@DocumentedFeature
@CommandLineProgramProperties(summary="Determine if two potentially identical BAMs have the same duplicate reads. This tool is useful for checking if two BAMs that seem identical have the same reads marked as duplicates.", oneLineSummary="Determine if two potentially identical BAMs have the same duplicate reads", programGroup=DiagnosticsAndQCProgramGroup.class)
@BetaFeature
public final class CompareDuplicatesSpark extends GATKSparkTool {
    private static final long serialVersionUID = 1L;

    public static final String INPUT_2_LONG_NAME = "input2";
    public static final String INPUT_2_SHORT_NAME = "I2";
    public static final String PRINT_SUMMARY_LONG_NAME = "print-summary";
    public static final String THROW_ON_DIFF_LONG_NAME = "throw-on-diff";

    /** Fallback library name passed to {@code getLibraryForRead} when a read has no library. */
    private static final String UNKNOWN_LIBRARY = "Unknown Library";

    @Argument(doc="The second BAM", shortName=INPUT_2_SHORT_NAME, fullName=INPUT_2_LONG_NAME, optional=false)
    protected String input2;

    @Argument(doc="Print a summary", fullName=PRINT_SUMMARY_LONG_NAME, optional=true)
    protected boolean printSummary = true;

    @Argument(doc="Throw error if any differences were found", fullName=THROW_ON_DIFF_LONG_NAME, optional=true)
    protected boolean throwOnDiff = false;

    @Argument(doc="If output is given, the tool will return a bam with all the mismatching duplicate groups in the first specified file", shortName="O", fullName="output", optional=true)
    protected String output;

    @Argument(doc="If output is given, the tool will return a bam with all the mismatching duplicate groups in the second specified input file", shortName="O2", fullName="output2", optional=true)
    protected String output2;

    @Override
    public boolean requiresReads() {
        return true;
    }

    @Override
    public List<ReadFilter> getDefaultReadFilters() {
        // Every read takes part in the comparison; nothing is filtered out up front.
        return Collections.singletonList(ReadFilterLibrary.ALLOW_ALL_READS);
    }

    /**
     * Loads both inputs, compares their duplicate groups, optionally writes the mismatching groups, and reports the
     * per-{@link MatchType} group counts.
     *
     * @throws IllegalArgumentException if only one of {@code --output} / {@code --output2} is given
     * @throws UserException if the inputs differ in read count, in group sizes, or in read names within a group, or
     *                       (with {@code --throw-on-diff}) in any other way
     */
    @Override
    protected void runTool(final JavaSparkContext ctx) {
        // The two mismatch outputs must be requested together or not at all.
        if (hasOutputSpecified() && (output == null || output2 == null)) {
            throw new IllegalArgumentException("Arguments '--output' and '--output2' must both be specified together in order to write mismatch bams or not at all");
        }

        final JavaRDD<GATKRead> firstReads = removeNonReadGroupAttributes(getReads());
        final ReadsSparkSource readsSource2 = new ReadsSparkSource(ctx, readArguments.getReadValidationStringency());
        final TraversalParameters traversalParameters = hasUserSuppliedIntervals()
                ? intervalArgumentCollection.getTraversalParameters(getHeaderForReads().getSequenceDictionary())
                : null;
        final JavaRDD<GATKRead> secondReads = removeNonReadGroupAttributes(
                readsSource2.getParallelReads(new GATKPath(input2), null, traversalParameters, bamPartitionSplitSize, useNio));

        final long firstBamSize = firstReads.count();
        final long secondBamSize = secondReads.count();
        if (firstBamSize != secondBamSize) {
            throw new UserException("input bams have different numbers of mapped reads: " + firstBamSize + "," + secondBamSize);
        }
        System.out.println("processing bams with " + firstBamSize + " mapped reads");

        final long firstDupesCount = firstReads.filter(GATKRead::isDuplicate).count();
        final long secondDupesCount = secondReads.filter(GATKRead::isDuplicate).count();
        if (firstDupesCount != secondDupesCount) {
            System.out.println("BAMs have different number of total duplicates: " + firstDupesCount + "," + secondDupesCount);
        }
        System.out.println("first and second: " + firstDupesCount + "," + secondDupesCount);

        final Broadcast<Map<String, Byte>> libraryIndex = ctx.broadcast(MarkDuplicatesSparkUtils.constructLibraryIndex(getHeaderForReads()));
        final Broadcast<SAMFileHeader> bHeader = ctx.broadcast(getHeaderForReads());

        final JavaPairRDD<ReadsKey, GATKRead> firstKeyed = keyByFragment(firstReads, bHeader, libraryIndex);
        final JavaPairRDD<ReadsKey, GATKRead> secondKeyed = keyByFragment(secondReads, bHeader, libraryIndex);
        final JavaPairRDD<ReadsKey, Tuple2<Iterable<GATKRead>, Iterable<GATKRead>>> cogrouped =
                firstKeyed.cogroup(secondKeyed, getRecommendedNumReducers());
        final JavaRDD<Tuple2<List<GATKRead>, List<GATKRead>>> groupPairs = pairGroupsByStart(cogrouped);

        if (hasOutputSpecified()) {
            writeMismatchedGroups(ctx, groupPairs, firstReads, secondReads, bHeader);
        }

        final Map<MatchType, Long> tagCountMap = countMatchTypes(groupPairs, bHeader);
        reportMatchCounts(tagCountMap);
    }

    /** Returns true if at least one of the mismatch output paths was given. */
    private boolean hasOutputSpecified() {
        return output != null || output2 != null;
    }

    /** Keys every read by its MarkDuplicates fragment key so reads from both inputs can be co-grouped. */
    private static JavaPairRDD<ReadsKey, GATKRead> keyByFragment(final JavaRDD<GATKRead> reads,
                                                                 final Broadcast<SAMFileHeader> bHeader,
                                                                 final Broadcast<Map<String, Byte>> libraryIndex) {
        return reads.mapToPair(read -> {
            final SAMFileHeader header = bHeader.getValue();
            final ReadsKey key = ReadsKey.getKeyForFragment(
                    ReadUtils.getStrandedUnclippedStart(read),
                    read.isReverseStrand(),
                    ReadUtils.getReferenceIndex(read, header),
                    libraryIndex.getValue().get(MarkDuplicatesSparkUtils.getLibraryForRead(read, header, UNKNOWN_LIBRARY)));
            return new Tuple2<>(key, read);
        });
    }

    /**
     * Splits each co-grouped key into per-start sub-groups and pairs the sub-groups up across the two inputs.
     *
     * <p>Every start seen in EITHER input produces one pair. A start present on only one side is paired with an empty
     * list, so {@link #getDupes} later reports it as {@link MatchType#SIZE_UNEQUAL}. Without this, such a start would
     * either be dropped or cause a NullPointerException.</p>
     */
    private static JavaRDD<Tuple2<List<GATKRead>, List<GATKRead>>> pairGroupsByStart(
            final JavaPairRDD<ReadsKey, Tuple2<Iterable<GATKRead>, Iterable<GATKRead>>> cogrouped) {
        return cogrouped.flatMap(keyAndGroups -> {
            final Map<Integer, List<GATKRead>> firstByStart = splitByStart(keyAndGroups._2()._1());
            final Map<Integer, List<GATKRead>> secondByStart = splitByStart(keyAndGroups._2()._2());

            final Set<Integer> starts = new LinkedHashSet<>(firstByStart.keySet());
            starts.addAll(secondByStart.keySet());

            final List<Tuple2<List<GATKRead>, List<GATKRead>>> out = new ArrayList<>(starts.size());
            for (final Integer start : starts) {
                out.add(new Tuple2<>(firstByStart.getOrDefault(start, Collections.emptyList()),
                                     secondByStart.getOrDefault(start, Collections.emptyList())));
            }
            return out.iterator();
        });
    }

    /**
     * Writes, for each input, the reads whose names occur in any group pair not classified as
     * {@link MatchType#EQUAL}.
     *
     * <p>Both outputs are written with a tagged copy of the first input's header.</p>
     */
    private void writeMismatchedGroups(final JavaSparkContext ctx,
                                       final JavaRDD<Tuple2<List<GATKRead>, List<GATKRead>>> groupPairs,
                                       final JavaRDD<GATKRead> firstReads,
                                       final JavaRDD<GATKRead> secondReads,
                                       final Broadcast<SAMFileHeader> bHeader) {
        final List<String> names = groupPairs
                .filter(pair -> getDupes(pair._1(), pair._2(), bHeader.getValue()) != MatchType.EQUAL)
                .flatMap(pair -> {
                    final Set<String> groupNames = new HashSet<>();
                    pair._1().forEach(read -> groupNames.add(read.getName()));
                    pair._2().forEach(read -> groupNames.add(read.getName()));
                    return groupNames.iterator();
                })
                .collect();
        final Set<String> mismatchedNames = new HashSet<>(names);
        final Broadcast<Set<String>> nameSet = ctx.broadcast(mismatchedNames);

        // Tag a copy: mutating bHeader.getValue() directly would also change the header the tool
        // (and, in local mode, the broadcast) hands out.
        final SAMFileHeader headerForWrite = bHeader.getValue().clone();
        headerForWrite.setAttribute("in", "original read file source");

        writeReads(ctx, output, firstReads.filter(read -> nameSet.value().contains(read.getName())), headerForWrite, true);
        writeReads(ctx, output2, secondReads.filter(read -> nameSet.value().contains(read.getName())), headerForWrite, true);
    }

    /**
     * Classifies every group pair and counts how many pairs fall into each {@link MatchType}.
     *
     * <p>Types that never occur are absent from the returned map. The aggregation is keyed by {@link Enum#name()}
     * rather than by the enum itself. {@code Enum.hashCode()} is identity-based and differs between JVMs, so enum keys
     * can be hash-partitioned inconsistently across executors, which splits and then loses counts.</p>
     */
    private static Map<MatchType, Long> countMatchTypes(final JavaRDD<Tuple2<List<GATKRead>, List<GATKRead>>> groupPairs,
                                                        final Broadcast<SAMFileHeader> bHeader) {
        final Map<String, Long> countsByName = groupPairs
                .map(pair -> getDupes(pair._1(), pair._2(), bHeader.getValue()).name())
                .countByValue();
        final Map<MatchType, Long> counts = new HashMap<>();
        countsByName.forEach((name, count) -> counts.put(MatchType.valueOf(name), count));
        return counts;
    }

    /**
     * Fails on structural differences, prints the optional summary in enum declaration order, and (with
     * {@code --throw-on-diff}) fails on any non-EQUAL group.
     */
    private void reportMatchCounts(final Map<MatchType, Long> tagCountMap) {
        if (tagCountMap.get(MatchType.SIZE_UNEQUAL) != null) {
            throw new UserException("The number of reads by the MarkDuplicates key were unequal, indicating that the BAMs are not the same");
        }
        if (tagCountMap.get(MatchType.READ_MISMATCH) != null) {
            throw new UserException("The reads grouped by the MarkDuplicates key were not the same, indicating that the BAMs are not the same");
        }
        if (printSummary) {
            System.out.println("##############################");
            for (final MatchType type : MatchType.values()) {
                System.out.println(type + ": " + tagCountMap.getOrDefault(type, 0L));
            }
        }
        if (throwOnDiff) {
            for (final MatchType type : MatchType.values()) {
                if (type != MatchType.EQUAL && tagCountMap.get(type) != null) {
                    throw new UserException("found difference between the two BAMs: " + type + " with count " + tagCountMap.get(type));
                }
            }
        }
    }

    /** Buckets the reads of one key group by stranded unclipped start. */
    private static Map<Integer, List<GATKRead>> splitByStart(final Iterable<GATKRead> duplicateGroup) {
        final Map<Integer, List<GATKRead>> byStart = new HashMap<>();
        for (final GATKRead read : duplicateGroup) {
            byStart.computeIfAbsent(ReadUtils.getStrandedUnclippedStart(read), start -> new ArrayList<>()).add(read);
        }
        return byStart;
    }

    /**
     * Compares one group of reads from each input.
     *
     * <p>Both groups are sorted with {@link ReadCoordinateComparator} and walked in lockstep. The first check that
     * fails determines the result:</p>
     * <ol>
     *   <li>{@link MatchType#SIZE_UNEQUAL} if the groups differ in size.</li>
     *   <li>{@link MatchType#READ_MISMATCH} if any position holds reads with different names.</li>
     *   <li>{@link MatchType#DIFF_NUM_DUPES} if the number of duplicate-flagged reads differs.</li>
     *   <li>{@link MatchType#DIFFERENT_REPRESENTATIVE_READ} if the duplicate-flagged reads are not equal as sets.</li>
     *   <li>Otherwise {@link MatchType#EQUAL}.</li>
     * </ol>
     *
     * @param f      reads from the first input
     * @param s      reads from the second input
     * @param header header used to order reads by coordinate
     * @return the classification of this group pair
     */
    static MatchType getDupes(final Iterable<GATKRead> f, final Iterable<GATKRead> s, final SAMFileHeader header) {
        final List<GATKRead> first = Lists.newArrayList(f);
        final List<GATKRead> second = Lists.newArrayList(s);
        if (first.size() != second.size()) {
            return MatchType.SIZE_UNEQUAL;
        }
        final ReadCoordinateComparator comparator = new ReadCoordinateComparator(header);
        first.sort(comparator);
        second.sort(comparator);

        final Set<GATKRead> firstDupes = new LinkedHashSet<>();
        final Set<GATKRead> secondDupes = new LinkedHashSet<>();
        for (int i = 0; i < first.size(); i++) {
            final GATKRead firstRead = first.get(i);
            final GATKRead secondRead = second.get(i);
            if (!firstRead.getName().equals(secondRead.getName())) {
                return MatchType.READ_MISMATCH;
            }
            if (firstRead.isDuplicate()) {
                firstDupes.add(firstRead);
            }
            if (secondRead.isDuplicate()) {
                secondDupes.add(secondRead);
            }
        }
        if (firstDupes.size() != secondDupes.size()) {
            return MatchType.DIFF_NUM_DUPES;
        }
        if (!firstDupes.equals(secondDupes)) {
            return MatchType.DIFFERENT_REPRESENTATIVE_READ;
        }
        return MatchType.EQUAL;
    }

    /**
     * Clears every attribute of each read except its read group, so attributes do not influence the later
     * {@code equals}-based comparison of duplicate reads.
     */
    static JavaRDD<GATKRead> removeNonReadGroupAttributes(final JavaRDD<GATKRead> initialReads) {
        return initialReads.map(read -> {
            final String readGroup = read.getReadGroup();
            read.clearAttributes();
            read.setReadGroup(readGroup);
            return read;
        });
    }

    /** Outcome of comparing one pair of start groups in {@link #getDupes}. */
    enum MatchType {
        /** Same reads in both groups and the same duplicate-flagged reads. */
        EQUAL,
        /** The two groups contain different numbers of reads. */
        SIZE_UNEQUAL,
        /** After coordinate sorting, reads at the same position have different names. */
        READ_MISMATCH,
        /** Same reads, but a different number of them are flagged as duplicates. */
        DIFF_NUM_DUPES,
        /** Same reads and same number of duplicates, but the duplicate-flagged reads differ. */
        DIFFERENT_REPRESENTATIVE_READ
    }
}

