/*
 * Decompiled with CFR 0.152.
 */
package elki.clustering.em.models;

import elki.clustering.em.models.EMClusterModel;
import elki.data.NumberVector;
import elki.data.model.EMModel;
import elki.math.MathUtil;
import elki.math.linearalgebra.VMath;
import net.jafama.FastMath;

/**
 * Spherical Gaussian cluster model for EM, using the "textbook" one-pass
 * estimator: during an E-step the weighted first moments (per coordinate) and
 * the weighted sum of squared norms are accumulated, and the variance is
 * obtained as {@code E[||x||^2] - ||E[x]||^2} when the step is finalized.
 * <p>
 * The covariance is {@code variance * I}, i.e. a single variance shared by all
 * dimensions.
 * <p>
 * The one-pass formula is prone to cancellation when the data lies far from
 * the origin; the resulting variance is therefore clamped to a small positive
 * value so that the density stays finite.
 */
public class TextbookSphericalGaussianModel
implements EMClusterModel<NumberVector, EMModel> {
    /** Smallest variance used; also the fallback for non-positive inputs. */
    private static final double MIN_VARIANCE = 1.0E-10;

    /** Cluster mean; during an E-step it holds the weighted coordinate sums. */
    double[] mean;
    /**
     * Per-dimension variance; during an E-step it holds the weighted sum of
     * squared vector norms.
     */
    double variance;
    /** Allocated with the dimensionality but not otherwise used by this class. */
    double[] nmea;
    /** {@code dim * log(2 pi)}. */
    double logNorm;
    /** {@code log(weight) - 0.5 * (logNorm + dim * log(variance))}. */
    double logNormDet;
    /** Cluster weight; its logarithm is added to the log density. */
    double weight;
    /** Sum of the instance weights seen during the current E-step. */
    double wsum;
    /** Initial variance, used as the prior variance in the MAP estimate. */
    double priorvar;

    /**
     * Create a model with unit initial variance.
     *
     * @param weight cluster weight
     * @param mean initial mean (used directly, not copied)
     */
    public TextbookSphericalGaussianModel(double weight, double[] mean) {
        this(weight, mean, 1.0);
    }

    /**
     * Create a model with the given initial variance.
     *
     * @param weight cluster weight
     * @param mean initial mean (used directly, not copied)
     * @param var initial per-dimension variance; non-positive values are
     *        replaced by a tiny positive constant
     */
    public TextbookSphericalGaussianModel(double weight, double[] mean, double var) {
        final int dim = mean.length;
        this.weight = weight;
        this.mean = mean;
        this.nmea = new double[dim];
        this.priorvar = this.variance = var > 0.0 ? var : MIN_VARIANCE;
        this.logNorm = MathUtil.LOGTWOPI * dim;
        // Include log det(variance * I) = dim * log(variance); omitting it is
        // only correct for a unit initial variance.
        this.logNormDet = FastMath.log(weight) - 0.5 * (this.logNorm + dim * FastMath.log(this.variance));
        this.wsum = 0.0;
    }

    /** Reset the moment accumulators before a new E-step. */
    @Override
    public void beginEStep() {
        wsum = 0.0;
        VMath.clear(mean);
        variance = 0.0;
    }

    /**
     * Accumulate one weighted instance into the raw moments.
     *
     * @param vec data vector (must have the model's dimensionality)
     * @param wei non-negative, finite instance weight
     */
    @Override
    public void updateE(NumberVector vec, double wei) {
        assert vec.getDimensionality() == mean.length;
        assert wei >= 0.0 && wei < Double.POSITIVE_INFINITY : wei;
        for (int i = 0; i < mean.length; i++) {
            final double vi = vec.doubleValue(i);
            final double viWei = vi * wei;
            mean[i] += viWei;
            variance += viWei * vi; // sum of w * x_i^2 over all dimensions
        }
        wsum += wei;
    }

    /**
     * Turn the accumulated moments into the mean and per-dimension variance,
     * and update the cached log-normalization.
     *
     * @param weight new cluster weight
     * @param prior strength of the variance prior; {@code <= 0} selects the
     *        maximum-likelihood estimate
     */
    @Override
    public void finalizeEStep(double weight, double prior) {
        final int dim = mean.length;
        this.weight = weight;
        final double f = wsum > Double.MIN_NORMAL && wsum < Double.POSITIVE_INFINITY ? 1.0 / wsum : 1.0;
        for (int i = 0; i < dim; i++) {
            mean[i] *= f;
        }
        double logDet = 0.0;
        if (prior > 0.0) {
            // MAP estimate shrinking towards priorvar:
            // sigma^2 = (trace(S) / dim + prior * priorvar) / (wsum + prior * (nu + dim + 2)),
            // with total scatter trace(S) = variance - wsum * ||mean||^2.
            // NOTE(review): nu = dim + 2 presumably follows an inverse-Wishart prior default.
            final double nu = dim + 2;
            final double f2 = 1.0 / (wsum + prior * (nu + dim + 2.0));
            // variance holds the sum over ALL dimensions; share it out per dimension.
            final double perDim = variance / dim;
            double newvar = 0.0;
            for (int i = 0; i < dim; i++) {
                newvar += (perDim - mean[i] * mean[i] * wsum + prior * priorvar) * f2;
            }
            variance = clampVariance(newvar / dim);
            logDet = FastMath.log(variance) * dim;
        } else if (wsum > 0.0) {
            // Maximum-likelihood estimate: E[||x||^2] / dim - mean_i^2, averaged over dimensions.
            final double wf = wsum > Double.MIN_NORMAL && wsum < Double.POSITIVE_INFINITY ? 1.0 / (wsum * dim) : 1.0 / dim;
            double newvar = 0.0;
            for (int i = 0; i < dim; i++) {
                newvar += variance * wf - mean[i] * mean[i];
            }
            variance = clampVariance(newvar / dim);
            logDet = FastMath.log(variance) * dim;
        }
        logNormDet = FastMath.log(weight) - 0.5 * (logNorm + logDet);
    }

    /**
     * Replace non-positive (or NaN) variances, which the one-pass formula can
     * produce through cancellation, by {@link #MIN_VARIANCE}.
     */
    private static double clampVariance(double v) {
        return v > 0.0 ? v : MIN_VARIANCE;
    }

    /**
     * Squared Mahalanobis distance to the mean under covariance
     * {@code variance * I}.
     *
     * @param vec query vector (its length determines the dimensions used)
     * @return squared distance
     */
    public double mahalanobisDistance(double[] vec) {
        double agg = 0.0;
        for (int i = 0; i < vec.length; ++i) {
            final double diff = vec[i] - mean[i];
            agg += diff / variance * diff;
        }
        return agg;
    }

    /**
     * Squared Mahalanobis distance to the mean under covariance
     * {@code variance * I}.
     *
     * @param vec query vector
     * @return squared distance
     */
    public double mahalanobisDistance(NumberVector vec) {
        double agg = 0.0;
        for (int i = 0; i < mean.length; ++i) {
            final double diff = vec.doubleValue(i) - mean[i];
            agg += diff / variance * diff;
        }
        return agg;
    }

    /**
     * Weighted log density: {@code -0.5 * mahalanobis + logNormDet}.
     *
     * @param vec query vector
     * @return log of weight times the Gaussian density at {@code vec}
     */
    @Override
    public double estimateLogDensity(NumberVector vec) {
        return -0.5 * mahalanobisDistance(vec) + logNormDet;
    }

    @Override
    public double getWeight() {
        return weight;
    }

    @Override
    public void setWeight(double weight) {
        this.weight = weight;
    }

    /**
     * Build the final cluster model with covariance {@code variance * I}.
     *
     * @return model holding a copy of the mean
     */
    @Override
    public EMModel finalizeCluster() {
        final int dim = mean.length;
        // Copy the mean: beginEStep() clears this array in place, which would
        // otherwise silently alter the returned model.
        return new EMModel(mean.clone(), VMath.timesEquals(VMath.identity(dim, dim), variance));
    }
}

