/*
 * Decompiled with CFR 0.152.
 */
package elki.clustering.em.models;

import elki.clustering.em.models.BetulaClusterModel;
import elki.data.NumberVector;
import elki.data.model.EMModel;
import elki.index.tree.betula.features.ClusterFeature;
import elki.logging.Logging;
import elki.math.MathUtil;
import elki.math.linearalgebra.CholeskyDecomposition;
import elki.math.linearalgebra.VMath;
import net.jafama.FastMath;

public class MultivariateGaussianModel
implements BetulaClusterModel {
    private static final Logging LOG = Logging.getLogger(MultivariateGaussianModel.class);
    private static final double SINGULARITY_CHEAT = 1.0E-10;
    double[] mean;
    double[][] covariance;
    CholeskyDecomposition chol;
    double[] nmea;
    double logNorm;
    double logNormDet;
    double weight;
    double wsum;
    double[][] priormatrix;

    public MultivariateGaussianModel(double weight, double[] mean) {
        this(weight, mean, null);
    }

    public MultivariateGaussianModel(double weight, double[] mean, double[][] covariance) {
        this.weight = weight;
        this.mean = mean;
        this.logNorm = MathUtil.LOGTWOPI * (double)mean.length;
        this.nmea = new double[mean.length];
        this.covariance = covariance != null ? VMath.copy((double[][])covariance) : VMath.identity((int)mean.length, (int)mean.length);
        this.priormatrix = (double[][])(covariance != null ? covariance : null);
        this.wsum = 0.0;
        this.chol = MultivariateGaussianModel.updateCholesky(this.covariance, null);
        this.logNormDet = FastMath.log((double)weight) - 0.5 * this.logNorm - MultivariateGaussianModel.getHalfLogDeterminant(this.chol);
    }

    @Override
    public void beginEStep() {
        this.wsum = 0.0;
        VMath.clear((double[])this.mean);
        VMath.clear((double[][])this.covariance);
    }

    @Override
    public void updateE(NumberVector vec, double wei) {
        int i;
        assert (vec.getDimensionality() == this.mean.length);
        assert (wei >= 0.0 && wei < Double.POSITIVE_INFINITY) : wei;
        if (wei < Double.MIN_NORMAL) {
            return;
        }
        int dim = this.mean.length;
        double nwsum = this.wsum + wei;
        double f = wei / nwsum;
        for (i = 0; i < dim; ++i) {
            this.nmea[i] = this.mean[i] + (vec.doubleValue(i) - this.mean[i]) * f;
        }
        for (i = 0; i < dim; ++i) {
            double vi = vec.doubleValue(i);
            double delta_i = vi - this.nmea[i];
            double[] cov_i = this.covariance[i];
            for (int j = 0; j < i; ++j) {
                int n = j;
                cov_i[n] = cov_i[n] + delta_i * (vec.doubleValue(j) - this.mean[j]) * wei;
            }
            int n = i;
            cov_i[n] = cov_i[n] + delta_i * (vi - this.mean[i]) * wei;
        }
        this.wsum = nwsum;
        System.arraycopy(this.nmea, 0, this.mean, 0, this.nmea.length);
    }

    @Override
    public void finalizeEStep(double weight, double prior) {
        double f;
        int dim = this.covariance.length;
        this.weight = weight;
        double d = f = this.wsum > Double.MIN_NORMAL && this.wsum < Double.POSITIVE_INFINITY ? 1.0 / this.wsum : 1.0;
        if (prior > 0.0 && this.priormatrix != null) {
            double nu = (double)dim + 2.0;
            double f2 = 1.0 / (this.wsum + prior * (nu + (double)dim + 2.0));
            for (int i = 0; i < dim; ++i) {
                double[] row_i = this.covariance[i];
                double[] pri_i = this.priormatrix[i];
                for (int j = 0; j < i; ++j) {
                    this.covariance[j][i] = row_i[j] = (row_i[j] + prior * pri_i[j]) * f2;
                }
                row_i[i] = (row_i[i] + prior * pri_i[i]) * f2;
            }
        } else {
            int i = 0;
            while (i < dim) {
                double[] row_i = this.covariance[i];
                for (int j = 0; j < i; ++j) {
                    int n = j;
                    double d2 = row_i[n] * f;
                    row_i[n] = d2;
                    this.covariance[j][i] = d2;
                }
                int n = i++;
                row_i[n] = row_i[n] * f;
            }
        }
        this.chol = MultivariateGaussianModel.updateCholesky(this.covariance, null);
        this.logNormDet = FastMath.log((double)weight) - 0.5 * this.logNorm - MultivariateGaussianModel.getHalfLogDeterminant(this.chol);
        if (prior > 0.0 && this.priormatrix == null) {
            this.priormatrix = VMath.copy((double[][])this.covariance);
        }
    }

    protected static CholeskyDecomposition updateCholesky(double[][] covariance, CholeskyDecomposition prev) {
        int i;
        CholeskyDecomposition nextchol = new CholeskyDecomposition(covariance);
        if (nextchol.isSPD()) {
            return nextchol;
        }
        double s = 0.0;
        for (i = 0; i < covariance.length; ++i) {
            s += covariance[i][i];
        }
        s = s > 1.0E-100 ? s * 1.0E-10 / (double)covariance.length : 1.0E-50;
        i = 0;
        while (i < covariance.length) {
            double[] dArray = covariance[i];
            int n = i++;
            dArray[n] = dArray[n] + s;
        }
        nextchol = new CholeskyDecomposition(covariance);
        if (!nextchol.isSPD()) {
            LOG.warning((CharSequence)"A cluster has degenerated, likely due to lack of variance in a subset of the data or too extreme magnitude differences.\nThe algorithm will likely stop without converging, and fail to produce a good fit.");
            return prev != null ? prev : nextchol;
        }
        return nextchol;
    }

    protected static double getHalfLogDeterminant(CholeskyDecomposition chol) {
        double[][] l = chol.getL();
        double logdet = FastMath.log((double)l[0][0]);
        for (int i = 1; i < l.length; ++i) {
            logdet += FastMath.log((double)l[i][i]);
        }
        return logdet;
    }

    public double mahalanobisDistance(double[] vec) {
        return VMath.squareSum((double[])this.chol.solveLInplace(VMath.minusEquals((double[])((double[])vec.clone()), (double[])this.mean)));
    }

    public double mahalanobisDistance(NumberVector vec) {
        return VMath.squareSum((double[])this.chol.solveLInplace(VMath.minusEquals((double[])vec.toArray(), (double[])this.mean)));
    }

    @Override
    public double estimateLogDensity(NumberVector vec) {
        return -0.5 * this.mahalanobisDistance(vec) + this.logNormDet;
    }

    @Override
    public double getWeight() {
        return this.weight;
    }

    @Override
    public void setWeight(double weight) {
        this.weight = weight;
    }

    @Override
    public EMModel finalizeCluster() {
        return new EMModel(this.mean, this.covariance);
    }

    @Override
    public double estimateLogDensity(ClusterFeature cf) {
        int i;
        double[][] combinedCov = cf.covariance();
        double[] delta = (double[])this.mean.clone();
        for (i = 0; i < this.mean.length; ++i) {
            int n = i;
            delta[n] = delta[n] - cf.centroid(i);
        }
        for (i = 0; i < this.covariance.length; ++i) {
            for (int j = 0; j <= i; ++j) {
                double[] dArray = combinedCov[i];
                int n = j;
                double d = dArray[n] + this.covariance[i][j];
                dArray[n] = d;
                combinedCov[j][i] = d;
            }
        }
        CholeskyDecomposition cchol = MultivariateGaussianModel.updateCholesky(combinedCov, this.chol);
        double clogNormDet = FastMath.log((double)((double)cf.getWeight() + this.wsum)) - 0.5 * this.logNorm - MultivariateGaussianModel.getHalfLogDeterminant(cchol);
        return -0.5 * VMath.squareSum((double[])this.chol.solveLInplace(delta)) + clogNormDet;
    }

    @Override
    public void updateE(ClusterFeature cf, double wei) {
        assert (cf.getDimensionality() == this.mean.length);
        assert (wei >= 0.0 && wei < Double.POSITIVE_INFINITY) : wei;
        if (wei < Double.MIN_NORMAL) {
            return;
        }
        int dim = this.mean.length;
        double nwsum = this.wsum + wei;
        double f = wei / nwsum;
        double[][] cfcov = VMath.timesEquals((double[][])cf.covariance(), (double)wei);
        for (int i = 0; i < dim; ++i) {
            double delta = cf.centroid(i) - this.mean[i];
            this.nmea[i] = this.mean[i] + delta * f;
            for (int j = 0; j <= i; ++j) {
                double[] dArray = this.covariance[i];
                int n = j;
                dArray[n] = dArray[n] + (cfcov[i][j] + wei * (delta * (cf.centroid(j) - this.nmea[j])));
            }
        }
        this.wsum = nwsum;
        System.arraycopy(this.nmea, 0, this.mean, 0, this.nmea.length);
    }
}

