/*
 * Decompiled with CFR 0.152.
 */
package elki.clustering.kmeans.initialization;

import elki.clustering.kmeans.initialization.AbstractKMeansInitialization;
import elki.clustering.kmedoids.initialization.KMedoidsInitialization;
import elki.data.NumberVector;
import elki.database.datastore.DataStoreUtil;
import elki.database.datastore.WritableDoubleDataStore;
import elki.database.ids.ArrayModifiableDBIDs;
import elki.database.ids.DBIDIter;
import elki.database.ids.DBIDRef;
import elki.database.ids.DBIDUtil;
import elki.database.ids.DBIDVar;
import elki.database.ids.DBIDs;
import elki.database.query.distance.DistanceQuery;
import elki.database.relation.Relation;
import elki.distance.NumberVectorDistance;
import elki.logging.Logging;
import elki.logging.statistics.LongStatistic;
import elki.logging.statistics.Statistic;
import elki.utilities.documentation.Reference;
import elki.utilities.documentation.Title;
import elki.utilities.random.RandomFactory;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

@Title(value="K-means++")
@Reference(authors="D. Arthur, S. Vassilvitskii", title="k-means++: the advantages of careful seeding", booktitle="Proc. 18th Annual ACM-SIAM Symposium on Discrete Algorithms (SODA 2007)", url="http://dl.acm.org/citation.cfm?id=1283383.1283494", bibkey="DBLP:conf/soda/ArthurV07")
public class KMeansPlusPlus<O>
extends AbstractKMeansInitialization
implements KMedoidsInitialization<O> {
    private static final Logging LOG = Logging.getLogger(KMeansPlusPlus.class);

    /**
     * Constructor.
     *
     * @param rnd random factory; handed unchanged to the superclass constructor
     */
    public KMeansPlusPlus(RandomFactory rnd) {
        super(rnd);
    }

    /**
     * Chooses {@code k} initial means by running a {@link NumberVectorInstance}
     * over the given relation.
     *
     * @param relation data relation to draw the means from
     * @param k number of means to choose
     * @param distance distance function used to weight the candidates
     * @return the chosen means, one row per mean
     * @throws IllegalArgumentException if the relation holds fewer than {@code k} objects
     */
    @Override
    public double[][] chooseInitialMeans(Relation<? extends NumberVector> relation, int k, NumberVectorDistance<?> distance) {
        final int n = relation.size();
        if (k > n) {
            throw new IllegalArgumentException("Cannot choose k=" + k + " means from N=" + n + " < k objects.");
        }
        NumberVectorInstance seeding = new NumberVectorInstance(relation, distance, this.rnd);
        return seeding.run(k);
    }

    /**
     * Chooses {@code k} initial medoids by running a {@link MedoidsInstance}
     * over the given ids.
     *
     * @param k number of medoids to choose
     * @param ids candidate object ids
     * @param distQ distance query used to weight the candidates
     * @return ids of the chosen medoids
     * @throws IllegalArgumentException if fewer than {@code k} ids are given
     */
    @Override
    public DBIDs chooseInitialMedoids(int k, DBIDs ids, DistanceQuery<? super O> distQ) {
        final int n = ids.size();
        if (k > n) {
            throw new IllegalArgumentException("Cannot choose k=" + k + " means from N=" + n + " < k objects.");
        }
        MedoidsInstance seeding = new MedoidsInstance(ids, distQ, this.rnd);
        return seeding.run(k);
    }

    /**
     * Parameterization helper that builds a {@link KMeansPlusPlus} instance.
     *
     * @param <V> object type of the produced initializer
     */
    public static class Par<V>
    extends AbstractKMeansInitialization.Par {
        /**
         * Creates the initializer from the {@code rnd} random factory
         * inherited from the parent parameterizer.
         *
         * @return a new {@link KMeansPlusPlus}
         */
        public KMeansPlusPlus<V> make() {
            // Explicit type argument: avoids a raw-type instantiation and the unchecked conversion.
            return new KMeansPlusPlus<V>(this.rnd);
        }
    }

    /**
     * Seeding state for choosing medoids: the chosen "means" are ids of
     * existing objects, and distances are evaluated through a distance query.
     */
    protected static class MedoidsInstance
    extends Instance<DBIDRef> {
        /** Distance query used to compare two objects by id. */
        DistanceQuery<?> distQ;

        /**
         * Constructor.
         *
         * @param ids candidate object ids
         * @param distQ distance query
         * @param rnd random factory
         */
        public MedoidsInstance(DBIDs ids, DistanceQuery<?> distQ, RandomFactory rnd) {
            super(ids, rnd);
            this.distQ = distQ;
        }

        /**
         * Picks the first medoid uniformly at random, then the remaining ones
         * with probability proportional to their current weight.
         *
         * @param k number of medoids to choose
         * @return ids of the chosen medoids
         */
        public DBIDs run(int k) {
            ArrayModifiableDBIDs means = DBIDUtil.newArray(k);
            DBIDVar first = DBIDUtil.randomSample(this.ids, this.random);
            means.add(first);
            this.chooseRemaining(k, means, this.initialWeights(first));
            this.weights.destroy();
            LOG.statistics(new LongStatistic(KMeansPlusPlus.class.getName() + ".distance-computations", this.diststat));
            return means;
        }

        /**
         * Distance between two objects; also counts the computation in
         * {@code diststat}.
         */
        @Override
        protected double distance(DBIDRef a, DBIDRef b) {
            ++this.diststat;
            return this.distQ.distance(a, b);
        }

        /**
         * Adds medoids to {@code means} until it holds {@code k} entries.
         *
         * @param k total number of medoids wanted
         * @param means medoids chosen so far; new choices are appended
         * @param weightsum sum of the current weights over all ids
         * @throws IllegalStateException if the weight sum exceeds {@link Double#MAX_VALUE}
         */
        protected void chooseRemaining(int k, ArrayModifiableDBIDs means, double weightsum) {
            // Nothing left to choose (e.g. k == 1): the loop below always adds
            // one element before testing the size, so it would overshoot k.
            if (means.size() >= k) {
                return;
            }
            while (true) {
                if (weightsum > Double.MAX_VALUE) {
                    throw new IllegalStateException("Could not choose a reasonable mean - too many data points, too large distance sum?");
                }
                if (weightsum < Double.MIN_NORMAL) {
                    LOG.warning("Could not choose a reasonable mean - to few unique data points?");
                }
                double r = this.nextDouble(weightsum);
                // Walk the ids, consuming each weight, until the sampled value is used up.
                DBIDIter it = this.ids.iter();
                while (it.valid()) {
                    r -= this.weights.doubleValue(it);
                    if (r <= 0.0) {
                        break;
                    }
                    it.advance();
                }
                if (!it.valid()) {
                    // Ran off the end: the stored weights summed to less than weightsum.
                    // Shrink the sum by the unconsumed remainder and retry.
                    weightsum -= r;
                    continue;
                }
                means.add(it);
                if (means.size() >= k) {
                    break;
                }
                // Zero weight so the chosen object is skipped by updateWeights.
                this.weights.putDouble(it, 0.0);
                weightsum = this.updateWeights(it);
            }
        }
    }

    protected static class NumberVectorInstance
    extends Instance<NumberVector> {
        protected NumberVectorDistance<?> distance;
        protected Relation<? extends NumberVector> relation;

        public NumberVectorInstance(Relation<? extends NumberVector> relation, NumberVectorDistance<?> distance, RandomFactory rnd) {
            super(relation.getDBIDs(), rnd);
            this.distance = distance;
            this.relation = relation;
        }

        public double[][] run(int k) {
            ArrayList<NumberVector> means = new ArrayList<NumberVector>(k);
            NumberVector firstvec = (NumberVector)this.relation.get((DBIDRef)DBIDUtil.randomSample((DBIDs)this.ids, (Random)this.random));
            means.add(firstvec);
            this.chooseRemaining(k, means, this.initialWeights(firstvec));
            this.weights.destroy();
            LOG.statistics((Statistic)new LongStatistic(KMeansPlusPlus.class.getName() + ".distance-computations", this.diststat));
            return AbstractKMeansInitialization.unboxVectors(means);
        }

        @Override
        protected double distance(NumberVector a, DBIDRef b) {
            ++this.diststat;
            return this.distance.distance(a, (NumberVector)this.relation.get(b));
        }

        protected void chooseRemaining(int k, List<NumberVector> means, double weightsum) {
            while (true) {
                if (weightsum > Double.MAX_VALUE) {
                    throw new IllegalStateException("Could not choose a reasonable mean - too many data points, too large distance sum?");
                }
                if (weightsum < Double.MIN_NORMAL) {
                    LOG.warning((CharSequence)"Could not choose a reasonable mean - to few unique data points?");
                }
                double r = this.nextDouble(weightsum);
                DBIDIter it = this.ids.iter();
                while (it.valid()) {
                    double d;
                    r -= this.weights.doubleValue((DBIDRef)it);
                    if (d <= 0.0) break;
                    it.advance();
                }
                if (!it.valid()) {
                    weightsum -= r;
                    continue;
                }
                NumberVector newmean = (NumberVector)this.relation.get((DBIDRef)it);
                means.add(newmean);
                if (means.size() >= k) break;
                this.weights.putDouble((DBIDRef)it, 0.0);
                weightsum = this.updateWeights(newmean);
            }
        }
    }

    /**
     * Shared seeding state: the candidate ids, a per-object weight store, a
     * counter for distance computations and the random generator.
     *
     * @param <T> type used to represent an already chosen center
     */
    protected abstract static class Instance<T> {
        /** Candidate object ids. */
        protected DBIDs ids;
        /** Current weight of every candidate. */
        protected WritableDoubleDataStore weights;
        /** Distance computation counter; the implementations of {@link #distance} in this file increment it. */
        protected long diststat;
        /** Random generator used for sampling. */
        protected Random random;

        /**
         * Constructor.
         *
         * @param ids candidate object ids
         * @param rnd random factory
         */
        public Instance(DBIDs ids, RandomFactory rnd) {
            this.ids = ids;
            this.random = rnd.getSingleThreadedRandom();
            // NOTE(review): 3 is presumably a bitmask of data store hints - confirm against DataStoreFactory.
            this.weights = DataStoreUtil.makeDoubleStorage(ids, 3, 0.0);
        }

        /**
         * Distance between a chosen center and a candidate object.
         *
         * @param var1 chosen center
         * @param var2 candidate object id
         * @return distance value
         */
        protected abstract double distance(T var1, DBIDRef var2);

        /**
         * Stores the distance to the first center as the weight of every
         * candidate.
         *
         * @param first first chosen center
         * @return sum of all stored weights
         */
        protected double initialWeights(T first) {
            double total = 0.0;
            for (DBIDIter iter = this.ids.iter(); iter.valid(); iter.advance()) {
                final double w = this.distance(first, iter);
                this.weights.putDouble(iter, w);
                total += w;
            }
            return total;
        }

        /**
         * Lowers each candidate's weight to its distance from the latest
         * center when that distance is smaller. Candidates whose weight is
         * not positive are skipped without a distance computation.
         *
         * @param latest most recently chosen center
         * @return sum of the weights of all candidates that were not skipped
         */
        protected double updateWeights(T latest) {
            double total = 0.0;
            for (DBIDIter iter = this.ids.iter(); iter.valid(); iter.advance()) {
                double current = this.weights.doubleValue(iter);
                if (current <= 0.0) {
                    continue;
                }
                final double candidate = this.distance(latest, iter);
                if (candidate < current) {
                    this.weights.putDouble(iter, candidate);
                    current = candidate;
                }
                total += current;
            }
            return total;
        }

        /**
         * Draws a random value scaled by the weight sum, redrawing while the
         * result is not positive and the weight sum is above
         * {@link Double#MIN_NORMAL}.
         *
         * @param weightsum current sum of weights
         * @return the scaled random value
         */
        protected double nextDouble(double weightsum) {
            double r;
            do {
                r = this.random.nextDouble() * weightsum;
            } while (r <= 0.0 && weightsum > Double.MIN_NORMAL);
            return r;
        }
    }
}

