/*
 * Decompiled with CFR 0.152.
 */
package elki.clustering.kmeans;

import elki.clustering.kmeans.AbstractKMeans;
import elki.clustering.kmeans.initialization.betula.AbstractCFKMeansInitialization;
import elki.clustering.kmeans.initialization.betula.CFKPlusPlusLeaves;
import elki.data.Cluster;
import elki.data.Clustering;
import elki.data.NumberVector;
import elki.data.model.KMeansModel;
import elki.database.ids.DBIDIter;
import elki.database.ids.DBIDRef;
import elki.database.ids.DBIDUtil;
import elki.database.ids.DBIDs;
import elki.database.ids.ModifiableDBIDs;
import elki.database.relation.Relation;
import elki.index.tree.betula.CFTree;
import elki.index.tree.betula.features.ClusterFeature;
import elki.logging.Logging;
import elki.logging.statistics.DoubleStatistic;
import elki.logging.statistics.Duration;
import elki.logging.statistics.LongStatistic;
import elki.logging.statistics.Statistic;
import elki.math.linearalgebra.VMath;
import elki.result.Metadata;
import elki.utilities.documentation.Reference;
import elki.utilities.optionhandling.OptionID;
import elki.utilities.optionhandling.parameterization.Parameterization;
import elki.utilities.optionhandling.parameters.Flag;
import elki.utilities.optionhandling.parameters.ObjectParameter;
import java.util.ArrayList;
import java.util.Arrays;

/**
 * Lloyd-style k-means clustering that operates on the leaf entries of a BETULA
 * CF-tree instead of on the individual data objects.
 * <p>
 * The CF-tree is built from the relation, and its leaf cluster features are
 * clustered with Lloyd iterations. Each leaf is weighted by its number of
 * points, unless {@code ignoreWeight} is set, in which case every leaf counts
 * as one point. Finally the original objects are distributed to the resulting
 * means in one of two ways:
 * <ul>
 * <li>by the object IDs stored in each leaf ({@code storeIds});</li>
 * <li>by assigning every object to its nearest mean.</li>
 * </ul>
 */
@Reference(authors="Andreas Lang and Erich Schubert", title="BETULA: Fast Clustering of Large Data with Improved BIRCH CF-Trees", booktitle="Information Systems", url="https://doi.org/10.1016/j.is.2021.101918", bibkey="DBLP:journals/is/LangS22")
public class BetulaLloydKMeans
extends AbstractKMeans<NumberVector, KMeansModel> {
    /** Class logger. */
    private static final Logging LOG = Logging.getLogger(BetulaLloydKMeans.class);

    /** Factory used to build the CF-tree. */
    CFTree.Factory<?> cffactory;

    /** Initialization method that chooses initial means from the tree leaves. */
    AbstractCFKMeansInitialization initialization;

    /** Store object IDs in the leaves, and use them for the final assignment. */
    boolean storeIds = false;

    /** Treat every leaf as a single point instead of weighting it by its size. */
    boolean ignoreWeight = false;

    /** Number of distance computations performed during the current run. */
    long diststat = 0L;

    /**
     * Constructor.
     *
     * @param k number of clusters
     * @param maxiter maximum number of iterations
     * @param cffactory factory used to build the CF-tree
     * @param initialization initialization operating on the leaf cluster features
     * @param storeIds store object IDs in the leaves and use them for the final
     *        assignment
     * @param ignoreWeight treat leaves as single points rather than weighted
     *        points
     */
    public BetulaLloydKMeans(int k, int maxiter, CFTree.Factory<?> cffactory, AbstractCFKMeansInitialization initialization, boolean storeIds, boolean ignoreWeight) {
        // The generic point-based initializer is not used; the CF-specific
        // initialization field is used instead.
        super(k, maxiter, null);
        this.cffactory = cffactory;
        this.initialization = initialization;
        this.storeIds = storeIds;
        this.ignoreWeight = ignoreWeight;
    }

    /**
     * Builds the CF-tree, clusters its leaves, and assigns the data objects to
     * the resulting clusters.
     *
     * @param relation data relation
     * @return clustering with one {@link KMeansModel} per cluster
     */
    @Override
    public Clustering<KMeansModel> run(Relation<NumberVector> relation) {
        // Reset the counter so that repeated runs do not report cumulative statistics.
        diststat = 0L;
        CFTree<?> tree = cffactory.newTree(relation.getDBIDs(), relation, storeIds);
        ArrayList<? extends ClusterFeature> cfs = tree.getLeaves();
        Duration modeltime = LOG.newDuration(getClass().getName() + ".modeltime").begin();
        int[] assignment = new int[cfs.size()];
        int[] weights = new int[k];
        // -1 marks "not yet assigned", so the first assignment counts every leaf as changed.
        Arrays.fill(assignment, -1);
        double[][] means = kmeans(cfs, assignment, weights, tree);
        LOG.statistics(modeltime.end());

        ModifiableDBIDs[] ids = new ModifiableDBIDs[k];
        for (int i = 0; i < k; i++) {
            // The cluster weight is only used as an initial capacity hint.
            ids[i] = DBIDUtil.newArray(weights[i]);
        }
        double[] varsum = new double[k];
        if (storeIds) {
            assignByStoredIds(tree, cfs, assignment, means, ids, varsum);
        } else {
            assignByNearestMean(relation, means, ids, varsum);
        }
        LOG.statistics(new LongStatistic(getClass().getName() + ".distance-computations", diststat));
        LOG.statistics(new DoubleStatistic(getClass().getName() + ".variance-sum", VMath.sum(varsum)));

        Clustering<KMeansModel> result = new Clustering<KMeansModel>();
        for (int i = 0; i < ids.length; i++) {
            KMeansModel model = new KMeansModel(means[i], varsum[i]);
            result.addToplevelCluster(new Cluster<KMeansModel>(ids[i], model));
        }
        Metadata.of(result).setLongName("BIRCH k-Means Clustering");
        return result;
    }

    /**
     * Distributes objects to clusters via the object IDs stored in each leaf,
     * following the final leaf assignment.
     *
     * @param tree CF-tree, queried for the IDs stored in a leaf
     * @param cfs leaf cluster features
     * @param assignment cluster index of every leaf
     * @param means final cluster means
     * @param ids output: object IDs per cluster (appended to)
     * @param varsum output: per-cluster sum of squares (accumulated into)
     */
    private void assignByStoredIds(CFTree<?> tree, ArrayList<? extends ClusterFeature> cfs, int[] assignment, double[][] means, ModifiableDBIDs[] ids, double[] varsum) {
        for (int i = 0; i < assignment.length; i++) {
            ClusterFeature cf = cfs.get(i);
            int c = assignment[i];
            double[] mean = means[c];
            // Leaf deviation (sumdev, presumably the squared deviation about the
            // leaf centroid) plus the size-weighted squared shift to the cluster mean.
            double s = cf.sumdev();
            for (int d = 0; d < means[0].length; d++) {
                double dx = cf.centroid(d) - mean[d];
                s += (double) cf.getWeight() * dx * dx;
            }
            varsum[c] += s;
            ids[c].addDBIDs(tree.getDBIDs(cf));
        }
    }

    /**
     * Assigns every object of the relation to its nearest mean.
     *
     * @param relation data relation
     * @param means final cluster means
     * @param ids output: object IDs per cluster (appended to)
     * @param varsum output: per-cluster sum of squared distances (accumulated into)
     */
    private void assignByNearestMean(Relation<NumberVector> relation, double[][] means, ModifiableDBIDs[] ids, double[] varsum) {
        for (DBIDIter iter = relation.iterDBIDs(); iter.valid(); iter.advance()) {
            NumberVector fv = relation.get(iter);
            double mindist = distance(fv, means[0]);
            int minIndex = 0;
            for (int j = 1; j < k; j++) {
                double dist = distance(fv, means[j]);
                if (dist < mindist) {
                    minIndex = j;
                    mindist = dist;
                }
            }
            varsum[minIndex] += mindist;
            ids[minIndex].add(iter);
        }
    }

    /**
     * Runs Lloyd iterations on the leaf cluster features.
     *
     * @param cfs leaf cluster features
     * @param assignment in/out: cluster index of each leaf
     * @param weights out: per-cluster weight after the last assignment
     * @param tree CF-tree, passed to the initialization
     * @return final cluster means
     */
    private double[][] kmeans(ArrayList<? extends ClusterFeature> cfs, int[] assignment, int[] weights, CFTree<?> tree) {
        double[][] means = initialization.chooseInitialMeans(tree, cfs, k);
        for (int i = 1; i <= maxiter || maxiter < 0; i++) {
            long prevdiststat = diststat;
            // The first iteration uses the initial means; later ones recompute them.
            if (i > 1) {
                means = means(assignment, means, cfs, weights);
                if (LOG.isStatistics()) {
                    double varsum = VMath.sum(calculateVariances(assignment, means, cfs, weights));
                    LOG.statistics(new DoubleStatistic(getClass().getName() + "." + (i - 1) + ".variance-sum", varsum));
                }
            }
            int changed = assignToNearestCluster(assignment, means, cfs, weights);
            if (LOG.isStatistics()) {
                LOG.statistics(new LongStatistic(getClass().getName() + "." + i + ".reassigned", changed));
                if (diststat > prevdiststat) {
                    LOG.statistics(new LongStatistic(getClass().getName() + "." + i + ".distance-computations", diststat - prevdiststat));
                }
            }
            if (changed == 0) {
                break;
            }
        }
        return means;
    }

    /**
     * Recomputes the cluster means from the current leaf assignment.
     * <p>
     * Each leaf contributes its centroid with weight {@code getWeight()}, or
     * with weight 1 when {@code ignoreWeight} is set. This matches
     * {@link #assignToNearestCluster} and {@link #calculateVariances}. Clusters
     * that receive no leaves keep their previous mean.
     *
     * @param assignment cluster index of each leaf
     * @param means previous means (used for empty clusters)
     * @param cfs leaf cluster features
     * @param weights out: total weight per cluster
     * @return new means
     */
    private double[][] means(int[] assignment, double[][] means, ArrayList<? extends ClusterFeature> cfs, int[] weights) {
        Arrays.fill(weights, 0);
        double[][] newMeans = new double[k][];
        for (int i = 0; i < assignment.length; i++) {
            int c = assignment[i];
            ClusterFeature cf = cfs.get(i);
            int d = cf.getDimensionality();
            int n = ignoreWeight ? 1 : cf.getWeight();
            if (newMeans[c] == null) {
                newMeans[c] = new double[d];
            }
            double[] sum = newMeans[c];
            for (int j = 0; j < d; j++) {
                sum[j] += cf.centroid(j) * n;
            }
            weights[c] += n;
        }
        for (int i = 0; i < k; i++) {
            if (weights[i] == 0) {
                newMeans[i] = means[i];
                continue;
            }
            VMath.timesEquals(newMeans[i], 1.0 / weights[i]);
        }
        return newMeans;
    }

    /**
     * Assigns every leaf to the mean nearest to its centroid (squared Euclidean).
     *
     * @param assignment in/out: cluster index of each leaf
     * @param means current means
     * @param cfs leaf cluster features
     * @param weights out: per-cluster weight (leaf count if {@code ignoreWeight})
     * @return number of leaves whose assignment changed
     */
    private int assignToNearestCluster(int[] assignment, double[][] means, ArrayList<? extends ClusterFeature> cfs, int[] weights) {
        Arrays.fill(weights, 0);
        int changed = 0;
        for (int i = 0; i < cfs.size(); i++) {
            ClusterFeature cf = cfs.get(i);
            double mindist = distance(cf, means[0]);
            int minIndex = 0;
            for (int j = 1; j < k; j++) {
                double dist = distance(cf, means[j]);
                if (dist < mindist) {
                    minIndex = j;
                    mindist = dist;
                }
            }
            if (assignment[i] != minIndex) {
                changed++;
                assignment[i] = minIndex;
            }
            weights[minIndex] += ignoreWeight ? 1 : cf.getWeight();
        }
        return changed;
    }

    /**
     * Computes the per-cluster sum of squares for the current leaf assignment.
     * <p>
     * With {@code ignoreWeight}, each leaf counts as one point: its deviation
     * is divided by its weight, and its shift to the mean is not scaled.
     *
     * @param assignment cluster index of each leaf
     * @param means current means
     * @param cfs leaf cluster features
     * @param weights per-cluster weights (not used by this implementation)
     * @return sum of squares per cluster
     */
    protected double[] calculateVariances(int[] assignment, double[][] means, ArrayList<? extends ClusterFeature> cfs, int[] weights) {
        double[] ss = new double[k];
        for (int i = 0; i < assignment.length; i++) {
            ClusterFeature cf = cfs.get(i);
            int c = assignment[i];
            double[] mean = means[c];
            int n = ignoreWeight ? 1 : cf.getWeight();
            double s = ignoreWeight ? cf.sumdev() / (double) cf.getWeight() : cf.sumdev();
            for (int d = 0; d < means[0].length; d++) {
                double dx = cf.centroid(d) - mean[d];
                s += (double) n * dx * dx;
            }
            ss[c] += s;
        }
        return ss;
    }

    /**
     * Squared Euclidean distance between a vector and a mean; counts the call.
     *
     * @param x data vector
     * @param y mean (its length determines the number of dimensions compared)
     * @return squared Euclidean distance
     */
    private double distance(NumberVector x, double[] y) {
        ++diststat;
        double v = 0.0;
        for (int i = 0; i < y.length; i++) {
            double d = x.doubleValue(i) - y[i];
            v += d * d;
        }
        return v;
    }

    /**
     * Squared Euclidean distance between a leaf centroid and a mean; counts the
     * call. Reads the centroid directly, without copying it into a buffer.
     *
     * @param cf leaf cluster feature (its dimensionality determines the loop length)
     * @param y mean
     * @return squared Euclidean distance
     */
    private double distance(ClusterFeature cf, double[] y) {
        ++diststat;
        final int dim = cf.getDimensionality();
        double v = 0.0;
        for (int i = 0; i < dim; i++) {
            double d = cf.centroid(i) - y[i];
            v += d * d;
        }
        return v;
    }

    @Override
    protected Logging getLogger() {
        return LOG;
    }

    /**
     * Parameterization class.
     */
    public static class Par
    extends AbstractKMeans.Par<NumberVector> {
        /** Option to store object IDs in the tree leaves. */
        public static final OptionID STORE_IDS_ID = new OptionID("betula.storeids", "Store IDs when building the tree, and use when assigning to leaves.");

        /** Option to treat leaves as single, unweighted points. */
        public static final OptionID IGNORE_WEIGHT_ID = new OptionID("betulakm.naive", "Treat leaves as single points, not weighted points.");

        /** CF-tree factory. */
        CFTree.Factory<?> cffactory;

        /** Initialization on cluster features. */
        AbstractCFKMeansInitialization initialization;

        /** Store object IDs in the leaves. */
        boolean storeIds = false;

        /** Ignore the leaf weights. */
        boolean ignoreWeight = false;

        @Override
        public void configure(Parameterization config) {
            this.cffactory = config.tryInstantiate(CFTree.Factory.class);
            super.getParameterK(config);
            super.getParameterMaxIter(config);
            new ObjectParameter<AbstractCFKMeansInitialization>(AbstractKMeans.INIT_ID, AbstractCFKMeansInitialization.class, CFKPlusPlusLeaves.class) //
                .grab(config, x -> this.initialization = x);
            new Flag(STORE_IDS_ID).grab(config, x -> this.storeIds = x);
            new Flag(IGNORE_WEIGHT_ID).grab(config, x -> this.ignoreWeight = x);
        }

        @Override
        public BetulaLloydKMeans make() {
            return new BetulaLloydKMeans(this.k, this.maxiter, this.cffactory, this.initialization, this.storeIds, this.ignoreWeight);
        }
    }
}

