/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.engine.spark;

import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMSequenceDictionary;
import htsjdk.samtools.util.Locatable;
import java.io.Serializable;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Spliterators;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import org.apache.spark.SparkFiles;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.CommandLineException;
import org.broadinstitute.hellbender.engine.AlignmentContext;
import org.broadinstitute.hellbender.engine.FeatureContext;
import org.broadinstitute.hellbender.engine.FeatureManager;
import org.broadinstitute.hellbender.engine.ReferenceContext;
import org.broadinstitute.hellbender.engine.ReferenceFileSource;
import org.broadinstitute.hellbender.engine.Shard;
import org.broadinstitute.hellbender.engine.spark.GATKSparkTool;
import org.broadinstitute.hellbender.engine.spark.LocusWalkerContext;
import org.broadinstitute.hellbender.engine.spark.SparkSharder;
import org.broadinstitute.hellbender.utils.IntervalUtils;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.io.IOUtils;
import org.broadinstitute.hellbender.utils.locusiterator.AlignmentContextIteratorBuilder;
import org.broadinstitute.hellbender.utils.locusiterator.LIBSDownsamplingInfo;
import org.broadinstitute.hellbender.utils.locusiterator.LocusIteratorByState;
import org.broadinstitute.hellbender.utils.read.GATKRead;

/**
 * Base class for Spark tools that process aligned reads one locus at a time.
 *
 * <p>Subclasses implement {@link #processAlignments(JavaRDD, JavaSparkContext)}. It receives an RDD of
 * {@link LocusWalkerContext} objects, each pairing an {@link AlignmentContext} with the
 * {@link ReferenceContext} and {@link FeatureContext} for that locus.
 */
public abstract class LocusWalkerSpark
extends GATKSparkTool {
    private static final long serialVersionUID = 1L;

    /**
     * Maximum number of reads retained per sample per locus; 0 disables downsampling.
     *
     * <p>The default is obtained from the overridable {@link #defaultMaxDepthPerSample()}. Because that call
     * happens in this field initializer (i.e. while the superclass part of the object is being constructed),
     * overrides must not rely on subclass instance fields.
     */
    @Argument(fullName="max-depth-per-sample", shortName="max-depth-per-sample", doc="Maximum number of reads to retain per sample per locus. Reads above this threshold will be downsampled. Set to 0 to disable.", optional=true)
    protected int maxDepthPerSample = this.defaultMaxDepthPerSample();

    /** Maximum size, in bases, of each interval shard the reads are partitioned into. */
    @Argument(fullName="read-shard-size", shortName="read-shard-size", doc="Maximum size of each read shard, in bases.", optional=true)
    public int readShardSize = 10000;

    /** Passed to {@link SparkSharder} to choose the shuffle implementation instead of overlaps partitioning. */
    @Argument(doc="whether to use the shuffle implementation or overlaps partitioning (the default)", shortName="shuffle", fullName="shuffle", optional=true)
    public boolean shuffle = false;

    /** Name of the reference as registered with Spark in {@link #runTool}; {@code null} until then. */
    private String referenceFileName;

    /**
     * Returns the default value for {@code --max-depth-per-sample}.
     *
     * <p>Subclasses may override this to change the default. The base implementation returns 0 (no downsampling).
     *
     * @return the default per-sample, per-locus depth cap
     */
    protected int defaultMaxDepthPerSample() {
        return 0;
    }

    /** Locus walkers always require reads. */
    @Override
    public boolean requiresReads() {
        return true;
    }

    /**
     * Builds the downsampling configuration from {@code --max-depth-per-sample}.
     *
     * @return {@link LocusIteratorByState#NO_DOWNSAMPLING} when the cap is 0, otherwise a downsampling
     *         configuration using that cap
     * @throws CommandLineException.BadArgumentValue if the cap is negative
     */
    protected final LIBSDownsamplingInfo getDownsamplingInfo() {
        if (this.maxDepthPerSample < 0) {
            // 0 is a legal value (it disables downsampling), so the requirement is "non-negative".
            throw new CommandLineException.BadArgumentValue("max-depth-per-sample", String.valueOf(this.maxDepthPerSample), "should be a non-negative number");
        }
        return this.maxDepthPerSample == 0 ? LocusIteratorByState.NO_DOWNSAMPLING : new LIBSDownsamplingInfo(true, this.maxDepthPerSample);
    }

    /**
     * Whether loci with no overlapping reads should be emitted. The base implementation returns {@code false}.
     *
     * @return {@code true} to emit empty loci
     */
    public boolean emitEmptyLoci() {
        return false;
    }

    /**
     * Builds the RDD of per-locus contexts for this tool.
     *
     * <p>The user-supplied intervals are used if any were given. Otherwise all intervals of the best available
     * sequence dictionary are used. Each interval is split into shards of at most {@link #readShardSize}
     * bases, and reads are partitioned across those shards by {@link SparkSharder}. If features are
     * configured, they are broadcast so that every shard can build {@link FeatureContext}s.
     *
     * @param ctx the Spark context
     * @return one {@link LocusWalkerContext} per emitted locus
     * @throws CommandLineException.BadArgumentValue if {@code --max-depth-per-sample} is negative
     */
    public JavaRDD<LocusWalkerContext> getAlignments(JavaSparkContext ctx) {
        // Validate the user argument first, so an invalid value is rejected before any sharding or
        // broadcasting work is set up.
        final LIBSDownsamplingInfo downsamplingInfo = this.getDownsamplingInfo();
        final SAMSequenceDictionary sequenceDictionary = this.getBestAvailableSequenceDictionary();
        final List<SimpleInterval> intervals = this.hasUserSuppliedIntervals() ? this.getIntervals() : IntervalUtils.getAllIntervalsForReference(sequenceDictionary);
        final List<? extends Locatable> intervalShards = intervals.stream()
                .flatMap(interval -> Shard.divideIntervalIntoShards(interval, this.readShardSize, 0, sequenceDictionary).stream())
                .collect(Collectors.toList());
        final JavaRDD<Shard<GATKRead>> shardedReads = SparkSharder.shard(ctx, this.getReads(), GATKRead.class, sequenceDictionary, intervalShards, this.readShardSize, this.shuffle);
        final Broadcast<FeatureManager> bFeatureManager = this.features == null ? null : ctx.broadcast(this.features);
        return shardedReads.flatMap(getAlignmentsFunction(this.referenceFileName, bFeatureManager, sequenceDictionary, this.getHeaderForReads(), downsamplingInfo, this.emitEmptyLoci()));
    }

    /**
     * Creates the per-shard function that turns a shard of reads into locus contexts.
     *
     * <p>This method is static so that the returned lambda captures only its arguments, not the enclosing tool.
     *
     * @param referenceFileName name of the reference registered via {@code SparkFiles}, or {@code null} if there is none
     * @param bFeatureManager   broadcast feature manager, or {@code null} if there are no features
     * @param sequenceDictionary sequence dictionary for the reads
     * @param header            header for the reads
     * @param downsamplingInfo  downsampling configuration for the locus iterator
     * @param isEmitEmptyLoci   whether loci without reads are emitted
     * @return a serializable flat-map function from a read shard to its {@link LocusWalkerContext}s
     */
    private static FlatMapFunction<Shard<GATKRead>, LocusWalkerContext> getAlignmentsFunction(String referenceFileName, Broadcast<FeatureManager> bFeatureManager, SAMSequenceDictionary sequenceDictionary, SAMFileHeader header, LIBSDownsamplingInfo downsamplingInfo, boolean isEmitEmptyLoci) {
        // FlatMapFunction already extends Serializable, so this lambda is serializable without an intersection cast.
        return shardedRead -> {
            final SimpleInterval interval = shardedRead.getInterval();
            final Iterator<GATKRead> readIterator = shardedRead.iterator();
            // NOTE(review): a new reference source is opened per shard and never closed here. If ReferenceFileSource
            // holds an open file handle, consider closing it once the returned iterator is exhausted.
            final ReferenceFileSource reference = referenceFileName == null ? null : new ReferenceFileSource(IOUtils.getPath(SparkFiles.get(referenceFileName)));
            final FeatureManager fm = bFeatureManager == null ? null : bFeatureManager.getValue();
            final AlignmentContextIteratorBuilder alignmentContextIteratorBuilder = new AlignmentContextIteratorBuilder();
            alignmentContextIteratorBuilder.setDownsamplingInfo(downsamplingInfo);
            alignmentContextIteratorBuilder.setEmitEmptyLoci(isEmitEmptyLoci);
            alignmentContextIteratorBuilder.setKeepUniqueReadListInLibs(false);
            alignmentContextIteratorBuilder.setIncludeNs(false);
            final Iterator<AlignmentContext> alignmentContextIterator = alignmentContextIteratorBuilder.build(readIterator, header, Collections.singletonList(interval), sequenceDictionary, true);
            // Lazily wrap each alignment context with its reference and feature context.
            return StreamSupport.stream(Spliterators.spliteratorUnknownSize(alignmentContextIterator, 0), false)
                    .map(alignmentContext -> {
                        final SimpleInterval alignmentInterval = new SimpleInterval(alignmentContext);
                        return new LocusWalkerContext(alignmentContext, new ReferenceContext(reference, alignmentInterval), new FeatureContext(fm, alignmentInterval));
                    })
                    .iterator();
        };
    }

    /**
     * Registers the reference with Spark, then hands the per-locus RDD to {@link #processAlignments}.
     *
     * @param ctx the Spark context
     */
    @Override
    protected void runTool(JavaSparkContext ctx) {
        this.referenceFileName = LocusWalkerSpark.addReferenceFilesForSpark(ctx, this.referenceArguments.getReferencePath());
        this.processAlignments(this.getAlignments(ctx), ctx);
    }

    /**
     * Processes the per-locus contexts produced by {@link #getAlignments(JavaSparkContext)}.
     *
     * @param rdd per-locus contexts
     * @param ctx the Spark context
     */
    protected abstract void processAlignments(JavaRDD<LocusWalkerContext> rdd, JavaSparkContext ctx);
}

