/*
 * Decompiled with CFR 0.152.
 */
package smile.stat.distribution;

import java.io.Serializable;
import java.util.Arrays;
import java.util.stream.Collectors;
import smile.math.MathEx;
import smile.math.matrix.DenseMatrix;
import smile.math.matrix.Matrix;
import smile.stat.distribution.MultivariateDistribution;

/**
 * A finite mixture of multivariate distributions,
 * {@code p(x) = sum_k priori_k * p_k(x)}.
 *
 * <p>The mixing weights ({@link Component#priori}) are used exactly as given.
 * The moment formulas in {@link #mean()} and {@link #cov()} assume that they
 * sum to 1.
 */
public class MultivariateMixture
implements MultivariateDistribution {
    private static final long serialVersionUID = 2L;
    /** The mixture components. */
    public final Component[] components;

    /**
     * Constructs a mixture from the given components.
     *
     * @param components the mixture components. The array is copied, so later
     *                   changes to the caller's array do not affect this mixture.
     * @throws IllegalStateException if no components are given.
     */
    public MultivariateMixture(Component ... components) {
        if (components.length == 0) {
            throw new IllegalStateException("Empty mixture!");
        }
        this.components = components.clone();
    }

    /**
     * Returns {@code log(priori_k) + log p_k(x)} for every component k.
     * Working in log space keeps the values finite in cases where the linear
     * densities would underflow to 0 (high dimension, or points far from every
     * component).
     */
    private double[] logWeightedDensities(double[] x) {
        double[] lw = new double[this.components.length];
        for (int i = 0; i < lw.length; ++i) {
            Component c = this.components[i];
            lw[i] = Math.log(c.priori) + c.distribution.logp(x);
        }
        return lw;
    }

    /** Returns {@code log(sum_i exp(a[i]))}, computed without overflow or underflow. */
    private static double logSumExp(double[] a) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : a) {
            if (Double.isNaN(v)) {
                return Double.NaN;
            }
            if (v > max) {
                max = v;
            }
        }
        if (Double.isInfinite(max)) {
            // Either every term is -inf (the sum is 0) or some term is +inf.
            return max;
        }
        double sum = 0.0;
        for (double v : a) {
            sum += Math.exp(v - max);
        }
        return max + Math.log(sum);
    }

    /**
     * Returns the posterior probability of each component given {@code x},
     * i.e. {@code priori_k * p_k(x) / p(x)}.
     *
     * @param x a sample.
     * @return the posterior probabilities, one per component. If every
     *         component assigns zero likelihood to {@code x}, the entries are NaN.
     */
    public double[] posteriori(double[] x) {
        double[] prob = this.logWeightedDensities(x);
        double norm = logSumExp(prob);
        for (int i = 0; i < prob.length; ++i) {
            prob[i] = Math.exp(prob[i] - norm);
        }
        return prob;
    }

    /**
     * Returns the index of the component with the largest posterior probability
     * for {@code x} (maximum a posteriori assignment).
     *
     * @param x a sample.
     * @return the index of the most probable component.
     */
    public int map(double[] x) {
        // log is monotonic, so the argmax in log space equals the argmax of
        // priori_k * p_k(x). It stays correct even when those products underflow.
        return MathEx.whichMax(this.logWeightedDensities(x));
    }

    /**
     * Returns the mixture mean {@code sum_k priori_k * mu_k}.
     */
    @Override
    public double[] mean() {
        double[] mu = null;
        for (Component c : this.components) {
            double w = c.priori;
            double[] m = c.distribution.mean();
            if (mu == null) {
                mu = new double[m.length];
            }
            for (int i = 0; i < mu.length; ++i) {
                mu[i] += w * m[i];
            }
        }
        return mu;
    }

    /**
     * Returns the mixture covariance
     * {@code sum_k priori_k * (Sigma_k + (mu_k - mu)(mu_k - mu)^T)},
     * where {@code mu} is the mixture mean. This follows from the law of total
     * covariance, assuming the priors sum to 1.
     */
    @Override
    public DenseMatrix cov() {
        double[] mu = this.mean();
        int d = mu.length;
        DenseMatrix cov = Matrix.zeros(d, d);
        for (Component c : this.components) {
            double w = c.priori;
            double[] m = c.distribution.mean();
            DenseMatrix v = c.distribution.cov();
            for (int i = 0; i < d; ++i) {
                // The centered form avoids the cancellation of E[xx^T] - mu mu^T.
                double di = m[i] - mu[i];
                for (int j = 0; j < d; ++j) {
                    cov.add(i, j, w * (v.get(i, j) + di * (m[j] - mu[j])));
                }
            }
        }
        return cov;
    }

    /**
     * Not supported: a mixture's entropy has no closed form.
     *
     * @throws UnsupportedOperationException always.
     */
    @Override
    public double entropy() {
        throw new UnsupportedOperationException("Mixture does not support entropy()");
    }

    /**
     * Returns the mixture density {@code sum_k priori_k * p_k(x)}.
     */
    @Override
    public double p(double[] x) {
        double p = 0.0;
        for (Component c : this.components) {
            p += c.priori * c.distribution.p(x);
        }
        return p;
    }

    /**
     * Returns the log density, computed as a log-sum-exp over
     * {@code log(priori_k) + logp_k(x)}. The result stays finite where
     * {@code Math.log(p(x))} would underflow to {@code -Infinity}.
     */
    @Override
    public double logp(double[] x) {
        return logSumExp(this.logWeightedDensities(x));
    }

    /**
     * Returns the mixture CDF {@code sum_k priori_k * F_k(x)}.
     */
    @Override
    public double cdf(double[] x) {
        double p = 0.0;
        for (Component c : this.components) {
            p += c.priori * c.distribution.cdf(x);
        }
        return p;
    }

    /**
     * Returns the number of free parameters. This is the component parameters
     * plus {@code k - 1} mixing weights, since the weights sum to 1.
     */
    @Override
    public int length() {
        int f = this.components.length - 1;
        for (Component component : this.components) {
            f += component.distribution.length();
        }
        return f;
    }

    /**
     * Returns the number of components in the mixture.
     */
    public int size() {
        return this.components.length;
    }

    /**
     * Returns the Bayesian information criterion in the form
     * {@code logL - 0.5 * length() * log(n)}, where larger is better.
     * Samples with zero or undefined likelihood are skipped.
     *
     * @param data the samples.
     * @return the BIC score.
     * @throws IllegalArgumentException if {@code data} is empty. Otherwise
     *         {@code log(0)} would make the score {@code +Infinity}.
     */
    public double bic(double[][] data) {
        int n = data.length;
        if (n == 0) {
            throw new IllegalArgumentException("Empty data");
        }
        double logLikelihood = 0.0;
        for (double[] x : data) {
            // Use logp rather than log(p) so that points whose linear density
            // underflows still contribute their true (very negative) value.
            double lp = this.logp(x);
            if (lp > Double.NEGATIVE_INFINITY) {
                logLikelihood += lp;
            }
        }
        return logLikelihood - 0.5 * (double) this.length() * Math.log(n);
    }

    @Override
    public String toString() {
        return Arrays.stream(this.components)
                .map(component -> String.format("%.2f x %s", component.priori, component.distribution))
                .collect(Collectors.joining(" + ", String.format("MultivariateMixture(%d)[", this.components.length), "]"));
    }

    /**
     * A mixture component: a distribution together with its mixing weight.
     */
    public static class Component
    implements Serializable {
        private static final long serialVersionUID = 2L;
        /** The mixing weight (prior probability) of this component. */
        public final double priori;
        /** The component distribution. */
        public final MultivariateDistribution distribution;

        /**
         * Constructs a mixture component.
         *
         * @param priori the mixing weight of the component.
         * @param distribution the component distribution.
         */
        public Component(double priori, MultivariateDistribution distribution) {
            this.priori = priori;
            this.distribution = distribution;
        }
    }
}

