/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.spark.bwa;

import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMRecord;
import htsjdk.samtools.SAMSequenceDictionary;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.spark.SparkFiles;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.bwa.BwaMemAligner;
import org.broadinstitute.hellbender.utils.bwa.BwaMemAlignment;
import org.broadinstitute.hellbender.utils.bwa.BwaMemAlignmentUtils;
import org.broadinstitute.hellbender.utils.bwa.BwaMemIndex;
import org.broadinstitute.hellbender.utils.bwa.BwaMemIndexCache;
import org.broadinstitute.hellbender.utils.io.IOUtils;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.read.SAMRecordToGATKReadAdapter;

public final class BwaSparkEngine
implements AutoCloseable {
    private static final String REFERENCE_INDEX_IMAGE_FILE_SUFFIX = ".img";
    private final JavaSparkContext ctx;
    private final String indexFileName;
    private final boolean resolveIndexFileName;
    private final Broadcast<SAMFileHeader> broadcastHeader;

    public BwaSparkEngine(JavaSparkContext ctx, String referenceFile, String indexFileName, SAMFileHeader inputHeader, SAMSequenceDictionary refDictionary) {
        Utils.nonNull(referenceFile);
        Utils.nonNull(inputHeader);
        this.ctx = ctx;
        if (indexFileName != null) {
            this.indexFileName = indexFileName;
            this.resolveIndexFileName = false;
        } else {
            String indexFile = referenceFile + REFERENCE_INDEX_IMAGE_FILE_SUFFIX;
            ctx.addFile(indexFile);
            this.indexFileName = IOUtils.getPath(indexFile).getFileName().toString();
            this.resolveIndexFileName = true;
        }
        if (inputHeader.getSequenceDictionary() == null || inputHeader.getSequenceDictionary().isEmpty()) {
            Utils.nonNull(refDictionary);
            inputHeader = inputHeader.clone();
            inputHeader.setSequenceDictionary(refDictionary);
        }
        this.broadcastHeader = ctx.broadcast((Object)inputHeader);
    }

    public SAMFileHeader getHeader() {
        return (SAMFileHeader)this.broadcastHeader.getValue();
    }

    public JavaRDD<GATKRead> alignPaired(JavaRDD<GATKRead> unalignedReads) {
        return this.align(unalignedReads, true);
    }

    public JavaRDD<GATKRead> alignUnpaired(JavaRDD<GATKRead> unalignedReads) {
        return this.align(unalignedReads, false);
    }

    public JavaRDD<GATKRead> align(JavaRDD<GATKRead> unalignedReads, boolean pairedAlignment) {
        Broadcast<SAMFileHeader> broadcastHeader = this.broadcastHeader;
        String indexFileName = this.indexFileName;
        boolean resolveIndexFileName = this.resolveIndexFileName;
        return unalignedReads.mapPartitions((FlatMapFunction & Serializable)itr -> new ReadAligner(resolveIndexFileName ? SparkFiles.get((String)indexFileName) : indexFileName, (SAMFileHeader)broadcastHeader.value(), pairedAlignment).apply((Iterator<GATKRead>)itr));
    }

    @Override
    public void close() {
        this.broadcastHeader.destroy();
        BwaMemIndexCache.closeAllDistributedInstances(this.ctx);
    }

    private static final class ReadAligner {
        private final BwaMemIndex bwaMemIndex;
        private final SAMFileHeader readsHeader;
        private final boolean alignsPairs;
        private static final int READS_PER_PARTITION_GUESS = 1500000;

        ReadAligner(String indexFileName, SAMFileHeader readsHeader, boolean alignsPairs) {
            this.bwaMemIndex = BwaMemIndexCache.getInstance(indexFileName);
            this.readsHeader = readsHeader;
            this.alignsPairs = alignsPairs;
            if (alignsPairs && readsHeader.getSortOrder() != SAMFileHeader.SortOrder.queryname) {
                throw new UserException("Input must be queryname sorted unless you use single-ended alignment mode.");
            }
        }

        Iterator<GATKRead> apply(Iterator<GATKRead> readItr) {
            List allAlignments;
            ArrayList<GATKRead> inputReads = new ArrayList<GATKRead>(1500000);
            while (readItr.hasNext()) {
                inputReads.add(readItr.next());
            }
            int nReads = inputReads.size();
            if (this.alignsPairs) {
                if ((nReads & 1) != 0) {
                    throw new GATKException("We're supposed to be aligning paired reads, but there are an odd number of them.");
                }
                for (int idx = 0; idx != nReads; idx += 2) {
                    Object readName2;
                    String readName1 = ((GATKRead)inputReads.get(idx)).getName();
                    if (Objects.equals(readName1, readName2 = ((GATKRead)inputReads.get(idx + 1)).getName())) continue;
                    throw new GATKException("Read pair has varying template name: " + readName1 + " .vs " + (String)readName2);
                }
            }
            if (nReads == 0) {
                allAlignments = Collections.emptyList();
            } else {
                ArrayList<byte[]> seqs = new ArrayList<byte[]>(nReads);
                for (GATKRead read : inputReads) {
                    seqs.add(read.getBases());
                }
                BwaMemAligner aligner = new BwaMemAligner(this.bwaMemIndex);
                if (this.alignsPairs) {
                    aligner.alignPairs();
                }
                allAlignments = aligner.alignSeqs(seqs);
            }
            List refNames = this.bwaMemIndex.getReferenceContigNames();
            ArrayList<SAMRecordToGATKReadAdapter> outputReads = new ArrayList<SAMRecordToGATKReadAdapter>(allAlignments.stream().mapToInt(List::size).sum());
            for (int idx = 0; idx != nReads; ++idx) {
                GATKRead originalRead = (GATKRead)inputReads.get(idx);
                String readName = originalRead.getName();
                byte[] bases = originalRead.getBases();
                byte[] quals = originalRead.getBaseQualities();
                String readGroup = originalRead.getReadGroup();
                List alignments = (List)allAlignments.get(idx);
                Map<BwaMemAlignment, String> saTagMap = BwaMemAlignmentUtils.createSATags(alignments, refNames);
                for (BwaMemAlignment alignment : alignments) {
                    SAMRecord samRecord = BwaMemAlignmentUtils.applyAlignment(readName, bases, quals, readGroup, alignment, refNames, this.readsHeader, false, true);
                    SAMRecordToGATKReadAdapter rec = SAMRecordToGATKReadAdapter.headerlessReadAdapter(samRecord);
                    String saTag = saTagMap.get(alignment);
                    if (saTag != null) {
                        rec.setAttribute("SA", saTag);
                    }
                    outputReads.add(rec);
                }
            }
            return outputReads.iterator();
        }
    }
}

