/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.walkers.mutect;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Doubles;
import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.util.Locatable;
import htsjdk.variant.variantcontext.Allele;
import htsjdk.variant.variantcontext.Genotype;
import htsjdk.variant.variantcontext.GenotypeBuilder;
import htsjdk.variant.variantcontext.VariantContext;
import htsjdk.variant.variantcontext.VariantContextBuilder;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.collections4.ListUtils;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.DefaultRealMatrixChangingVisitor;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealMatrixChangingVisitor;
import org.apache.commons.math3.util.FastMath;
import org.broadinstitute.hellbender.engine.FeatureContext;
import org.broadinstitute.hellbender.engine.ReferenceContext;
import org.broadinstitute.hellbender.tools.walkers.annotator.VariantAnnotatorEngine;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.AssemblyBasedCallerUtils;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.AssemblyResultSet;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.CalledHaplotypes;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.ReferenceConfidenceUtils;
import org.broadinstitute.hellbender.tools.walkers.mutect.M2ArgumentCollection;
import org.broadinstitute.hellbender.tools.walkers.mutect.Mutect2Engine;
import org.broadinstitute.hellbender.tools.walkers.mutect.PerAlleleCollection;
import org.broadinstitute.hellbender.tools.walkers.mutect.SomaticLikelihoodsEngine;
import org.broadinstitute.hellbender.tools.walkers.mutect.SubsettedLikelihoodMatrix;
import org.broadinstitute.hellbender.utils.IndexRange;
import org.broadinstitute.hellbender.utils.MathUtils;
import org.broadinstitute.hellbender.utils.NaturalLogUtils;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.genotyper.AlleleLikelihoods;
import org.broadinstitute.hellbender.utils.genotyper.AlleleList;
import org.broadinstitute.hellbender.utils.genotyper.LikelihoodMatrix;
import org.broadinstitute.hellbender.utils.genotyper.SampleList;
import org.broadinstitute.hellbender.utils.haplotype.EventMap;
import org.broadinstitute.hellbender.utils.haplotype.Haplotype;
import org.broadinstitute.hellbender.utils.read.Fragment;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.variant.GATKVariantContextUtils;

/**
 * Turns per-haplotype read likelihoods from an assembled active region into somatic variant calls.
 *
 * <p>For every event start position inside the active window, read/fragment likelihoods are marginalized onto the
 * alleles of the merged event. The engine then computes:
 * <ul>
 *   <li>tumor log odds (TLOD);</li>
 *   <li>normal diploid log odds (NLOD);</li>
 *   <li>normal artifact log odds (NALOD).</li>
 * </ul>
 * It emits an annotated {@link VariantContext} for sites that pass the emission and normal thresholds.
 */
public class SomaticGenotypingEngine {
    private final M2ArgumentCollection MTAC;
    /** Names of samples treated as normals; every other sample is treated as tumor. */
    private final Set<String> normalSamples;
    /** True when at least one normal sample was supplied. */
    final boolean hasNormal;
    protected VariantAnnotatorEngine annotationEngine;
    /** Prior pseudocount given to allele index 0 by {@link #makePriorPseudocounts(int)}. */
    private final double refPseudocount = 1.0;
    /**
     * Prior pseudocount given to every allele index other than 0.
     * It is 1.0 when {@code minAF == 0}, otherwise {@code 1 - ln(2) / ln(minAF)}.
     */
    private final double altPseudocount;

    /**
     * @param MTAC             Mutect2 arguments; {@code minAF} is read here to derive the alt prior pseudocount
     * @param normalSamples    names of normal samples (may be empty for tumor-only calling)
     * @param annotationEngine engine used to annotate each emitted call
     */
    public SomaticGenotypingEngine(final M2ArgumentCollection MTAC, final Set<String> normalSamples, final VariantAnnotatorEngine annotationEngine) {
        this.MTAC = MTAC;
        this.altPseudocount = MTAC.minAF == 0.0 ? 1.0 : 1.0 - Math.log(2.0) / Math.log(MTAC.minAF);
        this.normalSamples = normalSamples;
        this.hasNormal = !normalSamples.isEmpty();
        this.annotationEngine = annotationEngine;
    }

    /**
     * Calls somatic variants at every event start position (from the assembled haplotypes) that lies within
     * {@code activeRegionWindow}.
     *
     * <p>A non-forced alt allele is kept if its tumor log odds exceed {@code MTAC.getEmissionLogOdds()}. Alleles
     * consistent with {@code givenAlleles} are always kept. A site is emitted only if at least one kept allele is
     * forced, or there is no normal, or {@code MTAC.genotypeGermlineSites} is set, or its normal log odds exceed
     * the normal threshold.
     *
     * <p>Note: {@code logReadLikelihoods} is passed to {@code normalizeLikelihoods} (when a mismapping rate is set)
     * and, with bamout, to {@code annotateReadLikelihoodsWithRegions}. These presumably modify it in place.
     *
     * @return the phased, ECNT-annotated calls together with the haplotypes supporting them
     * @throws IllegalArgumentException if the likelihoods have no samples, or a site has no tumor sample
     */
    public CalledHaplotypes callMutations(final AlleleLikelihoods<GATKRead, Haplotype> logReadLikelihoods,
                                          final AssemblyResultSet assemblyResultSet,
                                          final ReferenceContext referenceContext,
                                          final SimpleInterval activeRegionWindow,
                                          final FeatureContext featureContext,
                                          final List<VariantContext> givenAlleles,
                                          final SAMFileHeader header,
                                          final boolean withBamOut,
                                          final boolean emitRefConf) {
        Utils.nonNull(logReadLikelihoods);
        Utils.validateArg(logReadLikelihoods.numberOfSamples() > 0, "likelihoods have no samples");
        Utils.nonNull(activeRegionWindow);

        final List<Haplotype> haplotypes = logReadLikelihoods.alleles();

        // Event start positions across all haplotypes, restricted to the active window.
        final List<Integer> startPosKeySet = EventMap.buildEventMapsForHaplotypes(haplotypes,
                        assemblyResultSet.getFullReferenceWithPadding(), assemblyResultSet.getPaddedReferenceLoc(),
                        MTAC.assemblerArgs.debugAssembly, MTAC.maxMnpDistance)
                .stream()
                .filter(loc -> activeRegionWindow.getStart() <= loc && loc <= activeRegionWindow.getEnd())
                .collect(Collectors.toList());

        final Set<Haplotype> calledHaplotypes = new HashSet<>();
        final List<VariantContext> returnCalls = new ArrayList<>();

        if (withBamOut) {
            AssemblyBasedCallerUtils.annotateReadLikelihoodsWithRegions(logReadLikelihoods, activeRegionWindow);
        }
        if (MTAC.likelihoodArgs.phredScaledGlobalReadMismappingRate > 0) {
            logReadLikelihoods.normalizeLikelihoods(NaturalLogUtils.qualToLogErrorProb(MTAC.likelihoodArgs.phredScaledGlobalReadMismappingRate), true);
        }
        // Group reads into fragments by read name unless mates are to be treated independently.
        final AlleleLikelihoods<Fragment, Haplotype> logFragmentLikelihoods = logReadLikelihoods.groupEvidence(
                MTAC.independentMates ? read -> read : GATKRead::getName, Fragment::createAndAvoidFailure);

        // Loop-invariant: the normal threshold converted from log10 to natural-log scale.
        final double normalLogOddsThreshold = MathUtils.log10ToLog(MTAC.normalLog10Odds);

        for (final int loc : startPosKeySet) {
            final List<VariantContext> eventsAtThisLoc = AssemblyBasedCallerUtils.getVariantContextsFromActiveHaplotypes(loc, haplotypes, false);
            VariantContext mergedVC = AssemblyBasedCallerUtils.makeMergedVariantContext(eventsAtThisLoc);
            if (mergedVC == null) {
                continue;
            }

            final Map<Allele, List<Haplotype>> alleleMapper = AssemblyBasedCallerUtils.createAlleleMapper(mergedVC, loc, haplotypes, true);
            final AlleleLikelihoods<Fragment, Allele> logLikelihoods = logFragmentLikelihoods.marginalize(alleleMapper);
            final SimpleInterval variantCallingRelevantFragmentOverlap = new SimpleInterval(mergedVC)
                    .expandWithinContig(MTAC.informativeReadOverlapMargin, header.getSequenceDictionary());
            logLikelihoods.retainEvidence(variantCallingRelevantFragmentOverlap::overlaps);

            if (emitRefConf) {
                mergedVC = ReferenceConfidenceUtils.addNonRefSymbolicAllele(mergedVC);
                logLikelihoods.addNonReferenceAllele(Allele.NON_REF_ALLELE);
            }

            final List<LikelihoodMatrix<Fragment, Allele>> tumorMatrices = sampleMatrices(logLikelihoods, false);
            // Fail with a clear message instead of an IndexOutOfBoundsException below.
            Utils.validateArg(!tumorMatrices.isEmpty(), "likelihoods have no tumor samples");
            final AlleleList<Allele> alleleList = tumorMatrices.get(0);
            final LikelihoodMatrix<Fragment, Allele> logTumorMatrix = combinedLikelihoodMatrix(tumorMatrices, alleleList);
            final PerAlleleCollection<Double> tumorLogOdds = somaticLogOdds(logTumorMatrix);

            final List<LikelihoodMatrix<Fragment, Allele>> normalMatrices = sampleMatrices(logLikelihoods, true);
            final LikelihoodMatrix<Fragment, Allele> logNormalMatrix = combinedLikelihoodMatrix(normalMatrices, alleleList);
            final PerAlleleCollection<Double> normalLogOdds = diploidAltLogOdds(logNormalMatrix);
            final PerAlleleCollection<Double> normalArtifactLogOdds = somaticLogOdds(logNormalMatrix);

            final Set<Allele> forcedAlleles = AssemblyBasedCallerUtils.getAllelesConsistentWithGivenAlleles(givenAlleles, mergedVC);
            final List<Allele> tumorAltAlleles = mergedVC.getAlternateAlleles().stream()
                    .filter(allele -> forcedAlleles.contains(allele) || tumorLogOdds.getAlt(allele) > MTAC.getEmissionLogOdds())
                    .collect(Collectors.toList());

            final long somaticAltCount = tumorAltAlleles.stream()
                    .filter(allele -> forcedAlleles.contains(allele) || !hasNormal || MTAC.genotypeGermlineSites
                            || normalLogOdds.getAlt(allele) > normalLogOddsThreshold)
                    .count();
            if (somaticAltCount == 0) {
                continue;
            }

            final List<Allele> allAllelesToEmit = ListUtils.union(Arrays.asList(mergedVC.getReference()), tumorAltAlleles);

            final Map<String, Object> negativeLogPopulationAFAnnotation = getNegativeLogPopulationAFAnnotation(
                    featureContext.getValues(MTAC.germlineResource, loc), tumorAltAlleles, MTAC.getDefaultAlleleFrequency());

            // Log odds are computed in natural log; VCF annotations are reported on the log10 scale.
            final VariantContextBuilder callVcb = new VariantContextBuilder(mergedVC)
                    .alleles(allAllelesToEmit)
                    .attributes(negativeLogPopulationAFAnnotation)
                    .attribute("TLOD", tumorAltAlleles.stream()
                            .mapToDouble(a -> MathUtils.logToLog10(tumorLogOdds.getAlt(a)))
                            .toArray());

            if (hasNormal) {
                callVcb.attribute("NALOD", Arrays.stream(normalArtifactLogOdds.asDoubleArray(tumorAltAlleles))
                        .map(x -> -MathUtils.logToLog10(x))
                        .toArray());
                callVcb.attribute("NLOD", Arrays.stream(normalLogOdds.asDoubleArray(tumorAltAlleles))
                        .map(MathUtils::logToLog10)
                        .toArray());
            }

            if (!featureContext.getValues(MTAC.pon, mergedVC.getStart()).isEmpty()) {
                callVcb.attribute("PON", true);
            }

            addGenotypes(logLikelihoods, allAllelesToEmit, callVcb);
            final VariantContext call = callVcb.make();
            final VariantContext trimmedCall = GATKVariantContextUtils.trimAlleles(call, true, true);

            // Trimming preserves allele order, so allele i of the trimmed call corresponds to allele i of the untrimmed call.
            final List<Allele> trimmedAlleles = trimmedCall.getAlleles();
            final List<Allele> untrimmedAlleles = call.getAlleles();
            final Map<Allele, List<Allele>> trimmedToUntrimmedAlleleMap = IntStream.range(0, trimmedCall.getNAlleles()).boxed()
                    .collect(Collectors.toMap(n -> trimmedAlleles.get(n), n -> Collections.singletonList(untrimmedAlleles.get(n))));
            final AlleleLikelihoods<Fragment, Allele> trimmedLikelihoods = logLikelihoods.marginalize(trimmedToUntrimmedAlleleMap);

            // Annotations are computed from per-read (not per-fragment) likelihoods.
            final AlleleLikelihoods<GATKRead, Allele> logReadAlleleLikelihoods = logReadLikelihoods.marginalize(alleleMapper);
            logReadAlleleLikelihoods.retainEvidence(variantCallingRelevantFragmentOverlap::overlaps);
            if (emitRefConf) {
                logReadAlleleLikelihoods.addNonReferenceAllele(Allele.NON_REF_ALLELE);
            }
            final AlleleLikelihoods<GATKRead, Allele> trimmedLikelihoodsForAnnotation = logReadAlleleLikelihoods.marginalize(trimmedToUntrimmedAlleleMap);

            final VariantContext annotatedCall = annotationEngine.annotateContext(trimmedCall, featureContext, referenceContext, trimmedLikelihoodsForAnnotation, a -> true);
            if (withBamOut) {
                AssemblyBasedCallerUtils.annotateReadLikelihoodsWithSupportedAlleles(trimmedCall, trimmedLikelihoods, Fragment::getReads);
            }

            // Alleles absent from alleleMapper (e.g. <NON_REF>) map to null and are skipped.
            call.getAlleles().stream().map(alleleMapper::get).filter(Objects::nonNull).forEach(calledHaplotypes::addAll);
            returnCalls.add(annotatedCall);
        }

        final List<VariantContext> outputCalls = AssemblyBasedCallerUtils.phaseCalls(returnCalls, calledHaplotypes);
        final int eventCount = outputCalls.size();
        final List<VariantContext> outputCallsWithEventCountAnnotation = outputCalls.stream()
                .map(vc -> new VariantContextBuilder(vc).attribute("ECNT", eventCount).make())
                .collect(Collectors.toList());
        return new CalledHaplotypes(outputCallsWithEventCountAnnotation, calledHaplotypes);
    }

    /**
     * Returns the per-sample likelihood matrices for either the normal samples ({@code normal == true}) or the
     * tumor samples ({@code normal == false}), in sample-index order.
     */
    private <EVIDENCE extends Locatable> List<LikelihoodMatrix<EVIDENCE, Allele>> sampleMatrices(final AlleleLikelihoods<EVIDENCE, Allele> likelihoods, final boolean normal) {
        return IntStream.range(0, likelihoods.numberOfSamples())
                .filter(n -> normalSamples.contains(likelihoods.getSample(n)) == normal)
                .mapToObj(likelihoods::sampleMatrix)
                .collect(Collectors.toList());
    }

    /**
     * Prior pseudocounts for {@code numAlleles} alleles.
     * Index 0 gets {@link #refPseudocount} and every other index gets {@link #altPseudocount}.
     * NOTE(review): assumes callers place the reference allele at index 0 — TODO confirm for all matrices.
     */
    private double[] makePriorPseudocounts(final int numAlleles) {
        return new IndexRange(0, numAlleles).mapToDouble(n -> n == 0 ? refPseudocount : altPseudocount);
    }

    /**
     * Log evidence of {@code logMatrix} as reported by {@link SomaticLikelihoodsEngine#logEvidence}, using this
     * engine's prior pseudocounts. It is defined as 0 when the matrix contains no evidence.
     */
    private <EVIDENCE> double logEvidenceWithPrior(final LikelihoodMatrix<EVIDENCE, Allele> logMatrix) {
        return logMatrix.evidenceCount() == 0 ? 0.0
                : SomaticLikelihoodsEngine.logEvidence(getAsRealMatrix(logMatrix), makePriorPseudocounts(logMatrix.numberOfAlleles()));
    }

    /**
     * For each non-reference allele, computes the log evidence of the data with all alleles minus the log
     * evidence with that allele excluded.
     *
     * @param logMatrix natural-log likelihoods; if {@code <NON_REF>} is present it must be the last allele
     * @return per-alt-allele log odds
     * @throws IllegalStateException if {@code <NON_REF>} is present but not last
     */
    protected <EVIDENCE extends Locatable> PerAlleleCollection<Double> somaticLogOdds(final LikelihoodMatrix<EVIDENCE, Allele> logMatrix) {
        final List<Allele> alleles = logMatrix.alleles();
        if (alleles.contains(Allele.NON_REF_ALLELE) && !alleles.get(alleles.size() - 1).equals(Allele.NON_REF_ALLELE)) {
            throw new IllegalStateException("<NON_REF> must be last in the allele list.");
        }

        final double logEvidenceWithAllAlleles = logEvidenceWithPrior(logMatrix);
        final PerAlleleCollection<Double> lods = new PerAlleleCollection<>(PerAlleleCollection.Type.ALT_ONLY);
        final int refIndex = getRefIndex(logMatrix);
        IntStream.range(0, logMatrix.numberOfAlleles()).filter(a -> a != refIndex).forEach(a -> {
            final Allele allele = logMatrix.getAllele(a);
            final LikelihoodMatrix<EVIDENCE, Allele> logMatrixWithoutThisAllele = SubsettedLikelihoodMatrix.excludingAllele(logMatrix, allele);
            lods.setAlt(allele, logEvidenceWithAllAlleles - logEvidenceWithPrior(logMatrixWithoutThisAllele));
        });
        return lods;
    }

    /**
     * Builds one genotype per sample, restricted to {@code allelesToEmit}, and sets them on {@code callVcb}.
     *
     * <ul>
     *   <li>AD is the rounded effective allele count.</li>
     *   <li>AF is the normalized posterior allele-fraction vector without its first entry. The caller passes the
     *       reference as the first allele to emit.</li>
     *   <li>Normal samples get a hom-ref diploid genotype; tumor samples get all emitted alleles.</li>
     * </ul>
     */
    private <EVIDENCE extends Locatable> void addGenotypes(final AlleleLikelihoods<EVIDENCE, Allele> logLikelihoods, final List<Allele> allelesToEmit, final VariantContextBuilder callVcb) {
        final List<Genotype> genotypes = IntStream.range(0, logLikelihoods.numberOfSamples()).mapToObj(n -> {
            final String sample = logLikelihoods.getSample(n);
            final LikelihoodMatrix<EVIDENCE, Allele> logMatrix = new SubsettedLikelihoodMatrix<>(logLikelihoods.sampleMatrix(n), allelesToEmit);
            final double[] alleleCounts = getEffectiveCounts(logMatrix);
            final double[] flatPriorPseudocounts = new IndexRange(0, logMatrix.numberOfAlleles()).mapToDouble(a -> 1.0);
            final double[] alleleFractionsPosterior = logMatrix.evidenceCount() == 0 ? flatPriorPseudocounts
                    : SomaticLikelihoodsEngine.alleleFractionsPosterior(getAsRealMatrix(logMatrix), flatPriorPseudocounts);
            final double[] tumorAlleleFractionsMean = MathUtils.normalizeSumToOne(alleleFractionsPosterior);
            final Allele ref = logMatrix.getAllele(getRefIndex(logMatrix));
            final List<Allele> genotypeAlleles = normalSamples.contains(sample) ? Collections.nCopies(2, ref) : logMatrix.alleles();
            return new GenotypeBuilder(sample, genotypeAlleles)
                    .AD(Arrays.stream(alleleCounts).mapToInt(x -> (int) FastMath.round(x)).toArray())
                    .attribute("AF", Arrays.copyOfRange(tumorAlleleFractionsMean, 1, tumorAlleleFractionsMean.length))
                    .make();
        }).collect(Collectors.toList());
        callVcb.genotypes(genotypes);
    }

    /**
     * Effective per-allele counts: the sum over evidence of each piece of evidence's allele likelihoods,
     * normalized from log space to linear space. Returns all zeros when there is no evidence.
     */
    private static <EVIDENCE> double[] getEffectiveCounts(final LikelihoodMatrix<EVIDENCE, Allele> logLikelihoodMatrix) {
        if (logLikelihoodMatrix.evidenceCount() == 0) {
            return new double[logLikelihoodMatrix.numberOfAlleles()];
        }
        final RealMatrix logLikelihoods = getAsRealMatrix(logLikelihoodMatrix);
        return MathUtils.sumArrayFunction(0, logLikelihoods.getColumnDimension(),
                read -> NaturalLogUtils.normalizeFromLogToLinearSpace(logLikelihoods.getColumn(read)));
    }

    /**
     * For each non-reference allele, computes the hom-ref log likelihood minus the het log likelihood.
     * The het model gives each piece of evidence a 1/2 mixture of the ref and alt likelihoods.
     */
    private <EVIDENCE extends Locatable> PerAlleleCollection<Double> diploidAltLogOdds(final LikelihoodMatrix<EVIDENCE, Allele> matrix) {
        final int refIndex = getRefIndex(matrix);
        final int numReads = matrix.evidenceCount();
        final double homRefLogLikelihood = new IndexRange(0, numReads).sum(r -> matrix.get(refIndex, r));

        final PerAlleleCollection<Double> result = new PerAlleleCollection<>(PerAlleleCollection.Type.ALT_ONLY);
        IntStream.range(0, matrix.numberOfAlleles()).filter(a -> a != refIndex).forEach(a -> {
            final double hetLogLikelihood = new IndexRange(0, numReads)
                    .sum(r -> NaturalLogUtils.logSumExp(matrix.get(refIndex, r), matrix.get(a, r)) + NaturalLogUtils.LOG_ONE_HALF);
            result.setAlt(matrix.getAllele(a), homRefLogLikelihood - hetLogLikelihood);
        });
        return result;
    }

    /**
     * Index of the first reference allele in {@code matrix}.
     *
     * @throws IllegalArgumentException if the matrix has no reference allele
     */
    private <EVIDENCE> int getRefIndex(final LikelihoodMatrix<EVIDENCE, Allele> matrix) {
        final OptionalInt optionalRefIndex = IntStream.range(0, matrix.numberOfAlleles())
                .filter(a -> matrix.getAllele(a).isReference())
                .findFirst();
        Utils.validateArg(optionalRefIndex.isPresent(), "No ref allele found in likelihoods");
        return optionalRefIndex.getAsInt();
    }

    /**
     * Copies a likelihood matrix into a {@link RealMatrix} with one row per allele and one column per piece of evidence.
     */
    public static <EVIDENCE> RealMatrix getAsRealMatrix(final LikelihoodMatrix<EVIDENCE, Allele> matrix) {
        final RealMatrix result = new Array2DRowRealMatrix(matrix.numberOfAlleles(), matrix.evidenceCount());
        result.walkInOptimizedOrder(new DefaultRealMatrixChangingVisitor() {
            @Override
            public double visit(final int row, final int column, final double value) {
                return matrix.get(row, column);
            }
        });
        return result;
    }

    /**
     * Concatenates the evidence of several matrices into one single-sample ("COMBINED") matrix over {@code alleleList}.
     * NOTE(review): allele indices are copied positionally, so every input matrix is assumed to share
     * {@code alleleList}'s allele order — confirm at call sites.
     */
    private static <EVIDENCE extends Locatable> LikelihoodMatrix<EVIDENCE, Allele> combinedLikelihoodMatrix(final List<LikelihoodMatrix<EVIDENCE, Allele>> matrices, final AlleleList<Allele> alleleList) {
        final List<EVIDENCE> reads = matrices.stream().flatMap(m -> m.evidence().stream()).collect(Collectors.toList());
        final AlleleLikelihoods<EVIDENCE, Allele> combinedLikelihoods = new AlleleLikelihoods<>(
                SampleList.singletonSampleList("COMBINED"), alleleList, ImmutableMap.of("COMBINED", reads));

        int combinedReadIndex = 0;
        final LikelihoodMatrix<EVIDENCE, Allele> result = combinedLikelihoods.sampleMatrix(0);
        final int alleleCount = result.numberOfAlleles();
        for (final LikelihoodMatrix<EVIDENCE, Allele> matrix : matrices) {
            final int readCount = matrix.evidenceCount();
            for (int r = 0; r < readCount; r++) {
                for (int a = 0; a < alleleCount; a++) {
                    result.set(a, combinedReadIndex, matrix.get(a, r));
                }
                combinedReadIndex++;
            }
        }
        return result;
    }

    /**
     * Builds the POPAF annotation: -log10 of each alt allele's population allele frequency. The frequencies come
     * from the first germline-resource record, if any.
     */
    private static Map<String, Object> getNegativeLogPopulationAFAnnotation(final List<VariantContext> germlineResourceVariants, final List<Allele> altAlleles, final double afOfAllelesNotInGermlineResource) {
        final Optional<VariantContext> germlineVC = germlineResourceVariants.isEmpty() ? Optional.empty() : Optional.of(germlineResourceVariants.get(0));
        final double[] populationAlleleFrequencies = getGermlineAltAlleleFrequencies(altAlleles, germlineVC, afOfAllelesNotInGermlineResource);
        return ImmutableMap.of("POPAF", (Object) MathUtils.applyToArray(populationAlleleFrequencies, x -> -Math.log10(x)));
    }

    /**
     * Looks up each alt allele's frequency in the germline resource's "AF" attribute, matching alleles by bases.
     *
     * @param altAlleles                      alleles to look up
     * @param germlineVC                      germline-resource record at this site, if any
     * @param afOfAllelesNotInGermlineResource value used for alleles not found (or without an AF value)
     * @return one frequency per element of {@code altAlleles}, in order
     */
    @VisibleForTesting
    static double[] getGermlineAltAlleleFrequencies(final List<Allele> altAlleles, final Optional<VariantContext> germlineVC, final double afOfAllelesNotInGermlineResource) {
        if (!germlineVC.isPresent()) {
            return Doubles.toArray(Collections.nCopies(altAlleles.size(), afOfAllelesNotInGermlineResource));
        }
        final VariantContext vc = germlineVC.get();
        final List<Double> germlineAltAFs = Mutect2Engine.getAttributeAsDoubleList(vc, "AF", afOfAllelesNotInGermlineResource);
        return altAlleles.stream().mapToDouble(allele -> {
            final OptionalInt germlineAltIndex = IntStream.range(0, vc.getNAlleles() - 1)
                    .filter(n -> vc.getAlternateAllele(n).basesMatch(allele))
                    .findFirst();
            // The AF list may hold fewer values than the record has alt alleles, e.g. a malformed or absent AF
            // field (exact fallback behavior of getAttributeAsDoubleList not verified here). Fall back to the
            // default instead of throwing IndexOutOfBoundsException.
            return germlineAltIndex.isPresent() && germlineAltIndex.getAsInt() < germlineAltAFs.size()
                    ? germlineAltAFs.get(germlineAltIndex.getAsInt())
                    : afOfAllelesNotInGermlineResource;
        }).toArray();
    }
}

