/*
 * Decompiled with CFR 0.152.
 */
package elki.clustering.kmeans;

import elki.clustering.kmeans.AbstractKMeans;
import elki.clustering.kmeans.initialization.KMeansInitialization;
import elki.data.Clustering;
import elki.data.NumberVector;
import elki.data.VectorUtil;
import elki.data.model.KMeansModel;
import elki.database.ids.ArrayModifiableDBIDs;
import elki.database.ids.DBIDArrayIter;
import elki.database.ids.DBIDArrayMIter;
import elki.database.ids.DBIDIter;
import elki.database.ids.DBIDRef;
import elki.database.ids.DBIDUtil;
import elki.database.ids.DBIDs;
import elki.database.ids.ModifiableDBIDs;
import elki.database.ids.QuickSelectDBIDs;
import elki.database.relation.Relation;
import elki.distance.NumberVectorDistance;
import elki.logging.Logging;
import elki.logging.statistics.Duration;
import elki.logging.statistics.Statistic;
import elki.logging.statistics.StringStatistic;
import elki.math.MathUtil;
import elki.math.linearalgebra.VMath;
import elki.utilities.documentation.Reference;
import elki.utilities.documentation.References;
import elki.utilities.documentation.Title;
import elki.utilities.optionhandling.OptionID;
import elki.utilities.optionhandling.constraints.CommonConstraints;
import elki.utilities.optionhandling.constraints.ParameterConstraint;
import elki.utilities.optionhandling.parameterization.Parameterization;
import elki.utilities.optionhandling.parameters.EnumParameter;
import elki.utilities.optionhandling.parameters.IntParameter;
import java.util.Arrays;
import java.util.Comparator;

@Title(value="K-d-tree K-means with Pruning")
@References(value={@Reference(authors="K. Alsabti, S. Ranka, V. Singh", title="An efficient k-means clustering algorithm", booktitle="Electrical Engineering and Computer Science, Technical Report 43", url="https://surface.syr.edu/eecs/43/", bibkey="tr/syracuse/AlsabtiRS97"), @Reference(authors="K. Alsabti, S. Ranka, V. Singh", title="An Efficient Space-Partitioning Based Algorithm for the K-Means Clustering", booktitle="Pacific-Asia Conference on Knowledge Discovery and Data Mining", url="https://doi.org/10.1007/3-540-48912-6_47", bibkey="DBLP:conf/pakdd/AlsabtiRS99")})
public class KDTreePruningKMeans<V extends NumberVector>
extends AbstractKMeans<V, KMeansModel> {
    private static final Logging LOG = Logging.getLogger(KDTreePruningKMeans.class);
    protected Split split = Split.MIDPOINT;
    protected int leafsize;

    /**
     * Constructor.
     *
     * @param distance distance function, forwarded to the superclass
     * @param k value forwarded to the superclass as {@code k}
     * @param maxiter value forwarded to the superclass as {@code maxiter}
     * @param initializer initialization method, forwarded to the superclass
     * @param split strategy used to split nodes while building the k-d-tree
     * @param leafsize ranges with at most this many points are not split
     *        further by the tree builders
     */
    public KDTreePruningKMeans(NumberVectorDistance<? super V> distance, int k, int maxiter, KMeansInitialization initializer, Split split, int leafsize) {
        super(distance, k, maxiter, initializer);
        this.leafsize = leafsize;
        this.split = split;
    }

    /**
     * Clusters the given relation: creates an {@link Instance} from the
     * initial means, runs it with the configured iteration limit and returns
     * the result the instance builds.
     *
     * @param relation data to cluster
     * @return clustering result built by the instance
     */
    @Override
    public Clustering<KMeansModel> run(Relation<V> relation) {
        final NumberVectorDistance<? super V> df = this.distance;
        final double[][] initial = this.initialMeans(relation);
        final Instance worker = new Instance(relation, df, initial);
        worker.run(this.maxiter);
        return worker.buildResult();
    }

    /**
     * @return the static logger of this class
     */
    @Override
    protected Logging getLogger() {
        return KDTreePruningKMeans.LOG;
    }

    /**
     * Parameterization class: reads the splitting strategy and the leaf size
     * in addition to the options handled by the parent parameterizer.
     *
     * @param <V> vector type
     */
    public static class Par<V extends NumberVector>
    extends AbstractKMeans.Par<V> {
        /** Option selecting the k-d-tree splitting strategy. */
        public static final OptionID SPLIT_ID = new OptionID("kmeans.kdtree.split", "Splitting strategy to use (midpoint or median).");
        /** Option selecting the leaf size of the k-d-tree. */
        public static final OptionID LEAFSIZE_ID = new OptionID("kmeans.kdtree.leafsize", "Leaf size of the k-d-tree.");
        /** Splitting strategy; defaults to {@link Split#MIDPOINT}. */
        protected Split split = Split.MIDPOINT;
        /** Leaf size; the option defaults to 5 and is constrained to be at least 1. */
        protected int leafsize;

        /**
         * Reads the parent options, then the split strategy and the leaf size.
         *
         * @param config parameterization to read from
         */
        @Override
        public void configure(Parameterization config) {
            super.configure(config);
            // Typed parameter instead of the raw type: the consumer then
            // receives a Split and the assignment type-checks.
            new EnumParameter<Split>(SPLIT_ID, Split.class, Split.MIDPOINT)
                .grab(config, x -> this.split = x);
            new IntParameter(LEAFSIZE_ID, 5)
                .addConstraint(CommonConstraints.GREATER_EQUAL_ONE_INT)
                .grab(config, x -> this.leafsize = x);
        }

        /**
         * @return a new algorithm instance built from the configured values
         */
        @Override
        public KDTreePruningKMeans<V> make() {
            // Diamond instead of a raw constructor call, to avoid an unchecked conversion.
            return new KDTreePruningKMeans<>(this.distance, this.k, this.maxiter, this.initializer, this.split, this.leafsize);
        }
    }

    /**
     * Node of the k-d-tree. It covers the offsets {@code [start, end)} of the
     * id array it was built from and stores the per-dimension sum of the
     * covered vectors together with the center and half extent of their
     * bounding box.
     */
    public static class KDNode {
        /** Per-dimension sum of all vectors in the node. */
        double[] sum;
        /** Per-dimension center of the bounding box, {@code (max + min) / 2}. */
        double[] mid;
        /** Per-dimension half extent of the bounding box, {@code (max - min) / 2}. */
        double[] halfwidth;
        /** Left child; stays {@code null} unless a tree builder assigns it. */
        KDNode leftChild;
        /** Right child; stays {@code null} unless a tree builder assigns it. */
        KDNode rightChild;
        /** First offset covered by this node (inclusive). */
        int start;
        /** End offset covered by this node (exclusive). */
        int end;

        /**
         * Scans the offsets {@code [start, end)} once and computes sum,
         * bounding box center and half width. The element at {@code start} is
         * always read, so the range is expected to be non-empty.
         *
         * @param relation data relation the vectors are read from
         * @param iter array iterator; it is repositioned by this constructor
         * @param start first offset (inclusive)
         * @param end end offset (exclusive)
         */
        public KDNode(Relation<? extends NumberVector> relation, DBIDArrayIter iter, int start, int end) {
            this.start = start;
            this.end = end;
            iter.seek(start);
            // The first vector initializes lower bound, upper bound and sum.
            final double[] lower = ((NumberVector)relation.get((DBIDRef)iter)).toArray();
            final double[] upper = lower.clone();
            final double[] total = lower.clone();
            this.sum = total;
            final int dims = lower.length;
            for (iter.advance(); iter.getOffset() < end; iter.advance()) {
                final NumberVector vec = (NumberVector)relation.get((DBIDRef)iter);
                for (int d = 0; d < dims; d++) {
                    final double val = vec.doubleValue(d);
                    total[d] += val;
                    if (val > upper[d]) {
                        upper[d] = val;
                    } else if (val < lower[d]) {
                        lower[d] = val;
                    }
                }
            }
            // Convert the min/max arrays in place into center and half width.
            for (int d = 0; d < dims; d++) {
                final double lo = lower[d];
                final double hi = upper[d];
                lower[d] = 0.5 * (hi + lo);
                upper[d] = 0.5 * (hi - lo);
            }
            this.mid = lower;
            this.halfwidth = upper;
        }
    }

    protected class Instance
    extends AbstractKMeans.Instance {
        protected KDNode root;
        protected ArrayModifiableDBIDs sorted;
        protected DBIDArrayMIter iter;
        protected int[] indices;
        protected double[][] clusterSums;
        protected int[] clusterSizes;

        /**
         * Constructor; all arguments are forwarded unchanged to the superclass.
         *
         * @param relation data relation
         * @param df distance function
         * @param means initial means
         */
        public Instance(Relation<? extends NumberVector> relation, NumberVectorDistance<?> df, double[][] means) {
            super(relation, df, means);
        }

        /**
         * Builds the k-d-tree with the configured split strategy, logging the
         * strategy and the construction time, then initializes the center
         * index permutation and the cluster size counters and delegates to
         * the superclass {@code run}.
         *
         * @param maxiter value passed on to the superclass {@code run}
         */
        @Override
        public void run(int maxiter) {
            final String prefix = KDTreePruningKMeans.this.getClass().getName();
            final Duration construction = LOG.newDuration(prefix + ".k-d-tree-construction").begin();
            this.sorted = DBIDUtil.newArray((DBIDs)this.relation.getDBIDs());
            this.iter = this.sorted.iter();
            final Split strategy = KDTreePruningKMeans.this.split;
            LOG.statistics((Statistic)new StringStatistic(prefix + ".k-d-tree-split", strategy.toString()));
            final Relation<? extends NumberVector> rel = (Relation<? extends NumberVector>)this.relation;
            final int size = this.sorted.size();
            if (strategy == Split.MIDPOINT) {
                this.root = this.buildTreeMidpoint(rel, 0, size);
            } else if (strategy == Split.BOUNDED_MIDPOINT) {
                this.root = this.buildTreeBoundedMidpoint(rel, 0, size, new VectorUtil.SortDBIDsBySingleDimension(this.relation));
            } else if (strategy == Split.MEDIAN) {
                this.root = this.buildTreeMedian(rel, 0, size, new VectorUtil.SortDBIDsBySingleDimension(this.relation));
            } else if (strategy == Split.SSQ) {
                this.root = this.buildTreeSSQ(rel, 0, size, new VectorUtil.SortDBIDsBySingleDimension(this.relation));
            }
            LOG.statistics((Statistic)construction.end());
            // indices holds a permutation of the center numbers 0..k-1 that
            // pruning reorders while traversing the tree.
            this.indices = MathUtil.sequence(0, this.k);
            this.clusterSizes = new int[this.k];
            super.run(maxiter);
        }

        /**
         * Recursively builds a subtree for the offsets {@code [left, right)},
         * splitting at the bounding box center of the widest dimension.
         *
         * @param relation data relation
         * @param left first offset (inclusive)
         * @param right end offset (exclusive)
         * @return node covering the range; a leaf if the range has at most
         *         {@code leafsize} points or no point ends up on the right side
         */
        protected KDNode buildTreeMidpoint(Relation<? extends NumberVector> relation, int left, int right) {
            final KDNode node = new KDNode(relation, (DBIDArrayIter)this.iter, left, right);
            if (right - left <= KDTreePruningKMeans.this.leafsize) {
                return node;
            }
            final int axis = VMath.argmax(node.halfwidth);
            final double pivot = node.mid[axis];
            // Two-pointer partition: values <= pivot may stay on the left,
            // values >= pivot may stay on the right.
            int lo = left;
            int hi = right - 1;
            for (;;) {
                while (lo <= hi && ((NumberVector)relation.get((DBIDRef)this.iter.seek(lo))).doubleValue(axis) <= pivot) {
                    lo++;
                }
                while (lo <= hi && ((NumberVector)relation.get((DBIDRef)this.iter.seek(hi))).doubleValue(axis) >= pivot) {
                    hi--;
                }
                if (lo >= hi) {
                    break;
                }
                this.sorted.swap(lo, hi);
                lo++;
                hi--;
            }
            assert (((NumberVector)relation.get((DBIDRef)this.iter.seek(hi))).doubleValue(axis) <= pivot) : ((NumberVector)relation.get((DBIDRef)this.iter.seek(hi))).doubleValue(axis) + " not less than " + pivot;
            final int cut = hi + 1;
            if (cut == right) {
                // Nothing ended up on the right side, so the node stays a leaf.
                return node;
            }
            node.leftChild = this.buildTreeMidpoint(relation, left, cut);
            node.rightChild = this.buildTreeMidpoint(relation, cut, right);
            return node;
        }

        /**
         * Recursively builds a subtree for the offsets {@code [left, right)}
         * with a midpoint split on the widest dimension. If the midpoint
         * split would put fewer than {@code (right - left) >>> 3} points on
         * one side, the split position is moved to that bound with a
         * quickselect.
         *
         * @param relation data relation
         * @param left first offset (inclusive)
         * @param right end offset (exclusive)
         * @param comp comparator whose dimension is set to the split dimension
         *        before each quickselect
         * @return node covering the range
         */
        protected KDNode buildTreeBoundedMidpoint(Relation<? extends NumberVector> relation, int left, int right, VectorUtil.SortDBIDsBySingleDimension comp) {
            KDNode node = new KDNode(relation, (DBIDArrayIter)this.iter, left, right);
            // Small enough: keep as leaf.
            if (right - left <= KDTreePruningKMeans.this.leafsize) {
                return node;
            }
            // Split on the dimension with the largest half width, at the box center.
            int dim = VMath.argmax((double[])node.halfwidth);
            double mid = node.mid[dim];
            int l = left;
            int r = right - 1;
            // Two-pointer partition: l skips values <= mid, r skips values >= mid,
            // and a misplaced pair is swapped.
            while (true) {
                if (l <= r && ((NumberVector)relation.get((DBIDRef)this.iter.seek(l))).doubleValue(dim) <= mid) {
                    ++l;
                    continue;
                }
                while (l <= r && ((NumberVector)relation.get((DBIDRef)this.iter.seek(r))).doubleValue(dim) >= mid) {
                    --r;
                }
                if (l >= r) break;
                this.sorted.swap(l++, r--);
            }
            assert (((NumberVector)relation.get((DBIDRef)this.iter.seek(r))).doubleValue(dim) <= mid) : ((NumberVector)relation.get((DBIDRef)this.iter.seek(r))).doubleValue(dim) + " not less than " + mid;
            // r becomes the first offset of the right part; if that is the end
            // of the range, no point went to the right side and the node stays a leaf.
            if (++r == right) {
                return node;
            }
            // One eighth of the range size is the minimum size of each side.
            int q = right - left >>> 3;
            if (left + q > r) {
                // Left side too small: select within the old right part [n, right)
                // so that the split moves to left + q.
                comp.setDimension(dim);
                int n = r;
                r = left + q;
                QuickSelectDBIDs.quickSelect((ArrayModifiableDBIDs)this.sorted, (Comparator)comp, (int)n, (int)right, (int)r);
            } else if (right - q < r) {
                // Right side too small: select within the old left part [left, n)
                // so that the split moves to right - q.
                comp.setDimension(dim);
                int n = r;
                r = right - q;
                QuickSelectDBIDs.quickSelect((ArrayModifiableDBIDs)this.sorted, (Comparator)comp, (int)left, (int)n, (int)r);
            }
            assert (left < r && r < right) : "Useless split selected: " + left + " < " + r + " < " + right;
            node.leftChild = this.buildTreeBoundedMidpoint(relation, left, r, comp);
            node.rightChild = this.buildTreeBoundedMidpoint(relation, r, right, comp);
            return node;
        }

        /**
         * Recursively builds a subtree for the offsets {@code [left, right)},
         * splitting at the middle offset after a quickselect on the widest
         * dimension.
         *
         * @param relation data relation
         * @param left first offset (inclusive)
         * @param right end offset (exclusive)
         * @param comp comparator; its dimension is set to the split dimension
         * @return node covering the range; a leaf if the range has at most
         *         {@code leafsize} points or the widest dimension has no
         *         positive extent
         */
        protected KDNode buildTreeMedian(Relation<? extends NumberVector> relation, int left, int right, VectorUtil.SortDBIDsBySingleDimension comp) {
            final KDNode node = new KDNode(relation, (DBIDArrayIter)this.iter, left, right);
            if (right - left <= KDTreePruningKMeans.this.leafsize) {
                return node;
            }
            final int middle = (left + right) >>> 1;
            final int sdim = VMath.argmax(node.halfwidth);
            if (!(node.halfwidth[sdim] > 0.0)) {
                // No positive extent in the widest dimension: nothing to split.
                return node;
            }
            comp.setDimension(sdim);
            QuickSelectDBIDs.quickSelect((ArrayModifiableDBIDs)this.sorted, (Comparator)comp, left, right, middle);
            node.leftChild = this.buildTreeMedian(relation, left, middle, comp);
            node.rightChild = this.buildTreeMedian(relation, middle, right, comp);
            return node;
        }

        /**
         * Recursively builds a subtree for the offsets {@code [left, right)}.
         * For every dimension the range is sorted and each scanned split
         * position is scored by the squared distance between the means of the
         * two sides, weighted by the product of the side sizes; the position
         * with the largest score is used.
         *
         * @param relation data relation
         * @param left first offset (inclusive)
         * @param right end offset (exclusive)
         * @param comp comparator used for sorting and for the final quickselect
         * @return node covering the range; a leaf if the range has at most
         *         {@code leafsize} points, fewer than two points, or the best
         *         score is exactly zero
         */
        protected KDNode buildTreeSSQ(Relation<? extends NumberVector> relation, int left, int right, VectorUtil.SortDBIDsBySingleDimension comp) {
            final KDNode node = new KDNode(relation, (DBIDArrayIter)this.iter, left, right);
            final int len = right - left;
            // A range with fewer than two points cannot be split. Without this
            // guard, a leaf size below 1 (possible through the public
            // constructor) yields a split position of 0 for a one-point range,
            // and the right child recurses on the same range forever.
            if (len <= KDTreePruningKMeans.this.leafsize || len < 2) {
                return node;
            }
            final int dims = node.sum.length;
            int bestdim = 0;
            // Fallback when no position is scored or no score is comparable: the middle.
            int bestpos = len >>> 1;
            double bestscore = Double.NEGATIVE_INFINITY;
            for (int dim = 0; dim < dims; ++dim) {
                comp.setDimension(dim);
                this.sorted.sort(left, right, (Comparator)comp);
                // s1 accumulates the left side, s2 holds the remaining right side.
                final double[] s1 = new double[dims];
                final double[] s2 = node.sum.clone();
                this.iter.seek(left);
                // i = size of the left side, j = size of the right side.
                // The scan stops while j is still 2.
                for (int i = 1, j = len - 1; j > 1; ++i, --j) {
                    final NumberVector vec = (NumberVector)relation.get((DBIDRef)this.iter);
                    AbstractKMeans.plusEquals(s1, vec);
                    AbstractKMeans.minusEquals(s2, vec);
                    double score = 0.0;
                    for (int d = 0; d < dims; ++d) {
                        final double v = s1[d] / (double)i - s2[d] / (double)j;
                        score += v * v;
                    }
                    final double s = score * (double)i * (double)j;
                    if (s > bestscore) {
                        bestscore = s;
                        bestdim = dim;
                        bestpos = i;
                    }
                    this.iter.advance();
                }
            }
            if (bestscore == 0.0) {
                // Every scored split has identical means on both sides.
                return node;
            }
            // The array is currently sorted by the last dimension; restore the
            // partition for the best dimension at the best position.
            comp.setDimension(bestdim);
            final int cut = left + bestpos;
            QuickSelectDBIDs.quickSelect((ArrayModifiableDBIDs)this.sorted, (Comparator)comp, left, right, cut);
            node.leftChild = this.buildTreeSSQ(relation, left, cut, comp);
            node.rightChild = this.buildTreeSSQ(relation, cut, right, comp);
            return node;
        }

        /**
         * Performs one iteration: resets the per-cluster sums and sizes,
         * assigns all points via the tree traversal, and replaces the mean of
         * every cluster with a positive size by sum / size. When nothing
         * changed, or {@code maxiter <= iteration}, the cluster id sets are
         * rebuilt from the assignment store.
         *
         * @param iteration iteration number, compared against {@code maxiter}
         * @return number of points whose assignment changed
         */
        @Override
        protected int iterate(int iteration) {
            final int dim = this.means[0].length;
            this.clusterSums = new double[this.means.length][dim];
            Arrays.fill(this.clusterSizes, 0);
            final int changed = this.traversal(this.root, this.indices.length);
            for (int c = 0; c < this.clusterSums.length; c++) {
                final int size = this.clusterSizes[c];
                if (size > 0) {
                    this.means[c] = VMath.timesEquals(this.clusterSums[c], 1.0 / (double)size);
                }
            }
            final boolean rebuild = changed == 0 || KDTreePruningKMeans.this.maxiter <= iteration;
            if (!rebuild) {
                return changed;
            }
            for (ModifiableDBIDs cluster : this.clusters) {
                cluster.clear();
            }
            for (DBIDIter it = this.relation.iterDBIDs(); it.valid(); it.advance()) {
                ((ModifiableDBIDs)this.clusters.get(this.assignment.intValue((DBIDRef)it))).add((DBIDRef)it);
            }
            return changed;
        }

        /**
         * Assigns the points below a node. Candidate centers are pruned
         * first; if exactly one remains, the whole subtree is labeled with
         * it. Otherwise a leaf is scanned point by point and an inner node
         * recurses into both children.
         *
         * @param u node to process
         * @param alive number of candidate centers at the front of {@code indices}
         * @return number of points whose assignment changed
         */
        protected int traversal(KDNode u, int alive) {
            final int remaining = this.pruning(u, alive);
            if (remaining == 1) {
                return this.labelSubtree(u.sum, u.start, u.end, this.indices[0]);
            }
            final boolean leaf = u.leftChild == null;
            if (leaf) {
                assert (u.rightChild == null);
                return this.traverseLeaf(u.start, u.end, remaining);
            }
            assert (u.rightChild != null);
            final int changedLeft = this.traversal(u.leftChild, remaining);
            final int changedRight = this.traversal(u.rightChild, remaining);
            return changedLeft + changedRight;
        }

        /**
         * Assigns every point at offsets {@code [start, end)} to one cluster,
         * adding the precomputed vector sum and the point count to that
         * cluster's aggregates.
         *
         * @param sum per-dimension sum of the points in the range
         * @param start first offset (inclusive)
         * @param end end offset (exclusive)
         * @param index cluster number to assign
         * @return number of points whose previous assignment differed
         */
        protected int labelSubtree(double[] sum, int start, int end, int index) {
            VMath.plusEquals(this.clusterSums[index], sum);
            this.clusterSizes[index] += end - start;
            int changed = 0;
            for (this.iter.seek(start); this.iter.getOffset() < end; this.iter.advance()) {
                final int previous = this.assignment.putInt((DBIDRef)this.iter, index);
                if (previous != index) {
                    changed++;
                }
            }
            return changed;
        }

        /**
         * Removes candidate centers that cannot be closest to any point in
         * the node's bounding box: a center is dropped when its minimum
         * squared distance to the box exceeds the smallest maximum squared
         * distance of any candidate. Dropped centers are swapped behind the
         * surviving prefix of {@code indices}.
         *
         * @param u node supplying the bounding box
         * @param alive number of candidates before pruning
         * @return number of candidates remaining at the front of {@code indices}
         */
        protected int pruning(KDNode u, int alive) {
            final double[] center = u.mid;
            final double[] extent = u.halfwidth;
            final double threshold = this.getMinMaxDist(center, extent, alive);
            int remaining = alive;
            for (int i = 0; i < remaining; ) {
                final int candidate = this.indices[i];
                if (this.mindist(this.means[candidate], center, extent) > threshold) {
                    // Move the pruned center behind the live prefix, then
                    // re-test the element swapped into position i.
                    --remaining;
                    this.indices[i] = this.indices[remaining];
                    this.indices[remaining] = candidate;
                } else {
                    i++;
                }
            }
            return remaining;
        }

        /**
         * Finds, among the first {@code alive} candidates, the center with
         * the smallest maximum squared distance to the box given by
         * {@code mid} and {@code halfwidth}. That center is swapped to
         * position {@code alive - 1} of {@code indices}. The distance counter
         * is incremented once per call.
         *
         * @param mid bounding box center
         * @param halfwidth bounding box half extent
         * @param alive number of candidates
         * @return smallest maximum squared distance
         */
        protected double getMinMaxDist(double[] mid, double[] halfwidth, int alive) {
            ++this.diststat;
            int best = 0;
            double bestDistance = Double.POSITIVE_INFINITY;
            for (int i = 0; i < alive; i++) {
                final double[] mean = this.means[this.indices[i]];
                double farthest = 0.0;
                for (int d = 0; d < mean.length; d++) {
                    final double a = mean[d];
                    final double b = mid[d];
                    final double gap;
                    if (a > b) {
                        gap = a - b;
                    } else {
                        gap = b - a;
                    }
                    // Distance to the far side of the box in this dimension.
                    final double delta = gap + halfwidth[d];
                    farthest += delta * delta;
                }
                if (farthest < bestDistance) {
                    best = i;
                    bestDistance = farthest;
                }
            }
            final int last = alive - 1;
            final int winner = this.indices[best];
            this.indices[best] = this.indices[last];
            this.indices[last] = winner;
            return bestDistance;
        }

        /**
         * Minimum squared distance from a center to the box given by
         * {@code mid} and {@code halfwidth}; dimensions in which the center
         * lies within the box extent contribute nothing. The distance counter
         * is incremented once per call.
         *
         * @param mean center coordinates
         * @param mid bounding box center
         * @param halfwidth bounding box half extent
         * @return minimum squared distance
         */
        protected double mindist(double[] mean, double[] mid, double[] halfwidth) {
            ++this.diststat;
            double total = 0.0;
            for (int d = 0; d < mean.length; d++) {
                final double a = mean[d];
                final double b = mid[d];
                final double gap;
                if (a > b) {
                    gap = a - b;
                } else {
                    gap = b - a;
                }
                final double outside = gap - halfwidth[d];
                if (outside > 0.0) {
                    total += outside * outside;
                }
            }
            return total;
        }

        /**
         * Assigns every point at offsets {@code [start, end)} to the nearest
         * of the first {@code alive} candidate centers, updating cluster sums,
         * cluster sizes and the assignment store.
         *
         * @param start first offset (inclusive)
         * @param end end offset (exclusive)
         * @param alive number of candidate centers at the front of {@code indices}
         * @return number of points whose previous assignment differed
         */
        protected int traverseLeaf(int start, int end, int alive) {
            int changed = 0;
            for (this.iter.seek(start); this.iter.getOffset() < end; this.iter.advance()) {
                int nearest = this.indices[0];
                double nearestDist = Double.POSITIVE_INFINITY;
                final NumberVector fv = (NumberVector)this.relation.get((DBIDRef)this.iter);
                for (int i = 0; i < alive; i++) {
                    final int candidate = this.indices[i];
                    final double dist = this.distance(fv, this.means[candidate]);
                    if (dist < nearestDist) {
                        nearest = candidate;
                        nearestDist = dist;
                    }
                }
                this.clusterSizes[nearest]++;
                AbstractKMeans.plusEquals(this.clusterSums[nearest], fv);
                final int previous = this.assignment.putInt((DBIDRef)this.iter, nearest);
                if (previous != nearest) {
                    changed++;
                }
            }
            return changed;
        }

        /**
         * @return the static logger of the enclosing class
         */
        @Override
        protected Logging getLogger() {
            return KDTreePruningKMeans.LOG;
        }
    }

    /**
     * Splitting strategies available for building the k-d-tree.
     */
    public enum Split {
        /** Split at the bounding box center of the widest dimension. */
        MIDPOINT,
        /** Midpoint split, moved to an eighth-of-range bound when too unbalanced. */
        BOUNDED_MIDPOINT,
        /** Split at the middle offset along the widest dimension. */
        MEDIAN,
        /** Split at the position with the best score over all dimensions. */
        SSQ
    }
}

