/*
 * Decompiled with CFR 0.152.
 */
package elki.clustering.kmeans;

import elki.Algorithm;
import elki.clustering.kmeans.KMeans;
import elki.clustering.kmeans.initialization.KMeansInitialization;
import elki.clustering.kmeans.initialization.RandomlyChosen;
import elki.data.Cluster;
import elki.data.Clustering;
import elki.data.DoubleVector;
import elki.data.NumberVector;
import elki.data.SparseNumberVector;
import elki.data.model.KMeansModel;
import elki.data.model.Model;
import elki.data.type.CombinedTypeInformation;
import elki.data.type.TypeInformation;
import elki.data.type.TypeUtil;
import elki.database.datastore.DataStoreUtil;
import elki.database.datastore.WritableIntegerDataStore;
import elki.database.ids.DBIDIter;
import elki.database.ids.DBIDMIter;
import elki.database.ids.DBIDRef;
import elki.database.ids.DBIDUtil;
import elki.database.ids.DBIDs;
import elki.database.ids.ModifiableDBIDs;
import elki.database.relation.Relation;
import elki.distance.CosineDistance;
import elki.distance.NumberVectorDistance;
import elki.distance.PrimitiveDistance;
import elki.distance.minkowski.EuclideanDistance;
import elki.distance.minkowski.SquaredEuclideanDistance;
import elki.logging.Logging;
import elki.logging.progress.AbstractProgress;
import elki.logging.progress.IndefiniteProgress;
import elki.logging.statistics.DoubleStatistic;
import elki.logging.statistics.Duration;
import elki.logging.statistics.LongStatistic;
import elki.logging.statistics.Statistic;
import elki.math.linearalgebra.VMath;
import elki.result.Metadata;
import elki.utilities.datastructures.arrays.DoubleIntegerArrayQuickSort;
import elki.utilities.optionhandling.Parameterizer;
import elki.utilities.optionhandling.constraints.CommonConstraints;
import elki.utilities.optionhandling.constraints.ParameterConstraint;
import elki.utilities.optionhandling.parameterization.Parameterization;
import elki.utilities.optionhandling.parameters.Flag;
import elki.utilities.optionhandling.parameters.IntParameter;
import elki.utilities.optionhandling.parameters.ObjectParameter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Common base class for k-means style clustering algorithms.
 * <p>
 * Holds the shared configuration (distance function, number of clusters,
 * iteration limit, initialization method), static helpers to compute cluster
 * means from dense or sparse vectors, the {@link Par} parameterizer base and
 * the {@link Instance} base class for per-run state.
 *
 * @param <V> vector type
 * @param <M> cluster model type
 */
public abstract class AbstractKMeans<V extends NumberVector, M extends Model>
implements KMeans<V, M> {
    /** Distance function; both constructors overwrite this default. */
    protected NumberVectorDistance<? super V> distance = SquaredEuclideanDistance.STATIC;

    /** Number of clusters to find. */
    protected int k;

    /** Iteration limit; non-positive constructor arguments are stored as {@link Integer#MAX_VALUE}. */
    protected int maxiter;

    /** Strategy used to choose the initial means. */
    protected KMeansInitialization initializer;

    /**
     * Constructor using squared Euclidean distance.
     *
     * @param k Number of clusters
     * @param maxiter Maximum number of iterations (&lt;= 0 means unlimited)
     * @param initializer Initialization method
     */
    public AbstractKMeans(int k, int maxiter, KMeansInitialization initializer) {
        this(SquaredEuclideanDistance.STATIC, k, maxiter, initializer);
    }

    /**
     * Constructor.
     *
     * @param distance Distance function
     * @param k Number of clusters
     * @param maxiter Maximum number of iterations (&lt;= 0 means unlimited)
     * @param initializer Initialization method
     */
    public AbstractKMeans(NumberVectorDistance<? super V> distance, int k, int maxiter, KMeansInitialization initializer) {
        this.distance = distance;
        this.k = k;
        this.maxiter = maxiter > 0 ? maxiter : Integer.MAX_VALUE;
        this.initializer = initializer;
    }

    /**
     * Input restriction: number vector fields that are also acceptable to the
     * configured distance function.
     *
     * @return type restriction array with one combined restriction
     */
    public TypeInformation[] getInputTypeRestriction() {
        return TypeUtil.array(new CombinedTypeInformation(TypeUtil.NUMBER_VECTOR_FIELD, this.distance.getInputTypeRestriction()));
    }

    /**
     * Choose the initial means using the configured initializer, logging the
     * time spent as a statistic.
     *
     * @param relation Data relation
     * @return initial means, as returned by the initializer
     */
    protected double[][] initialMeans(Relation<V> relation) {
        Duration inittime = this.getLogger().newDuration(this.initializer.getClass().getName() + ".time").begin();
        double[][] means = this.initializer.chooseInitialMeans(relation, this.k, this.distance);
        this.getLogger().statistics(inittime.end());
        return means;
    }

    /**
     * Compute the mean of each cluster. Empty clusters keep their previous mean
     * (the same array instance is reused).
     *
     * @param clusters Cluster members, one entry per mean
     * @param means Previous means (used for empty clusters and dimensionality)
     * @param relation Data relation
     * @return new array of means
     */
    protected static double[][] means(List<? extends DBIDs> clusters, double[][] means, Relation<? extends NumberVector> relation) {
        if (TypeUtil.SPARSE_VECTOR_FIELD.isAssignableFromType(relation.getDataTypeInformation())) {
            // Guarded by the type check above: the relation stores sparse vectors.
            @SuppressWarnings("unchecked")
            Relation<? extends SparseNumberVector> sparse = (Relation<? extends SparseNumberVector>) relation;
            return sparseMeans(clusters, means, sparse);
        }
        return denseMeans(clusters, means, relation);
    }

    /** Mean computation for dense data: sum all members, then scale by 1/size. */
    private static double[][] denseMeans(List<? extends DBIDs> clusters, double[][] means, Relation<? extends NumberVector> relation) {
        final int k = means.length;
        double[][] newMeans = new double[k][];
        for (int i = 0; i < k; ++i) {
            DBIDs list = clusters.get(i);
            if (list.isEmpty()) {
                newMeans[i] = means[i];
                continue;
            }
            DBIDIter iter = list.iter();
            // The first member's toArray() result is used as the accumulator and
            // modified in place (assumes toArray() returns a copy - confirm).
            double[] sum = relation.get(iter).toArray();
            for (iter.advance(); iter.valid(); iter.advance()) {
                plusEquals(sum, relation.get(iter));
            }
            newMeans[i] = VMath.timesEquals(sum, 1.0 / list.size());
        }
        return newMeans;
    }

    /**
     * Add a vector to a sum array in place; uses the sparse iterator for
     * {@link SparseNumberVector} inputs.
     *
     * @param sum Accumulator (modified)
     * @param vec Vector to add
     */
    public static void plusEquals(double[] sum, NumberVector vec) {
        if (vec instanceof SparseNumberVector) {
            sparsePlusEquals(sum, (SparseNumberVector) vec);
        } else {
            densePlusEquals(sum, vec);
        }
    }

    /** Dense addition over the first {@code sum.length} dimensions. */
    private static void densePlusEquals(double[] sum, NumberVector vec) {
        for (int d = 0; d < sum.length; ++d) {
            sum[d] += vec.doubleValue(d);
        }
    }

    /** Sparse addition: only the stored (non-zero) dimensions are visited. */
    private static void sparsePlusEquals(double[] sum, SparseNumberVector vec) {
        for (int j = vec.iter(); vec.iterValid(j); j = vec.iterAdvance(j)) {
            sum[vec.iterDim(j)] += vec.iterDoubleValue(j);
        }
    }

    /**
     * Subtract a vector from a sum array in place; uses the sparse iterator for
     * {@link SparseNumberVector} inputs, consistent with {@link #plusEquals}.
     *
     * @param sum Accumulator (modified)
     * @param vec Vector to subtract
     */
    public static void minusEquals(double[] sum, NumberVector vec) {
        if (vec instanceof SparseNumberVector) {
            sparseMinusEquals(sum, (SparseNumberVector) vec);
        } else {
            denseMinusEquals(sum, vec);
        }
    }

    /** Dense subtraction over the first {@code sum.length} dimensions. */
    private static void denseMinusEquals(double[] sum, NumberVector vec) {
        for (int d = 0; d < sum.length; ++d) {
            sum[d] -= vec.doubleValue(d);
        }
    }

    /** Sparse subtraction: only the stored (non-zero) dimensions are visited. */
    private static void sparseMinusEquals(double[] sum, SparseNumberVector vec) {
        for (int j = vec.iter(); vec.iterValid(j); j = vec.iterAdvance(j)) {
            sum[vec.iterDim(j)] -= vec.iterDoubleValue(j);
        }
    }

    /**
     * Add a vector to one array and subtract it from another, in one pass
     * (for moving a point between two cluster sums).
     *
     * @param add Array to add to (modified)
     * @param sub Array to subtract from (modified)
     * @param vec Vector
     */
    public static void plusMinusEquals(double[] add, double[] sub, NumberVector vec) {
        if (vec instanceof SparseNumberVector) {
            sparsePlusMinusEquals(add, sub, (SparseNumberVector) vec);
        } else {
            densePlusMinusEquals(add, sub, vec);
        }
    }

    /** Dense variant of {@link #plusMinusEquals}, over {@code add.length} dimensions. */
    private static void densePlusMinusEquals(double[] add, double[] sub, NumberVector vec) {
        for (int d = 0; d < add.length; ++d) {
            final double v = vec.doubleValue(d);
            add[d] += v;
            sub[d] -= v;
        }
    }

    /** Sparse variant of {@link #plusMinusEquals}. */
    private static void sparsePlusMinusEquals(double[] add, double[] sub, SparseNumberVector vec) {
        for (int j = vec.iter(); vec.iterValid(j); j = vec.iterAdvance(j)) {
            final int d = vec.iterDim(j);
            final double v = vec.iterDoubleValue(j);
            add[d] += v;
            sub[d] -= v;
        }
    }

    /** Mean computation for sparse data: accumulate non-zeros into a zero array, then scale. */
    private static double[][] sparseMeans(List<? extends DBIDs> clusters, double[][] means, Relation<? extends SparseNumberVector> relation) {
        final int k = means.length;
        double[][] newMeans = new double[k][];
        for (int i = 0; i < k; ++i) {
            DBIDs list = clusters.get(i);
            if (list.isEmpty()) {
                newMeans[i] = means[i];
                continue;
            }
            double[] mean = new double[means[i].length];
            for (DBIDIter iter = list.iter(); iter.valid(); iter.advance()) {
                sparsePlusEquals(mean, relation.get(iter));
            }
            newMeans[i] = VMath.timesEquals(mean, 1.0 / list.size());
        }
        return newMeans;
    }

    /**
     * For each cluster, order the other clusters by the given pairwise values.
     *
     * @param cdist k x k matrix of pairwise values (the diagonal is ignored)
     * @param cnum Output: row i receives the k-1 other cluster indexes sorted
     *        ascending by {@code cdist[i][.]}
     */
    protected static void nearestMeans(double[][] cdist, int[][] cnum) {
        final int k = cdist.length;
        double[] buf = new double[k - 1];
        for (int i = 0; i < k; ++i) {
            // Copy row i, skipping the diagonal entry cdist[i][i].
            System.arraycopy(cdist[i], 0, buf, 0, i);
            System.arraycopy(cdist[i], i + 1, buf, i, k - i - 1);
            for (int j = 0; j < buf.length; ++j) {
                cnum[i][j] = j < i ? j : j + 1;
            }
            DoubleIntegerArrayQuickSort.sort(buf, cnum[i], k - 1);
        }
    }

    /**
     * Incrementally update a mean in place:
     * {@code mean += (vec - mean) * op / newsize}. Does nothing if
     * {@code newsize == 0}.
     *
     * @param mean Mean to update (modified)
     * @param vec Object being added or removed
     * @param newsize Cluster size after the change
     * @param op Weight of the update (presumably +1 for add, -1 for remove -
     *        confirm at call sites)
     */
    protected static void incrementalUpdateMean(double[] mean, NumberVector vec, int newsize, double op) {
        if (newsize == 0) {
            return;
        }
        VMath.plusTimesEquals(mean, VMath.minusEquals(vec.toArray(), mean), op / newsize);
    }

    @Override
    public void setK(int k) {
        this.k = k;
    }

    @Override
    public NumberVectorDistance<? super V> getDistance() {
        return this.distance;
    }

    @Override
    public void setDistance(NumberVectorDistance<? super V> distance) {
        this.distance = distance;
    }

    @Override
    public void setInitializer(KMeansInitialization init) {
        this.initializer = init;
    }

    /**
     * Get the class logger.
     *
     * @return logger
     */
    protected abstract Logging getLogger();

    /**
     * Base parameterizer for k-means variants.
     *
     * @param <V> vector type
     */
    public static abstract class Par<V extends NumberVector>
    implements Parameterizer {
        /** Number of clusters. */
        protected int k;

        /** Maximum number of iterations (0 is the parameter default). */
        protected int maxiter;

        /** Initialization method. */
        protected KMeansInitialization initializer;

        /** Whether to compute variance statistics at the end. */
        protected boolean varstat = false;

        /** Distance function. */
        protected NumberVectorDistance<? super V> distance;

        public void configure(Parameterization config) {
            this.getParameterK(config);
            this.getParameterInitialization(config);
            this.getParameterDistance(config);
            this.getParameterMaxIter(config);
        }

        /** Read the number of clusters k (constraint: k &gt;= 1). */
        protected void getParameterK(Parameterization config) {
            new IntParameter(KMeans.K_ID) //
                    .addConstraint(CommonConstraints.GREATER_EQUAL_ONE_INT) //
                    .grab(config, x -> this.k = x);
        }

        /**
         * Read the distance function (default: squared Euclidean). Logs a
         * warning for distances other than squared Euclidean, Euclidean, or
         * cosine.
         */
        protected void getParameterDistance(Parameterization config) {
            new ObjectParameter<NumberVectorDistance<? super V>>(Algorithm.Utils.DISTANCE_FUNCTION_ID, PrimitiveDistance.class, SquaredEuclideanDistance.class) //
                    .grab(config, x -> {
                        this.distance = x;
                        if (x instanceof SquaredEuclideanDistance || x instanceof EuclideanDistance || x instanceof CosineDistance) {
                            return;
                        }
                        Logging log = Logging.getLogger(this.getClass());
                        if (this.needsMetric() && !x.isMetric()) {
                            log.warning("This k-means variant requires the triangle inequality, and thus should only be used with squared Euclidean distance!");
                        } else {
                            log.warning("k-means optimizes the sum of squares - it should be used with squared euclidean distance and may stop converging otherwise!");
                        }
                    });
        }

        /**
         * Whether the variant relies on the triangle inequality; subclasses
         * override this to get the stronger distance warning.
         *
         * @return {@code false} in this base class
         */
        protected boolean needsMetric() {
            return false;
        }

        /** Read the initialization method (default: {@link RandomlyChosen}). */
        protected void getParameterInitialization(Parameterization config) {
            new ObjectParameter<KMeansInitialization>(KMeans.INIT_ID, KMeansInitialization.class, RandomlyChosen.class) //
                    .grab(config, x -> this.initializer = x);
        }

        /** Read the iteration limit (default 0, constraint: &gt;= 0). */
        protected void getParameterMaxIter(Parameterization config) {
            new IntParameter(KMeans.MAXITER_ID, 0) //
                    .addConstraint(CommonConstraints.GREATER_EQUAL_ZERO_INT) //
                    .grab(config, x -> this.maxiter = x);
        }

        /** Read the variance-statistics flag. */
        protected void getParameterVarstat(Parameterization config) {
            new Flag(KMeans.VARSTAT_ID).grab(config, x -> this.varstat = x);
        }

        public abstract AbstractKMeans<V, ?> make();
    }

    /**
     * Per-run state of a k-means execution: current means, cluster members,
     * assignments, per-cluster variance sums, and a distance computation
     * counter.
     */
    public static abstract class Instance {
        /** Current cluster means, one row per cluster. */
        protected double[][] means;

        /** Cluster members, one set per cluster. */
        protected List<ModifiableDBIDs> clusters;

        /** Cluster index of each object; -1 before the first assignment. */
        protected WritableIntegerDataStore assignment;

        /** Per-cluster sum of squared distances to the mean. */
        protected double[] varsum;

        /** Data relation. */
        protected Relation<? extends NumberVector> relation;

        /** Number of distance computations made through the distance helpers. */
        protected long diststat = 0L;

        /** Distance function. */
        private final NumberVectorDistance<?> df;

        /** Number of clusters ({@code means.length}). */
        protected final int k;

        /** Whether {@link #df} reports itself as a squared distance. */
        protected final boolean isSquared;

        /** Prefix for statistics keys (class name without "$Instance"). */
        protected String key;

        /**
         * Constructor.
         *
         * @param relation Data relation
         * @param df Distance function
         * @param means Initial means; {@code k} is taken from its length
         */
        public Instance(Relation<? extends NumberVector> relation, NumberVectorDistance<?> df, double[][] means) {
            this.relation = relation;
            this.df = df;
            this.isSquared = df.isSquared();
            this.means = means;
            this.k = means.length;
            // Initial hash set capacity: twice the average cluster size.
            final int guessedsize = (int) (relation.size() * 2.0 / this.k);
            this.clusters = new ArrayList<>(this.k);
            for (int i = 0; i < this.k; ++i) {
                this.clusters.add(DBIDUtil.newHashSet(guessedsize));
            }
            // Storage hints 3 - presumably DataStoreFactory.HINT_HOT | HINT_TEMP; confirm.
            this.assignment = DataStoreUtil.makeIntegerStorage(relation.getDBIDs(), 3, -1);
            this.varsum = new double[this.k];
            this.key = this.getClass().getName().replace("$Instance", "");
        }

        /** Distance between two vectors; increments {@link #diststat}. */
        protected double distance(NumberVector x, NumberVector y) {
            ++this.diststat;
            return this.df.distance(x, y);
        }

        /**
         * Distance between a vector and an array; increments {@link #diststat}.
         * Squared Euclidean distance (exact class match) is computed inline.
         *
         * @throws IllegalArgumentException on dimensionality mismatch in the
         *         inline path
         */
        protected double distance(NumberVector x, double[] y) {
            ++this.diststat;
            if (this.df.getClass() == SquaredEuclideanDistance.class) {
                if (y.length != x.getDimensionality()) {
                    throw new IllegalArgumentException("Objects do not have the same dimensionality.");
                }
                double v = 0.0;
                for (int i = 0; i < y.length; ++i) {
                    final double d = x.doubleValue(i) - y[i];
                    v += d * d;
                }
                return v;
            }
            return this.df.distance(x, DoubleVector.wrap(y));
        }

        /**
         * Distance between two arrays; increments {@link #diststat}. Squared
         * Euclidean distance (exact class match) is computed inline.
         *
         * @throws IllegalArgumentException on length mismatch in the inline path
         */
        protected double distance(double[] x, double[] y) {
            ++this.diststat;
            if (this.df.getClass() == SquaredEuclideanDistance.class) {
                if (y.length != x.length) {
                    throw new IllegalArgumentException("Objects do not have the same dimensionality.");
                }
                double v = 0.0;
                for (int i = 0; i < x.length; ++i) {
                    final double d = x[i] - y[i];
                    v += d * d;
                }
                return v;
            }
            return this.df.distance(DoubleVector.wrap(x), DoubleVector.wrap(y));
        }

        /** Distance, square-rooted if the distance function is squared. */
        protected double sqrtdistance(NumberVector x, NumberVector y) {
            final double d = this.distance(x, y);
            return this.isSquared ? Math.sqrt(d) : d;
        }

        /** Distance, square-rooted if the distance function is squared. */
        protected double sqrtdistance(NumberVector x, double[] y) {
            final double d = this.distance(x, y);
            return this.isSquared ? Math.sqrt(d) : d;
        }

        /** Distance, square-rooted if the distance function is squared. */
        protected double sqrtdistance(double[] x, double[] y) {
            final double d = this.distance(x, y);
            return this.isSquared ? Math.sqrt(d) : d;
        }

        /**
         * Main loop: call {@link #iterate} until it reports no changes or the
         * iteration limit is reached, logging per-iteration statistics. The
         * final {@code .iterations} statistic is the number of iterations
         * actually executed.
         *
         * @param maxiter Iteration limit
         */
        public void run(int maxiter) {
            final Logging log = this.getLogger();
            IndefiniteProgress prog = log.isVerbose() ? new IndefiniteProgress("Iteration") : null;
            int iteration = 0;
            // Check before incrementing, so the counter never exceeds maxiter
            // (and cannot overflow when maxiter == Integer.MAX_VALUE).
            while (iteration < maxiter) {
                ++iteration;
                Duration duration = log.newDuration(this.key + "." + iteration + ".time").begin();
                final long prevdiststat = this.diststat;
                log.incrementProcessed(prog);
                final int changed = this.iterate(iteration);
                if (log.isStatistics()) {
                    log.statistics(duration.end());
                    log.statistics(new LongStatistic(this.key + "." + iteration + ".reassignments", Math.abs(changed)));
                    if (this.diststat > prevdiststat) {
                        log.statistics(new LongStatistic(this.key + "." + iteration + ".distance-computations", this.diststat - prevdiststat));
                    }
                    final double s = VMath.sum(this.varsum);
                    if (s > 0.0) {
                        log.statistics(new DoubleStatistic(this.key + "." + iteration + ".variance-sum", s));
                    }
                }
                if (changed <= 0) {
                    break;
                }
            }
            log.setCompleted(prog);
            log.statistics(new LongStatistic(this.key + ".iterations", iteration));
        }

        /**
         * Perform one iteration.
         *
         * @param iteration Iteration number (starting at 1)
         * @return number of changes; a value &lt;= 0 ends {@link #run}, and its
         *         absolute value is logged as the reassignment count
         */
        protected abstract int iterate(int iteration);

        /**
         * Compute means from per-cluster sums: {@code dst[i] = sums[i] / |C_i|};
         * empty clusters copy {@code prev[i]}.
         */
        protected void meansFromSums(double[][] dst, double[][] sums, double[][] prev) {
            for (int i = 0; i < this.k; ++i) {
                final int size = this.clusters.get(i).size();
                if (size == 0) {
                    System.arraycopy(prev[i], 0, dst[i], 0, prev[i].length);
                    continue;
                }
                VMath.overwriteTimes(dst[i], sums[i], 1.0 / size);
            }
        }

        /**
         * Copy each row of {@code src} into {@code dst}; any remaining tail of a
         * longer destination row is zero-filled.
         */
        protected void copyMeans(double[][] src, double[][] dst) {
            for (int i = 0; i < this.k; ++i) {
                final double[] srci = src[i];
                final double[] dsti = dst[i];
                System.arraycopy(srci, 0, dsti, 0, srci.length);
                if (srci.length < dsti.length) {
                    Arrays.fill(dsti, srci.length, dsti.length, 0.0);
                }
            }
        }

        /**
         * Assign every object to its nearest mean (ties go to the lower index),
         * rebuilding {@link #clusters} and {@link #varsum}.
         *
         * @return number of objects whose stored assignment changed
         */
        protected int assignToNearestCluster() {
            assert (this.k == this.means.length);
            int changed = 0;
            Arrays.fill(this.varsum, 0.0);
            for (ModifiableDBIDs cluster : this.clusters) {
                cluster.clear();
            }
            for (DBIDIter iditer = this.relation.iterDBIDs(); iditer.valid(); iditer.advance()) {
                NumberVector fv = this.relation.get(iditer);
                double mindist = this.distance(fv, this.means[0]);
                int minIndex = 0;
                for (int i = 1; i < this.k; ++i) {
                    final double dist = this.distance(fv, this.means[i]);
                    if (dist < mindist) {
                        minIndex = i;
                        mindist = dist;
                    }
                }
                this.varsum[minIndex] += this.isSquared ? mindist : mindist * mindist;
                this.clusters.get(minIndex).add(iditer);
                if (this.assignment.putInt(iditer, minIndex) != minIndex) {
                    ++changed;
                }
            }
            return changed;
        }

        /**
         * Compute half the (square-rooted) distances between all pairs of means
         * into {@code cdist}, and for each mean the smallest such value into
         * {@code sep}.
         */
        protected void recomputeSeperation(double[] sep, double[][] cdist) {
            final int k = this.means.length;
            assert (sep.length == k);
            Arrays.fill(sep, Double.POSITIVE_INFINITY);
            for (int i = 1; i < k; ++i) {
                final double[] mi = this.means[i];
                for (int j = 0; j < i; ++j) {
                    final double halfd = 0.5 * this.sqrtdistance(mi, this.means[j]);
                    cdist[i][j] = cdist[j][i] = halfd;
                    sep[i] = halfd < sep[i] ? halfd : sep[i];
                    sep[j] = halfd < sep[j] ? halfd : sep[j];
                }
            }
        }

        /** Fill {@code cdist} with half the (square-rooted) pairwise mean distances. */
        protected void initialSeperation(double[][] cdist) {
            final int k = this.means.length;
            for (int i = 1; i < k; ++i) {
                final double[] mi = this.means[i];
                for (int j = 0; j < i; ++j) {
                    cdist[i][j] = cdist[j][i] = 0.5 * this.sqrtdistance(mi, this.means[j]);
                }
            }
        }

        /**
         * Fill {@code cost} with {@code 0.25 * distance(mean_i, mean_j)}. This
         * is the squared half-distance when the distance function is squared.
         * NOTE(review): for non-squared distances no squaring is applied;
         * callers presumably only use this with squared distances.
         */
        protected void computeSquaredSeparation(double[][] cost) {
            for (int i = 0; i < this.k; ++i) {
                final double[] mi = this.means[i];
                for (int j = 0; j < i; ++j) {
                    cost[i][j] = cost[j][i] = this.distance(mi, this.means[j]) * 0.25;
                }
            }
        }

        /** Store in {@code dists[i]} the (square-rooted) distance mean i moved. */
        protected void movedDistance(double[][] means, double[][] newmeans, double[] dists) {
            assert (newmeans.length == means.length && dists.length == means.length);
            for (int i = 0; i < means.length; ++i) {
                dists[i] = this.sqrtdistance(means[i], newmeans[i]);
            }
        }

        /**
         * Build the clustering result from the current state; warns about
         * empty clusters.
         *
         * @return clustering with one {@link KMeansModel} cluster per mean
         */
        public Clustering<KMeansModel> buildResult() {
            Clustering<KMeansModel> result = new Clustering<>();
            Metadata.of(result).setLongName("k-Means Clustering");
            for (int i = 0; i < this.clusters.size(); ++i) {
                ModifiableDBIDs ids = this.clusters.get(i);
                if (ids.isEmpty()) {
                    this.getLogger().warning("K-Means produced an empty cluster - bad initialization?");
                }
                final double variance = this.varsum != null ? this.varsum[i] : Double.NaN;
                result.addToplevelCluster(new Cluster<KMeansModel>(ids, new KMeansModel(this.means[i], variance)));
            }
            return result;
        }

        /**
         * Build the result, optionally recomputing the variance sums first
         * (otherwise they are set to NaN), and log distance statistics.
         *
         * @param varstat Whether to recompute variances
         * @param relation Data relation
         * @return clustering
         */
        public Clustering<KMeansModel> buildResult(boolean varstat, Relation<? extends NumberVector> relation) {
            final Logging log = this.getLogger();
            if (varstat) {
                final long beforestat = this.diststat;
                log.statistics(new LongStatistic(this.key + ".distance-computations.main", this.diststat));
                this.recomputeVariance(relation);
                log.statistics(new DoubleStatistic(this.key + ".variance-sum", VMath.sum(this.varsum)));
                log.statistics(new LongStatistic(this.key + ".variance.distance-computations", this.diststat - beforestat));
            } else {
                Arrays.fill(this.varsum, Double.NaN);
            }
            Clustering<KMeansModel> result = this.buildResult();
            log.statistics(new LongStatistic(this.key + ".distance-computations", this.diststat));
            return result;
        }

        /**
         * Recompute {@link #varsum} as the per-cluster sum of squared distances
         * to the current means. Non-squared distances are squared, matching
         * {@link #assignToNearestCluster()}.
         *
         * @param relation Data relation
         */
        protected void recomputeVariance(Relation<? extends NumberVector> relation) {
            Arrays.fill(this.varsum, 0.0);
            for (int i = 0; i < this.clusters.size(); ++i) {
                final double[] mean = this.means[i];
                double vsum = 0.0;
                for (DBIDMIter it = this.clusters.get(i).iter(); it.valid(); it.advance()) {
                    final double d = this.distance(relation.get(it), mean);
                    vsum += this.isSquared ? d : d * d;
                }
                this.varsum[i] = vsum;
            }
        }

        /**
         * Get the class logger.
         *
         * @return logger
         */
        protected abstract Logging getLogger();
    }
}

