/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.walkers.realignmentfilter;

import com.google.common.collect.HashMultiset;
import com.google.common.collect.Multiset;
import com.google.common.collect.Multisets;
import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.util.Locatable;
import htsjdk.variant.variantcontext.VariantContext;
import htsjdk.variant.variantcontext.VariantContextBuilder;
import htsjdk.variant.variantcontext.writer.VariantContextWriter;
import htsjdk.variant.vcf.VCFHeader;
import htsjdk.variant.vcf.VCFHeaderLine;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.ArgumentCollection;
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import org.broadinstitute.barclay.argparser.ExperimentalFeature;
import org.broadinstitute.barclay.help.DocumentedFeature;
import org.broadinstitute.hellbender.engine.AssemblyRegion;
import org.broadinstitute.hellbender.engine.GATKPath;
import org.broadinstitute.hellbender.engine.MultiVariantWalkerGroupedOnStart;
import org.broadinstitute.hellbender.engine.ReadsContext;
import org.broadinstitute.hellbender.engine.ReferenceContext;
import org.broadinstitute.hellbender.engine.filters.CountingVariantFilter;
import org.broadinstitute.hellbender.engine.filters.ReadFilter;
import org.broadinstitute.hellbender.engine.filters.VariantFilterLibrary;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.AssemblyBasedCallerUtils;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.AssemblyResultSet;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.ReadLikelihoodCalculationEngine;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.readthreading.ReadThreadingAssembler;
import org.broadinstitute.hellbender.tools.walkers.mutect.M2ArgumentCollection;
import org.broadinstitute.hellbender.tools.walkers.mutect.Mutect2Engine;
import org.broadinstitute.hellbender.tools.walkers.realignmentfilter.RealignmentArgumentCollection;
import org.broadinstitute.hellbender.tools.walkers.realignmentfilter.RealignmentEngine;
import org.broadinstitute.hellbender.utils.BaseUtils;
import org.broadinstitute.hellbender.utils.MathUtils;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.bwa.BwaMemAlignment;
import org.broadinstitute.hellbender.utils.downsampling.DownsamplingMethod;
import org.broadinstitute.hellbender.utils.fasta.CachingIndexedFastaSequenceFile;
import org.broadinstitute.hellbender.utils.genotyper.AlleleLikelihoods;
import org.broadinstitute.hellbender.utils.genotyper.IndexedSampleList;
import org.broadinstitute.hellbender.utils.genotyper.SampleList;
import org.broadinstitute.hellbender.utils.haplotype.Haplotype;
import org.broadinstitute.hellbender.utils.haplotype.HaplotypeBAMWriter;
import org.broadinstitute.hellbender.utils.io.IOUtils;
import org.broadinstitute.hellbender.utils.locusiterator.LocusIteratorByState;
import org.broadinstitute.hellbender.utils.pileup.PileupElement;
import org.broadinstitute.hellbender.utils.pileup.ReadPileup;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.read.ReadUtils;
import org.broadinstitute.hellbender.utils.reference.ReferenceUtils;
import org.broadinstitute.hellbender.utils.smithwaterman.SmithWatermanAligner;
import org.broadinstitute.hellbender.utils.variant.GATKVCFHeaderLines;
import picard.cmdline.programgroups.VariantFilteringProgramGroup;

/**
 * Filters alignment artifacts from a VCF callset.
 *
 * <p>For every input variant the tool:
 * <ol>
 *   <li>collects the reads that support the variant, plus every read sharing a name with one of them,
 *       and wraps them in an {@link AssemblyRegion};</li>
 *   <li>assembles that region, computes read-vs-haplotype likelihoods and realigns each read to its
 *       best haplotype;</li>
 *   <li>walks the pileup of the region's reads to build consensus "unitigs" (one per contiguous run of
 *       covered loci);</li>
 *   <li>realigns each unitig with the {@link RealignmentEngine} and combines the per-unitig alignments
 *       into joint alignments;</li>
 *   <li>marks the variant with the {@code alignment} filter when the best joint alignment is on a
 *       different contig than the variant, or when the two best joint alignments are too similar
 *       (both in aligner score and in mismatch count, per unitig base).</li>
 * </ol>
 *
 * <p>Every input variant is written to the output, annotated with {@code UNITIGS} (unitig lengths),
 * {@code NALIGNS} (number of joint alignments) and, when at least two joint alignments exist and the
 * contig check did not already filter the variant, {@code ALIGN_DIFF} (score difference of the top two).
 */
@CommandLineProgramProperties(summary="Filter alignment artifacts from a vcf callset.", oneLineSummary="Filter alignment artifacts from a vcf callset.", programGroup=VariantFilteringProgramGroup.class)
@DocumentedFeature
@ExperimentalFeature
public class FilterAlignmentArtifacts
extends MultiVariantWalkerGroupedOnStart {
    public static final int DEFAULT_DISTANCE_TO_GROUP_VARIANTS = 1000;
    public static final int DEFAULT_REF_PADDING = 100;
    public static final int DEFAULT_MAX_GROUPED_SPAN = 10000;

    /** Unitigs must be strictly longer than this many bases to be realigned. */
    private static final int MIN_UNITIG_LENGTH = 30;

    /** Number of bases added on each side of the variant reads' span to form the assembly window. */
    private static final int ASSEMBLY_PADDING = 50;

    // VCF filter / INFO identifiers emitted by this tool.
    private static final String ALIGNMENT_FILTER_NAME = "alignment";
    private static final String UNITIG_SIZES_KEY = "UNITIGS";
    private static final String ALIGNMENT_SCORE_DIFFERENCE_KEY = "ALIGN_DIFF";
    private static final String JOINT_ALIGNMENT_COUNT_KEY = "NALIGNS";

    private static final SmithWatermanAligner ALIGNER = SmithWatermanAligner.getAligner(SmithWatermanAligner.Implementation.FASTEST_AVAILABLE);

    @Argument(fullName="output", shortName="O", doc="The output filtered VCF file", optional=false)
    private final GATKPath outputVcf = null;

    public static final int DEFAULT_INDEL_START_TOLERANCE = 5;
    public static final String INDEL_START_TOLERANCE_LONG_NAME = "indel-start-tolerance";
    @Argument(fullName=INDEL_START_TOLERANCE_LONG_NAME, doc="Max distance between indel start of aligned read in the bam and the variant in the vcf", optional=true)
    private int indelStartTolerance = DEFAULT_INDEL_START_TOLERANCE;

    public static final int DEFAULT_KMER_SIZE = 21;
    public static final String KMER_SIZE_LONG_NAME = "kmer-size";
    @Argument(fullName=KMER_SIZE_LONG_NAME, doc="Kmer size for reassembly", optional=true)
    private int kmerSize = DEFAULT_KMER_SIZE;

    public static final String DONT_SKIP_ALREADY_FILTERED_VARIANTS_LONG_NAME = "dont-skip-filtered-variants";
    @Argument(fullName=DONT_SKIP_ALREADY_FILTERED_VARIANTS_LONG_NAME, doc="Try to realign all variants, even ones that have already been filtered.", optional=true)
    private boolean dontSkipFilteredVariants = false;

    @Argument(fullName="bam-output", shortName="bamout", doc="File to which assembled haplotypes should be written", optional=true)
    public String bamOutputPath = null;

    @ArgumentCollection
    protected RealignmentArgumentCollection realignmentArgumentCollection = new RealignmentArgumentCollection();

    private VariantContextWriter vcfWriter;
    private RealignmentEngine realignmentEngine;
    private SAMFileHeader bamHeader;
    private SampleList samplesList;
    private CachingIndexedFastaSequenceFile referenceReader;
    private ReadThreadingAssembler assemblyEngine;
    private final M2ArgumentCollection MTAC = new M2ArgumentCollection();
    private ReadLikelihoodCalculationEngine likelihoodCalculationEngine;

    // Starts out empty (never null) so closeTool() is safe even if onTraversalStart() did not complete.
    private Optional<HaplotypeBAMWriter> haplotypeBAMWriter = Optional.empty();

    /** Uses the standard Mutect2 read filters. */
    @Override
    public List<ReadFilter> getDefaultReadFilters() {
        return Mutect2Engine.makeStandardMutect2ReadFilters();
    }

    /**
     * Selects which input variants are processed: all of them when
     * {@code --dont-skip-filtered-variants} is set, otherwise only those selected by
     * {@link VariantFilterLibrary#PASSES_FILTERS}.
     */
    @Override
    protected CountingVariantFilter makeVariantFilter() {
        return new CountingVariantFilter(dontSkipFilteredVariants ? VariantFilterLibrary.ALLOW_ALL_VARIANTS : VariantFilterLibrary.PASSES_FILTERS);
    }

    @Override
    public boolean requiresReads() {
        return true;
    }

    @Override
    protected int defaultDistanceToGroupVariants() {
        return DEFAULT_DISTANCE_TO_GROUP_VARIANTS;
    }

    @Override
    protected int defaultReferenceWindowPadding() {
        return DEFAULT_REF_PADDING;
    }

    @Override
    protected int defaultMaxGroupedSpan() {
        return DEFAULT_MAX_GROUPED_SPAN;
    }

    /**
     * Creates the realignment engine, writes the output VCF header (input header lines plus this
     * tool's filter/INFO lines and the default tool header lines) and sets up the assembly,
     * likelihood and optional haplotype-BAM machinery.
     */
    @Override
    public void onTraversalStart() {
        realignmentEngine = new RealignmentEngine(realignmentArgumentCollection);
        vcfWriter = createVCFWriter(outputVcf);

        final VCFHeader inputHeader = getHeaderForVariants();
        final Set<VCFHeaderLine> headerLines = new HashSet<>(inputHeader.getMetaDataInSortedOrder());
        headerLines.add(GATKVCFHeaderLines.getFilterLine(ALIGNMENT_FILTER_NAME));
        headerLines.add(GATKVCFHeaderLines.getInfoLine(UNITIG_SIZES_KEY));
        headerLines.add(GATKVCFHeaderLines.getInfoLine(ALIGNMENT_SCORE_DIFFERENCE_KEY));
        headerLines.add(GATKVCFHeaderLines.getInfoLine(JOINT_ALIGNMENT_COUNT_KEY));
        headerLines.addAll(getDefaultToolVCFHeaderLines());
        vcfWriter.writeHeader(new VCFHeader(headerLines, inputHeader.getGenotypeSamples()));

        bamHeader = getHeaderForReads();
        samplesList = new IndexedSampleList(new ArrayList<>(ReadUtils.getSamplesFromHeader(bamHeader)));
        referenceReader = ReferenceUtils.createReferenceReader(Utils.nonNull(referenceArguments.getReferenceSpecifier()));
        assemblyEngine = MTAC.createReadThreadingAssembler();
        likelihoodCalculationEngine = AssemblyBasedCallerUtils.createLikelihoodCalculationEngine(MTAC.likelihoodArgs, true);
        haplotypeBAMWriter = bamOutputPath == null
                ? Optional.empty()
                : Optional.of(new HaplotypeBAMWriter(HaplotypeBAMWriter.WriterType.ALL_POSSIBLE_HAPLOTYPES, IOUtils.getPath(bamOutputPath), true, false, getHeaderForSAMWriter()));
    }

    @Override
    public Object onTraversalSuccess() {
        return "SUCCESS";
    }

    /**
     * Processes one group of variants: each variant is reassembled, annotated, possibly marked with
     * the {@code alignment} filter, and written to the output VCF.
     *
     * @param variantContexts  the variants in the current group; handled one at a time
     * @param referenceContext unused by this tool
     * @param readsContexts    reads overlapping the group, searched for variant-supporting reads
     */
    @Override
    public void apply(List<VariantContext> variantContexts, ReferenceContext referenceContext, List<ReadsContext> readsContexts) {
        for (final VariantContext vc : variantContexts) {
            final AssemblyRegion assemblyRegion = makeAssemblyRegionFromVariantReads(readsContexts, vc);
            final AssemblyResultSet assemblyResult = AssemblyBasedCallerUtils.assembleReads(assemblyRegion, Collections.emptyList(), MTAC, bamHeader, samplesList, logger, referenceReader, assemblyEngine, ALIGNER, false);
            final AssemblyRegion regionForGenotyping = assemblyResult.getRegionForGenotyping();

            final Map<String, List<GATKRead>> readsBySample = AssemblyBasedCallerUtils.splitReadsBySample(samplesList, bamHeader, regionForGenotyping.getReads());
            final AlleleLikelihoods<GATKRead, Haplotype> readLikelihoods = likelihoodCalculationEngine.computeReadLikelihoods(assemblyResult, samplesList, readsBySample);
            readLikelihoods.switchToNaturalLog();
            final Map<GATKRead, GATKRead> readRealignments = AssemblyBasedCallerUtils.realignReadsToTheirBestHaplotype(readLikelihoods, assemblyResult.getReferenceHaplotype(), assemblyResult.getPaddedReferenceLoc(), ALIGNER);
            readLikelihoods.changeEvidence(readRealignments);
            writeBamOutput(assemblyResult, readLikelihoods, new HashSet<>(readLikelihoods.alleles()), regionForGenotyping.getSpan());

            final LocusIteratorByState libs = new LocusIteratorByState(regionForGenotyping.getReads().iterator(), DownsamplingMethod.NONE, false, samplesList.asListOfSamples(), bamHeader, true);
            final List<byte[]> unitigs = getUnitigs(libs);

            final VariantContextBuilder vcb = new VariantContextBuilder(vc)
                    .attribute(UNITIG_SIZES_KEY, unitigs.stream().mapToInt(unitig -> unitig.length).toArray());
            annotateAndFilterByRealignment(vcb, vc, unitigs);
            vcfWriter.add(vcb.make());
        }
    }

    /**
     * Realigns the unitigs, records the joint-alignment annotations on {@code vcb} and applies the
     * {@code alignment} filter when the evidence points to an alignment artifact.
     */
    private void annotateAndFilterByRealignment(final VariantContextBuilder vcb, final VariantContext vc, final List<byte[]> unitigs) {
        final List<List<BwaMemAlignment>> unitigAlignments = unitigs.stream()
                .map(realignmentEngine::realign)
                .collect(Collectors.toList());
        final List<List<BwaMemAlignment>> jointAlignments = RealignmentEngine.findJointAlignments(unitigAlignments, realignmentArgumentCollection.maxReasonableFragmentLength);
        vcb.attribute(JOINT_ALIGNMENT_COUNT_KEY, jointAlignments.size());

        // Best (highest total aligner score) joint alignment first.
        jointAlignments.sort(Comparator.comparingInt(FilterAlignmentArtifacts::jointAlignmentScore).reversed());

        if (!jointAlignments.isEmpty()
                && jointAlignments.get(0).get(0).getRefId() != getReferenceDictionary().getSequenceIndex(vc.getContig())) {
            // The best joint alignment lands on a different contig than the variant.
            vcb.filter(ALIGNMENT_FILTER_NAME);
        } else if (jointAlignments.size() > 1) {
            final int totalBases = unitigs.stream().mapToInt(unitig -> unitig.length).sum();
            final int scoreDiff = jointAlignmentScore(jointAlignments.get(0)) - jointAlignmentScore(jointAlignments.get(1));
            final int mismatchDiff = totalMismatches(jointAlignments.get(1)) - totalMismatches(jointAlignments.get(0));
            vcb.attribute(ALIGNMENT_SCORE_DIFFERENCE_KEY, scoreDiff);

            // Multimapping: the runner-up is nearly as good as the best in BOTH score and mismatches.
            final boolean multimapping =
                    (double) scoreDiff / (double) totalBases < realignmentArgumentCollection.minAlignerScoreDifferencePerBase
                    && (double) mismatchDiff / (double) totalBases < realignmentArgumentCollection.minMismatchDifferencePerBase;
            if (multimapping) {
                vcb.filter(ALIGNMENT_FILTER_NAME);
            }
        }
    }

    /**
     * Builds an assembly region from the reads supporting {@code vc}.
     *
     * <p>Reads are selected by name: any read whose name matches a variant-supporting read is
     * included (presumably so that mates of supporting reads come along -- confirm against
     * {@link RealignmentEngine#supportsVariant}). The window spans the selected reads, padded by
     * {@link #ASSEMBLY_PADDING} on each side and clipped to start at position 1; if no read supports
     * the variant the variant's own start/end are used instead.
     */
    private AssemblyRegion makeAssemblyRegionFromVariantReads(List<ReadsContext> readsContexts, VariantContext vc) {
        final Set<String> variantReadNames = readsContexts.stream()
                .flatMap(Utils::stream)
                .filter(read -> RealignmentEngine.supportsVariant(read, vc, indelStartTolerance))
                .map(GATKRead::getName)
                .collect(Collectors.toSet());

        final List<GATKRead> variantReads = readsContexts.stream()
                .flatMap(Utils::stream)
                .filter(read -> variantReadNames.contains(read.getName()))
                .sorted(Comparator.comparingInt(Locatable::getStart))
                .collect(Collectors.toList());

        final int firstReadStart = variantReads.stream().mapToInt(Locatable::getStart).min().orElse(vc.getStart());
        final int lastReadEnd = variantReads.stream().mapToInt(Locatable::getEnd).max().orElse(vc.getEnd());
        final SimpleInterval assemblyWindow = new SimpleInterval(vc.getContig(), Math.max(firstReadStart - ASSEMBLY_PADDING, 1), lastReadEnd + ASSEMBLY_PADDING);

        final AssemblyRegion assemblyRegion = new AssemblyRegion(assemblyWindow, 0, bamHeader);
        assemblyRegion.addAll(variantReads);
        return assemblyRegion;
    }

    /** Writes reads aligned to their haplotypes to the haplotype BAM, if one was requested. */
    private void writeBamOutput(AssemblyResultSet assemblyResult, AlleleLikelihoods<GATKRead, Haplotype> readLikelihoods, Set<Haplotype> haplotypes, Locatable callableRegion) {
        haplotypeBAMWriter.ifPresent(writer -> writer.writeReadsAlignedToHaplotypes(assemblyResult.getHaplotypeList(), assemblyResult.getPaddedReferenceLoc(), assemblyResult.getHaplotypeList(), haplotypes, readLikelihoods, callableRegion));
    }

    /**
     * Builds consensus sequences ("unitigs") from the pileups produced by {@code libs}.
     *
     * <p>A new unitig is started whenever the current non-empty pileup is not at the locus
     * immediately after the previous one. At each locus the most frequent base is appended unless
     * at least half of the pileup elements are deletions; if a base was appended and more than half
     * of the elements are immediately followed by an insertion, the most common inserted sequence is
     * appended as well.
     *
     * @return the unitigs strictly longer than {@link #MIN_UNITIG_LENGTH} bases
     */
    private List<byte[]> getUnitigs(LocusIteratorByState libs) {
        final List<StringBuilder> unitigBuilders = new ArrayList<>();
        int lastCoveredLocus = Integer.MIN_VALUE;

        while (libs.hasNext()) {
            final ReadPileup pileup = libs.next().getBasePileup();
            if (pileup.isEmpty()) {
                continue;
            }

            final int currentLocus = pileup.getLocation().getStart();
            if (currentLocus != lastCoveredLocus + 1) {
                // Gap in coverage (or very first locus): begin a new unitig.
                unitigBuilders.add(new StringBuilder());
            }
            lastCoveredLocus = currentLocus;
            final StringBuilder currentUnitigBuilder = unitigBuilders.get(unitigBuilders.size() - 1);

            final int deletionCount = (int) Utils.stream(pileup).filter(PileupElement::isDeletion).count();
            // Skip the locus when at least half of the elements are deletions. The comparison is done
            // as 2 * count >= size rather than count >= size / 2: integer division rounded the
            // threshold down, so at odd depth a minority of deletions suppressed the base -- and at
            // depth 1 with no deletion (0 >= 0) every base was dropped.
            if (2 * deletionCount >= pileup.size()) {
                continue;
            }

            final int[] baseCounts = pileup.getBaseCounts();
            final byte consensusBase = BaseUtils.baseIndexToSimpleBase(MathUtils.maxElementIndex(baseCounts));
            currentUnitigBuilder.append((char) consensusBase);

            // In addition to the consensus base, append inserted bases when a strict majority of
            // the pileup elements carry an insertion right after this locus.
            final Multiset<String> insertedBases = Utils.stream(pileup)
                    .map(PileupElement::getBasesOfImmediatelyFollowingInsertion)
                    .filter(s -> s != null)
                    .collect(Collectors.toCollection(HashMultiset::create));
            if (insertedBases.size() > pileup.size() / 2) {
                final String consensusInsertion = Multisets.copyHighestCountFirst(insertedBases).entrySet().iterator().next().getElement();
                currentUnitigBuilder.append(consensusInsertion);
            }
        }

        return unitigBuilders.stream()
                // Base characters are ASCII; name the charset so the bytes do not depend on the platform default.
                .map(builder -> builder.toString().getBytes(StandardCharsets.US_ASCII))
                .filter(unitig -> unitig.length > MIN_UNITIG_LENGTH)
                .collect(Collectors.toList());
    }

    /** Sum of the aligner scores of all alignments in a joint alignment. */
    private static int jointAlignmentScore(List<BwaMemAlignment> alignments) {
        return alignments.stream().mapToInt(BwaMemAlignment::getAlignerScore).sum();
    }

    /** Sum of the mismatch counts of all alignments in a joint alignment. */
    private static int totalMismatches(List<BwaMemAlignment> alignments) {
        return alignments.stream().mapToInt(BwaMemAlignment::getNMismatches).sum();
    }

    /**
     * Releases the resources opened in {@link #onTraversalStart()}: the VCF writer, the optional
     * haplotype BAM writer and the likelihood calculation engine. Each close is attempted even if an
     * earlier one throws, so that a failure closing the VCF does not leave the BAM output unflushed.
     */
    @Override
    public void closeTool() {
        try {
            if (vcfWriter != null) {
                vcfWriter.close();
            }
        } finally {
            try {
                haplotypeBAMWriter.ifPresent(HaplotypeBAMWriter::close);
            } finally {
                if (likelihoodCalculationEngine != null) {
                    likelihoodCalculationEngine.close();
                }
            }
        }
    }
}

