/*
 * Decompiled with CFR 0.152.
 */
package elki.clustering.kmeans.initialization;

import elki.clustering.kmeans.initialization.AbstractKMeansInitialization;
import elki.data.NumberVector;
import elki.database.datastore.DataStoreUtil;
import elki.database.datastore.WritableDoubleDataStore;
import elki.database.ids.DBIDIter;
import elki.database.ids.DBIDRef;
import elki.database.ids.DBIDUtil;
import elki.database.ids.DBIDs;
import elki.database.relation.Relation;
import elki.distance.NumberVectorDistance;
import elki.logging.Logging;
import elki.logging.statistics.LongStatistic;
import elki.logging.statistics.Statistic;
import elki.utilities.documentation.Reference;
import elki.utilities.documentation.Title;
import elki.utilities.optionhandling.OptionID;
import elki.utilities.optionhandling.constraints.CommonConstraints;
import elki.utilities.optionhandling.constraints.ParameterConstraint;
import elki.utilities.optionhandling.parameterization.Parameterization;
import elki.utilities.optionhandling.parameters.IntParameter;
import elki.utilities.random.RandomFactory;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/**
 * K-MC2 initialization for k-means: seeds are chosen by running a short
 * Markov chain of sampled candidates for every center instead of recomputing
 * an exact sampling distribution over the whole data set for each new center.
 *
 * The first center is drawn uniformly at random. The distance of every object
 * to that first center is stored as its sampling weight; all further
 * candidates are drawn proportionally to these weights.
 */
@Title(value="K-MC\u00b2")
@Reference(authors="O. Bachem, M. Lucic, S. H. Hassani, A. Krause", title="Approximate K-Means++ in Sublinear Time", booktitle="Proc. 30th AAAI Conference on Artificial Intelligence", url="http://www.aaai.org/ocs/index.php/AAAI/AAAI16/paper/view/12147", bibkey="DBLP:conf/aaai/BachemLHK16")
public class KMC2 extends AbstractKMeansInitialization {
    /** Class logger. */
    private static final Logging LOG = Logging.getLogger(KMC2.class);

    /** Number of MCMC steps to do for each center. */
    protected int m;

    /**
     * Constructor.
     *
     * @param m number of MCMC steps to do
     * @param rnd random generator factory
     */
    public KMC2(int m, RandomFactory rnd) {
        super(rnd);
        this.m = m;
    }

    /**
     * Choose {@code k} initial means from the relation.
     *
     * @param relation data to choose the means from
     * @param k number of means to choose
     * @param distance distance function used for weighting the candidates
     * @return the chosen means
     * @throws IllegalArgumentException if the relation holds fewer than {@code k} objects
     */
    @Override
    public double[][] chooseInitialMeans(Relation<? extends NumberVector> relation, int k, NumberVectorDistance<?> distance) {
        if (relation.size() < k) {
            throw new IllegalArgumentException("Cannot choose k=" + k + " means from N=" + relation.size() + " < k objects.");
        }
        return new Instance(relation, distance, this.m, this.rnd).run(k);
    }

    /**
     * Parameterization class.
     */
    public static class Par extends AbstractKMeansInitialization.Par {
        /** Option for the number of MCMC steps (default 100, must be at least 1). */
        public static final OptionID M_ID = new OptionID("afkmc2.m", "Number of MCMC steps to do");

        /** Number of MCMC steps to do. */
        protected int m;

        @Override
        public void configure(Parameterization config) {
            super.configure(config);
            new IntParameter(M_ID, 100)
                .addConstraint(CommonConstraints.GREATER_EQUAL_ONE_INT)
                .grab(config, x -> {
                    this.m = x;
                });
        }

        /**
         * Build the configured initialization.
         *
         * @return new KMC2 instance
         */
        public KMC2 make() {
            return new KMC2(this.m, this.rnd);
        }
    }

    /**
     * State of a single initialization run on one relation.
     */
    protected static class Instance {
        /** Data relation the means are chosen from. */
        protected Relation<? extends NumberVector> relation;

        /** Distance function. */
        protected NumberVectorDistance<?> distance;

        /** Per-object sampling weight: the distance to the first chosen mean. */
        protected WritableDoubleDataStore weights;

        /** Number of distance computations performed so far. */
        protected long diststat;

        /** Number of MCMC steps to do for each center. */
        protected int m;

        /** Random generator. */
        protected Random random;

        /**
         * Constructor.
         *
         * @param relation data relation to choose the means from
         * @param distance distance function
         * @param m number of MCMC steps to do
         * @param rnd random generator factory
         */
        public Instance(Relation<? extends NumberVector> relation, NumberVectorDistance<?> distance, int m, RandomFactory rnd) {
            this.relation = relation;
            this.distance = distance;
            this.m = m;
            this.random = rnd.getSingleThreadedRandom();
            // NOTE(review): the hint value 3 is presumably the decompiled form of
            // DataStoreFactory.HINT_HOT | DataStoreFactory.HINT_TEMP - confirm.
            this.weights = DataStoreUtil.makeDoubleStorage(relation.getDBIDs(), 3, 0.0);
        }

        /**
         * Store the distance of every object to the first mean as its sampling
         * weight.
         *
         * @param first the first chosen mean
         * @return sum of all weights
         */
        protected double initialWeights(NumberVector first) {
            double weightsum = 0.0;
            for (DBIDIter it = this.relation.iterDBIDs(); it.valid(); it.advance()) {
                double weight = this.distance(first, it);
                this.weights.putDouble(it, weight);
                weightsum += weight;
            }
            return weightsum;
        }

        /**
         * Run the initialization: pick the first mean uniformly at random, then
         * choose the remaining means by MCMC sampling, and log the number of
         * distance computations.
         *
         * @param k number of means to choose
         * @return the chosen means
         */
        public double[][] run(int k) {
            List<NumberVector> means = new ArrayList<NumberVector>(k);
            NumberVector firstvec = this.relation.get(DBIDUtil.randomSample(this.relation.getDBIDs(), this.random));
            means.add(firstvec);
            try {
                this.chooseRemaining(k, means, this.initialWeights(firstvec));
            }
            finally {
                // Release the temporary weight storage even if sampling fails.
                this.weights.destroy();
            }
            this.getLogger().statistics(new LongStatistic(this.getClass().getName() + ".distance-computations", this.diststat));
            return AbstractKMeansInitialization.unboxVectors(means);
        }

        /**
         * Compute the distance between a vector and a database object, counting
         * the computation in {@link #diststat}.
         *
         * @param a vector
         * @param b reference to the database object
         * @return distance
         */
        protected double distance(NumberVector a, DBIDRef b) {
            ++this.diststat;
            return this.distance.distance(a, this.relation.get(b));
        }

        /**
         * Choose means until {@code k} have been collected. For every mean one
         * candidate is sampled as the start of the chain, followed by
         * {@code m - 1} further proposals. Each proposal replaces the current
         * candidate if the ratio of their scores exceeds a uniform random
         * number, where the score is the distance to the nearest chosen mean
         * divided by the sampling weight.
         *
         * @param k total number of means wanted
         * @param means means chosen so far; new means are appended
         * @param weightsum sum of all sampling weights
         */
        protected void chooseRemaining(int k, List<NumberVector> means, double weightsum) {
            while (means.size() < k) {
                DBIDRef best = this.sample(weightsum);
                double curp = this.distance(best, means) / this.weights.doubleValue(best);
                for (int i = 1; i < this.m; ++i) {
                    DBIDRef cand = this.sample(weightsum);
                    double candp = this.distance(cand, means) / this.weights.doubleValue(cand);
                    // A current score that is not positive is always replaced; the
                    // random number is only drawn when the current score is positive.
                    if (curp > 0.0 && !(candp / curp > this.random.nextDouble())) {
                        continue;
                    }
                    best = cand;
                    curp = candp;
                }
                means.add(this.relation.get(best));
            }
        }

        /**
         * Sample one object with probability proportional to its stored weight.
         *
         * A random value in {@code [0, weightsum)} is drawn and the weights are
         * subtracted from it in iteration order; the object at which the
         * remainder reaches zero or below is returned. If the iteration is
         * exhausted first (the given sum was larger than the stored weights add
         * up to), the local sum is reduced by the remainder and the draw is
         * repeated.
         *
         * @param weightsum sum of all sampling weights
         * @return iterator positioned at the sampled object
         * @throws IllegalStateException if the weight sum exceeds {@link Double#MAX_VALUE}
         */
        protected DBIDRef sample(double weightsum) {
            while (true) {
                if (weightsum > Double.MAX_VALUE) {
                    throw new IllegalStateException("Could not choose a reasonable mean - too many data points, too large distance sum?");
                }
                if (weightsum < Double.MIN_NORMAL) {
                    LOG.warning("Could not choose a reasonable mean - to few unique data points?");
                }
                double r = this.random.nextDouble() * weightsum;
                DBIDIter it = this.relation.iterDBIDs();
                while (it.valid()) {
                    r -= this.weights.doubleValue(it);
                    // Stop at the object whose weight uses up the remainder.
                    if (r <= 0.0) {
                        break;
                    }
                    it.advance();
                }
                if (it.valid()) {
                    return it;
                }
                // Nothing selected: shrink the sum by the unused remainder and retry.
                weightsum -= r;
            }
        }

        /**
         * Distance of a candidate to its nearest already chosen mean.
         *
         * The stored weight is used as the distance to the first mean, so only
         * the means from index 1 on are evaluated. This relies on
         * {@code means.get(0)} being the vector the weights were initialized
         * with, as done in {@link #run(int)}.
         *
         * @param cand candidate object
         * @param means means chosen so far
         * @return minimum distance to any chosen mean
         */
        protected double distance(DBIDRef cand, List<NumberVector> means) {
            double d = this.weights.doubleValue(cand);
            for (int i = 1; i < means.size(); ++i) {
                double d2 = this.distance(means.get(i), cand);
                d = d2 < d ? d2 : d;
            }
            return d;
        }

        /**
         * Logger used for reporting statistics.
         *
         * @return the class logger
         */
        protected Logging getLogger() {
            return LOG;
        }
    }
}

