/*
 * Decompiled with CFR 0.152.
 */
package elki.clustering.em;

import elki.clustering.ClusteringAlgorithm;
import elki.clustering.em.models.EMClusterModel;
import elki.clustering.em.models.EMClusterModelFactory;
import elki.clustering.em.models.MultivariateGaussianModelFactory;
import elki.data.Cluster;
import elki.data.Clustering;
import elki.data.model.MeanModel;
import elki.data.type.SimpleTypeInformation;
import elki.data.type.TypeInformation;
import elki.data.type.TypeUtil;
import elki.database.datastore.DataStore;
import elki.database.datastore.DataStoreUtil;
import elki.database.datastore.WritableDataStore;
import elki.database.datastore.WritableDoubleDataStore;
import elki.database.ids.ArrayModifiableDBIDs;
import elki.database.ids.DBIDIter;
import elki.database.ids.DBIDRef;
import elki.database.ids.DBIDUtil;
import elki.database.ids.DBIDs;
import elki.database.ids.ModifiableDBIDs;
import elki.database.relation.MaterializedRelation;
import elki.database.relation.Relation;
import elki.logging.Logging;
import elki.logging.statistics.DoubleStatistic;
import elki.logging.statistics.LongStatistic;
import elki.logging.statistics.Statistic;
import elki.math.linearalgebra.VMath;
import elki.result.Metadata;
import elki.utilities.Priority;
import elki.utilities.documentation.Description;
import elki.utilities.documentation.Reference;
import elki.utilities.documentation.References;
import elki.utilities.documentation.Title;
import elki.utilities.optionhandling.OptionID;
import elki.utilities.optionhandling.Parameterizer;
import elki.utilities.optionhandling.constraints.CommonConstraints;
import elki.utilities.optionhandling.constraints.ParameterConstraint;
import elki.utilities.optionhandling.parameterization.Parameterization;
import elki.utilities.optionhandling.parameters.DoubleParameter;
import elki.utilities.optionhandling.parameters.Flag;
import elki.utilities.optionhandling.parameters.IntParameter;
import elki.utilities.optionhandling.parameters.ObjectParameter;
import java.util.ArrayList;
import java.util.List;
import net.jafama.FastMath;

@Title(value="EM-Clustering: Clustering by Expectation Maximization")
@Description(value="Cluster data via Gaussian mixture modeling and the EM algorithm")
@References(value={@Reference(authors="A. P. Dempster, N. M. Laird, D. B. Rubin", title="Maximum Likelihood from Incomplete Data via the EM algorithm", booktitle="Journal of the Royal Statistical Society, Series B, 39(1)", url="http://www.jstor.org/stable/2984875", bibkey="journals/jroyastatsocise2/DempsterLR77"), @Reference(title="Bayesian Regularization for Normal Mixture Estimation and Model-Based Clustering", authors="C. Fraley, A. E. Raftery", booktitle="J. Classification 24(2)", url="https://doi.org/10.1007/s00357-007-0004-5", bibkey="DBLP:journals/classification/FraleyR07")})
@Priority(value=200)
public class EM<O, M extends MeanModel>
implements ClusteringAlgorithm<Clustering<M>> {
    private static final Logging LOG = Logging.getLogger(EM.class);
    private static final String KEY = EM.class.getName();
    protected int k;
    protected double delta;
    protected EMClusterModelFactory<? super O, M> mfactory;
    protected int miniter;
    protected int maxiter;
    protected double prior = 0.0;
    protected boolean soft;
    protected static final double MIN_LOGLIKELIHOOD = -100000.0;
    public static final SimpleTypeInformation<double[]> SOFT_TYPE = new SimpleTypeInformation(double[].class);

    public EM(int k, double delta, EMClusterModelFactory<? super O, M> mfactory) {
        this(k, delta, mfactory, -1, 0.0, false);
    }

    public EM(int k, double delta, EMClusterModelFactory<? super O, M> mfactory, int maxiter, boolean soft) {
        this(k, delta, mfactory, maxiter, 0.0, soft);
    }

    public EM(int k, double delta, EMClusterModelFactory<? super O, M> mfactory, int maxiter, double prior, boolean soft) {
        this(k, delta, mfactory, 1, maxiter, prior, soft);
    }

    public EM(int k, double delta, EMClusterModelFactory<? super O, M> mfactory, int miniter, int maxiter, double prior, boolean soft) {
        this.k = k;
        this.delta = delta;
        this.mfactory = mfactory;
        this.miniter = miniter;
        this.maxiter = maxiter;
        this.prior = prior;
        this.soft = soft;
    }

    public TypeInformation[] getInputTypeRestriction() {
        return TypeUtil.array((TypeInformation[])new TypeInformation[]{TypeUtil.NUMBER_VECTOR_FIELD});
    }

    public Clustering<M> run(Relation<O> relation) {
        if (relation.size() == 0) {
            throw new IllegalArgumentException("database empty: must contain elements");
        }
        List<EMClusterModel<O, M>> models = this.mfactory.buildInitialModels(relation, this.k);
        WritableDataStore probClusterIGivenX = DataStoreUtil.makeStorage((DBIDs)relation.getDBIDs(), (int)10, double[].class);
        double loglikelihood = EM.assignProbabilitiesToInstances(relation, models, (WritableDataStore<double[]>)probClusterIGivenX, null);
        DoubleStatistic likestat = new DoubleStatistic(this.getClass().getName() + ".loglikelihood");
        LOG.statistics((Statistic)likestat.setDouble(loglikelihood));
        int it = 0;
        int lastimprovement = 0;
        double bestloglikelihood = Double.NEGATIVE_INFINITY;
        ++it;
        while (it < this.maxiter || this.maxiter < 0) {
            double oldloglikelihood = loglikelihood;
            EM.recomputeCovarianceMatrices(relation, (WritableDataStore<double[]>)probClusterIGivenX, models, this.prior);
            loglikelihood = EM.assignProbabilitiesToInstances(relation, models, (WritableDataStore<double[]>)probClusterIGivenX, null);
            LOG.statistics((Statistic)likestat.setDouble(loglikelihood));
            if (loglikelihood - bestloglikelihood > this.delta) {
                lastimprovement = it;
                bestloglikelihood = loglikelihood;
            }
            if (it >= this.miniter && (Math.abs(loglikelihood - oldloglikelihood) <= this.delta || lastimprovement < it >> 1)) break;
            ++it;
        }
        LOG.statistics((Statistic)new LongStatistic(KEY + ".iterations", (long)it));
        ArrayList<ArrayModifiableDBIDs> hardClusters = new ArrayList<ArrayModifiableDBIDs>(this.k);
        for (int i = 0; i < this.k; ++i) {
            hardClusters.add(DBIDUtil.newArray());
        }
        DBIDIter iditer = relation.iterDBIDs();
        while (iditer.valid()) {
            ((ModifiableDBIDs)hardClusters.get(VMath.argmax((double[])((double[])probClusterIGivenX.get((DBIDRef)iditer))))).add((DBIDRef)iditer);
            iditer.advance();
        }
        Clustering<MeanModel> result = new Clustering<MeanModel>();
        Metadata.of(result).setLongName("EM Clustering");
        for (int i = 0; i < this.k; ++i) {
            result.addToplevelCluster(new Cluster<MeanModel>((DBIDs)hardClusters.get(i), (MeanModel)models.get(i).finalizeCluster()));
        }
        if (this.soft) {
            Metadata.hierarchyOf(result).addChild((Object)new MaterializedRelation("EM Cluster Probabilities", SOFT_TYPE, relation.getDBIDs(), (DataStore)probClusterIGivenX));
        } else {
            probClusterIGivenX.destroy();
        }
        return result;
    }

    /*
     * WARNING - void declaration
     */
    public static <O> void recomputeCovarianceMatrices(Relation<? extends O> relation, WritableDataStore<double[]> probClusterIGivenX, List<? extends EMClusterModel<? super O, ?>> models, double prior) {
        void var8_14;
        int k = models.size();
        boolean needsTwoPass = false;
        for (EMClusterModel<O, ?> eMClusterModel : models) {
            eMClusterModel.beginEStep();
            needsTwoPass |= eMClusterModel.needsTwoPass();
        }
        if (needsTwoPass) {
            DBIDIter iditer = relation.iterDBIDs();
            while (iditer.valid()) {
                double[] dArray = (double[])probClusterIGivenX.get((DBIDRef)iditer);
                Object instance = relation.get((DBIDRef)iditer);
                for (int i = 0; i < dArray.length; ++i) {
                    double prob = dArray[i];
                    if (!(prob > 1.0E-10)) continue;
                    models.get(i).firstPassE(instance, prob);
                }
                iditer.advance();
            }
            for (EMClusterModel eMClusterModel : models) {
                eMClusterModel.finalizeFirstPassE();
            }
        }
        double[] wsum = new double[k];
        DBIDIter dBIDIter = relation.iterDBIDs();
        while (dBIDIter.valid()) {
            double[] clusterProbabilities = (double[])probClusterIGivenX.get((DBIDRef)dBIDIter);
            Object instance = relation.get((DBIDRef)dBIDIter);
            int i = 0;
            while (i < clusterProbabilities.length) {
                double prob = clusterProbabilities[i];
                if (prob > 1.0E-10) {
                    models.get(i).updateE(instance, prob);
                }
                int n = i++;
                wsum[n] = wsum[n] + prob;
            }
            dBIDIter.advance();
        }
        boolean bl = false;
        while (var8_14 < models.size()) {
            double weight = prior <= 0.0 ? wsum[var8_14] / (double)relation.size() : (wsum[var8_14] + prior - 1.0) / ((double)relation.size() + prior * (double)k - (double)k);
            models.get((int)var8_14).finalizeEStep(weight, prior);
            ++var8_14;
        }
    }

    public static <O> double assignProbabilitiesToInstances(Relation<? extends O> relation, List<? extends EMClusterModel<? super O, ?>> models, WritableDataStore<double[]> probClusterIGivenX, WritableDoubleDataStore loglikelihoods) {
        int k = models.size();
        double emSum = 0.0;
        DBIDIter iditer = relation.iterDBIDs();
        while (iditer.valid()) {
            Object vec = relation.get((DBIDRef)iditer);
            double[] probs = new double[k];
            for (int i = 0; i < k; ++i) {
                double v = models.get(i).estimateLogDensity(vec);
                probs[i] = v > -100000.0 ? v : -100000.0;
            }
            double logP = EM.logSumExp(probs);
            for (int i = 0; i < k; ++i) {
                probs[i] = FastMath.exp((double)(probs[i] - logP));
            }
            probClusterIGivenX.put((DBIDRef)iditer, (Object)probs);
            if (loglikelihoods != null) {
                loglikelihoods.put((DBIDRef)iditer, logP);
            }
            emSum += logP;
            iditer.advance();
        }
        return emSum / (double)relation.size();
    }

    public static double logSumExp(double[] x) {
        double max = x[0];
        for (int i = 1; i < x.length; ++i) {
            double v = x[i];
            max = v > max ? v : max;
        }
        double cutoff = max - 35.350506209;
        double acc = 0.0;
        for (int i = 0; i < x.length; ++i) {
            double v = x[i];
            if (!(v > cutoff)) continue;
            acc += v < max ? FastMath.exp((double)(v - max)) : 1.0;
        }
        return acc > 1.0 ? max + FastMath.log((double)acc) : max;
    }

    protected static double logSumExp(double a, double b) {
        return (a > b ? a : b) + FastMath.log((double)(a > b ? FastMath.exp((double)(b - a)) + 1.0 : FastMath.exp((double)(a - b)) + 1.0));
    }

    public static class Par<O, M extends MeanModel>
    implements Parameterizer {
        public static final OptionID K_ID = new OptionID("em.k", "The number of clusters to find.");
        public static final OptionID DELTA_ID = new OptionID("em.delta", "The termination criterion for maximization of E(M): E(M) - E(M') < em.delta");
        public static final OptionID MODEL_ID = new OptionID("em.model", "Model factory.");
        public static final OptionID MINITER_ID = new OptionID("em.miniter", "Minimum number of iterations.");
        public static final OptionID MAXITER_ID = new OptionID("em.maxiter", "Maximum number of iterations.");
        public static final OptionID PRIOR_ID = new OptionID("em.map.prior", "Regularization factor for MAP estimation.");
        public static final OptionID SOFT_ID = new OptionID("em.soft", "Retain soft assignment of clusters.");
        protected int k;
        protected double delta;
        protected EMClusterModelFactory<O, M> mfactory;
        protected int miniter = 1;
        protected int maxiter = -1;
        double prior = 0.0;
        boolean soft = false;

        public void configure(Parameterization config) {
            ((IntParameter)new IntParameter(K_ID).addConstraint((ParameterConstraint)CommonConstraints.GREATER_EQUAL_ONE_INT)).grab(config, x -> {
                this.k = x;
            });
            new ObjectParameter(MODEL_ID, EMClusterModelFactory.class, MultivariateGaussianModelFactory.class).grab(config, x -> {
                this.mfactory = x;
            });
            ((DoubleParameter)new DoubleParameter(DELTA_ID, 1.0E-7).addConstraint((ParameterConstraint)CommonConstraints.GREATER_EQUAL_ZERO_DOUBLE)).grab(config, x -> {
                this.delta = x;
            });
            ((IntParameter)((IntParameter)new IntParameter(MINITER_ID).addConstraint((ParameterConstraint)CommonConstraints.GREATER_EQUAL_ZERO_INT)).setOptional(true)).grab(config, x -> {
                this.miniter = x;
            });
            ((IntParameter)((IntParameter)new IntParameter(MAXITER_ID).addConstraint((ParameterConstraint)CommonConstraints.GREATER_EQUAL_ZERO_INT)).setOptional(true)).grab(config, x -> {
                this.maxiter = x;
            });
            ((DoubleParameter)((DoubleParameter)new DoubleParameter(PRIOR_ID).setOptional(true)).addConstraint((ParameterConstraint)CommonConstraints.GREATER_THAN_ZERO_DOUBLE)).grab(config, x -> {
                this.prior = x;
            });
            new Flag(SOFT_ID).grab(config, x -> {
                this.soft = x;
            });
        }

        public EM<O, M> make() {
            return new EM<O, M>(this.k, this.delta, this.mfactory, this.miniter, this.maxiter, this.prior, this.soft);
        }
    }
}

