/*
 * Decompiled with CFR 0.152.
 */
package elki.clustering.kmeans.spherical;

import elki.clustering.kmeans.AbstractKMeans;
import elki.clustering.kmeans.initialization.KMeansInitialization;
import elki.clustering.kmeans.spherical.SphericalKMeans;
import elki.data.Clustering;
import elki.data.NumberVector;
import elki.data.model.KMeansModel;
import elki.database.datastore.DataStoreUtil;
import elki.database.datastore.WritableDoubleDataStore;
import elki.database.ids.DBIDIter;
import elki.database.ids.DBIDRef;
import elki.database.ids.DBIDs;
import elki.database.ids.ModifiableDBIDs;
import elki.database.relation.Relation;
import elki.database.relation.RelationUtil;
import elki.logging.Logging;
import elki.utilities.documentation.Reference;
import elki.utilities.documentation.References;
import elki.utilities.optionhandling.parameterization.Parameterization;

/**
 * Spherical k-means clustering accelerated with simplified Hamerly-style
 * similarity bounds.
 * <p>
 * Each point keeps two bounds:
 * <ul>
 * <li>a lower bound on its similarity to the assigned center;</li>
 * <li>an upper bound on its similarity to any other center.</li>
 * </ul>
 * After the centers move, both bounds are adjusted using the similarity
 * between each old and new center. A point is only re-examined when its
 * lower bound drops below its upper bound. This variant does not compute
 * center-to-center similarities.
 *
 * @param <V> vector type
 */
@References(value = { @Reference(authors = "Erich Schubert, Andreas Lang, Gloria Feher", //
    title = "Accelerating Spherical k-Means", //
    booktitle = "Int. Conf. on Similarity Search and Applications, SISAP 2021", //
    url = "https://doi.org/10.1007/978-3-030-89657-7_17", //
    bibkey = "DBLP:conf/sisap/SchubertLF21"), //
    @Reference(authors = "Erich Schubert", //
        title = "A Triangle Inequality for Cosine Similarity", //
        booktitle = "Int. Conf. on Similarity Search and Applications, SISAP 2021", //
        url = "https://doi.org/10.1007/978-3-030-89657-7_3", //
        bibkey = "DBLP:conf/sisap/Schubert21") })
public class SphericalSimplifiedHamerlyKMeans<V extends NumberVector> extends SphericalKMeans<V> {
    /** Class logger. */
    private static final Logging LOG = Logging.getLogger(SphericalSimplifiedHamerlyKMeans.class);

    /**
     * Flag forwarded to {@code buildResult}. It presumably enables variance
     * statistics in the result (TODO confirm in AbstractKMeans).
     */
    protected boolean varstat;

    /**
     * Constructor.
     *
     * @param k number of clusters
     * @param maxiter iteration limit passed to the instance's {@code run}
     * @param initializer initialization method for the starting means
     * @param varstat flag forwarded to {@code buildResult}
     */
    public SphericalSimplifiedHamerlyKMeans(int k, int maxiter, KMeansInitialization initializer, boolean varstat) {
        super(k, maxiter, initializer);
        this.varstat = varstat;
    }

    @Override
    public Clustering<KMeansModel> run(Relation<V> relation) {
        Instance instance = new Instance(relation, initialMeans(relation));
        instance.run(maxiter);
        return instance.buildResult(varstat, relation);
    }

    @Override
    protected Logging getLogger() {
        return LOG;
    }

    /**
     * Parameterization class.
     *
     * @param <V> vector type
     */
    public static class Par<V extends NumberVector> extends SphericalKMeans.Par<V> {
        @Override
        public void configure(Parameterization config) {
            super.configure(config);
            // Additionally read the varstat flag.
            super.getParameterVarstat(config);
        }

        @Override
        public SphericalSimplifiedHamerlyKMeans<V> make() {
            // Use the diamond, not the raw type, to avoid an unchecked conversion.
            return new SphericalSimplifiedHamerlyKMeans<>(this.k, this.maxiter, this.initializer, this.varstat);
        }
    }

    /**
     * Per-run state: running cluster sums and per-point similarity bounds.
     */
    protected static class Instance extends SphericalKMeans.Instance {
        /**
         * Running per-cluster sums of the assigned vectors. They are updated
         * incrementally with plusEquals / plusMinusEquals.
         */
        double[][] sums;

        /** Buffer that receives the recomputed means before they are copied into {@code means}. */
        double[][] newmeans;

        /** Per point: lower bound on the similarity to its assigned center. */
        WritableDoubleDataStore lsim;

        /** Per point: upper bound on the similarity to any other center. */
        WritableDoubleDataStore usim;

        /** Per center: output buffer of movedSimilarity (presumably old-vs-new center similarity). */
        double[] csim;

        /**
         * Constructor.
         *
         * @param relation data relation
         * @param means initial means
         */
        public Instance(Relation<? extends NumberVector> relation, double[][] means) {
            super(relation, means);
            // NOTE(review): storage hint 3 is presumably DataStoreFactory.HINT_HOT | HINT_TEMP;
            // TODO confirm. The initial values 0 and 2 are overwritten for every point by
            // initialAssignToNearestCluster in the first iteration.
            lsim = DataStoreUtil.makeDoubleStorage(relation.getDBIDs(), 3, 0.);
            usim = DataStoreUtil.makeDoubleStorage(relation.getDBIDs(), 3, 2.);
            final int dim = RelationUtil.maxDimensionality(relation);
            sums = new double[k][dim];
            newmeans = new double[k][dim];
            csim = new double[k];
        }

        /**
         * Performs one iteration.
         * <p>
         * Iteration 1 does an exact assignment that also initializes the bounds.
         * Each later iteration does the following in order:
         * <ol>
         * <li>recompute the means from the sums;</li>
         * <li>measure how far each center moved;</li>
         * <li>loosen the bounds accordingly;</li>
         * <li>install the new means;</li>
         * <li>reassign the points that the bounds cannot prune.</li>
         * </ol>
         *
         * @param iteration iteration number, starting at 1
         * @return number of points whose assignment changed (all points in
         *         iteration 1)
         */
        @Override
        public int iterate(int iteration) {
            if (iteration == 1) {
                return initialAssignToNearestCluster();
            }
            meansFromSums(newmeans, sums, means);
            // csim must be computed against the OLD means, before copyMeans overwrites them.
            movedSimilarity(means, newmeans, csim);
            updateBounds(csim);
            copyMeans(newmeans, means);
            return assignToNearestCluster();
        }

        /**
         * Exact first assignment. It compares every point with all centers and
         * records the best and second-best similarities as the initial bounds.
         *
         * @return number of points (all of them count as changed)
         */
        protected int initialAssignToNearestCluster() {
            assert k == means.length;
            for (DBIDIter it = relation.iterDBIDs(); it.valid(); it.advance()) {
                final NumberVector fv = relation.get(it);
                double best = similarity(fv, means[0]);
                double second = -1.;
                int bestIndex = 0;
                for (int j = 1; j < k; j++) {
                    final double sim = similarity(fv, means[j]);
                    if (sim > best) {
                        bestIndex = j;
                        second = best;
                        best = sim;
                    } else if (sim > second) {
                        second = sim;
                    }
                }
                clusters.get(bestIndex).add(it);
                assignment.putInt(it, bestIndex);
                AbstractKMeans.plusEquals(sums[bestIndex], fv);
                // Exact values are used as the initial bounds.
                lsim.putDouble(it, best);
                usim.putDouble(it, second);
            }
            return relation.size();
        }

        /**
         * Reassigns points using the bounds to skip unnecessary similarity
         * computations.
         *
         * @return number of points whose assignment changed
         */
        @Override
        protected int assignToNearestCluster() {
            int changed = 0;
            for (DBIDIter it = relation.iterDBIDs(); it.valid(); it.advance()) {
                final int orig = assignment.intValue(it);
                double ls = lsim.doubleValue(it);
                final double us = usim.doubleValue(it);
                if (ls >= us) {
                    // No other center can be more similar; keep the assignment.
                    continue;
                }
                // Tighten the lower bound to the exact similarity to the current center.
                final NumberVector fv = relation.get(it);
                ls = similarity(fv, means[orig]);
                lsim.putDouble(it, ls);
                if (ls >= us) {
                    continue;
                }
                // The bounds cannot decide: scan all other centers.
                double best = ls;
                double second = Double.NEGATIVE_INFINITY;
                int cur = orig;
                for (int i = 0; i < k; i++) {
                    if (i == orig) {
                        continue;
                    }
                    final double sim = similarity(fv, means[i]);
                    if (sim > best) {
                        cur = i;
                        second = best;
                        best = sim;
                    } else if (sim > second) {
                        second = sim;
                    }
                }
                if (cur != orig) {
                    clusters.get(cur).add(it);
                    clusters.get(orig).remove(it);
                    assignment.putInt(it, cur);
                    AbstractKMeans.plusMinusEquals(sums[cur], sums[orig], fv);
                    ++changed;
                    lsim.putDouble(it, best);
                }
                // The second-best exact similarity becomes the new upper bound.
                usim.putDouble(it, second);
            }
            return changed;
        }

        /**
         * Loosens every point's bounds to account for the center movement.
         *
         * @param msim per-center similarity between old and new position
         *        (values at most 1 assumed; 1 means "did not move")
         */
        protected void updateBounds(double[] msim) {
            // Find the smallest and second-smallest similarity, i.e. the largest movements.
            int least = 0;
            double delta = msim[0];
            double delta2 = 1.;
            for (int i = 1; i < msim.length; i++) {
                final double m = msim[i];
                if (m < delta) {
                    delta2 = delta;
                    delta = m;
                    least = i;
                } else if (m < delta2) {
                    delta2 = m;
                }
            }
            // Convert the similarities m into 1 - m^2.
            delta = 1. - delta * delta;
            delta2 = 1. - delta2 * delta2;
            for (DBIDIter it = relation.iterDBIDs(); it.valid(); it.advance()) {
                final int ai = assignment.intValue(it);
                final double v2 = msim[ai];
                if (v2 < 1.) {
                    // Lower bound: v1*v2 - sqrt((1-v1^2)(1-v2^2)) = cos(arccos v1 + arccos v2).
                    // Clamp at 1 so the sqrt argument stays non-negative.
                    final double v1 = Math.min(1., lsim.doubleValue(it));
                    lsim.putDouble(it, v1 * v2 - Math.sqrt((1. - v1 * v1) * (1. - v2 * v2)));
                }
                // If this point's own center moved the most, the other centers moved
                // at most by the second-largest amount.
                final double w2 = least == ai ? delta2 : delta;
                if (w2 > 0.) {
                    // usim can exceed 1 after this update, so clamp before taking the sqrt.
                    // NOTE(review): this relaxed bound w1 + sqrt((1-w1^2)(1-m^2)) differs from
                    // w1*m + sqrt(...) by w1*(1-m). Confirm against the cited paper that it stays
                    // a valid upper bound for negative similarities and large center movements.
                    final double w1 = Math.min(1., usim.doubleValue(it));
                    usim.putDouble(it, w1 + Math.sqrt((1. - w1 * w1) * w2));
                }
            }
        }

        @Override
        protected Logging getLogger() {
            return LOG;
        }
    }
}

