/*
 * Decompiled with CFR 0.152.
 */
package elki.clustering.kmeans.initialization;

import elki.clustering.kmeans.initialization.AbstractKMeansInitialization;
import elki.clustering.kmeans.initialization.KMeansPlusPlus;
import elki.data.NumberVector;
import elki.data.VectorUtil;
import elki.database.datastore.DataStoreUtil;
import elki.database.datastore.WritableDoubleDataStore;
import elki.database.ids.DBIDIter;
import elki.database.ids.DBIDRef;
import elki.database.ids.DBIDUtil;
import elki.database.ids.DBIDs;
import elki.database.relation.Relation;
import elki.distance.ArcCosineDistance;
import elki.distance.ArcCosineUnitlengthDistance;
import elki.distance.CosineDistance;
import elki.distance.CosineUnitlengthDistance;
import elki.distance.NumberVectorDistance;
import elki.logging.Logging;
import elki.logging.statistics.LongStatistic;
import elki.logging.statistics.Statistic;
import elki.utilities.documentation.Reference;
import elki.utilities.documentation.Title;
import elki.utilities.optionhandling.OptionID;
import elki.utilities.optionhandling.constraints.CommonConstraints;
import elki.utilities.optionhandling.constraints.ParameterConstraint;
import elki.utilities.optionhandling.parameterization.Parameterization;
import elki.utilities.optionhandling.parameters.DoubleParameter;
import elki.utilities.random.RandomFactory;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/**
 * Spherical k-means++ seeding (Endo and Miyamoto).
 * <p>
 * The first center is drawn uniformly at random. Every further center is
 * drawn with probability proportional to {@code alpha - s}, where {@code s} is
 * the largest dot-product similarity of the point to any center chosen so far.
 * <p>
 * This strategy is only used with cosine-type distances. For any other
 * distance, a warning is logged and regular k-means++ is used instead.
 *
 * @param <O> object type (not used by the implementation itself)
 */
@Title(value="Spherical K-means++")
@Reference(authors="Y. Endo and S. Miyamoto", title="Spherical k-Means++ Clustering", booktitle="Modeling Decisions for Artificial Intelligence", url="https://doi.org/10.1007/978-3-319-23240-9_9", bibkey="DBLP:conf/mdai/EndoM15")
public class SphericalKMeansPlusPlus<O>
extends AbstractKMeansInitialization {
    /** Class logger. */
    private static final Logging LOG = Logging.getLogger(SphericalKMeansPlusPlus.class);

    /** Offset from which the similarity is subtracted to obtain sampling weights. */
    protected double alpha;

    /**
     * Constructor.
     *
     * @param alpha weight offset; the parameterizer enforces {@code alpha >= 1}
     * @param rnd random generator factory
     */
    public SphericalKMeansPlusPlus(double alpha, RandomFactory rnd) {
        super(rnd);
        this.alpha = alpha;
    }

    /**
     * Choose {@code k} initial means.
     *
     * @param relation data relation
     * @param k number of means to choose
     * @param distance distance function; only the four cosine-type distances
     *        use the spherical variant
     * @return the chosen means as arrays
     * @throws IllegalArgumentException if the relation has fewer than {@code k}
     *         objects
     */
    @Override
    public double[][] chooseInitialMeans(Relation<? extends NumberVector> relation, int k, NumberVectorDistance<?> distance) {
        if (relation.size() < k) {
            throw new IllegalArgumentException("Cannot choose k=" + k + " means from N=" + relation.size() + " < k objects.");
        }
        if (distance instanceof CosineDistance || distance instanceof CosineUnitlengthDistance || distance instanceof ArcCosineDistance || distance instanceof ArcCosineUnitlengthDistance) {
            return new Instance(relation, this.alpha, this.rnd).run(k);
        }
        LOG.warning("Spherical k-means++ was used with an instance of " + distance.getClass() + ". Falling back to regular k-means++.");
        return new KMeansPlusPlus.NumberVectorInstance(relation, distance, this.rnd).run(k);
    }

    /**
     * Parameterization class.
     *
     * @param <V> vector type
     */
    public static class Par<V>
    extends AbstractKMeansInitialization.Par {
        /** Option for the alpha offset (default 1.5, constrained to be at least 1). */
        public static final OptionID ALPHA_ID = new OptionID("skmpp.alpha", "Alpha parameter for alpha-SKM, usually 1.5 to achieve triangular inequality.");

        /** Configured alpha value. */
        protected double alpha;

        @Override
        public void configure(Parameterization config) {
            super.configure(config);
            new DoubleParameter(ALPHA_ID, 1.5) //
                .addConstraint(CommonConstraints.GREATER_EQUAL_ONE_DOUBLE) //
                .grab(config, x -> this.alpha = x);
        }

        public SphericalKMeansPlusPlus<V> make() {
            // Diamond operator: avoids the unchecked raw-type construction.
            return new SphericalKMeansPlusPlus<>(this.alpha, this.rnd);
        }
    }

    /**
     * Per-run state of the seeding procedure.
     */
    protected static class Instance {
        /** Weight offset. */
        protected double alpha;

        /** Data relation. */
        protected Relation<? extends NumberVector> relation;

        /** Current sampling weight of every object; 0 marks already-chosen means. */
        protected WritableDoubleDataStore weights;

        /** Number of similarity computations, reported as a statistic. */
        protected long diststat;

        /** Random generator. */
        protected Random random;

        /**
         * Constructor.
         *
         * @param relation data relation
         * @param alpha weight offset
         * @param rnd random factory; a single-threaded generator is drawn from it
         */
        public Instance(Relation<? extends NumberVector> relation, double alpha, RandomFactory rnd) {
            this.relation = relation;
            this.alpha = alpha;
            this.random = rnd.getSingleThreadedRandom();
            // NOTE(review): hint 3 is presumably DataStoreFactory.HINT_HOT | HINT_TEMP -- confirm.
            this.weights = DataStoreUtil.makeDoubleStorage(relation.getDBIDs(), 3, 0.0);
        }

        /**
         * Run the seeding.
         * <p>
         * Destroys the weight store afterwards and logs the number of
         * similarity computations.
         *
         * @param k number of means
         * @return the chosen means
         */
        public double[][] run(int k) {
            List<NumberVector> means = new ArrayList<>(k);
            NumberVector firstvec = this.relation.get(DBIDUtil.randomSample(this.relation.getDBIDs(), this.random));
            means.add(firstvec);
            this.chooseRemaining(k, means, this.initialWeights(firstvec));
            this.weights.destroy();
            LOG.statistics(new LongStatistic(SphericalKMeansPlusPlus.class.getName() + ".distance-computations", this.diststat));
            return AbstractKMeansInitialization.unboxVectors(means);
        }

        /**
         * Similarity between a vector and a stored object, computed as the plain
         * dot product.
         * <p>
         * NOTE(review): the dot product equals cosine similarity only for
         * unit-length vectors. Non-unitlength CosineDistance and
         * ArcCosineDistance also reach this code; with non-normalized data,
         * {@code alpha - dot} may become negative. Confirm the input is
         * normalized.
         *
         * @param a vector
         * @param b object reference
         * @return dot product of {@code a} and the object at {@code b}
         */
        protected double similarity(NumberVector a, DBIDRef b) {
            ++this.diststat;
            return VectorUtil.dot(a, this.relation.get(b));
        }

        /**
         * Draw means by weighted sampling until {@code k} means are present.
         *
         * @param k target number of means
         * @param means list of means, appended to
         * @param weightsum current sum of all weights
         * @throws IllegalStateException if the weight sum is NaN or infinite
         */
        protected void chooseRemaining(int k, List<NumberVector> means, double weightsum) {
            while (true) {
                // NaN weights would otherwise make the retry path below loop forever.
                if (Double.isNaN(weightsum)) {
                    throw new IllegalStateException("Could not choose a reasonable mean - weight sum is NaN (NaN values in the data?)");
                }
                if (weightsum > Double.MAX_VALUE) {
                    throw new IllegalStateException("Could not choose a reasonable mean - too many data points, too large distance sum?");
                }
                if (weightsum < Double.MIN_NORMAL) {
                    LOG.warning("Could not choose a reasonable mean - too few unique data points?");
                }
                double r = this.nextDouble(weightsum);
                // Walk the cumulative weights until r is used up.
                DBIDIter it = this.relation.iterDBIDs();
                for (; it.valid(); it.advance()) {
                    if ((r -= this.weights.doubleValue(it)) <= 0.0) {
                        break;
                    }
                }
                if (!it.valid()) {
                    // The stored sum exceeded the true sum (floating-point drift):
                    // correct it by the leftover and retry.
                    weightsum -= r;
                    continue;
                }
                NumberVector newmean = this.relation.get(it);
                means.add(newmean);
                if (means.size() >= k) {
                    break;
                }
                // Zero weight excludes this object from further sampling.
                this.weights.putDouble(it, 0.0);
                weightsum = this.updateWeights(newmean);
            }
        }

        /**
         * Initialize each object's weight as {@code alpha - similarity(first, obj)}.
         *
         * @param first first chosen mean
         * @return sum of all weights
         */
        protected double initialWeights(NumberVector first) {
            double weightsum = 0.0;
            for (DBIDIter it = this.relation.iterDBIDs(); it.valid(); it.advance()) {
                double weight = this.alpha - this.similarity(first, it);
                this.weights.putDouble(it, weight);
                weightsum += weight;
            }
            return weightsum;
        }

        /**
         * Lower weights to {@code alpha - similarity(latest, obj)} where that
         * value is smaller.
         * <p>
         * Objects whose weight is already {@code <= 0} are skipped and excluded
         * from the sum.
         *
         * @param latest most recently chosen mean
         * @return new sum of positive weights
         */
        protected double updateWeights(NumberVector latest) {
            double weightsum = 0.0;
            for (DBIDIter it = this.relation.iterDBIDs(); it.valid(); it.advance()) {
                double weight = this.weights.doubleValue(it);
                if (weight <= 0.0) {
                    continue;
                }
                double newweight = this.alpha - this.similarity(latest, it);
                if (newweight < weight) {
                    this.weights.putDouble(it, newweight);
                    weight = newweight;
                }
                weightsum += weight;
            }
            return weightsum;
        }

        /**
         * Draw a uniform random value in {@code (0, weightsum)}.
         * <p>
         * A draw of exactly 0 is re-rolled, unless {@code weightsum} is at or
         * below {@link Double#MIN_NORMAL}.
         *
         * @param weightsum upper bound
         * @return random value
         */
        protected double nextDouble(double weightsum) {
            double r = this.random.nextDouble() * weightsum;
            while (r <= 0.0 && weightsum > Double.MIN_NORMAL) {
                r = this.random.nextDouble() * weightsum;
            }
            return r;
        }
    }
}

