/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.walkers.mutect.clustering;

import java.util.List;
import org.apache.commons.math3.special.Gamma;
import org.broadinstitute.hellbender.tools.walkers.mutect.SomaticLikelihoodsEngine;
import org.broadinstitute.hellbender.tools.walkers.mutect.clustering.AlleleFractionCluster;
import org.broadinstitute.hellbender.tools.walkers.mutect.clustering.Datum;
import org.broadinstitute.hellbender.tools.walkers.readorientation.BetaDistributionShape;
import org.broadinstitute.hellbender.tools.walkers.validation.basicshortmutpileup.BetaBinomialDistribution;

/**
 * An allele-fraction cluster in which the allele fraction follows a Beta distribution, so the alt read
 * count of a site with a given total depth follows a beta-binomial distribution.
 *
 * <p>The Beta shape is refined in place by {@link #learn(List, double[])} using responsibility-weighted
 * stochastic gradient ascent on the beta-binomial log-likelihood.
 */
public class BetaBinomialCluster implements AlleleFractionCluster {
    /** Step size (learning rate) of each gradient-ascent update in {@link #learn(List, double[])}. */
    private static final double RATE = 0.01;

    /** Number of full passes over the data performed by {@link #learn(List, double[])}. */
    private static final int NUM_EPOCHS = 10;

    /**
     * Lower bound enforced on alpha after every update.
     * NOTE(review): the rationale for this particular floor is not visible in this file.
     */
    private static final double MIN_ALPHA = 1.0;

    /**
     * Lower bound enforced on beta after every update.
     * NOTE(review): the rationale for this particular floor is not visible in this file.
     */
    private static final double MIN_BETA = 0.5;

    /** Current Beta shape of the allele fraction; replaced with a new instance by {@link #learn(List, double[])}. */
    BetaDistributionShape betaDistributionShape;

    /**
     * Creates a cluster with the given initial Beta shape.
     *
     * @param betaDistributionShape initial alpha/beta of the allele-fraction Beta distribution
     */
    public BetaBinomialCluster(final BetaDistributionShape betaDistributionShape) {
        this.betaDistributionShape = betaDistributionShape;
    }

    /**
     * Returns the datum's tumor log odds corrected for this cluster's Beta shape.
     *
     * @param datum site with alt/total counts and tumor log odds
     * @return corrected log likelihood; see {@link #correctedLogLikelihood(Datum, BetaDistributionShape)}
     */
    @Override
    public double correctedLogLikelihood(final Datum datum) {
        return correctedLogLikelihood(datum, betaDistributionShape);
    }

    /**
     * Returns the beta-binomial log probability of observing {@code altCount} alt reads out of
     * {@code totalCount} under this cluster's current Beta shape.
     *
     * @param totalCount total read depth
     * @param altCount   number of alt reads
     * @return log probability (natural log assumed; defined by {@code BetaBinomialDistribution})
     */
    @Override
    public double logLikelihood(final int totalCount, final int altCount) {
        // NOTE(review): the first constructor argument is passed as null; presumably a random generator
        // that is not needed for probability evaluation — confirm against BetaBinomialDistribution.
        return new BetaBinomialDistribution(null, betaDistributionShape.getAlpha(), betaDistributionShape.getBeta(), totalCount)
                .logProbability(altCount);
    }

    /**
     * Adds to the datum's tumor log odds the log-odds correction that replaces a flat Beta
     * ({@code BetaDistributionShape.FLAT_BETA}) by {@code betaDistributionShape}.
     *
     * <p>Presumably the stored tumor log odds were computed under a flat allele-fraction prior;
     * verify against the code that produces {@link Datum}.
     *
     * @param datum                 site with alt/total counts and tumor log odds
     * @param betaDistributionShape Beta shape to substitute for the flat one
     * @return tumor log odds plus the prior-swap correction
     */
    public static double correctedLogLikelihood(final Datum datum, final BetaDistributionShape betaDistributionShape) {
        final int altCount = datum.getAltCount();
        final int refCount = datum.getTotalCount() - altCount;
        return datum.getTumorLogOdds()
                + logOddsCorrection(BetaDistributionShape.FLAT_BETA, betaDistributionShape, altCount, refCount);
    }

    /**
     * Refines this cluster's Beta shape by responsibility-weighted stochastic gradient ascent.
     *
     * <p>For each datum with {@code k} alt reads, {@code r} ref reads and {@code n = k + r} total reads, the
     * gradient of the beta-binomial log-likelihood is
     * <pre>
     *   d/d(alpha) = psi(alpha + k) - psi(n + alpha + beta) - psi(alpha) + psi(alpha + beta)
     *   d/d(beta)  = psi(beta + r)  - psi(n + alpha + beta) - psi(beta)  + psi(alpha + beta)
     * </pre>
     * where {@code psi} is the digamma function. Each step is scaled by {@link #RATE} and by the datum's
     * responsibility, and the results are clamped to at least {@link #MIN_ALPHA} / {@link #MIN_BETA}.
     * Updates are applied after every datum, for {@link #NUM_EPOCHS} passes over {@code data}.
     *
     * @param data             sites to fit
     * @param responsibilities per-datum weights; must have at least {@code data.size()} entries
     *                         (index {@code i} weights {@code data.get(i)})
     */
    @Override
    public void learn(final List<Datum> data, final double[] responsibilities) {
        double alpha = betaDistributionShape.getAlpha();
        double beta = betaDistributionShape.getBeta();

        for (int epoch = 0; epoch < NUM_EPOCHS; epoch++) {
            for (int n = 0; n < data.size(); n++) {
                final Datum datum = data.get(n);
                final int total = datum.getTotalCount();
                final int alt = datum.getAltCount();
                final int ref = total - alt;
                final double weight = responsibilities[n];

                // Terms shared by both partial derivatives.
                final double digammaOfTotalPlusAlphaPlusBeta = Gamma.digamma(total + alpha + beta);
                final double digammaOfAlphaPlusBeta = Gamma.digamma(alpha + beta);

                // Both gradients are evaluated at the current (alpha, beta) before either is updated.
                final double alphaGradient = Gamma.digamma(alpha + alt) - digammaOfTotalPlusAlphaPlusBeta
                        - Gamma.digamma(alpha) + digammaOfAlphaPlusBeta;
                final double betaGradient = Gamma.digamma(beta + ref) - digammaOfTotalPlusAlphaPlusBeta
                        - Gamma.digamma(beta) + digammaOfAlphaPlusBeta;

                alpha = Math.max(alpha + RATE * alphaGradient * weight, MIN_ALPHA);
                beta = Math.max(beta + RATE * betaGradient * weight, MIN_BETA);
            }
        }
        betaDistributionShape = new BetaDistributionShape(alpha, beta);
    }

    /**
     * Log-odds correction for replacing Beta prior {@code originalBeta} with {@code newBeta}, given the
     * observed alt/ref counts:
     * {@code g(new) - g(new + counts) - g(original) + g(original + counts)}, where {@code g} is the
     * log Dirichlet normalization.
     *
     * @param originalBeta prior assumed by the uncorrected log odds
     * @param newBeta      prior to substitute
     * @param altCount     number of alt reads
     * @param refCount     number of ref reads
     * @return the additive log-odds correction
     */
    private static double logOddsCorrection(final BetaDistributionShape originalBeta, final BetaDistributionShape newBeta,
                                            final int altCount, final int refCount) {
        final double newPrior = g(newBeta.getAlpha(), newBeta.getBeta());
        final double newPosterior = g(newBeta.getAlpha() + altCount, newBeta.getBeta() + refCount);
        final double originalPrior = g(originalBeta.getAlpha(), originalBeta.getBeta());
        final double originalPosterior = g(originalBeta.getAlpha() + altCount, originalBeta.getBeta() + refCount);
        // Same left-to-right evaluation order as ((a - b) - c) + d to keep floating-point results identical.
        return newPrior - newPosterior - originalPrior + originalPosterior;
    }

    /**
     * Log normalization constant of a Dirichlet with concentration parameters {@code omega}.
     * Presumably {@code log Gamma(sum omega) - sum log Gamma(omega_i)}; see
     * {@code SomaticLikelihoodsEngine.logDirichletNormalization} to confirm.
     */
    private static double g(final double... omega) {
        return SomaticLikelihoodsEngine.logDirichletNormalization(omega);
    }

    @Override
    public String toString() {
        return String.format("alpha = %.2f, beta = %.2f", betaDistributionShape.getAlpha(), betaDistributionShape.getBeta());
    }
}

