/*
 * Decompiled with CFR 0.152.
 */
package elki.clustering.kmeans.initialization;

import elki.clustering.kmeans.initialization.AFKMC2;
import elki.clustering.kmeans.initialization.SphericalKMeansPlusPlus;
import elki.data.NumberVector;
import elki.data.VectorUtil;
import elki.database.ids.DBIDIter;
import elki.database.ids.DBIDRef;
import elki.database.relation.Relation;
import elki.distance.ArcCosineDistance;
import elki.distance.ArcCosineUnitlengthDistance;
import elki.distance.CosineDistance;
import elki.distance.CosineUnitlengthDistance;
import elki.distance.NumberVectorDistance;
import elki.logging.Logging;
import elki.utilities.documentation.Reference;
import elki.utilities.documentation.Title;
import elki.utilities.optionhandling.OptionID;
import elki.utilities.optionhandling.constraints.CommonConstraints;
import elki.utilities.optionhandling.constraints.ParameterConstraint;
import elki.utilities.optionhandling.parameterization.Parameterization;
import elki.utilities.optionhandling.parameters.DoubleParameter;
import elki.utilities.random.RandomFactory;
import java.util.List;

/**
 * Spherical variant of the AFK-MC&#178; k-means initialization (see {@link AFKMC2}),
 * intended for cosine-based distances.
 * <p>
 * Instead of a distance, the sampling weight of an object {@code x} with respect
 * to a mean {@code c} is {@code alpha - <c, x>} (dot product). If a distance other
 * than {@link CosineDistance}, {@link CosineUnitlengthDistance},
 * {@link ArcCosineDistance} or {@link ArcCosineUnitlengthDistance} is supplied,
 * a warning is logged and the regular {@link AFKMC2} initialization is used.
 */
@Title(value="Spherical AFK-MC\u00b2")
@Reference(authors="R. Pratap, A. A. Deshmukh, P. Nair, T. Dutt", title="A Faster Sampling Algorithm for Spherical k-means", booktitle="Proc. 10th Asian Conference on Machine Learning, ACML", url="http://proceedings.mlr.press/v95/pratap18a.html", bibkey="DBLP:conf/acml/PratapDND18")
public class SphericalAFKMC2 extends AFKMC2 {
    /** Class logger. */
    private static final Logging LOG = Logging.getLogger(SphericalAFKMC2.class);

    /**
     * Offset added to the negated similarity: the weight of an object is
     * {@code alpha - similarity}. The parameterizer enforces {@code alpha >= 1}.
     */
    protected double alpha;

    /**
     * Constructor.
     *
     * @param m AFKMC2 parameter {@code m}, forwarded unchanged to {@link AFKMC2}
     *        (presumably the Markov chain length — see AFKMC2)
     * @param alpha offset used to turn similarities into sampling weights
     * @param rnd random generator factory, forwarded to {@link AFKMC2}
     */
    public SphericalAFKMC2(int m, double alpha, RandomFactory rnd) {
        super(m, rnd);
        this.alpha = alpha;
    }

    /**
     * Choose {@code k} initial means.
     * <p>
     * The spherical sampler is only used for cosine / arc-cosine distances;
     * for any other distance this logs a warning and delegates to the
     * regular AFK-MC&#178; instance with the given distance.
     *
     * @param relation data relation to sample from
     * @param k number of means to choose
     * @param distance distance function requested by the caller
     * @return initial means
     * @throws IllegalArgumentException if the relation has fewer than {@code k} objects
     */
    @Override
    public double[][] chooseInitialMeans(Relation<? extends NumberVector> relation, int k, NumberVectorDistance<?> distance) {
        if (relation.size() < k) {
            throw new IllegalArgumentException("Cannot choose k=" + k + " means from N=" + relation.size() + " < k objects.");
        }
        if (distance instanceof CosineDistance || distance instanceof CosineUnitlengthDistance //
                || distance instanceof ArcCosineDistance || distance instanceof ArcCosineUnitlengthDistance) {
            return new Instance(relation, m, alpha, rnd).run(k);
        }
        // The fallback is regular AFK-MC², not k-means++: report it accordingly.
        LOG.warning("Spherical AFK-MC\u00b2 was used with an instance of " + distance.getClass() + ". Falling back to regular AFK-MC\u00b2.");
        return new AFKMC2.Instance(relation, distance, m, rnd).run(k);
    }

    /**
     * Parameterization class.
     */
    public static class Par extends AFKMC2.Par {
        /** Alpha parameter, shared with {@link SphericalKMeansPlusPlus}. */
        public static final OptionID ALPHA_ID = SphericalKMeansPlusPlus.Par.ALPHA_ID;

        /** Configured alpha value (default 1.5, must be at least 1). */
        protected double alpha;

        @Override
        public void configure(Parameterization config) {
            super.configure(config);
            new DoubleParameter(ALPHA_ID, 1.5) //
                .addConstraint(CommonConstraints.GREATER_EQUAL_ONE_DOUBLE) //
                .grab(config, x -> alpha = x);
        }

        @Override
        public SphericalAFKMC2 make() {
            return new SphericalAFKMC2(m, alpha, rnd);
        }
    }

    /**
     * Instance for a single run: replaces the distance-based weights of
     * {@link AFKMC2.Instance} with {@code alpha - <mean, x>}.
     */
    protected static class Instance extends AFKMC2.Instance {
        /** Offset used to turn similarities into weights. */
        protected double alpha;

        /**
         * Constructor.
         *
         * @param relation data relation
         * @param m AFKMC2 parameter {@code m}, forwarded unchanged
         * @param alpha similarity offset
         * @param rnd random generator factory
         */
        public Instance(Relation<? extends NumberVector> relation, int m, double alpha, RandomFactory rnd) {
            // The overridden methods below do not use this distance; it only satisfies the super constructor.
            super(relation, CosineDistance.STATIC, m, rnd);
            this.alpha = alpha;
        }

        /**
         * Store {@code alpha - <first, x>} as the weight of every object
         * {@code x} and return the sum of all weights.
         *
         * @param first the first chosen mean
         * @return sum of all stored weights
         */
        @Override
        protected double initialWeights(NumberVector first) {
            double weightsum = 0.0;
            for (DBIDIter it = relation.iterDBIDs(); it.valid(); it.advance()) {
                double weight = alpha - similarity(first, it);
                weights.putDouble(it, weight);
                weightsum += weight;
            }
            return weightsum;
        }

        /**
         * Dot product of {@code a} with the vector of object {@code b}; also
         * increments the {@code diststat} counter.
         * <p>
         * NOTE(review): this equals cosine similarity only for unit-length
         * vectors. {@link CosineDistance} and {@link ArcCosineDistance} are also
         * accepted, and with non-normalized data {@code alpha - dot} may become
         * negative. Confirm that inputs are normalized.
         *
         * @param a first vector
         * @param b reference to the second object
         * @return dot product
         */
        protected double similarity(NumberVector a, DBIDRef b) {
            ++diststat;
            return VectorUtil.dot(a, relation.get(b));
        }

        /**
         * Minimum of {@code alpha - similarity} over all current means.
         * <p>
         * The term for the first mean is not recomputed; it is read from the
         * {@code weights} cache filled by {@link #initialWeights}. This assumes
         * the superclass does not rescale {@code weights} in place. TODO confirm
         * against AFKMC2.Instance.
         *
         * @param cand candidate object
         * @param means current means (index 0 is the first mean)
         * @return minimum weight of the candidate with respect to the means
         */
        @Override
        protected double distance(DBIDRef cand, List<NumberVector> means) {
            double d = weights.doubleValue(cand);
            for (int i = 1; i < means.size(); ++i) {
                double d2 = alpha - similarity(means.get(i), cand);
                d = d2 < d ? d2 : d;
            }
            return d;
        }
    }
}

