/*
 * Decompiled with CFR 0.152.
 */
package smile.stat.distribution;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.MathEx;
import smile.stat.distribution.ExponentialFamily;
import smile.stat.distribution.Mixture;

/**
 * A finite mixture model whose component distributions must all implement
 * {@link ExponentialFamily}. Models can be built directly from components or
 * fitted to data with the (optionally regularized) expectation-maximization
 * (EM) algorithm in {@link #fit(double[], Mixture.Component[], double, int, double)}.
 */
public class ExponentialFamilyMixture extends Mixture {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(ExponentialFamilyMixture.class);

    /** The log-likelihood recorded at construction (0.0 when built directly from components). */
    public final double L;
    /**
     * The Bayesian information criterion, {@code L - 0.5 * length() * log(n)}.
     * NOTE(review): {@code length()} is inherited from Mixture and presumably
     * counts the free parameters — confirm there.
     */
    public final double bic;

    /**
     * Creates a mixture from the given components without any data, so
     * {@link #L} is 0 and {@link #bic} is 0 (log(1) = 0).
     *
     * @param components the mixture components.
     * @throws IllegalArgumentException if a component's distribution is not an
     *     {@link ExponentialFamily}.
     */
    public ExponentialFamilyMixture(Mixture.Component ... components) {
        this(0.0, 1, components);
    }

    /**
     * Creates a mixture and records its log-likelihood and BIC.
     *
     * @param L the log-likelihood of the training data.
     * @param n the number of samples used to compute the BIC penalty.
     * @param components the mixture components.
     * @throws IllegalArgumentException if a component's distribution is not an
     *     {@link ExponentialFamily}.
     */
    ExponentialFamilyMixture(double L, int n, Mixture.Component ... components) {
        super(components);
        for (Mixture.Component component : components) {
            if (!(component.distribution() instanceof ExponentialFamily)) {
                throw new IllegalArgumentException("Component " + String.valueOf(component) + " is not of exponential family.");
            }
        }
        this.L = L;
        this.bic = L - 0.5 * (double) this.length() * Math.log(n);
    }

    /**
     * Fits the mixture with plain EM (gamma = 0), at most 500 iterations and
     * tolerance 1e-4.
     *
     * @param x the training data.
     * @param components the initial components; the array is not modified.
     * @return the fitted mixture.
     */
    public static ExponentialFamilyMixture fit(double[] x, Mixture.Component ... components) {
        return ExponentialFamilyMixture.fit(x, components, 0.0, 500, 1.0E-4);
    }

    /**
     * Fits the mixture with the EM algorithm, starting from the given components.
     *
     * @param x the training data.
     * @param components the initial components; the array is not modified.
     * @param gamma the regularization factor in [0, 0.2]. When positive, every
     *     posterior r is replaced by {@code r * (1 + gamma * log2(r))}, with NaN
     *     or negative results clamped to 0. Use 0 for plain EM.
     * @param maxIter the maximum number of EM iterations.
     * @param tol iteration stops once the log-likelihood improves by no more than
     *     {@code tol} over the previous iteration (a decrease also stops it).
     * @return the fitted mixture; its {@link #L} is the final log-likelihood.
     * @throws IllegalArgumentException if {@code x.length < components.length / 2}
     *     or gamma is outside [0, 0.2].
     */
    public static ExponentialFamilyMixture fit(double[] x, Mixture.Component[] components, double gamma, int maxIter, double tol) {
        if (x.length < components.length / 2) {
            throw new IllegalArgumentException("Too many components");
        }
        if (gamma < 0.0 || gamma > 0.2) {
            throw new IllegalArgumentException("Invalid regularization factor gamma.");
        }

        // Work on a private copy so the caller's initial components are not overwritten.
        Mixture.Component[] model = components.clone();
        double[][] posteriori = new double[model.length][x.length];

        double L = 0.0;
        double diff = Double.MAX_VALUE;
        for (int iter = 1; iter <= maxIter && diff > tol; ++iter) {
            updatePosteriori(x, model, gamma, posteriori);
            updateComponents(x, model, posteriori);

            double loglikelihood = dataLogLikelihood(x, model);
            // The first pass has no previous log-likelihood to compare against.
            // Comparing with the initial 0.0 would end EM immediately whenever the
            // log-likelihood is negative, which is the common case.
            diff = (iter == 1) ? Double.MAX_VALUE : loglikelihood - L;
            L = loglikelihood;

            if (iter % 10 == 0) {
                logger.info("The log-likelihood after {} iterations: {}", iter, L);
            }
        }
        return new ExponentialFamilyMixture(L, x.length, model);
    }

    /**
     * E-step: fills {@code posteriori[i][j]} with the normalized weight
     * {@code priori_i * p_i(x[j])} of component i for sample j, then applies
     * the optional gamma regularization.
     */
    private static void updatePosteriori(double[] x, Mixture.Component[] components, double gamma, double[][] posteriori) {
        int k = components.length;
        for (int i = 0; i < k; ++i) {
            Mixture.Component c = components[i];
            for (int j = 0; j < x.length; ++j) {
                posteriori[i][j] = c.priori() * c.distribution().p(x[j]);
            }
        }

        for (int j = 0; j < x.length; ++j) {
            double p = 0.0;
            for (int i = 0; i < k; ++i) {
                p += posteriori[i][j];
            }

            if (!(p > 0.0)) {
                // No component gives x[j] positive density. Dividing would produce NaN
                // posteriors that poison the M-step, so this sample gets zero weight.
                // The gamma > 0 path already clamped these NaNs to 0, and
                // dataLogLikelihood skips such points too.
                for (int i = 0; i < k; ++i) {
                    posteriori[i][j] = 0.0;
                }
                continue;
            }

            for (int i = 0; i < k; ++i) {
                posteriori[i][j] /= p;
            }

            if (gamma > 0.0) {
                for (int i = 0; i < k; ++i) {
                    double r = posteriori[i][j];
                    double adjusted = r * (1.0 + gamma * MathEx.log2(r));
                    // log2(0) = -inf makes 0 * -inf = NaN; clamp NaN/negative to 0.
                    posteriori[i][j] = (Double.isNaN(adjusted) || adjusted < 0.0) ? 0.0 : adjusted;
                }
            }
        }
    }

    /**
     * M-step: re-estimates each component from its posterior weights via
     * {@link ExponentialFamily#M}, then rescales the priors so they sum to 1.
     */
    private static void updateComponents(double[] x, Mixture.Component[] components, double[][] posteriori) {
        double Z = 0.0;
        for (int i = 0; i < components.length; ++i) {
            components[i] = ((ExponentialFamily) components[i].distribution()).M(x, posteriori[i]);
            Z += components[i].priori();
        }
        for (int i = 0; i < components.length; ++i) {
            components[i] = new Mixture.Component(components[i].priori() / Z, components[i].distribution());
        }
    }

    /**
     * Returns the sum of {@code log(sum_i priori_i * p_i(x))} over the data,
     * skipping samples whose mixture density is not positive.
     */
    private static double dataLogLikelihood(double[] x, Mixture.Component[] components) {
        double loglikelihood = 0.0;
        for (double xi : x) {
            double p = 0.0;
            for (Mixture.Component c : components) {
                p += c.priori() * c.distribution().p(xi);
            }
            if (p > 0.0) {
                loglikelihood += Math.log(p);
            }
        }
        return loglikelihood;
    }
}

