/*
 * Decompiled with CFR 0.152.
 */
package elki.clustering.em.models;

import elki.clustering.em.models.BetulaClusterModel;
import elki.data.NumberVector;
import elki.data.model.EMModel;
import elki.index.tree.betula.features.ClusterFeature;
import elki.math.MathUtil;
import elki.math.linearalgebra.VMath;
import java.util.Arrays;
import net.jafama.FastMath;

/**
 * Gaussian mixture component with a diagonal covariance matrix, i.e. one
 * independent variance per dimension.
 * <p>
 * Life cycle per EM iteration: {@link #beginEStep()} resets the accumulators,
 * the {@code updateE} methods fold in weighted observations (single vectors or
 * cluster features) with incremental mean/variance updates, and
 * {@link #finalizeEStep(double, double)} turns the accumulated sums into
 * variances and refreshes the cached log-normalization term.
 */
public class DiagonalGaussianModel
implements BetulaClusterModel {
    /** Lower bound for variances, to avoid singular (zero-variance) dimensions. */
    private static final double SINGULARITY_CHEAT = 1.0E-10;
    /** Current mean; stored by reference and overwritten in place during E-steps. */
    double[] mean;
    /** Per-dimension variances; during an E-step, weighted sums of squared deviations. */
    double[] variances;
    /** Scratch buffer holding the updated mean during incremental updates. */
    double[] nmea;
    /** {@code dim * MathUtil.LOGTWOPI} (presumably dim * log(2 pi)). */
    double logNorm;
    /** Cached {@code log(weight) - 0.5 * (logNorm + sum_i log(variances[i]))}. */
    double logNormDet;
    /** Mixing weight of this component. */
    double weight;
    /** Sum of observation weights accumulated in the current E-step. */
    double wsum;
    /** Prior variances used for MAP regularization; may be {@code null}. */
    double[] priordiag;

    /**
     * Constructor with unit variances.
     *
     * @param weight mixing weight of the component
     * @param mean initial mean (stored by reference, not copied)
     */
    public DiagonalGaussianModel(double weight, double[] mean) {
        this(weight, mean, null);
    }

    /**
     * Constructor.
     *
     * @param weight mixing weight of the component
     * @param mean initial mean (stored by reference, not copied)
     * @param vars initial per-dimension variances, or {@code null} for unit
     *        variances; when given, values are clamped to at least
     *        {@link #SINGULARITY_CHEAT} and the array is also kept as the MAP prior
     */
    public DiagonalGaussianModel(double weight, double[] mean, double[] vars) {
        final int dim = mean.length;
        this.weight = weight;
        this.mean = mean;
        this.nmea = new double[dim];
        this.logNorm = MathUtil.LOGTWOPI * (double) dim;
        this.variances = new double[dim];
        double logDet = 0.0;
        if (vars == null) {
            // Unit variances: log-determinant is zero.
            Arrays.fill(this.variances, 1.0);
        } else {
            for (int i = 0; i < dim; ++i) {
                final double v = vars[i];
                // Clamp tiny/non-positive values; the comparison also maps NaN to the floor.
                this.variances[i] = v > SINGULARITY_CHEAT ? v : SINGULARITY_CHEAT;
                logDet += FastMath.log(this.variances[i]);
            }
            this.priordiag = vars;
        }
        // Include the log-determinant of the initial covariance so densities evaluated
        // before the first finalizeEStep are correctly normalized.
        this.logNormDet = FastMath.log(weight) - 0.5 * (this.logNorm + logDet);
        this.wsum = 0.0;
    }

    /**
     * Reset the accumulators (weight sum, mean, variance sums) for a new E-step.
     */
    @Override
    public void beginEStep() {
        this.wsum = 0.0;
        VMath.clear(this.mean);
        VMath.clear(this.variances);
    }

    /**
     * Incrementally add a weighted data vector to the running mean and the
     * running weighted sum of squared deviations.
     *
     * @param vec data vector (same dimensionality as the model)
     * @param wei non-negative, finite weight; weights below
     *        {@link Double#MIN_NORMAL} are ignored
     */
    @Override
    public void updateE(NumberVector vec, double wei) {
        assert (vec.getDimensionality() == this.mean.length);
        assert (wei >= 0.0 && wei < Double.POSITIVE_INFINITY) : wei;
        if (wei < Double.MIN_NORMAL) {
            return; // Negligible contribution; also avoids 0/0 while wsum == 0.
        }
        final double nwsum = this.wsum + wei;
        final double f = wei / nwsum;
        for (int i = 0; i < this.mean.length; ++i) {
            this.nmea[i] = this.mean[i] + (vec.doubleValue(i) - this.mean[i]) * f;
        }
        for (int i = 0; i < this.mean.length; ++i) {
            final double vi = vec.doubleValue(i);
            // Welford-style update: deviation from new mean times deviation from old mean.
            this.variances[i] += (vi - this.nmea[i]) * (vi - this.mean[i]) * wei;
        }
        this.wsum = nwsum;
        System.arraycopy(this.nmea, 0, this.mean, 0, this.nmea.length);
    }

    /**
     * Convert the accumulated sums into variances and refresh the cached
     * normalization term.
     *
     * @param weight new mixing weight
     * @param prior MAP prior strength; {@code > 0} enables regularization toward
     *        {@link #priordiag} (or initializes it from the current variances)
     */
    @Override
    public void finalizeEStep(double weight, double prior) {
        final int dim = this.variances.length;
        this.weight = weight;
        double logDet = 0.0;
        if (prior > 0.0 && this.priordiag != null) {
            // MAP estimate regularized toward the prior variances.
            final double nu = (double) dim + 2.0;
            final double f2 = 1.0 / (this.wsum + prior * (nu + (double) dim + 2.0));
            for (int i = 0; i < dim; ++i) {
                final double v = this.variances[i] + prior * this.priordiag[i];
                this.variances[i] = v > 0.0 ? v * f2 : SINGULARITY_CHEAT;
                logDet += FastMath.log(this.variances[i]);
            }
        } else {
            // Maximum-likelihood estimate; an empty cluster keeps the raw sums.
            final double f = this.wsum > 0.0 ? 1.0 / this.wsum : 1.0;
            for (int i = 0; i < dim; ++i) {
                final double v = this.variances[i];
                this.variances[i] = v > 0.0 ? v * f : SINGULARITY_CHEAT;
                logDet += FastMath.log(this.variances[i]);
            }
        }
        this.logNormDet = FastMath.log(weight) - 0.5 * (this.logNorm + logDet);
        if (prior > 0.0 && this.priordiag == null) {
            // First MAP use: adopt the current estimate as the prior.
            this.priordiag = VMath.copy(this.variances);
        }
    }

    /**
     * Squared Mahalanobis distance to the mean under the diagonal covariance.
     *
     * @param vec vector (at least as long as the mean)
     * @return sum over dimensions of {@code (vec[i] - mean[i])^2 / variances[i]}
     */
    public double mahalanobisDistance(double[] vec) {
        double agg = 0.0;
        for (int i = 0; i < this.mean.length; ++i) {
            final double diff = vec[i] - this.mean[i];
            agg += diff / this.variances[i] * diff;
        }
        return agg;
    }

    /**
     * Squared Mahalanobis distance to the mean under the diagonal covariance.
     *
     * @param vec data vector
     * @return sum over dimensions of {@code (vec_i - mean[i])^2 / variances[i]}
     */
    public double mahalanobisDistance(NumberVector vec) {
        double agg = 0.0;
        for (int i = 0; i < this.mean.length; ++i) {
            final double diff = vec.doubleValue(i) - this.mean[i];
            agg += diff / this.variances[i] * diff;
        }
        return agg;
    }

    /**
     * Weighted log density of a vector: {@code log(weight) + log N(vec | mean, diag(variances))}.
     *
     * @param vec data vector
     * @return weighted log density
     */
    @Override
    public double estimateLogDensity(NumberVector vec) {
        return -0.5 * this.mahalanobisDistance(vec) + this.logNormDet;
    }

    @Override
    public double getWeight() {
        return this.weight;
    }

    @Override
    public void setWeight(double weight) {
        this.weight = weight;
    }

    /**
     * Build the final cluster model with a diagonal covariance matrix.
     *
     * @return EM model sharing this model's mean array
     */
    @Override
    public EMModel finalizeCluster() {
        return new EMModel(this.mean, VMath.diagonal(this.variances));
    }

    /**
     * Log density of a cluster feature, treating the feature's own variance as
     * additive to the model variance in each dimension.
     * <p>
     * NOTE(review): unlike {@link #estimateLogDensity(NumberVector)}, this does
     * not include {@code log(weight)}; confirm whether callers add the mixing
     * weight separately.
     *
     * @param cf cluster feature
     * @return log density (unweighted)
     */
    @Override
    public double estimateLogDensity(ClusterFeature cf) {
        double agg = this.logNorm;
        for (int i = 0; i < this.mean.length; ++i) {
            final double diff = cf.centroid(i) - this.mean[i];
            agg += diff / (this.variances[i] + cf.variance(i)) * diff;
        }
        for (int i = 0; i < this.mean.length; ++i) {
            agg += FastMath.log(this.variances[i] + cf.variance(i));
        }
        return -0.5 * agg;
    }

    /**
     * Incrementally add a weighted cluster feature: the centroid updates the
     * mean, and both the feature's internal variance and the centroid's
     * deviation contribute to the variance sums.
     *
     * @param cf cluster feature (same dimensionality as the model)
     * @param wei non-negative, finite weight; weights below
     *        {@link Double#MIN_NORMAL} are ignored (previously a zero weight on
     *        an empty accumulator produced 0/0 = NaN means)
     */
    @Override
    public void updateE(ClusterFeature cf, double wei) {
        assert (cf.getDimensionality() == this.mean.length);
        assert (wei >= 0.0 && wei < Double.POSITIVE_INFINITY) : wei;
        if (wei < Double.MIN_NORMAL) {
            return; // Same guard as updateE(NumberVector, double).
        }
        final double nwsum = this.wsum + wei;
        for (int i = 0; i < this.mean.length; ++i) {
            this.nmea[i] = this.mean[i] + (cf.centroid(i) - this.mean[i]) * wei / nwsum;
        }
        for (int i = 0; i < this.mean.length; ++i) {
            final double vi = cf.centroid(i);
            this.variances[i] += wei * cf.variance(i) + (vi - this.nmea[i]) * (vi - this.mean[i]) * wei;
        }
        this.wsum = nwsum;
        System.arraycopy(this.nmea, 0, this.mean, 0, this.nmea.length);
    }
}

