/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.spark.pathseq;

import htsjdk.samtools.SAMFileHeader;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.storage.StorageLevel;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.ArgumentCollection;
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import org.broadinstitute.barclay.help.DocumentedFeature;
import org.broadinstitute.hellbender.cmdline.GATKPlugin.GATKReadFilterPluginDescriptor;
import org.broadinstitute.hellbender.cmdline.programgroups.MetagenomicsProgramGroup;
import org.broadinstitute.hellbender.engine.spark.GATKSparkTool;
import org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSink;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.tools.spark.pathseq.PSBwaAlignerSpark;
import org.broadinstitute.hellbender.tools.spark.pathseq.PSBwaArgumentCollection;
import org.broadinstitute.hellbender.tools.spark.pathseq.PSBwaUtils;
import org.broadinstitute.hellbender.tools.spark.pathseq.PSFilter;
import org.broadinstitute.hellbender.tools.spark.pathseq.PSFilterArgumentCollection;
import org.broadinstitute.hellbender.tools.spark.pathseq.PSScoreArgumentCollection;
import org.broadinstitute.hellbender.tools.spark.pathseq.PSScorer;
import org.broadinstitute.hellbender.tools.spark.pathseq.PSUtils;
import org.broadinstitute.hellbender.tools.spark.pathseq.loggers.PSFilterEmptyLogger;
import org.broadinstitute.hellbender.tools.spark.pathseq.loggers.PSFilterFileLogger;
import org.broadinstitute.hellbender.tools.spark.pathseq.loggers.PSFilterLogger;
import org.broadinstitute.hellbender.tools.spark.pathseq.loggers.PSScoreFileLogger;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.read.ReadsWriteFormat;
import scala.Tuple2;

/**
 * Combined PathSeq tool: runs read filtering ({@link PSFilter}), microbe reference alignment
 * ({@link PSBwaAlignerSpark}) and abundance scoring ({@link PSScorer}) in a single invocation,
 * optionally writing the scored reads to a BAM.
 *
 * <p>The {@code --num-reducers} setting is rejected by this tool; output sharding is controlled by
 * {@code --readsPerPartitionOutput} instead.
 */
@CommandLineProgramProperties(summary="Combined tool that performs all PathSeq steps: read filtering, microbe reference alignment and abundance scoring", oneLineSummary="Combined tool that performs all steps: read filtering, microbe reference alignment, and abundance scoring", programGroup=MetagenomicsProgramGroup.class)
@DocumentedFeature
public class PathSeqPipelineSpark
extends GATKSparkTool {
    private static final long serialVersionUID = 1L;

    /** Long name of the argument that sets {@link #readsPerPartition}. */
    public static final String READS_PER_PARTITION_LONG_NAME = "pipeline-reads-per-partition";

    /** Arguments forwarded to the filtering stage. */
    @ArgumentCollection
    public PSFilterArgumentCollection filterArgs = new PSFilterArgumentCollection();

    /** Arguments forwarded to the alignment stage. */
    @ArgumentCollection
    public PSBwaArgumentCollection bwaArgs = new PSBwaArgumentCollection();

    /** Arguments forwarded to the scoring stage. */
    @ArgumentCollection
    public PSScoreArgumentCollection scoreArgs = new PSScoreArgumentCollection();

    /** Destination for the scored reads; when {@code null} no reads are written. */
    @Argument(doc="Output BAM", fullName="output", shortName="O", optional=true)
    public String outputPath = null;

    /** Target partition size used to rebalance the filtered reads before alignment and scoring. */
    @Argument(doc="Number of reads per partition to use for alignment and scoring.", fullName=READS_PER_PARTITION_LONG_NAME, optional=true, minValue=100.0)
    public int readsPerPartition = 5000;

    /** Target partition size used when coalescing the final reads for output. */
    @Argument(doc="Number of reads per partition for output. Use this to control the number of sharded BAMs (not --num-reducers).", fullName="readsPerPartitionOutput", optional=true, minValue=100.0, minRecommendedValue=100000.0)
    public int readsPerPartitionOutput = 1000000;

    /**
     * Repartitions paired reads while keeping the two reads of every pair together.
     *
     * <p>Each input partition is first grouped into two-element lists (one per pair), the lists are
     * repartitioned, and the lists are then flattened back into individual reads.
     *
     * @param pairedReads         reads in which each partition holds mates adjacent to each other
     * @param alignmentPartitions number of partitions to produce
     * @param numReads            total number of reads in {@code pairedReads}
     * @return the same reads spread over {@code alignmentPartitions} partitions
     */
    private static JavaRDD<GATKRead> repartitionPairedReads(final JavaRDD<GATKRead> pairedReads, final int alignmentPartitions, final long numReads) {
        // Only used as a capacity hint for the per-partition list of pairs.
        final int readsPerPartition = 1 + (int) (numReads / (long) alignmentPartitions);
        final FlatMapFunction<Iterator<GATKRead>, List<GATKRead>> groupIntoPairs =
                iter -> PathSeqPipelineSpark.pairPartitionReads(iter, readsPerPartition);
        return pairedReads.mapPartitions(groupIntoPairs)
                .repartition(alignmentPartitions)
                .flatMap(List::iterator);
    }

    /**
     * Groups consecutive reads of one partition into two-element lists.
     *
     * @param iter              reads of a single partition, expected as adjacent same-named pairs
     * @param readsPerPartition expected number of reads; half of it is used as the initial capacity
     * @return iterator over the pairs, in input order
     * @throws GATKException if the partition holds an odd number of reads, or two adjacent reads
     *                       that should form a pair have different names
     */
    private static Iterator<List<GATKRead>> pairPartitionReads(final Iterator<GATKRead> iter, final int readsPerPartition) {
        final List<List<GATKRead>> readPairs = new ArrayList<>(readsPerPartition / 2);
        while (iter.hasNext()) {
            final List<GATKRead> pair = new ArrayList<>(2);
            pair.add(iter.next());
            if (!iter.hasNext()) {
                throw new GATKException("Odd number of read pairs in paired reads partition");
            }
            pair.add(iter.next());
            if (!pair.get(0).getName().equals(pair.get(1).getName())) {
                throw new GATKException("Pair did not have the same name in a paired reads partition");
            }
            readPairs.add(pair);
        }
        return readPairs.iterator();
    }

    /** This tool always needs input reads. */
    @Override
    public boolean requiresReads() {
        return true;
    }

    /**
     * Runs filtering, alignment and scoring in sequence, then optionally logs score metrics and
     * writes the final reads to {@link #outputPath}.
     *
     * <p>The filter and the aligner are closed in {@code finally} blocks so that they are released
     * even when a later stage fails.
     *
     * @param ctx Spark context used by every stage
     * @throws UserException.BadInput                 if {@code --num-reducers} was set
     * @throws UserException.CouldNotCreateOutputFile if writing the output reads fails with an
     *                                                {@link IOException}
     */
    @Override
    protected void runTool(final JavaSparkContext ctx) {
        this.filterArgs.doReadFilterArgumentWarnings(this.getCommandLineParser().getPluginDescriptor(GATKReadFilterPluginDescriptor.class), this.logger);
        SAMFileHeader header = PSUtils.checkAndClearHeaderSequences(this.getHeaderForReads(), this.filterArgs, this.logger);

        // Output partitioning is derived from readsPerPartitionOutput, so num-reducers is refused.
        if (this.numReducers > 0) {
            throw new UserException.BadInput("Use --readsPerPartitionOutput instead of --num-reducers.");
        }

        // Stage 1: filtering.
        JavaRDD<GATKRead> pairedReads;
        JavaRDD<GATKRead> unpairedReads;
        final long numPairedReads;
        final long numUnpairedReads;
        final PSFilter filter = new PSFilter(ctx, this.filterArgs, header);
        try {
            final Tuple2<JavaRDD<GATKRead>, JavaRDD<GATKRead>> filterResult;
            try (PSFilterLogger filterLogger = this.filterArgs.filterMetricsFileUri != null ? new PSFilterFileLogger(this.getMetricsFile(), this.filterArgs.filterMetricsFileUri) : new PSFilterEmptyLogger()) {
                final JavaRDD<GATKRead> inputReads = this.getReads();
                filterResult = filter.doFilter(inputReads, filterLogger);
            }
            pairedReads = filterResult._1;
            unpairedReads = filterResult._2;
            // NOTE(review): the counts are taken before the filter is closed; presumably they
            // force evaluation of the filtered RDDs while the filter's resources are still
            // available - confirm against PSFilter before reordering.
            numPairedReads = pairedReads.count();
            numUnpairedReads = unpairedReads.count();
        } finally {
            // Close the filter before the aligner is created, on success and on failure alike.
            filter.close();
        }
        final long numTotalReads = numPairedReads + numUnpairedReads;

        // Rebalance both read sets to roughly readsPerPartition reads per partition.
        final int numPairedPartitions = 1 + (int) (numPairedReads / (long) this.readsPerPartition);
        final int numUnpairedPartitions = 1 + (int) (numUnpairedReads / (long) this.readsPerPartition);
        pairedReads = PathSeqPipelineSpark.repartitionPairedReads(pairedReads, numPairedPartitions, numPairedReads);
        unpairedReads = unpairedReads.repartition(numUnpairedPartitions);

        // Stage 2 and 3: alignment and scoring.
        final PSBwaAlignerSpark aligner = new PSBwaAlignerSpark(ctx, this.bwaArgs);
        try {
            PSBwaUtils.addReferenceSequencesToHeader(header, this.bwaArgs.microbeDictionary);
            final Broadcast<SAMFileHeader> headerBroadcast = ctx.broadcast(header);
            final JavaRDD<GATKRead> alignedPairedReads = aligner.doBwaAlignment(pairedReads, true, headerBroadcast);
            final JavaRDD<GATKRead> alignedUnpairedReads = aligner.doBwaAlignment(unpairedReads, false, headerBroadcast);

            // Persist the alignments (serialized, spilling to disk) since they feed several later steps.
            alignedPairedReads.persist(StorageLevel.MEMORY_AND_DISK_SER());
            alignedUnpairedReads.persist(StorageLevel.MEMORY_AND_DISK_SER());

            final PSScorer scorer = new PSScorer(this.scoreArgs);
            final JavaRDD<GATKRead> readsFinal = scorer.scoreReads(ctx, alignedPairedReads, alignedUnpairedReads, header);
            header = PSBwaUtils.removeUnmappedHeaderSequences(header, readsFinal, this.logger);

            if (this.scoreArgs.scoreMetricsFileUri != null) {
                try (PSScoreFileLogger scoreLogger = new PSScoreFileLogger(this.getMetricsFile(), this.scoreArgs.scoreMetricsFileUri)) {
                    scoreLogger.logReadCounts(readsFinal);
                }
            }

            if (this.outputPath != null) {
                try {
                    // Coalesce to about readsPerPartitionOutput reads per partition, at least one partition.
                    final int numPartitions = Math.max(1, (int) (numTotalReads / (long) this.readsPerPartitionOutput));
                    final JavaRDD<GATKRead> readsFinalRepartitioned = readsFinal.coalesce(numPartitions, false);
                    ReadsSparkSink.writeReads(ctx, this.outputPath, null, readsFinalRepartitioned, header, this.shardedOutput ? ReadsWriteFormat.SHARDED : ReadsWriteFormat.SINGLE, numPartitions, this.shardedPartsDir, true, this.splittingIndexGranularity);
                }
                catch (IOException e) {
                    throw new UserException.CouldNotCreateOutputFile(this.outputPath, "writing failed", e);
                }
            }
        } finally {
            // Release the aligner even if alignment, scoring or writing failed.
            aligner.close();
        }
    }
}

