/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.walkers.vqsr;

import Jama.Matrix;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.broadinstitute.hellbender.tools.walkers.vqsr.MultivariateGaussian;
import org.broadinstitute.hellbender.tools.walkers.vqsr.VariantDatum;
import org.broadinstitute.hellbender.utils.MathUtils;
import org.broadinstitute.hellbender.utils.Utils;

class GaussianMixtureModel {
    protected static final Logger logger = LogManager.getLogger(GaussianMixtureModel.class);
    private final List<MultivariateGaussian> gaussians;
    private final double shrinkage;
    private final double dirichletParameter;
    private final double priorCounts;
    private final double[] empiricalMu;
    private final Matrix empiricalSigma;
    public boolean isModelReadyForEvaluation;
    public boolean failedToConverge = false;

    public GaussianMixtureModel(int numGaussians, int numVariantData, int numAnnotations, double shrinkage, double dirichletParameter, double priorCounts) {
        this.gaussians = new ArrayList<MultivariateGaussian>(numGaussians);
        for (int iii = 0; iii < numGaussians; ++iii) {
            MultivariateGaussian gaussian = new MultivariateGaussian(numVariantData, numAnnotations);
            this.gaussians.add(gaussian);
        }
        this.shrinkage = shrinkage;
        this.dirichletParameter = dirichletParameter;
        this.priorCounts = priorCounts;
        this.empiricalMu = new double[numAnnotations];
        this.empiricalSigma = new Matrix(numAnnotations, numAnnotations);
        this.isModelReadyForEvaluation = false;
        Arrays.fill(this.empiricalMu, 0.0);
        this.empiricalSigma.setMatrix(0, this.empiricalMu.length - 1, 0, this.empiricalMu.length - 1, Matrix.identity((int)this.empiricalMu.length, (int)this.empiricalMu.length).times(200.0).inverse());
    }

    protected GaussianMixtureModel(List<MultivariateGaussian> gaussians, double shrinkage, double dirichletParameter, double priorCounts) {
        this.gaussians = gaussians;
        int numAnnotations = gaussians.get((int)0).mu.length;
        this.shrinkage = shrinkage;
        this.dirichletParameter = dirichletParameter;
        this.priorCounts = priorCounts;
        this.empiricalMu = new double[numAnnotations];
        this.empiricalSigma = new Matrix(numAnnotations, numAnnotations);
        this.isModelReadyForEvaluation = false;
        Arrays.fill(this.empiricalMu, 0.0);
        this.empiricalSigma.setMatrix(0, this.empiricalMu.length - 1, 0, this.empiricalMu.length - 1, Matrix.identity((int)this.empiricalMu.length, (int)this.empiricalMu.length).times(200.0).inverse());
    }

    public void initializeRandomModel(List<VariantDatum> data, int numKMeansIterations) {
        for (MultivariateGaussian gaussian : this.gaussians) {
            gaussian.initializeRandomMu(Utils.getRandomGenerator());
        }
        logger.info("Initializing model with " + numKMeansIterations + " k-means iterations...");
        this.initializeMeansUsingKMeans(data, numKMeansIterations);
        for (MultivariateGaussian gaussian : this.gaussians) {
            gaussian.pMixtureLog10 = Math.log10(1.0 / (double)this.gaussians.size());
            gaussian.sumProb = 1.0 / (double)this.gaussians.size();
            gaussian.initializeRandomSigma(Utils.getRandomGenerator());
            gaussian.hyperParameter_a = this.priorCounts;
            gaussian.hyperParameter_b = this.shrinkage;
            gaussian.hyperParameter_lambda = this.dirichletParameter;
        }
    }

    private void initializeMeansUsingKMeans(List<VariantDatum> data, int numIterations) {
        int ttt = 0;
        while (ttt++ < numIterations) {
            for (VariantDatum datum : data) {
                MultivariateGaussian minGaussian;
                double minDistance = Double.MAX_VALUE;
                datum.assignment = minGaussian = null;
                for (MultivariateGaussian gaussian : this.gaussians) {
                    double dist = gaussian.calculateDistanceFromMeanSquared(datum);
                    if (!(dist < minDistance)) continue;
                    minDistance = dist;
                    minGaussian = gaussian;
                }
                datum.assignment = minGaussian;
            }
            for (MultivariateGaussian gaussian : this.gaussians) {
                gaussian.zeroOutMu();
                int numAssigned = 0;
                for (VariantDatum datum : data) {
                    if (!datum.assignment.equals(gaussian)) continue;
                    ++numAssigned;
                    gaussian.incrementMu(datum);
                }
                if (numAssigned != 0) {
                    gaussian.divideEqualsMu(numAssigned);
                    continue;
                }
                gaussian.initializeRandomMu(Utils.getRandomGenerator());
            }
        }
    }

    public void expectationStep(List<VariantDatum> data) {
        for (MultivariateGaussian gaussian : this.gaussians) {
            gaussian.precomputeDenominatorForVariationalBayes(this.getSumHyperParameterLambda());
        }
        for (VariantDatum datum : data) {
            double[] pVarInGaussianLog10 = this.gaussians.stream().mapToDouble(g -> g.evaluateDatumLog10(datum)).toArray();
            double[] pVarInGaussianNormalized = MathUtils.normalizeLog10DeleteMePlease(pVarInGaussianLog10, false);
            int gaussianIndex = 0;
            for (MultivariateGaussian gaussian : this.gaussians) {
                gaussian.assignPVarInGaussian(pVarInGaussianNormalized[gaussianIndex++]);
            }
        }
    }

    public void maximizationStep(List<VariantDatum> data) {
        this.gaussians.forEach(g -> g.maximizeGaussian(data, this.empiricalMu, this.empiricalSigma, this.shrinkage, this.dirichletParameter, this.priorCounts));
    }

    private double getSumHyperParameterLambda() {
        return this.gaussians.stream().mapToDouble(g -> g.hyperParameter_lambda).sum();
    }

    public void evaluateFinalModelParameters(List<VariantDatum> data) {
        this.gaussians.forEach(g -> g.evaluateFinalModelParameters(data));
        this.normalizePMixtureLog10();
    }

    public double normalizePMixtureLog10() {
        double sumDiff = 0.0;
        double sumPK = this.gaussians.stream().mapToDouble(g -> g.sumProb).sum();
        double log10SumPK = Math.log10(sumPK);
        double[] pGaussianLog10 = this.gaussians.stream().mapToDouble(g -> Math.log10(g.sumProb) - log10SumPK).toArray();
        MathUtils.normalizeLog10DeleteMePlease(pGaussianLog10, true);
        int gaussianIndex = 0;
        for (MultivariateGaussian gaussian : this.gaussians) {
            sumDiff += Math.abs(pGaussianLog10[gaussianIndex] - gaussian.pMixtureLog10);
            gaussian.pMixtureLog10 = pGaussianLog10[gaussianIndex++];
        }
        return sumDiff;
    }

    public void precomputeDenominatorForEvaluation() {
        for (MultivariateGaussian gaussian : this.gaussians) {
            gaussian.precomputeDenominatorForEvaluation();
        }
        this.isModelReadyForEvaluation = true;
    }

    private double nanTolerantLog10SumLog10(double[] values) {
        for (double value : values) {
            if (!Double.isNaN(value)) continue;
            return Double.NaN;
        }
        return MathUtils.log10sumLog10(values);
    }

    public double evaluateDatum(VariantDatum datum) {
        for (boolean isNull : datum.isNull) {
            if (!isNull) continue;
            return this.evaluateDatumMarginalized(datum);
        }
        double[] pVarInGaussianLog10 = new double[this.gaussians.size()];
        int gaussianIndex = 0;
        for (MultivariateGaussian gaussian : this.gaussians) {
            pVarInGaussianLog10[gaussianIndex++] = gaussian.pMixtureLog10 + gaussian.evaluateDatumLog10(datum);
        }
        return this.nanTolerantLog10SumLog10(pVarInGaussianLog10);
    }

    public Double evaluateDatumInOneDimension(VariantDatum datum, int iii) {
        if (datum.isNull[iii]) {
            return null;
        }
        double[] pVarInGaussianLog10 = new double[this.gaussians.size()];
        int gaussianIndex = 0;
        for (MultivariateGaussian gaussian : this.gaussians) {
            pVarInGaussianLog10[gaussianIndex] = gaussian.pMixtureLog10;
            if (gaussian.pMixtureLog10 != Double.NEGATIVE_INFINITY) {
                int n = gaussianIndex;
                pVarInGaussianLog10[n] = pVarInGaussianLog10[n] + MathUtils.normalDistributionLog10(gaussian.mu[iii], gaussian.sigma.get(iii, iii), datum.annotations[iii]);
            }
            ++gaussianIndex;
        }
        return this.nanTolerantLog10SumLog10(pVarInGaussianLog10);
    }

    public double evaluateDatumMarginalized(VariantDatum datum) {
        int numRandomDraws = 0;
        double sumPVarInGaussian = 0.0;
        int numIterPerMissingAnnotation = 20;
        double[] pVarInGaussianLog10 = new double[this.gaussians.size()];
        for (int iii = 0; iii < datum.annotations.length; ++iii) {
            if (!datum.isNull[iii]) continue;
            for (int ttt = 0; ttt < 20; ++ttt) {
                datum.annotations[iii] = Utils.getRandomGenerator().nextGaussian();
                int gaussianIndex = 0;
                for (MultivariateGaussian gaussian : this.gaussians) {
                    pVarInGaussianLog10[gaussianIndex++] = gaussian.pMixtureLog10 + gaussian.evaluateDatumLog10(datum);
                }
                sumPVarInGaussian += Math.pow(10.0, this.nanTolerantLog10SumLog10(pVarInGaussianLog10));
                ++numRandomDraws;
            }
        }
        return Math.log10(sumPVarInGaussian / (double)numRandomDraws);
    }

    protected List<MultivariateGaussian> getModelGaussians() {
        return Collections.unmodifiableList(this.gaussians);
    }

    protected int getNumAnnotations() {
        return this.empiricalMu.length;
    }
}

