/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.walkers.genotyper.afcalc;

import htsjdk.variant.variantcontext.Allele;
import htsjdk.variant.variantcontext.Genotype;
import htsjdk.variant.variantcontext.VariantContext;
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import it.unimi.dsi.fastutil.ints.Int2ObjectArrayMap;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.math3.special.Gamma;
import org.apache.commons.math3.util.MathArrays;
import org.broadinstitute.hellbender.tools.walkers.genotyper.GenotypeAlleleCounts;
import org.broadinstitute.hellbender.tools.walkers.genotyper.GenotypeCalculationArgumentCollection;
import org.broadinstitute.hellbender.tools.walkers.genotyper.GenotypeLikelihoodCalculator;
import org.broadinstitute.hellbender.tools.walkers.genotyper.GenotypeLikelihoodCalculators;
import org.broadinstitute.hellbender.tools.walkers.genotyper.afcalc.AFCalculationResult;
import org.broadinstitute.hellbender.utils.Dirichlet;
import org.broadinstitute.hellbender.utils.IndexRange;
import org.broadinstitute.hellbender.utils.MathUtils;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.dragstr.DragstrParams;

/**
 * Estimates allele frequencies across the samples of a {@link VariantContext} and derives, from them, the
 * probability that the site is non-variant and the per-alt-allele probability that each allele is absent.
 *
 * <p>Allele frequencies get a Dirichlet prior whose pseudocounts depend on allele type:
 * <ul>
 *   <li>reference;</li>
 *   <li>SNP-like (same length as the reference);</li>
 *   <li>indel-like.</li>
 * </ul>
 * They are refined iteratively. Each round recomputes the effective allele counts from the genotype posteriors,
 * and repeats until the counts stop moving by more than {@link #THRESHOLD_FOR_ALLELE_COUNT_CONVERGENCE}.</p>
 */
public final class AlleleFrequencyCalculator {
    private static final GenotypeLikelihoodCalculators GL_CALCS = new GenotypeLikelihoodCalculators();

    /** Iteration stops once no effective allele count changes by more than this between two rounds. */
    private static final double THRESHOLD_FOR_ALLELE_COUNT_CONVERGENCE = 0.1;

    /** Index of the homozygous-reference genotype in a genotype likelihood / posterior vector. */
    private static final int HOM_REF_GENOTYPE_INDEX = 0;

    /** Dirichlet prior pseudocount applied to the reference allele. */
    private final double refPseudocount;
    /** Dirichlet prior pseudocount applied to alt alleles with the same length as the reference. */
    private final double snpPseudocount;
    /** Dirichlet prior pseudocount applied to alt alleles whose length differs from the reference. */
    private final double indelPseudocount;
    /** Ploidy used by {@link #calculate(VariantContext)} for genotypes that report a ploidy of 0. */
    private final int defaultPloidy;

    /**
     * Creates a calculator with explicit prior pseudocounts.
     *
     * @param refPseudocount   prior pseudocount for the reference allele
     * @param snpPseudocount   prior pseudocount for alt alleles of the same length as the reference
     * @param indelPseudocount prior pseudocount for alt alleles of a different length than the reference
     * @param defaultPloidy    ploidy assumed for genotypes that report a ploidy of 0
     */
    public AlleleFrequencyCalculator(final double refPseudocount, final double snpPseudocount,
                                     final double indelPseudocount, final int defaultPloidy) {
        this.refPseudocount = refPseudocount;
        this.snpPseudocount = snpPseudocount;
        this.indelPseudocount = indelPseudocount;
        this.defaultPloidy = defaultPloidy;
    }

    /**
     * Builds a calculator from the genotyping command-line arguments.
     *
     * <p>The reference pseudocount is {@code snpHeterozygosity / heterozygosityStandardDeviation^2}. The SNP and
     * indel pseudocounts are the corresponding heterozygosities scaled by that reference pseudocount.</p>
     *
     * @param genotypeArgs genotyping arguments supplying heterozygosities, their standard deviation and ploidy
     * @return a new calculator
     */
    public static AlleleFrequencyCalculator makeCalculator(final GenotypeCalculationArgumentCollection genotypeArgs) {
        final double refPseudocount = genotypeArgs.snpHeterozygosity / Math.pow(genotypeArgs.heterozygosityStandardDeviation, 2.0);
        final double snpPseudocount = genotypeArgs.snpHeterozygosity * refPseudocount;
        final double indelPseudocount = genotypeArgs.indelHeterozygosity * refPseudocount;
        return new AlleleFrequencyCalculator(refPseudocount, snpPseudocount, indelPseudocount, genotypeArgs.samplePloidy);
    }

    /**
     * Builds a calculator whose indel prior comes from DRAGstr parameters for a given repeat context.
     *
     * <p>The DRAGstr {@code api} value is treated as a Phred-scaled indel frequency, i.e.
     * {@code log10IndelFreq = -api / 10}. The reference frequency is its complement, and the SNP frequency is the
     * reference frequency times {@code snpHeterozygosity}. The three are then rescaled so that the pseudocounts sum
     * to {@code scale}.</p>
     *
     * @param dragstrParms      DRAGstr parameter table
     * @param period            repeat unit length of the site
     * @param repeats           number of repeat units at the site
     * @param ploidy            default ploidy for the returned calculator
     * @param snpHeterozygosity SNP heterozygosity relative to the reference frequency
     * @param scale             total of the three pseudocounts
     * @return a new calculator
     */
    public static AlleleFrequencyCalculator makeCalculator(final DragstrParams dragstrParms, final int period, final int repeats,
                                                           final int ploidy, final double snpHeterozygosity, final double scale) {
        final double api = dragstrParms.api(period, repeats);
        final double log10IndelFreq = api * -0.1;
        final double log10RefFreq = MathUtils.log10OneMinusPow10(log10IndelFreq);
        final double log10SnpFreq = log10RefFreq + Math.log10(snpHeterozygosity);
        final double log10Sum = MathUtils.log10SumLog10(log10RefFreq, log10IndelFreq, log10SnpFreq);
        // Shift every log10 frequency by the same amount so that the linear-space values sum to `scale`.
        final double log10ScaleUp = Math.log10(scale) - log10Sum;
        final double refPseudoCount = Math.pow(10.0, log10ScaleUp + log10RefFreq);
        final double indelPseudoCount = Math.pow(10.0, log10ScaleUp + log10IndelFreq);
        final double snpPseudoCount = Math.pow(10.0, log10ScaleUp + log10SnpFreq);
        return new AlleleFrequencyCalculator(refPseudoCount, snpPseudoCount, indelPseudoCount, ploidy);
    }

    /**
     * Computes the normalized log10 genotype posteriors of one sample.
     *
     * <p>Each genotype's unnormalized log10 posterior is the sum of three terms:
     * <ul>
     *   <li>its log10 likelihood;</li>
     *   <li>the log10 multinomial combination count;</li>
     *   <li>{@code sum(count * log10Freq)} over its alleles.</li>
     * </ul></p>
     *
     * @param g                      genotype carrying likelihoods; its vector must have at least
     *                               {@code glCalc.genotypeCount()} entries
     * @param glCalc                 calculator matching the genotype's ploidy and the site's allele count
     * @param log10AlleleFrequencies log10 allele frequencies, indexed like the site's alleles
     * @return log10 posteriors normalized so that they sum to 1 in linear space
     */
    private static double[] log10NormalizedGenotypePosteriors(final Genotype g, final GenotypeLikelihoodCalculator glCalc,
                                                              final double[] log10AlleleFrequencies) {
        final double[] log10Likelihoods = g.getLikelihoods().getAsVector();
        final double[] log10Posteriors = new IndexRange(0, glCalc.genotypeCount()).mapToDouble(genotypeIndex -> {
            final GenotypeAlleleCounts gac = glCalc.genotypeAlleleCountsAt(genotypeIndex);
            return gac.log10CombinationCount() + log10Likelihoods[genotypeIndex]
                    + gac.sumOverAlleleIndicesAndCounts((index, count) -> count * log10AlleleFrequencies[index]);
        });
        return MathUtils.normalizeLog10(log10Posteriors);
    }

    /**
     * Returns the genotype indices considered non-variant at this ploidy.
     *
     * <p>Without a spanning deletion, only hom-ref is non-variant. With one, every genotype composed solely of
     * reference and spanning-deletion alleles is non-variant.</p>
     *
     * @param ploidy  sample ploidy
     * @param alleles the site's alleles, reference first
     * @return genotype indices treated as non-variant
     */
    private static int[] genotypeIndicesWithOnlyRefAndSpanDel(final int ploidy, final List<Allele> alleles) {
        if (!alleles.contains(Allele.SPAN_DEL)) {
            return new int[]{HOM_REF_GENOTYPE_INDEX};
        }
        final GenotypeLikelihoodCalculator glCalc = GL_CALCS.getInstance(ploidy, alleles.size());
        final int spanDelIndex = alleles.indexOf(Allele.SPAN_DEL);
        // Arguments are read as (alleleIndex, count) pairs: reference (index 0) with ploidy - n copies and the
        // spanning deletion with n copies. NOTE(review): pair layout inferred from the call shape — confirm.
        return new IndexRange(0, ploidy + 1).mapToInteger(n -> glCalc.alleleCountsToIndex(0, ploidy - n, spanDelIndex, n));
    }

    /** Returns the default ploidy this calculator was built with. */
    public int getPloidy() {
        return this.defaultPloidy;
    }

    /**
     * Same as {@link #calculate(VariantContext, int)} using this calculator's default ploidy.
     *
     * @param vc variant context with at least one alt allele
     * @return the allele-frequency calculation result
     */
    public AFCalculationResult calculate(final VariantContext vc) {
        return calculate(vc, this.defaultPloidy);
    }

    /**
     * Estimates allele frequencies for the site and the probabilities that it, and each of its alt alleles, is
     * absent from the samples.
     *
     * @param vc            variant context; must have at least one alt allele
     * @param defaultPloidy ploidy used for genotypes that report a ploidy of 0 (shadows the field of the same name)
     * @return a result holding:
     *         <ul>
     *           <li>the rounded effective alt-allele counts;</li>
     *           <li>the log10 probability that no sample carries a variant;</li>
     *           <li>for each alt allele, the log10 probability that the allele is absent.</li>
     *         </ul>
     * @throws IllegalArgumentException if {@code vc} has fewer than two alleles
     */
    public AFCalculationResult calculate(final VariantContext vc, final int defaultPloidy) {
        Utils.nonNull(vc, "VariantContext cannot be null");
        final int numAlleles = vc.getNAlleles();
        final List<Allele> alleles = vc.getAlleles();
        Utils.validateArg(numAlleles > 1, () -> "VariantContext has only a single reference allele, but getLog10PNonRef requires at least one alt allele: " + vc);

        final int refLength = vc.getReference().length();
        final double[] priorPseudocounts = alleles.stream()
                .mapToDouble(a -> a.isReference() ? refPseudocount : (a.length() == refLength ? snpPseudocount : indelPseudocount))
                .toArray();

        // Iteratively refine allele frequencies, starting from a flat distribution:
        // 1. derive effective allele counts from the genotype posteriors under the current frequencies;
        // 2. take the Dirichlet posterior-mean frequencies given prior + effective counts;
        // 3. repeat until the effective counts converge.
        double[] alleleCounts = new double[numAlleles];
        final double flatLog10AlleleFrequency = -MathUtils.log10(numAlleles);
        double[] log10AlleleFrequencies = new IndexRange(0, numAlleles).mapToDouble(n -> flatLog10AlleleFrequency);
        double alleleCountsMaximumDifference = Double.POSITIVE_INFINITY;
        while (alleleCountsMaximumDifference > THRESHOLD_FOR_ALLELE_COUNT_CONVERGENCE) {
            final double[] newAlleleCounts = effectiveAlleleCounts(vc, log10AlleleFrequencies, defaultPloidy);
            alleleCountsMaximumDifference = Arrays.stream(MathArrays.ebeSubtract(alleleCounts, newAlleleCounts))
                    .map(Math::abs).max().getAsDouble();
            alleleCounts = newAlleleCounts;
            final double[] posteriorPseudocounts = MathArrays.ebeAdd(priorPseudocounts, alleleCounts);
            log10AlleleFrequencies = new Dirichlet(posteriorPseudocounts).log10MeanWeights();
        }

        final double[] log10POfZeroCountsByAllele = new double[numAlleles];
        double log10PNoVariant = 0.0;
        final boolean spanningDeletionPresent = alleles.contains(Allele.SPAN_DEL);
        // Non-variant genotype indices only depend on ploidy, so cache them per ploidy across samples.
        final Int2ObjectArrayMap<int[]> nonVariantIndicesByPloidy = new Int2ObjectArrayMap<>();
        // Per-allele scratch buffers, reused across samples, holding posteriors of genotypes lacking that allele.
        final List<DoubleArrayList> log10AbsentPosteriors = IntStream.range(0, numAlleles)
                .mapToObj(n -> new DoubleArrayList())
                .collect(Collectors.toList());

        for (final Genotype g : vc.getGenotypes()) {
            if (!g.hasLikelihoods()) {
                continue;
            }
            final int ploidy = resolvePloidy(g, defaultPloidy);
            final GenotypeLikelihoodCalculator glCalc = GL_CALCS.getInstance(ploidy, numAlleles);
            final double[] log10GenotypePosteriors = log10NormalizedGenotypePosteriors(g, glCalc, log10AlleleFrequencies);

            if (!spanningDeletionPresent) {
                log10PNoVariant += log10GenotypePosteriors[HOM_REF_GENOTYPE_INDEX];
            } else {
                final int[] nonVariantIndices = nonVariantIndicesByPloidy.computeIfAbsent(ploidy,
                        p -> genotypeIndicesWithOnlyRefAndSpanDel(p, alleles));
                final double[] nonVariantLog10Posteriors = MathUtils.applyToArray(nonVariantIndices, n -> log10GenotypePosteriors[n]);
                // Clamp to <= 0 (probability <= 1); presumably absorbs floating-point round-off in the sum.
                log10PNoVariant += Math.min(0.0, MathUtils.log10SumLog10(nonVariantLog10Posteriors));
            }

            // In the plain biallelic case the single alt's absence equals "no variant"; it is filled in after the loop.
            if (numAlleles == 2 && !spanningDeletionPresent) {
                continue;
            }

            log10AbsentPosteriors.forEach(DoubleArrayList::clear);
            for (int genotype = 0; genotype < glCalc.genotypeCount(); genotype++) {
                final double log10GenotypePosterior = log10GenotypePosteriors[genotype];
                glCalc.genotypeAlleleCountsAt(genotype)
                        .forEachAbsentAlleleIndex(a -> log10AbsentPosteriors.get(a).add(log10GenotypePosterior), numAlleles);
            }
            final double[] log10PNoAllele = log10AbsentPosteriors.stream()
                    .mapToDouble(buffer -> MathUtils.log10SumLog10(buffer.toDoubleArray()))
                    .map(x -> Math.min(0.0, x))
                    .toArray();
            // Samples are treated independently, so log10 probabilities of absence add across samples.
            MathUtils.addToArrayInPlace(log10POfZeroCountsByAllele, log10PNoAllele);
        }

        if (numAlleles == 2 && !spanningDeletionPresent) {
            log10POfZeroCountsByAllele[1] = log10PNoVariant;
        }

        final int[] integerAlleleCounts = Arrays.stream(alleleCounts).mapToInt(x -> (int) Math.round(x)).toArray();
        final int[] integerAltAlleleCounts = Arrays.copyOfRange(integerAlleleCounts, 1, numAlleles);
        final Map<Allele, Double> log10PRefByAllele = IntStream.range(1, numAlleles).boxed()
                .collect(Collectors.toMap(alleles::get, a -> log10POfZeroCountsByAllele[a]));
        return new AFCalculationResult(integerAltAlleleCounts, alleles, log10PNoVariant, log10PRefByAllele);
    }

    /**
     * Computes the posterior probability that a single sample at a biallelic site is non-reference.
     *
     * <p>Entry {@code n} of the likelihood vector is combined with:
     * <ul>
     *   <li>a binomial coefficient;</li>
     *   <li>{@code Gamma(n + snpPseudocount) * Gamma(ploidy - n + refPseudocount)}.</li>
     * </ul>
     * So {@code n} acts as the number of alt copies, and {@code ploidy} is {@code length - 1}.</p>
     *
     * @param log10GenotypeLikelihoods log10 genotype likelihoods; must be non-empty
     * @param returnZeroIfRefIsMax     if true, return 0 whenever hom-ref is the most likely genotype, either by
     *                                 likelihood or by unnormalized posterior
     * @return the posterior probability of a non-reference genotype, in linear space
     * @throws IllegalArgumentException if {@code log10GenotypeLikelihoods} is empty
     */
    public double calculateSingleSampleBiallelicNonRefPosterior(final double[] log10GenotypeLikelihoods, final boolean returnZeroIfRefIsMax) {
        Utils.nonNull(log10GenotypeLikelihoods);
        Utils.validateArg(log10GenotypeLikelihoods.length > 0, "genotype likelihoods must not be empty");
        if (returnZeroIfRefIsMax && MathUtils.maxElementIndex(log10GenotypeLikelihoods) == HOM_REF_GENOTYPE_INDEX) {
            return 0.0;
        }
        final int ploidy = log10GenotypeLikelihoods.length - 1;
        final double[] log10UnnormalizedPosteriors = new IndexRange(0, ploidy + 1).mapToDouble(n ->
                log10GenotypeLikelihoods[n]
                        + MathUtils.log10BinomialCoefficient(ploidy, n)
                        + MathUtils.logToLog10(Gamma.logGamma(n + snpPseudocount) + Gamma.logGamma(ploidy - n + refPseudocount)));
        if (returnZeroIfRefIsMax && MathUtils.maxElementIndex(log10UnnormalizedPosteriors) == HOM_REF_GENOTYPE_INDEX) {
            return 0.0;
        }
        return 1.0 - MathUtils.normalizeFromLog10ToLinearSpace(log10UnnormalizedPosteriors)[HOM_REF_GENOTYPE_INDEX];
    }

    /**
     * Returns the ploidy to model a genotype with, substituting {@code defaultPloidy} when the genotype reports 0.
     * Used by both the allele-count refinement and the final posterior pass so they agree on every sample.
     */
    private static int resolvePloidy(final Genotype g, final int defaultPloidy) {
        return g.getPloidy() == 0 ? defaultPloidy : g.getPloidy();
    }

    /**
     * Computes each allele's effective count: the posterior-weighted number of copies, summed over samples.
     *
     * <p>For every genotype of every sample with likelihoods, the genotype's normalized posterior is multiplied by
     * each allele's copy number in that genotype. Accumulation is done in log10 space.</p>
     *
     * @param vc                     the variant context
     * @param log10AlleleFrequencies current log10 allele frequencies; length must equal the number of alleles
     * @param defaultPloidy          ploidy used for genotypes that report a ploidy of 0
     * @return effective allele counts in linear space, indexed like the site's alleles
     * @throws IllegalArgumentException if the frequency array length does not match the allele count
     */
    private double[] effectiveAlleleCounts(final VariantContext vc, final double[] log10AlleleFrequencies, final int defaultPloidy) {
        final int numAlleles = vc.getNAlleles();
        Utils.validateArg(numAlleles == log10AlleleFrequencies.length, "number of alleles inconsistent");
        final double[] log10Result = new double[numAlleles];
        Arrays.fill(log10Result, Double.NEGATIVE_INFINITY);
        for (final Genotype g : vc.getGenotypes()) {
            if (!g.hasLikelihoods()) {
                continue;
            }
            final GenotypeLikelihoodCalculator glCalc = GL_CALCS.getInstance(resolvePloidy(g, defaultPloidy), numAlleles);
            final double[] log10GenotypePosteriors = log10NormalizedGenotypePosteriors(g, glCalc, log10AlleleFrequencies);
            new IndexRange(0, glCalc.genotypeCount()).forEach(genotypeIndex ->
                    glCalc.genotypeAlleleCountsAt(genotypeIndex).forEachAlleleIndexAndCount((alleleIndex, count) ->
                            log10Result[alleleIndex] = MathUtils.log10SumLog10(log10Result[alleleIndex],
                                    log10GenotypePosteriors[genotypeIndex] + MathUtils.log10(count))));
        }
        return MathUtils.applyToArrayInPlace(log10Result, x -> Math.pow(10.0, x));
    }
}

