/*
 * Decompiled with CFR 0.152.
 */
package smile.stat.distribution;

import java.util.ArrayList;
import java.util.List;
import smile.math.Math;
import smile.stat.distribution.DiscreteExponentialFamily;
import smile.stat.distribution.DiscreteMixture;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * A mixture of discrete distributions in which every component's distribution
 * implements {@link DiscreteExponentialFamily}, so that the mixture can be
 * fitted with the EM algorithm: the maximization step is delegated to each
 * component's {@code M(int[], double[])} method.
 *
 * <p>Note that {@code Math} in this file refers to {@code smile.math.Math}
 * (see the imports), not {@code java.lang.Math}.
 */
public class DiscreteExponentialFamilyMixture extends DiscreteMixture {
    /**
     * Package-private constructor that creates the mixture without supplying
     * any components; it relies on the superclass no-argument constructor.
     */
    DiscreteExponentialFamilyMixture() {
    }

    /**
     * Creates a mixture from the given components without fitting it to data.
     *
     * @param mixture the initial components; each component's distribution
     *                must be an instance of {@link DiscreteExponentialFamily}.
     * @throws IllegalArgumentException if any component's distribution is not
     *                                  a {@link DiscreteExponentialFamily}.
     */
    public DiscreteExponentialFamilyMixture(List<DiscreteMixture.Component> mixture) {
        super(mixture);
        for (DiscreteMixture.Component component : mixture) {
            if (!(component.distribution instanceof DiscreteExponentialFamily)) {
                throw new IllegalArgumentException("Component " + component + " is not of discrete exponential family.");
            }
        }
    }

    /**
     * Creates a mixture from the given initial components and fits it to
     * {@code data} by running EM on this object's {@code components} list
     * with the default regularization factor (0.2) and no iteration limit.
     *
     * @param mixture the initial guess of the components.
     * @param data    the training samples.
     * @throws IllegalArgumentException if a component is not of the discrete
     *                                  exponential family, or if there are too
     *                                  many components for the amount of data.
     */
    public DiscreteExponentialFamilyMixture(List<DiscreteMixture.Component> mixture, int[] data) {
        this(mixture);
        this.EM(this.components, data);
    }

    /**
     * Runs EM with the default regularization factor of 0.2 and no iteration
     * limit.
     *
     * @param mixture the components to fit; updated in place.
     * @param x       the training samples.
     * @return the log-likelihood of the last accepted configuration.
     */
    double EM(List<DiscreteMixture.Component> mixture, int[] x) {
        return this.EM(mixture, x, 0.2);
    }

    /**
     * Runs EM with the given regularization factor and no iteration limit.
     *
     * @param mixture the components to fit; updated in place.
     * @param x       the training samples.
     * @param gamma   the regularization factor, in {@code [0, 0.2]}.
     * @return the log-likelihood of the last accepted configuration.
     */
    double EM(List<DiscreteMixture.Component> mixture, int[] x, double gamma) {
        return this.EM(mixture, x, gamma, Integer.MAX_VALUE);
    }

    /**
     * Fits the mixture to the data with the (optionally regularized) EM
     * algorithm. Iteration stops as soon as a step fails to strictly increase
     * the log-likelihood, when the new component weights cannot be normalized,
     * or when the iteration limit is reached. Whenever a step improves the
     * log-likelihood, the content of {@code components} is replaced in place
     * by the new configuration.
     *
     * @param components the components to fit; updated in place. Every
     *                   distribution must implement
     *                   {@link DiscreteExponentialFamily}.
     * @param x          the training samples.
     * @param gamma      the regularization factor, in {@code [0, 0.2]};
     *                   a value of 0 disables the posterior adjustment.
     * @param maxIter    the maximum number of iterations; a non-positive value
     *                   means no limit.
     * @return the log-likelihood of the last accepted configuration, where
     *         samples with non-positive mixture probability are left out of
     *         the sum.
     * @throws IllegalArgumentException if {@code x.length} is smaller than
     *                                  half the number of components, or if
     *                                  {@code gamma} is outside {@code [0, 0.2]}.
     */
    double EM(List<DiscreteMixture.Component> components, int[] x, double gamma, int maxIter) {
        if (x.length < components.size() / 2) {
            throw new IllegalArgumentException("Too many components");
        }
        if (gamma < 0.0 || gamma > 0.2) {
            throw new IllegalArgumentException("Invalid regularization factor gamma.");
        }
        // A non-positive limit is interpreted as "iterate until convergence".
        int iterationLimit = maxIter <= 0 ? Integer.MAX_VALUE : maxIter;

        int n = x.length;
        int m = components.size();
        double[][] posteriori = new double[m][n];

        // Log-likelihood of the initial configuration.
        double L = mixtureLogLikelihood(components, x);

        for (int iter = 0; iter < iterationLimit; ++iter) {
            // Expectation step: unnormalized posterior of component i for sample j.
            for (int i = 0; i < m; ++i) {
                DiscreteMixture.Component c = components.get(i);
                for (int j = 0; j < n; ++j) {
                    posteriori[i][j] = c.priori * c.distribution.p(x[j]);
                }
            }

            // Normalize the posteriors of every sample so they sum to one.
            for (int j = 0; j < n; ++j) {
                double p = 0.0;
                for (int i = 0; i < m; ++i) {
                    p += posteriori[i][j];
                }

                if (!(p > 0.0)) {
                    // No component explains this sample. Dividing by p would
                    // produce NaN weights (0.0 / 0.0); give the sample zero
                    // weight instead, which is also what the regularized path
                    // below ends up with after its NaN clean-up.
                    for (int i = 0; i < m; ++i) {
                        posteriori[i][j] = 0.0;
                    }
                    continue;
                }

                for (int i = 0; i < m; ++i) {
                    posteriori[i][j] /= p;
                }

                if (gamma > 0.0) {
                    // Regularized EM adjustment of the posterior probabilities.
                    for (int i = 0; i < m; ++i) {
                        posteriori[i][j] *= 1.0 + gamma * Math.log2(posteriori[i][j]);
                        // A zero posterior yields 0 * -Infinity = NaN, and a
                        // small posterior can turn negative; clamp both to 0.
                        if (Double.isNaN(posteriori[i][j]) || posteriori[i][j] < 0.0) {
                            posteriori[i][j] = 0.0;
                        }
                    }
                }
            }

            // Maximization step: each component re-estimates itself from the
            // samples weighted by its posteriors.
            List<DiscreteMixture.Component> newConfig = new ArrayList<DiscreteMixture.Component>(m);
            double sumAlpha = 0.0;
            for (int i = 0; i < m; ++i) {
                DiscreteExponentialFamily family = (DiscreteExponentialFamily) components.get(i).distribution;
                DiscreteMixture.Component updated = family.M(x, posteriori[i]);
                newConfig.add(updated);
                sumAlpha += updated.priori;
            }

            if (!(sumAlpha > 0.0)) {
                // The new weights cannot be normalized (zero or NaN total).
                // Dividing would give NaN priors, whose log-likelihood sums to
                // 0 because such samples are skipped, and that would wrongly
                // beat a negative L. Keep the last valid configuration.
                break;
            }
            for (DiscreteMixture.Component c : newConfig) {
                c.priori /= sumAlpha;
            }

            double newL = mixtureLogLikelihood(newConfig, x);
            if (!(newL > L)) {
                // No strict improvement: converged (or got worse).
                break;
            }
            L = newL;
            components.clear();
            components.addAll(newConfig);
        }
        return L;
    }

    /**
     * Computes the log-likelihood of the samples under the given mixture.
     * Samples whose mixture probability is not strictly positive (including
     * NaN) are left out of the sum.
     *
     * @param mixture the mixture components.
     * @param x       the samples.
     * @return the sum of the log of the mixture probability over the samples
     *         that have positive probability.
     */
    private static double mixtureLogLikelihood(List<DiscreteMixture.Component> mixture, int[] x) {
        double logLikelihood = 0.0;
        for (double xi : x) {
            double p = 0.0;
            for (DiscreteMixture.Component c : mixture) {
                p += c.priori * c.distribution.p(xi);
            }
            if (p > 0.0) {
                logLikelihood += Math.log(p);
            }
        }
        return logLikelihood;
    }
}

