/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.walkers.mutect;

import com.google.common.annotations.VisibleForTesting;
import java.util.Arrays;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.special.Gamma;
import org.apache.commons.math3.util.MathArrays;
import org.broadinstitute.hellbender.utils.Dirichlet;
import org.broadinstitute.hellbender.utils.IndexRange;
import org.broadinstitute.hellbender.utils.MathUtils;
import org.broadinstitute.hellbender.utils.NaturalLogUtils;
import org.broadinstitute.hellbender.utils.Utils;

/**
 * Static helpers that fit a Dirichlet posterior over allele fractions to a matrix of
 * log likelihoods and compute quantities derived from that posterior.
 *
 * <p>Throughout this class the likelihood matrix is indexed with one row per allele and one
 * column per read: the number of alleles is taken from {@code getRowDimension()} and each read's
 * likelihoods are fetched with {@code getColumn(read)}.
 */
public class SomaticLikelihoodsEngine {
    /** Relative L1 change in the Dirichlet parameters below which the posterior iteration stops. */
    public static final double CONVERGENCE_THRESHOLD = 0.001;

    /** Responsibilities below this value contribute exactly zero to the likelihood term. */
    private static final double NEGLIGIBLE_RESPONSIBILITY = 1.0E-10;

    /** Arguments below this value make {@link #xLogx(double)} return exactly zero. */
    private static final double NEGLIGIBLE_X_LOG_X_ARGUMENT = 1.0E-8;

    /**
     * Iteratively computes the parameters of a Dirichlet posterior over allele fractions.
     *
     * <p>The iteration starts from an all-ones parameter vector. Each step computes effective
     * allele counts from the current parameters (see {@link #getEffectiveCounts}), adds the prior
     * pseudocounts element-wise, and stops once the L1 distance between the old and new parameter
     * vectors, divided by the sum of the new parameters, drops below
     * {@link #CONVERGENCE_THRESHOLD}.
     *
     * @param logLikelihoods    matrix with one row per allele and one column per read
     * @param priorPseudocounts prior Dirichlet pseudocounts, one per allele; its length is checked
     *                          against the row count via {@code Utils.validateArg}
     * @return the converged Dirichlet parameters, one per allele
     */
    public static double[] alleleFractionsPosterior(RealMatrix logLikelihoods, double[] priorPseudocounts) {
        final int numberOfAlleles = logLikelihoods.getRowDimension();
        Utils.validateArg(numberOfAlleles == priorPseudocounts.length, "Must have one pseudocount per allele.");

        // Flat starting point: every allele gets a parameter of 1.
        double[] dirichletPosterior = new IndexRange(0, numberOfAlleles).mapToDouble(n -> 1.0);

        // NOTE(review): there is no iteration cap; termination relies on the relative change
        // eventually falling below the threshold. A NaN change would never satisfy the test —
        // confirm that upstream validation rules that out.
        boolean converged = false;
        while (!converged) {
            final double[] alleleCounts = getEffectiveCounts(logLikelihoods, dirichletPosterior);
            final double[] newDirichletPosterior = MathArrays.ebeAdd(alleleCounts, priorPseudocounts);
            final double relativeChange =
                    MathArrays.distance1(dirichletPosterior, newDirichletPosterior) / MathUtils.sum(newDirichletPosterior);
            converged = relativeChange < CONVERGENCE_THRESHOLD;
            dirichletPosterior = newDirichletPosterior;
        }
        return dirichletPosterior;
    }

    /**
     * Sums, over all reads, the per-allele posteriors obtained by combining the effective log
     * multinomial weights of the given Dirichlet with each read's column of log likelihoods.
     *
     * @param logLikelihoods matrix with one row per allele and one column per read
     * @param dirichletPrior Dirichlet parameters used to derive the per-allele log weights
     * @return the element-wise sum over reads of {@code NaturalLogUtils.posteriors(...)}
     */
    @VisibleForTesting
    protected static double[] getEffectiveCounts(RealMatrix logLikelihoods, double[] dirichletPrior) {
        final double[] effectiveLogWeights = new Dirichlet(dirichletPrior).effectiveLogMultinomialWeights();
        return MathUtils.sumArrayFunction(0, logLikelihoods.getColumnDimension(),
                read -> NaturalLogUtils.posteriors(effectiveLogWeights, logLikelihoods.getColumn(read)));
    }

    /**
     * Computes a log evidence value for the likelihood matrix under the given Dirichlet prior.
     *
     * <p>The result is the sum of three terms: the log normalization of the prior, minus the log
     * normalization of the posterior returned by {@link #alleleFractionsPosterior}, plus a sum
     * over reads of the responsibility-weighted log likelihoods minus the sum of
     * {@code r * log(r)} over that read's responsibilities.
     *
     * @param logLikelihoods    matrix with one row per allele and one column per read
     * @param priorPseudocounts prior Dirichlet pseudocounts, one per allele; its length is checked
     *                          against the row count via {@code Utils.validateArg}
     * @return the log evidence as described above
     */
    public static double logEvidence(RealMatrix logLikelihoods, double[] priorPseudocounts) {
        final int numberOfAlleles = logLikelihoods.getRowDimension();
        Utils.validateArg(numberOfAlleles == priorPseudocounts.length, "Must have one pseudocount per allele.");

        final double[] alleleFractionsPosterior = alleleFractionsPosterior(logLikelihoods, priorPseudocounts);
        final double priorContribution = logDirichletNormalization(priorPseudocounts);
        final double posteriorContribution = -logDirichletNormalization(alleleFractionsPosterior);
        final double[] logAlleleFractions = new Dirichlet(alleleFractionsPosterior).effectiveLogMultinomialWeights();

        final double likelihoodsAndEntropyContribution = new IndexRange(0, logLikelihoods.getColumnDimension()).sum(r -> {
            final double[] logLikelihoodsForRead = logLikelihoods.getColumn(r);
            final double[] responsibilities = NaturalLogUtils.posteriors(logAlleleFractions, logLikelihoodsForRead);
            // Sum of r*log(r) over this read's responsibilities; subtracted below.
            final double entropyContribution = Arrays.stream(responsibilities).map(SomaticLikelihoodsEngine::xLogx).sum();
            return likelihoodsContribution(logLikelihoodsForRead, responsibilities) - entropyContribution;
        });

        return priorContribution + posteriorContribution + likelihoodsAndEntropyContribution;
    }

    /**
     * Returns the responsibility-weighted sum of one read's log likelihoods, skipping alleles
     * whose responsibility is below {@link #NEGLIGIBLE_RESPONSIBILITY}.
     */
    private static double likelihoodsContribution(double[] logLikelihoodsForRead, double[] responsibilities) {
        double result = 0.0;
        for (int n = 0; n < logLikelihoodsForRead.length; ++n) {
            // Treat negligible responsibilities as exactly zero instead of multiplying them by
            // the log likelihood.
            if (responsibilities[n] >= NEGLIGIBLE_RESPONSIBILITY) {
                result += logLikelihoodsForRead[n] * responsibilities[n];
            }
        }
        return result;
    }

    /** Returns {@code x * ln(x)}, or zero when {@code x} is below {@link #NEGLIGIBLE_X_LOG_X_ARGUMENT}. */
    private static double xLogx(double x) {
        return x < NEGLIGIBLE_X_LOG_X_ARGUMENT ? 0.0 : x * Math.log(x);
    }

    /**
     * Computes {@code logGamma(sum of params) - sum of logGamma(param)} for the given Dirichlet
     * parameters.
     *
     * @param dirichletParams the Dirichlet parameters
     * @return the log of the Dirichlet normalization constant as defined above
     */
    public static double logDirichletNormalization(double ... dirichletParams) {
        final double logNumerator = Gamma.logGamma(MathUtils.sum(dirichletParams));
        final double logDenominator = MathUtils.sum(MathUtils.applyToArray(dirichletParams, Gamma::logGamma));
        return logNumerator - logDenominator;
    }
}

