/*
 * Decompiled with CFR 0.152.
 */
package elki.clustering.em;

import elki.clustering.ClusteringAlgorithm;
import elki.clustering.em.BetulaGMMWeighted;
import elki.clustering.em.EM;
import elki.clustering.em.models.BetulaClusterModel;
import elki.clustering.em.models.BetulaClusterModelFactory;
import elki.clustering.kmeans.AbstractKMeans;
import elki.clustering.kmeans.KMeans;
import elki.data.Cluster;
import elki.data.Clustering;
import elki.data.NumberVector;
import elki.data.model.EMModel;
import elki.data.type.SimpleTypeInformation;
import elki.data.type.TypeInformation;
import elki.data.type.TypeUtil;
import elki.database.datastore.DataStore;
import elki.database.datastore.DataStoreUtil;
import elki.database.datastore.WritableDataStore;
import elki.database.ids.ArrayModifiableDBIDs;
import elki.database.ids.DBIDIter;
import elki.database.ids.DBIDRef;
import elki.database.ids.DBIDUtil;
import elki.database.ids.DBIDs;
import elki.database.ids.ModifiableDBIDs;
import elki.database.relation.MaterializedRelation;
import elki.database.relation.Relation;
import elki.index.tree.betula.CFTree;
import elki.index.tree.betula.features.ClusterFeature;
import elki.logging.Logging;
import elki.logging.statistics.DoubleStatistic;
import elki.logging.statistics.Duration;
import elki.logging.statistics.LongStatistic;
import elki.logging.statistics.Statistic;
import elki.math.linearalgebra.VMath;
import elki.result.Metadata;
import elki.utilities.documentation.Reference;
import elki.utilities.optionhandling.OptionID;
import elki.utilities.optionhandling.Parameterizer;
import elki.utilities.optionhandling.constraints.CommonConstraints;
import elki.utilities.optionhandling.constraints.ParameterConstraint;
import elki.utilities.optionhandling.parameterization.Parameterization;
import elki.utilities.optionhandling.parameters.DoubleParameter;
import elki.utilities.optionhandling.parameters.IntParameter;
import elki.utilities.optionhandling.parameters.ObjectParameter;
import it.unimi.dsi.fastutil.objects.Reference2ObjectOpenHashMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import net.jafama.FastMath;

@Reference(authors="Andreas Lang and Erich Schubert", title="BETULA: Fast Clustering of Large Data with Improved BIRCH CF-Trees", booktitle="Information Systems", url="https://doi.org/10.1016/j.is.2021.101918", bibkey="DBLP:journals/is/LangS22")
public class BetulaGMM
implements ClusteringAlgorithm<Clustering<EMModel>> {
    private static final Logging LOG = Logging.getLogger(BetulaGMMWeighted.class);
    CFTree.Factory<?> cffactory;
    int k;
    private double delta;
    int maxiter;
    private double prior = 0.0;
    private boolean soft;
    protected static final double MIN_LOGLIKELIHOOD = -100000.0;
    public static final SimpleTypeInformation<double[]> SOFT_TYPE = new SimpleTypeInformation(double[].class);
    BetulaClusterModelFactory<?> initializer;

    public BetulaGMM(CFTree.Factory<?> cffactory, double delta, int k, int maxiter, boolean soft, BetulaClusterModelFactory<?> initialization, double prior) {
        this.cffactory = cffactory;
        this.delta = delta;
        this.k = k;
        this.maxiter = maxiter;
        this.soft = soft;
        this.initializer = initialization;
        this.prior = prior;
    }

    public TypeInformation[] getInputTypeRestriction() {
        return TypeUtil.array((TypeInformation[])new TypeInformation[]{TypeUtil.NUMBER_VECTOR_FIELD});
    }

    public Clustering<EMModel> run(Relation<NumberVector> relation) {
        if (relation.size() == 0) {
            throw new IllegalArgumentException("database empty: must contain elements");
        }
        CFTree<?> tree = this.cffactory.newTree(relation.getDBIDs(), relation, false);
        Duration modeltime = LOG.newDuration(this.getClass().getName() + ".modeltime").begin();
        ArrayList<?> cfs = tree.getLeaves();
        List<?> models = this.initializer.buildInitialModels(cfs, this.k, tree);
        Reference2ObjectOpenHashMap probClusterIGivenX = new Reference2ObjectOpenHashMap(cfs.size());
        double loglikelihood = this.assignProbabilitiesToInstances(cfs, models, (Map<ClusterFeature, double[]>)probClusterIGivenX);
        DoubleStatistic likestat = new DoubleStatistic(this.getClass().getName() + ".modelloglikelihood");
        LOG.statistics((Statistic)likestat.setDouble(loglikelihood));
        int it = 0;
        int lastimprovement = 0;
        double bestloglikelihood = Double.NEGATIVE_INFINITY;
        ++it;
        while (it < this.maxiter || this.maxiter < 0) {
            double oldloglikelihood = loglikelihood;
            this.recomputeCovarianceMatrices(cfs, (Map<ClusterFeature, double[]>)probClusterIGivenX, models, this.prior, tree.getRoot().getCF().getWeight());
            loglikelihood = this.assignProbabilitiesToInstances(cfs, models, (Map<ClusterFeature, double[]>)probClusterIGivenX);
            LOG.statistics((Statistic)likestat.setDouble(loglikelihood));
            if (loglikelihood - bestloglikelihood > this.delta) {
                lastimprovement = it;
                bestloglikelihood = loglikelihood;
            }
            if (Math.abs(loglikelihood - oldloglikelihood) <= this.delta || lastimprovement < it >> 1) break;
            ++it;
        }
        LOG.statistics((Statistic)new LongStatistic(this.getClass().getName() + ".iterations", (long)it));
        LOG.statistics((Statistic)modeltime.end());
        ArrayList<ArrayModifiableDBIDs> hardClusters = new ArrayList<ArrayModifiableDBIDs>(this.k);
        for (int i = 0; i < this.k; ++i) {
            hardClusters.add(DBIDUtil.newArray());
        }
        WritableDataStore finalClusterIGivenX = DataStoreUtil.makeStorage((DBIDs)relation.getDBIDs(), (int)10, double[].class);
        loglikelihood = this.assignProbabilitiesToInstances(relation, models, (WritableDataStore<double[]>)finalClusterIGivenX);
        LOG.statistics((Statistic)new DoubleStatistic(this.getClass().getName() + ".loglikelihood", loglikelihood));
        DBIDIter iditer = relation.iterDBIDs();
        while (iditer.valid()) {
            ((ModifiableDBIDs)hardClusters.get(VMath.argmax((double[])((double[])finalClusterIGivenX.get((DBIDRef)iditer))))).add((DBIDRef)iditer);
            iditer.advance();
        }
        Clustering<EMModel> result = new Clustering<EMModel>();
        Metadata.of(result).setLongName("EM Clustering");
        for (int i = 0; i < this.k; ++i) {
            result.addToplevelCluster(new Cluster<EMModel>((DBIDs)hardClusters.get(i), (EMModel)((BetulaClusterModel)models.get(i)).finalizeCluster()));
        }
        if (this.isSoft()) {
            Metadata.hierarchyOf(result).addChild((Object)new MaterializedRelation("EM Cluster Probabilities", SOFT_TYPE, relation.getDBIDs(), (DataStore)finalClusterIGivenX));
        }
        return result;
    }

    private boolean isSoft() {
        return this.soft;
    }

    public double assignProbabilitiesToInstances(ArrayList<? extends ClusterFeature> cfs, List<? extends BetulaClusterModel> models, Map<ClusterFeature, double[]> probClusterIGivenX) {
        int k = models.size();
        double emSum = 0.0;
        int n = 0;
        for (int i = 0; i < cfs.size(); ++i) {
            ClusterFeature cfsi = cfs.get(i);
            double[] probs = new double[k];
            for (int j = 0; j < k; ++j) {
                double v = models.get(j).estimateLogDensity(cfsi);
                probs[j] = v > -100000.0 ? v : -100000.0;
            }
            double logP = EM.logSumExp(probs);
            for (int j = 0; j < k; ++j) {
                probs[j] = FastMath.exp((double)(probs[j] - logP));
            }
            probClusterIGivenX.put(cfsi, probs);
            emSum += logP * (double)cfsi.getWeight();
            n += cfsi.getWeight();
        }
        return emSum / (double)n;
    }

    public double assignProbabilitiesToInstances(Relation<? extends NumberVector> relation, List<? extends BetulaClusterModel> models, WritableDataStore<double[]> probClusterIGivenX) {
        int k = models.size();
        double emSum = 0.0;
        DBIDIter iditer = relation.iterDBIDs();
        while (iditer.valid()) {
            NumberVector vec = (NumberVector)relation.get((DBIDRef)iditer);
            double[] probs = new double[k];
            for (int i = 0; i < k; ++i) {
                double v = models.get(i).estimateLogDensity(vec);
                probs[i] = v > -100000.0 ? v : -100000.0;
            }
            double logP = EM.logSumExp(probs);
            for (int i = 0; i < k; ++i) {
                probs[i] = FastMath.exp((double)(probs[i] - logP));
            }
            probClusterIGivenX.put((DBIDRef)iditer, (Object)probs);
            emSum += logP;
            iditer.advance();
        }
        return emSum / (double)relation.size();
    }

    /*
     * WARNING - void declaration
     */
    public void recomputeCovarianceMatrices(ArrayList<? extends ClusterFeature> cfs, Map<ClusterFeature, double[]> probClusterIGivenX, List<? extends BetulaClusterModel> models, double prior, int n) {
        void var10_13;
        void var10_11;
        int k = models.size();
        boolean needsTwoPass = false;
        for (BetulaClusterModel betulaClusterModel : models) {
            betulaClusterModel.beginEStep();
            needsTwoPass |= betulaClusterModel.needsTwoPass();
        }
        if (needsTwoPass) {
            throw new IllegalStateException("Not Implemented");
        }
        double[] wsum = new double[k];
        boolean bl = false;
        while (var10_11 < cfs.size()) {
            ClusterFeature cfsi = cfs.get((int)var10_11);
            double[] clusterProbabilities = probClusterIGivenX.get(cfsi);
            int j = 0;
            while (j < clusterProbabilities.length) {
                double prob = clusterProbabilities[j];
                if (prob > 1.0E-10) {
                    models.get(j).updateE(cfsi, prob * (double)cfsi.getWeight());
                }
                int n2 = j++;
                wsum[n2] = wsum[n2] + prob * (double)cfsi.getWeight();
            }
            ++var10_11;
        }
        boolean bl2 = false;
        while (var10_13 < models.size()) {
            double weight = prior <= 0.0 ? wsum[var10_13] / (double)n : (wsum[var10_13] + prior - 1.0) / ((double)n + prior * (double)k - (double)k);
            models.get((int)var10_13).finalizeEStep(weight, prior);
            ++var10_13;
        }
    }

    public static class Par
    implements Parameterizer {
        public static final OptionID INIT_ID = new OptionID("em.model", "Model factory.");
        public static final OptionID DELTA_ID = new OptionID("em.delta", "The termination criterion for maximization of E(M): E(M) - E(M') < em.delta");
        public static final OptionID PRIOR_ID = new OptionID("em.map.prior", "Regularization factor for MAP estimation.");
        CFTree.Factory<?> cffactory;
        protected int k;
        protected int maxiter = -1;
        protected double delta;
        protected boolean soft;
        protected double prior = 0.0;
        protected BetulaClusterModelFactory<?> initialization;

        public void configure(Parameterization config) {
            this.cffactory = (CFTree.Factory)config.tryInstantiate(CFTree.Factory.class);
            new ObjectParameter(INIT_ID, BetulaClusterModelFactory.class).grab(config, x -> {
                this.initialization = x;
            });
            ((IntParameter)new IntParameter(AbstractKMeans.K_ID).addConstraint((ParameterConstraint)CommonConstraints.GREATER_EQUAL_ONE_INT)).grab(config, x -> {
                this.k = x;
            });
            ((DoubleParameter)new DoubleParameter(DELTA_ID, 1.0E-7).addConstraint((ParameterConstraint)CommonConstraints.GREATER_EQUAL_ZERO_DOUBLE)).grab(config, x -> {
                this.delta = x;
            });
            ((DoubleParameter)((DoubleParameter)new DoubleParameter(PRIOR_ID).setOptional(true)).addConstraint((ParameterConstraint)CommonConstraints.GREATER_THAN_ZERO_DOUBLE)).grab(config, x -> {
                this.prior = x;
            });
            ((IntParameter)((IntParameter)new IntParameter(KMeans.MAXITER_ID).addConstraint((ParameterConstraint)CommonConstraints.GREATER_EQUAL_ZERO_INT)).setOptional(true)).grab(config, x -> {
                this.maxiter = x;
            });
        }

        public BetulaGMM make() {
            return new BetulaGMM(this.cffactory, this.delta, this.k, this.maxiter, this.soft, this.initialization, this.prior);
        }
    }
}

