/*
 * Decompiled with CFR 0.152.
 */
package elki.clustering.kmeans.spherical;

import elki.clustering.kmeans.AbstractKMeans;
import elki.clustering.kmeans.initialization.KMeansInitialization;
import elki.data.Clustering;
import elki.data.NumberVector;
import elki.data.VectorUtil;
import elki.data.model.KMeansModel;
import elki.data.type.TypeInformation;
import elki.data.type.TypeUtil;
import elki.database.ids.DBIDIter;
import elki.database.ids.DBIDRef;
import elki.database.ids.DBIDs;
import elki.database.ids.ModifiableDBIDs;
import elki.database.relation.Relation;
import elki.database.relation.RelationUtil;
import elki.distance.CosineDistance;
import elki.distance.NumberVectorDistance;
import elki.logging.Logging;
import elki.math.linearalgebra.VMath;
import elki.utilities.documentation.Reference;
import elki.utilities.optionhandling.parameterization.Parameterization;
import java.util.Arrays;
import java.util.List;

/**
 * Spherical k-means clustering after Dhillon and Modha: a Lloyd-style k-means
 * variant for cosine similarity. Every point is assigned to the mean with the
 * largest dot product, and each new mean is the normalized (unit-length) sum of
 * its cluster members.
 * <p>
 * NOTE(review): the conversions {@code 2 - 2 * dot} used as squared distances
 * below equal the squared Euclidean distance only when both vectors have unit
 * length. The input data is presumably expected to be L2-normalized; confirm
 * with callers.
 *
 * @param <V> vector type
 */
@Reference(authors="I. S. Dhillon, D. S. Modha", title="Concept Decompositions for Large Sparse Text Data Using Clustering", booktitle="Machine Learning 42", url="https://doi.org/10.1023/A:1007612920971", bibkey="DBLP:journals/ml/DhillonM01")
public class SphericalKMeans<V extends NumberVector>
extends AbstractKMeans<V, KMeansModel> {
    /** Class logger. */
    private static final Logging LOG = Logging.getLogger(SphericalKMeans.class);

    /**
     * Constructor; always uses {@link CosineDistance#STATIC}.
     *
     * @param k number of clusters
     * @param maxiter maximum number of iterations
     * @param initializer initialization method for the means
     */
    public SphericalKMeans(int k, int maxiter, KMeansInitialization initializer) {
        super(CosineDistance.STATIC, k, maxiter, initializer);
    }

    /**
     * Runs spherical k-means on the given relation.
     *
     * @param relation data relation
     * @return clustering result
     */
    @Override
    public Clustering<KMeansModel> run(Relation<V> relation) {
        Instance instance = new Instance(relation, initialMeans(relation));
        instance.run(maxiter);
        return instance.buildResult();
    }

    @Override
    public TypeInformation[] getInputTypeRestriction() {
        return TypeUtil.array(distance.getInputTypeRestriction());
    }

    @Override
    protected Logging getLogger() {
        return LOG;
    }

    /**
     * Parameterization class.
     *
     * @param <V> vector type
     */
    public static class Par<V extends NumberVector>
    extends AbstractKMeans.Par<V> {
        @Override
        public void configure(Parameterization config) {
            // Parameter order matters for the parameterization API; keep it stable.
            getParameterK(config);
            getParameterInitialization(config);
            getParameterMaxIter(config);
        }

        @Override
        public SphericalKMeans<V> make() {
            return new SphericalKMeans<>(k, maxiter, initializer);
        }
    }

    /**
     * State of a single spherical k-means run.
     */
    public static class Instance
    extends AbstractKMeans.Instance {
        /**
         * Minimum Euclidean length of a cluster's vector sum below which it is not
         * normalized. The previous mean is kept instead, because normalizing a
         * (near) zero vector does not yield a meaningful direction.
         */
        private static final double MIN_MEAN_NORM = 1.0E-7;

        /**
         * Constructor.
         * <p>
         * Means shorter than the relation's maximum dimensionality are padded with
         * zeros so that dense dot products with every data vector are possible.
         * <p>
         * NOTE(review): the initial means are not normalized here. The first
         * assignment therefore uses raw dot products, which rank clusters by cosine
         * only if the initial means already have equal length. Confirm this with the
         * initializers in use.
         *
         * @param relation data relation
         * @param means initial means; entries may be replaced by padded copies
         */
        public Instance(Relation<? extends NumberVector> relation, double[][] means) {
            super(relation, (NumberVectorDistance<?>) CosineDistance.STATIC, means);
            final int dim = RelationUtil.maxDimensionality(relation);
            for (int i = 0; i < means.length; ++i) {
                if (means[i].length < dim) {
                    means[i] = Arrays.copyOf(means[i], dim);
                }
            }
        }

        /**
         * Performs one iteration. The first iteration only assigns points to the
         * initial means; later iterations first recompute the means from the
         * current clusters.
         *
         * @param iteration iteration number (1-based)
         * @return number of points whose assignment changed
         */
        @Override
        public int iterate(int iteration) {
            if (iteration != 1) {
                means = Instance.means(clusters, means, relation);
            }
            return assignToNearestCluster();
        }

        /**
         * Assigns each point to the mean with the largest dot product. As a side
         * effect, it clears and refills {@code clusters}, updates
         * {@code assignment}, and accumulates {@code 2 - 2 * sim} (clamped at 0)
         * into {@code varsum}.
         *
         * @return number of points whose assignment changed
         */
        @Override
        protected int assignToNearestCluster() {
            assert (k == means.length);
            Arrays.fill(varsum, 0.0);
            for (ModifiableDBIDs cluster : clusters) {
                cluster.clear();
            }
            int changed = 0;
            for (DBIDIter it = relation.iterDBIDs(); it.valid(); it.advance()) {
                final NumberVector fv = relation.get(it);
                int best = 0;
                double bestSim = VectorUtil.dot(fv, means[0]);
                ++diststat;
                for (int i = 1; i < k; ++i) {
                    final double sim = VectorUtil.dot(fv, means[i]);
                    ++diststat;
                    if (sim > bestSim) {
                        best = i;
                        bestSim = sim;
                    }
                }
                // For unit vectors, 2 * (1 - sim) is the squared Euclidean distance.
                varsum[best] += bestSim < 1.0 ? 2.0 * (1.0 - bestSim) : 0.0;
                clusters.get(best).add(it);
                if (assignment.putInt(it, best) != best) {
                    ++changed;
                }
            }
            return changed;
        }

        /**
         * Dot-product similarity between a data vector and a mean, capped at 1.
         *
         * @param vec1 data vector
         * @param vec2 mean vector
         * @return {@code min(1, vec1 . vec2)}
         */
        protected double similarity(NumberVector vec1, double[] vec2) {
            ++diststat;
            return Math.min(1.0, VectorUtil.dot(vec1, vec2));
        }

        /**
         * Dot-product similarity between two dense vectors, capped at 1.
         *
         * @param vec1 first vector
         * @param vec2 second vector
         * @return {@code min(1, vec1 . vec2)}
         */
        protected double similarity(double[] vec1, double[] vec2) {
            ++diststat;
            return Math.min(1.0, VMath.dot(vec1, vec2));
        }

        /**
         * Squared Euclidean distance of two dense vectors, computed element-wise.
         * A NaN result is mapped to 0.
         */
        @Override
        protected double distance(double[] x, double[] y) {
            ++diststat;
            double sum = 0.0;
            for (int i = 0; i < x.length; ++i) {
                final double delta = x[i] - y[i];
                sum += delta * delta;
            }
            return sum > 0.0 ? sum : 0.0;
        }

        /** Squared distance {@code 2 - 2 * dot}, clamped at 0 (exact for unit vectors). */
        @Override
        protected double distance(NumberVector x, double[] y) {
            ++diststat;
            final double s = VectorUtil.dot(x, y);
            return s < 1.0 ? 2.0 - 2.0 * s : 0.0;
        }

        /** Squared distance {@code 2 - 2 * dot}, clamped at 0 (exact for unit vectors). */
        @Override
        protected double distance(NumberVector x, NumberVector y) {
            ++diststat;
            final double s = VectorUtil.dot(x, y);
            return s < 1.0 ? 2.0 - 2.0 * s : 0.0;
        }

        /** Distance {@code sqrt(2 - 2 * dot)}, clamped at 0 (exact for unit vectors). */
        @Override
        protected double sqrtdistance(NumberVector x, double[] y) {
            ++diststat;
            final double s = VectorUtil.dot(x, y);
            return s < 1.0 ? Math.sqrt(2.0 - 2.0 * s) : 0.0;
        }

        /** Distance {@code sqrt(2 - 2 * dot)}, clamped at 0 (exact for unit vectors). */
        @Override
        protected double sqrtdistance(NumberVector x, NumberVector y) {
            ++diststat;
            final double s = VectorUtil.dot(x, y);
            return s < 1.0 ? Math.sqrt(2.0 - 2.0 * s) : 0.0;
        }

        /**
         * Fills the symmetric matrix {@code ccsim} with {@code sqrt((s + 1) / 2)}
         * for every pair of distinct means with similarity {@code s}. If {@code s}
         * is the cosine of the angle between them, this is the cosine of half that
         * angle. Pairs with {@code s <= -1} get 0. The diagonal is not written.
         *
         * @param ccsim output matrix, at least {@code k x k}
         */
        protected void initialSeparation(double[][] ccsim) {
            final int n = means.length;
            for (int i = 1; i < n; ++i) {
                final double[] mi = means[i];
                for (int j = 0; j < i; ++j) {
                    final double s = similarity(mi, means[j]);
                    final double halfAngleCos = s > -1.0 ? Math.sqrt((s + 1.0) * 0.5) : 0.0;
                    ccsim[j][i] = halfAngleCos;
                    ccsim[i][j] = halfAngleCos;
                }
            }
        }

        /**
         * Computes the similarity between each old mean and its replacement.
         *
         * @param means old means
         * @param newmeans new means
         * @param sims output array receiving one similarity per mean
         */
        protected void movedSimilarity(double[][] means, double[][] newmeans, double[] sims) {
            assert (newmeans.length == means.length && sims.length == means.length);
            for (int i = 0; i < means.length; ++i) {
                sims[i] = similarity(means[i], newmeans[i]);
            }
        }

        /**
         * Writes the unit-length version of each sum into {@code dst}. If a sum's
         * length is not above {@link #MIN_MEAN_NORM} (or is NaN), the previous mean
         * is copied instead.
         *
         * @param dst output means
         * @param sums per-cluster vector sums
         * @param prev previous means
         */
        @Override
        protected void meansFromSums(double[][] dst, double[][] sums, double[][] prev) {
            for (int i = 0; i < dst.length; ++i) {
                final double len = VMath.euclideanLength(sums[i]);
                if (len > MIN_MEAN_NORM) {
                    VMath.overwriteTimes(dst[i], sums[i], 1.0 / len);
                } else {
                    System.arraycopy(prev[i], 0, dst[i], 0, prev[i].length);
                }
            }
        }

        /**
         * Recomputes {@code varsum[i]} as {@code 2 * (|C_i| - sum of min(1, x . mean_i))}
         * over each cluster's members.
         *
         * @param relation data relation
         */
        @Override
        protected void recomputeVariance(Relation<? extends NumberVector> relation) {
            Arrays.fill(varsum, 0.0);
            for (int i = 0; i < clusters.size(); ++i) {
                final DBIDs ids = clusters.get(i);
                final double[] mean = means[i];
                double simSum = 0.0;
                for (DBIDIter it = ids.iter(); it.valid(); it.advance()) {
                    simSum += Math.min(1.0, VectorUtil.dot(relation.get(it), mean));
                    ++diststat;
                }
                varsum[i] = 2.0 * ((double) ids.size() - simSum);
            }
        }

        @Override
        protected Logging getLogger() {
            return LOG;
        }

        /**
         * Computes new unit-length means as the normalized vector sums of the
         * clusters. Empty clusters keep their previous mean. Clusters whose sum has
         * (near) zero length, such as members pointing in opposing directions, also
         * keep their previous mean. This matches {@code meansFromSums}.
         *
         * @param clusters current cluster memberships
         * @param means previous means (must be non-empty)
         * @param relation data relation
         * @return new means; kept entries share the arrays of {@code means}
         */
        protected static double[][] means(List<? extends DBIDs> clusters, double[][] means, Relation<? extends NumberVector> relation) {
            final int k = means.length;
            final int dim = means[0].length;
            final double[][] newMeans = new double[k][];
            for (int i = 0; i < k; ++i) {
                final DBIDs members = clusters.get(i);
                if (members.isEmpty()) {
                    newMeans[i] = means[i];
                    continue;
                }
                final double[] sum = new double[dim];
                for (DBIDIter it = members.iter(); it.valid(); it.advance()) {
                    AbstractKMeans.plusEquals(sum, relation.get(it));
                }
                // Normalizing a (near) zero sum would yield a degenerate mean.
                newMeans[i] = VMath.euclideanLength(sum) > MIN_MEAN_NORM ? VMath.normalizeEquals(sum) : means[i];
            }
            return newMeans;
        }
    }
}

