/*
 * Decompiled with CFR 0.152.
 */
package elki.clustering.kmeans;

import elki.clustering.kmeans.AbstractKMeans;
import elki.clustering.kmeans.initialization.KMeansInitialization;
import elki.data.Clustering;
import elki.data.NumberVector;
import elki.data.model.KMeansModel;
import elki.database.datastore.DataStoreUtil;
import elki.database.datastore.WritableDataStore;
import elki.database.datastore.WritableDoubleDataStore;
import elki.database.ids.DBIDIter;
import elki.database.ids.DBIDRef;
import elki.database.ids.DBIDs;
import elki.database.ids.ModifiableDBIDs;
import elki.database.relation.Relation;
import elki.distance.NumberVectorDistance;
import elki.logging.Logging;
import elki.logging.statistics.LongStatistic;
import elki.logging.statistics.Statistic;
import elki.math.MathUtil;
import elki.math.linearalgebra.VMath;
import elki.utilities.documentation.Reference;
import elki.utilities.optionhandling.OptionID;
import elki.utilities.optionhandling.constraints.CommonConstraints;
import elki.utilities.optionhandling.constraints.ParameterConstraint;
import elki.utilities.optionhandling.parameterization.Parameterization;
import elki.utilities.optionhandling.parameters.IntParameter;
import java.util.Arrays;

@Reference(authors="Y. Ding, Y. Zhao, X. Shen, M, Musuvathi, T. Mytkowicz", title="Yinyang K-Means: A Drop-In Replacement of the Classic K-Means with Consistent Speedup", booktitle="Proc. International Conference on Machine Learning (ICML 2015)", url="http://proceedings.mlr.press/v37/ding15.html", bibkey="DBLP:conf/icml/DingZSMM15")
public class YinYangKMeans<V extends NumberVector>
extends AbstractKMeans<V, KMeansModel> {
    private static final Logging LOG = Logging.getLogger(YinYangKMeans.class);
    private static final int GROUP_KMEANS_MAXITER = 5;
    private int t;

    public YinYangKMeans(int k, int maxiter, KMeansInitialization initializer, int t) {
        super(k, maxiter, initializer);
        this.t = t;
    }

    @Override
    public Clustering<KMeansModel> run(Relation<V> rel) {
        Instance instance = new Instance(rel, this.getDistance(), this.initialMeans(rel), this.t);
        instance.run(this.maxiter);
        return instance.buildResult();
    }

    @Override
    protected Logging getLogger() {
        return LOG;
    }

    public static class Par<V extends NumberVector>
    extends AbstractKMeans.Par<V> {
        public static final OptionID T_ID = new OptionID("kmeans.yinyang.t", "The number of groups to use for bounding the centroids.");
        protected int t;

        @Override
        protected boolean needsMetric() {
            return true;
        }

        @Override
        public void configure(Parameterization config) {
            super.configure(config);
            int deft = this.k > 10 ? this.k / 10 : (this.k > 1 ? this.k / 2 : 1);
            ((IntParameter)((IntParameter)new IntParameter(T_ID).setDefaultValue((Object)deft)).addConstraint((ParameterConstraint)CommonConstraints.GREATER_EQUAL_ONE_INT)).grab(config, x -> {
                this.t = x;
            });
        }

        @Override
        public YinYangKMeans<V> make() {
            return new YinYangKMeans(this.k, this.maxiter, this.initializer, this.t);
        }
    }

    protected static class Instance
    extends AbstractKMeans.Instance {
        int[][] groups;
        double[] gdrift;
        double[] cdrift;
        double[][] sums;
        int[] glabel;
        WritableDoubleDataStore upper;
        WritableDataStore<double[]> lower;

        public Instance(Relation<? extends NumberVector> relation, NumberVectorDistance<?> df, double[][] means, int t) {
            super(relation, df, means);
            this.glabel = new int[this.k];
            t = t > 0 ? (t < this.k ? t : this.k) : (this.k >= 10 ? this.k / 10 : this.k / 2);
            this.upper = DataStoreUtil.makeDoubleStorage((DBIDs)relation.getDBIDs(), (int)3, (double)Double.POSITIVE_INFINITY);
            this.lower = DataStoreUtil.makeStorage((DBIDs)relation.getDBIDs(), (int)3, double[].class);
            DBIDIter it = relation.iterDBIDs();
            while (it.valid()) {
                this.lower.put((DBIDRef)it, (Object)new double[t]);
                it.advance();
            }
            int dim = means[0].length;
            this.cdrift = new double[this.k];
            this.sums = new double[this.k][dim];
            this.gdrift = new double[t];
        }

        @Override
        public void run(int maxiter) {
            this.groups = this.groupKMeans(this.gdrift.length);
            super.run(maxiter);
        }

        private int[][] groupKMeans(int t) {
            if (t <= 1) {
                Arrays.fill(this.glabel, 0);
                return new int[][]{MathUtil.sequence((int)0, (int)this.means.length)};
            }
            long before = this.diststat;
            double[][] gmean = new double[t][];
            int[] gweight = new int[t];
            this.initialGroupAssignment(t, gmean, gweight);
            for (int it = 1; it <= 5 && this.updateGroupAssignment(t, gmean, gweight); ++it) {
            }
            int[][] meanGroups = new int[t][];
            for (int i = 0; i < t; ++i) {
                meanGroups[i] = new int[gweight[i]];
                int p = 0;
                for (int j = 0; j < this.k; ++j) {
                    if (this.glabel[j] != i) continue;
                    meanGroups[i][p++] = j;
                }
            }
            if (this.getLogger().isStatistics()) {
                this.getLogger().statistics((Statistic)new LongStatistic(this.key + ".yinyang-grouping.distance-computations", this.diststat - before));
            }
            return meanGroups;
        }

        private void initialGroupAssignment(int t, double[][] scratch, int[] gweight) {
            int i;
            for (i = 0; i < t; ++i) {
                scratch[i] = (double[])this.means[i].clone();
                this.glabel[i] = i;
            }
            Arrays.fill(gweight, 1);
            for (i = t; i < this.k; ++i) {
                double[] cur = this.means[i];
                int best = 0;
                double bestd = this.distance(cur, this.means[0]);
                for (int j = 1; j < t; ++j) {
                    double d = this.distance(cur, this.means[j]);
                    if (!(d < bestd)) continue;
                    bestd = d;
                    best = j;
                }
                VMath.plusEquals((double[])scratch[best], (double[])cur);
                this.glabel[i] = best;
                int n = best;
                gweight[n] = gweight[n] + 1;
            }
            for (i = 0; i < t; ++i) {
                VMath.timesEquals((double[])scratch[i], (double)(1.0 / (double)gweight[i]));
            }
        }

        private boolean updateGroupAssignment(int t, double[][] gmeans, int[] gweight) {
            int i;
            boolean changed = false;
            for (i = 0; i < t; ++i) {
                Arrays.fill(this.sums[i], 0.0);
            }
            Arrays.fill(gweight, 0);
            for (i = 0; i < this.k; ++i) {
                double[] cur = this.means[i];
                int prev = this.glabel[i];
                double bestd = this.distance(cur, gmeans[0]);
                int best = 0;
                for (int j = 1; j < t; ++j) {
                    double d = this.distance(cur, gmeans[j]);
                    if (!(d < bestd) && (d != bestd || j != prev)) continue;
                    best = j;
                    bestd = d;
                }
                VMath.plusEquals((double[])this.sums[best], (double[])cur);
                int n = best;
                gweight[n] = gweight[n] + 1;
                this.glabel[i] = best;
                changed |= best != prev;
            }
            for (i = 0; i < t; ++i) {
                if (gweight[i] <= 0) continue;
                VMath.overwriteTimes((double[])gmeans[i], (double[])this.sums[i], (double)(1.0 / (double)gweight[i]));
            }
            return changed;
        }

        @Override
        protected int iterate(int iteration) {
            if (iteration == 1) {
                return this.initialAssignToNearestCluster();
            }
            this.updateCenters();
            return this.assignToNearestCluster();
        }

        private void updateCenters() {
            int dim = this.means[0].length;
            double[] oldmean = new double[dim];
            for (int g = 0; g < this.groups.length; ++g) {
                double gd = 0.0;
                for (int i : this.groups[g]) {
                    int size = ((ModifiableDBIDs)this.clusters.get(i)).size();
                    if (size <= 0) continue;
                    double[] sum = this.sums[i];
                    double[] mean = this.means[i];
                    System.arraycopy(mean, 0, oldmean, 0, dim);
                    VMath.overwriteTimes((double[])mean, (double[])sum, (double)(1.0 / (double)size));
                    double d = this.cdrift[i] = this.sqrtdistance(mean, oldmean);
                    gd = d > gd ? d : gd;
                }
                this.gdrift[g] = gd;
            }
        }

        @Override
        protected int assignToNearestCluster() {
            int t = this.gdrift.length;
            int changed = 0;
            double[] prevlb = new double[t];
            DBIDIter it = this.relation.iterDBIDs();
            while (it.valid()) {
                NumberVector cur = (NumberVector)this.relation.get((DBIDRef)it);
                int prev = this.assignment.intValue((DBIDRef)it);
                double[] lbs = (double[])this.lower.get((DBIDRef)it);
                System.arraycopy(lbs, 0, prevlb, 0, lbs.length);
                double drift = this.cdrift[prev];
                if (drift > 0.0) {
                    this.upper.increment((DBIDRef)it, drift);
                }
                double minlb = Double.POSITIVE_INFINITY;
                for (int g = 0; g < t; ++g) {
                    int n = g;
                    double d = lbs[n] - this.gdrift[g];
                    lbs[n] = d;
                    double lb = d;
                    minlb = lb < minlb ? lb : minlb;
                }
                double ub = this.upper.doubleValue((DBIDRef)it);
                if (!(minlb >= ub)) {
                    ub = this.sqrtdistance(cur, this.means[prev]);
                    this.upper.put((DBIDRef)it, ub);
                    if (!(minlb >= ub)) {
                        int best = prev;
                        for (int g = 0; g < t; ++g) {
                            double lb = lbs[g];
                            if (lb >= ub) continue;
                            double plb = prevlb[g];
                            double sc = Double.POSITIVE_INFINITY;
                            for (int i : this.groups[g]) {
                                double di;
                                if (i == prev || sc < plb - this.cdrift[i] || !((di = this.sqrtdistance(cur, this.means[i])) < sc)) continue;
                                if (di < ub) {
                                    lb = sc = ub;
                                    ub = di;
                                    best = i;
                                    continue;
                                }
                                sc = di;
                            }
                            lbs[g] = sc;
                        }
                        if (prev != best) {
                            this.upper.put((DBIDRef)it, ub);
                            ((ModifiableDBIDs)this.clusters.get(this.assignment.intValue((DBIDRef)it))).remove((DBIDRef)it);
                            ((ModifiableDBIDs)this.clusters.get(best)).add((DBIDRef)it);
                            AbstractKMeans.plusMinusEquals(this.sums[best], this.sums[prev], cur);
                            this.assignment.put((DBIDRef)it, best);
                            ++changed;
                        }
                    }
                }
                it.advance();
            }
            return changed;
        }

        private int initialAssignToNearestCluster() {
            assert (this.k == this.means.length);
            DBIDIter id = this.relation.iterDBIDs();
            while (id.valid()) {
                NumberVector point = (NumberVector)this.relation.get((DBIDRef)id);
                double[] lower = (double[])this.lower.get((DBIDRef)id);
                double min = Double.POSITIVE_INFINITY;
                int globalindex = 0;
                for (int g = 0; g < this.groups.length; ++g) {
                    int[] group = this.groups[g];
                    if (group.length == 0) continue;
                    double min1 = this.distance(point, this.means[group[0]]);
                    double min2 = Double.POSITIVE_INFINITY;
                    int best = group[0];
                    for (int c = 1; c < group.length; ++c) {
                        int center = group[c];
                        double dist = this.distance(point, this.means[center]);
                        if (dist < min1) {
                            min2 = min1;
                            best = center;
                            min1 = dist;
                            continue;
                        }
                        if (!(dist < min2)) continue;
                        min2 = dist;
                    }
                    double d = min1 = this.isSquared ? Math.sqrt(min1) : min1;
                    double d2 = min2 < Double.POSITIVE_INFINITY ? (this.isSquared ? Math.sqrt(min2) : min2) : (min2 = min1);
                    if (min1 < min) {
                        if (globalindex != -1) {
                            lower[this.glabel[globalindex]] = min;
                        }
                        min = min1;
                        globalindex = best;
                        lower[g] = min2;
                        continue;
                    }
                    lower[g] = min1;
                }
                ((ModifiableDBIDs)this.clusters.get(globalindex)).add((DBIDRef)id);
                this.assignment.put((DBIDRef)id, globalindex);
                this.upper.put((DBIDRef)id, min);
                AbstractKMeans.plusEquals(this.sums[globalindex], point);
                id.advance();
            }
            return this.relation.size();
        }

        @Override
        protected Logging getLogger() {
            return LOG;
        }
    }
}

