001/*
002 * Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.clustering.kmeans;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.config.PropertyException;
021import com.oracle.labs.mlrg.olcut.provenance.Provenance;
022import com.oracle.labs.mlrg.olcut.util.MutableLong;
023import com.oracle.labs.mlrg.olcut.util.StreamUtil;
024import org.tribuo.Dataset;
025import org.tribuo.Example;
026import org.tribuo.ImmutableFeatureMap;
027import org.tribuo.ImmutableOutputInfo;
028import org.tribuo.Trainer;
029import org.tribuo.clustering.ClusterID;
030import org.tribuo.clustering.ImmutableClusteringInfo;
031import org.tribuo.math.distance.DistanceType;
032import org.tribuo.math.la.DenseVector;
033import org.tribuo.math.la.SGDVector;
034import org.tribuo.math.la.SparseVector;
035import org.tribuo.provenance.ModelProvenance;
036import org.tribuo.provenance.TrainerProvenance;
037import org.tribuo.provenance.impl.TrainerProvenanceImpl;
038import org.tribuo.util.Util;
039
040import java.security.AccessController;
041import java.security.PrivilegedAction;
042import java.time.OffsetDateTime;
043import java.util.ArrayList;
044import java.util.Arrays;
045import java.util.Collections;
046import java.util.HashMap;
047import java.util.List;
048import java.util.Map;
049import java.util.Map.Entry;
050import java.util.SplittableRandom;
051import java.util.concurrent.ExecutionException;
052import java.util.concurrent.ForkJoinPool;
053import java.util.concurrent.ForkJoinWorkerThread;
054import java.util.concurrent.atomic.AtomicInteger;
055import java.util.function.Consumer;
056import java.util.logging.Level;
057import java.util.logging.Logger;
058import java.util.stream.IntStream;
059import java.util.stream.Stream;
060
061/**
062 * A K-Means trainer, which generates a K-means clustering of the supplied
063 * data. The model finds the centres, and then predict needs to be
064 * called to infer the centre assignments for the input data.
065 * <p>
066 * It's slightly contorted to fit the Tribuo Trainer and Model API, as the cluster assignments
067 * can only be retrieved from the model after training, and require re-evaluating each example.
068 * <p>
069 * The Trainer has a parameterised distance function, and a selectable number
070 * of threads used in the training step. The thread pool is local to an invocation of train,
071 * so there can be multiple concurrent trainings.
072 * <p>
073 * The train method will instantiate dense examples as dense vectors, speeding up the computation.
074 * <p>
075 * Note parallel training uses a {@link ForkJoinPool} which requires that the Tribuo codebase
076 * is given the "modifyThread" and "modifyThreadGroup" privileges when running under a
077 * {@link java.lang.SecurityManager}.
078 * <p>
079 * See:
080 * <pre>
081 * J. Friedman, T. Hastie, &amp; R. Tibshirani.
082 * "The Elements of Statistical Learning"
083 * Springer 2001. <a href="http://web.stanford.edu/~hastie/ElemStatLearn/">PDF</a>
084 * </pre>
085 * <p>
086 * For more on optional kmeans++ initialisation, see:
087 * <pre>
088 * D. Arthur, S. Vassilvitskii.
089 * "K-Means++: The Advantages of Careful Seeding"
 * <a href="https://theory.stanford.edu/~sergei/papers/kMeansPP-soda.pdf">PDF</a>
091 * </pre>
092 */
093public class KMeansTrainer implements Trainer<ClusterID> {
    // Logger named after this class; train reports iteration progress and convergence through it.
    private static final Logger logger = Logger.getLogger(KMeansTrainer.class.getName());

    // Thread factory for the FJP, to allow use with OpenSearch's SecureSM.
    // train only hands this factory to the ForkJoinPool when a SecurityManager is installed.
    private static final CustomForkJoinWorkerThreadFactory THREAD_FACTORY = new CustomForkJoinWorkerThreadFactory();
098
099    /**
100     * Possible distance functions.
101     * @deprecated
102     * This Enum is deprecated in version 4.3, replaced by {@link DistanceType}
103     */
104    @Deprecated
105    public enum Distance {
106        /**
107         * Euclidean (or l2) distance.
108         */
109        EUCLIDEAN(DistanceType.L2),
110        /**
111         * Cosine similarity as a distance measure.
112         */
113        COSINE(DistanceType.COSINE),
114        /**
115         * L1 (or Manhattan) distance.
116         */
117        L1(DistanceType.L1);
118
119        private final DistanceType distanceType;
120
121        Distance(DistanceType distanceType) {
122            this.distanceType = distanceType;
123        }
124
125        /**
126         * Returns the {@link DistanceType} mapping for the enumeration's value.
127         *
128         * @return distanceType The {@link DistanceType} value.
129         */
130        public DistanceType getDistanceType() {
131            return distanceType;
132        }
133    }
134
    /**
     * Possible centroid initialisation strategies. The trainer switches on this
     * value at the start of each train call to build the starting centroids.
     */
    public enum Initialisation {
        /**
         * Initialise centroids by sampling each feature value uniformly at random
         * from that feature's domain, as recorded in the dataset's feature map.
         */
        RANDOM,
        /**
         * KMeans++ initialisation: the first centroid is a uniformly chosen data point,
         * later centroids are data points sampled proportionally to their squared
         * distance from the nearest centroid chosen so far.
         */
        PLUSPLUS
    }
149
    // Number of clusters to fit; postConfig rejects values below 1.
    @Config(mandatory = true, description = "Number of centroids (i.e., the \"k\" in k-means).")
    private int centroids;

    // Upper bound on the number of E/M iterations; train stops earlier once no assignment changes.
    @Config(mandatory = true, description = "The number of iterations to run.")
    private int iterations;

    // Legacy distance selector. postConfig converts it into dist and then nulls this field.
    @Deprecated
    @Config(description = "The distance function to use. This is now deprecated.")
    private Distance distanceType;

    // The distance function used for cluster assignment and for kmeans++ seeding.
    @Config(description = "The distance function to use.")
    private org.tribuo.math.distance.Distance dist;

    // How the starting centroids are produced; defaults to uniform random sampling.
    @Config(description = "The centroid initialisation method to use.")
    private Initialisation initialisationType = Initialisation.RANDOM;

    // Values of 1 or less make train run sequentially on the calling thread.
    @Config(description = "The number of threads to use for training.")
    private int numThreads = 1;

    @Config(mandatory = true, description = "The seed to use for the RNG.")
    private long seed;

    // Trainer-level RNG, (re)created from seed in postConfig and setInvocationCount,
    // and split once per train call to give that call its own generator.
    private SplittableRandom rng;

    // Number of train invocations so far; written while holding this trainer's monitor.
    private int trainInvocationCounter;
175
    /**
     * No-argument constructor for OLCUT. It assigns nothing itself, so every
     * field keeps its declared default and {@link #postConfig()} is not run here.
     */
    private KMeansTrainer() { }
180
    /**
     * Constructs a K-Means trainer using the supplied parameters and the default random initialisation.
     * <p>
     * Delegates to the constructor taking an {@link Initialisation}, passing {@link Initialisation#RANDOM}.
     *
     * @param centroids The number of centroids to use.
     * @param iterations The maximum number of iterations.
     * @param distanceType The distance function.
     * @param numThreads The number of threads.
     * @param seed The random seed.
     * @deprecated
     * This Constructor is deprecated in version 4.3.
     */
    @Deprecated
    public KMeansTrainer(int centroids, int iterations, Distance distanceType, int numThreads, long seed) {
        this(centroids,iterations,distanceType,Initialisation.RANDOM,numThreads,seed);
    }
196
    /**
     * Constructs a K-Means trainer using the supplied parameters and the default random initialisation.
     * <p>
     * Delegates to the constructor taking an {@link Initialisation}, passing {@link Initialisation#RANDOM}.
     *
     * @param centroids The number of centroids to use.
     * @param iterations The maximum number of iterations.
     * @param dist The distance function.
     * @param numThreads The number of threads.
     * @param seed The random seed.
     */
    public KMeansTrainer(int centroids, int iterations, org.tribuo.math.distance.Distance dist, int numThreads, long seed) {
        this(centroids,iterations,dist,Initialisation.RANDOM,numThreads,seed);
    }
209
    /**
     * Constructs a K-Means trainer using the supplied parameters.
     * <p>
     * The legacy {@link Distance} value is converted to its distance function before
     * delegating, so {@code distanceType} must not be null.
     *
     * @param centroids The number of centroids to use.
     * @param iterations The maximum number of iterations.
     * @param distanceType The distance function.
     * @param initialisationType The centroid initialization method.
     * @param numThreads The number of threads.
     * @param seed The random seed.
     * @deprecated
     * This Constructor is deprecated in version 4.3.
     */
    @Deprecated
    public KMeansTrainer(int centroids, int iterations, Distance distanceType, Initialisation initialisationType, int numThreads, long seed) {
        this(centroids, iterations, distanceType.getDistanceType().getDistance(), initialisationType, numThreads, seed);
    }
226
    /**
     * Constructs a K-Means trainer using the supplied parameters.
     * <p>
     * After storing the arguments this calls {@link #postConfig()}, which seeds the
     * trainer's RNG and rejects a non-positive number of centroids.
     *
     * @param centroids The number of centroids to use.
     * @param iterations The maximum number of iterations.
     * @param dist The distance function.
     * @param initialisationType The centroid initialization method.
     * @param numThreads The number of threads.
     * @param seed The random seed.
     * @throws PropertyException If {@code centroids} is less than 1.
     */
    public KMeansTrainer(int centroids, int iterations, org.tribuo.math.distance.Distance dist, Initialisation initialisationType, int numThreads, long seed) {
        this.centroids = centroids;
        this.iterations = iterations;
        this.dist = dist;
        this.initialisationType = initialisationType;
        this.numThreads = numThreads;
        this.seed = seed;
        // Runs the same validation and RNG setup as the configuration path.
        postConfig();
    }
246
247    /**
248     * Used by the OLCUT configuration system, and should not be called by external code.
249     */
250    @Override
251    public synchronized void postConfig() {
252        this.rng = new SplittableRandom(seed);
253
254        if (this.distanceType != null) {
255            if (this.dist != null) {
256                throw new PropertyException("dist", "Both dist and distanceType must not both be set.");
257            } else {
258                this.dist = this.distanceType.getDistanceType().getDistance();
259                this.distanceType = null;
260            }
261        }
262
263        if (centroids < 1) {
264            throw new PropertyException("centroids", "Centroids must be positive, found " + centroids);
265        }
266    }
267
    /**
     * Trains a K-Means model, letting the trainer advance its own invocation count.
     * <p>
     * Delegates to the three-argument overload with {@code INCREMENT_INVOCATION_COUNT}.
     *
     * @param examples The data to cluster.
     * @param runProvenance Run-specific provenance to record on the model.
     * @return The trained model.
     */
    @Override
    public KMeansModel train(Dataset<ClusterID> examples, Map<String, Provenance> runProvenance) {
        return train(examples, runProvenance, INCREMENT_INVOCATION_COUNT);
    }
272
273    @Override
274    public KMeansModel train(Dataset<ClusterID> examples, Map<String, Provenance> runProvenance, int invocationCount) {
275        // Creates a new local RNG and adds one to the invocation count.
276        TrainerProvenance trainerProvenance;
277        SplittableRandom localRNG;
278        synchronized (this) {
279            if(invocationCount != INCREMENT_INVOCATION_COUNT) {
280                setInvocationCount(invocationCount);
281            }
282            localRNG = rng.split();
283            trainerProvenance = getProvenance();
284            trainInvocationCounter++;
285        }
286        ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
287
288        int[] oldCentre = new int[examples.size()];
289        SGDVector[] data = new SGDVector[examples.size()];
290        double[] weights = new double[examples.size()];
291        int n = 0;
292        for (Example<ClusterID> example : examples) {
293            weights[n] = example.getWeight();
294            if (example.size() == featureMap.size()) {
295                data[n] = DenseVector.createDenseVector(example, featureMap, false);
296            } else {
297                data[n] = SparseVector.createSparseVector(example, featureMap, false);
298            }
299            oldCentre[n] = -1;
300            n++;
301        }
302
303        DenseVector[] centroidVectors;
304        switch (initialisationType) {
305            case RANDOM:
306                centroidVectors = initialiseRandomCentroids(centroids, featureMap, localRNG);
307                break;
308            case PLUSPLUS:
309                centroidVectors = initialisePlusPlusCentroids(centroids, data, localRNG, dist);
310                break;
311            default:
312                throw new IllegalStateException("Unknown initialisation" + initialisationType);
313        }
314
315        Map<Integer, List<Integer>> clusterAssignments = new HashMap<>();
316        boolean parallel = numThreads > 1;
317        for (int i = 0; i < centroids; i++) {
318            clusterAssignments.put(i, parallel ? Collections.synchronizedList(new ArrayList<>()) : new ArrayList<>());
319        }
320
321        AtomicInteger changeCounter = new AtomicInteger(0);
322        Consumer<IntAndVector> eStepFunc = (IntAndVector e) -> {
323            double minDist = Double.POSITIVE_INFINITY;
324            int clusterID = -1;
325            int id = e.idx;
326            SGDVector vector = e.vector;
327            for (int j = 0; j < centroids; j++) {
328                DenseVector cluster = centroidVectors[j];
329                double distance = dist.computeDistance(cluster, vector);
330                if (distance < minDist) {
331                    minDist = distance;
332                    clusterID = j;
333                }
334            }
335
336            clusterAssignments.get(clusterID).add(id);
337            if (oldCentre[id] != clusterID) {
338                // Changed the centroid of this vector.
339                oldCentre[id] = clusterID;
340                changeCounter.incrementAndGet();
341            }
342        };
343
344        boolean converged = false;
345        ForkJoinPool fjp = null;
346        try {
347            if (parallel) {
348                if (System.getSecurityManager() == null) {
349                    fjp = new ForkJoinPool(numThreads);
350                } else {
351                    fjp = new ForkJoinPool(numThreads, THREAD_FACTORY, null, false);
352                }
353            }
354            for (int i = 0; (i < iterations) && !converged; i++) {
355                logger.log(Level.FINE,"Beginning iteration " + i);
356                changeCounter.set(0);
357
358                for (Entry<Integer, List<Integer>> e : clusterAssignments.entrySet()) {
359                    e.getValue().clear();
360                }
361
362                // E step
363                Stream<SGDVector> vecStream = Arrays.stream(data);
364                Stream<Integer> intStream = IntStream.range(0, data.length).boxed();
365                Stream<IntAndVector> zipStream = StreamUtil.zip(intStream, vecStream, IntAndVector::new);
366                if (parallel) {
367                    Stream<IntAndVector> parallelZipStream = StreamUtil.boundParallelism(zipStream.parallel());
368                    try {
369                        fjp.submit(() -> parallelZipStream.forEach(eStepFunc)).get();
370                    } catch (InterruptedException | ExecutionException e) {
371                        throw new RuntimeException("Parallel execution failed", e);
372                    }
373                } else {
374                    zipStream.forEach(eStepFunc);
375                }
376                logger.log(Level.FINE, "E step completed. " + changeCounter.get() + " words updated.");
377
378                mStep(fjp, centroidVectors, clusterAssignments, data, weights);
379
380                logger.log(Level.INFO, "Iteration " + i + " completed. " + changeCounter.get() + " examples updated.");
381
382                if (changeCounter.get() == 0) {
383                    converged = true;
384                    logger.log(Level.INFO, "K-Means converged at iteration " + i);
385                }
386            }
387        } finally {
388            if (fjp != null) {
389                fjp.shutdown();
390            }
391        }
392
393        Map<Integer, MutableLong> counts = new HashMap<>();
394        for (Entry<Integer, List<Integer>> e : clusterAssignments.entrySet()) {
395            counts.put(e.getKey(), new MutableLong(e.getValue().size()));
396        }
397
398        ImmutableOutputInfo<ClusterID> outputMap = new ImmutableClusteringInfo(counts);
399
400        ModelProvenance provenance = new ModelProvenance(KMeansModel.class.getName(), OffsetDateTime.now(),
401                examples.getProvenance(), trainerProvenance, runProvenance);
402
403        return new KMeansModel("k-means-model", provenance, featureMap, outputMap, centroidVectors, dist);
404    }
405
    /**
     * Trains a K-Means model with no run-specific provenance.
     * <p>
     * Delegates to {@link #train(Dataset, Map)} with an empty provenance map.
     *
     * @param dataset The data to cluster.
     * @return The trained model.
     */
    @Override
    public KMeansModel train(Dataset<ClusterID> dataset) {
        return train(dataset, Collections.emptyMap());
    }
410
411    @Override
412    public int getInvocationCount() {
413        return trainInvocationCounter;
414    }
415
416    @Override
417    public synchronized void setInvocationCount(int invocationCount){
418        if(invocationCount < 0){
419            throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
420        }
421
422        rng = new SplittableRandom(seed);
423
424        for (trainInvocationCounter = 0; trainInvocationCounter < invocationCount; trainInvocationCounter++){
425            SplittableRandom localRNG = rng.split();
426        }
427
428    }
429
430    /**
431     * Initialisation method called at the start of each train call when using the default centroid initialisation.
432     * Centroids are initialised using a uniform random sample from the feature domain.
433     *
434     * @param centroids  The number of centroids to create.
435     * @param featureMap The feature map to use for centroid sampling.
436     * @param rng The RNG to use.
437     * @return A {@link DenseVector} array of centroids.
438     */
439    private static DenseVector[] initialiseRandomCentroids(int centroids, ImmutableFeatureMap featureMap,
440                                                           SplittableRandom rng) {
441        DenseVector[] centroidVectors = new DenseVector[centroids];
442        int numFeatures = featureMap.size();
443        for (int i = 0; i < centroids; i++) {
444            double[] newCentroid = new double[numFeatures];
445
446            for (int j = 0; j < numFeatures; j++) {
447                newCentroid[j] = featureMap.get(j).uniformSample(rng);
448            }
449
450            centroidVectors[i] = DenseVector.createDenseVector(newCentroid);
451        }
452        return centroidVectors;
453    }
454
455    /**
456     * Initialisation method called at the start of each train call when using kmeans++ centroid initialisation.
457     *
458     * @param centroids The number of centroids to create.
459     * @param data The dataset of {@link SGDVector} to use.
460     * @param rng The RNG to use.
461     * @param dist The distance function.
462     * @return A {@link DenseVector} array of centroids.
463     */
464    private static DenseVector[] initialisePlusPlusCentroids(int centroids, SGDVector[] data, SplittableRandom rng,
465                                                             org.tribuo.math.distance.Distance dist) {
466        if (centroids > data.length) {
467            throw new IllegalArgumentException("The number of centroids may not exceed the number of samples.");
468        }
469
470        double[] minDistancePerVector = new double[data.length];
471        Arrays.fill(minDistancePerVector, Double.POSITIVE_INFINITY);
472
473        double[] squaredMinDistance = new double[data.length];
474        double[] probabilities = new double[data.length];
475        DenseVector[] centroidVectors = new DenseVector[centroids];
476
477        // set first centroid randomly from the data
478        centroidVectors[0] = getRandomCentroidFromData(data, rng);
479
480        // Set each uninitialised centroid remaining
481        for (int i = 1; i < centroids; i++) {
482            DenseVector prevCentroid = centroidVectors[i - 1];
483
484            // go through every vector and see if the min distance to the
485            // newest centroid is smaller than previous min distance for vec
486            for (int j = 0; j < data.length; j++) {
487                double tempDistance = dist.computeDistance(prevCentroid, data[j]);
488                minDistancePerVector[j] = Math.min(minDistancePerVector[j], tempDistance);
489            }
490
491            // square the distances and get total for normalization
492            double total = 0.0;
493            for (int j = 0; j < data.length; j++) {
494                squaredMinDistance[j] = minDistancePerVector[j] * minDistancePerVector[j];
495                total += squaredMinDistance[j];
496            }
497
498            // compute probabilities as p[i] = D(xi)^2 / sum(D(x)^2)
499            for (int j = 0; j < probabilities.length; j++) {
500                probabilities[j] = squaredMinDistance[j] / total;
501            }
502
503            // sample from probabilities to get the new centroid from data
504            double[] cdf = Util.generateCDF(probabilities);
505            int idx = Util.sampleFromCDF(cdf, rng);
506            centroidVectors[i] = DenseVector.createDenseVector(data[idx].toArray());
507        }
508        return centroidVectors;
509    }
510
    /**
     * Randomly select a piece of data as the starting centroid.
     * <p>
     * The index is drawn uniformly from {@code [0, data.length)} and the chosen vector
     * is copied into a new dense vector, so the returned centroid does not share
     * storage with the data array.
     *
     * @param data The dataset of {@link SGDVector} to use.
     * @param rng The RNG to use.
     * @return A {@link DenseVector} representing a centroid.
     */
    private static DenseVector getRandomCentroidFromData(SGDVector[] data, SplittableRandom rng) {
        int randIdx = rng.nextInt(data.length);
        return DenseVector.createDenseVector(data[randIdx].toArray());
    }
522
523    /**
524     * Runs the mStep, writing to the {@code centroidVectors} array.
525     * <p>
526     * Note in 4.2 this method changed signature slightly, and overrides of the old
527     * version will not match.
528     * @param fjp The ForkJoinPool to run the computation in if it should be executed in parallel.
529     *            If the fjp is null then the computation is executed sequentially on the main thread.
530     * @param centroidVectors The centroid vectors to write out.
531     * @param clusterAssignments The current cluster assignments.
532     * @param data The data points.
533     * @param weights The example weights.
534     */
535    protected void mStep(ForkJoinPool fjp, DenseVector[] centroidVectors, Map<Integer, List<Integer>> clusterAssignments, SGDVector[] data, double[] weights) {
536        // M step
537        Consumer<Entry<Integer, List<Integer>>> mStepFunc = (e) -> {
538            DenseVector newCentroid = centroidVectors[e.getKey()];
539            newCentroid.fill(0.0);
540
541            double weightSum = 0.0;
542            for (Integer idx : e.getValue()) {
543                newCentroid.intersectAndAddInPlace(data[idx], (double f) -> f * weights[idx]);
544                weightSum += weights[idx];
545            }
546            if (weightSum != 0.0) {
547                newCentroid.scaleInPlace(1.0 / weightSum);
548            }
549        };
550
551        Stream<Entry<Integer, List<Integer>>> mStream = clusterAssignments.entrySet().stream();
552        if (fjp != null) {
553            Stream<Entry<Integer, List<Integer>>> parallelMStream = StreamUtil.boundParallelism(mStream.parallel());
554            try {
555                fjp.submit(() -> parallelMStream.forEach(mStepFunc)).get();
556            } catch (InterruptedException | ExecutionException e) {
557                throw new RuntimeException("Parallel execution failed", e);
558            }
559        } else {
560            mStream.forEach(mStepFunc);
561        }
562    }
563
564    @Override
565    public String toString() {
566        return "KMeansTrainer(centroids=" + centroids + ",distance=" + dist + ",seed=" + seed + ",numThreads=" + numThreads + ", initialisationType=" + initialisationType + ")";
567    }
568
    /**
     * Creates a fresh provenance object describing this trainer's current state.
     * <p>
     * train calls this while holding the trainer's monitor, before it increments
     * the invocation counter.
     *
     * @return A new {@link TrainerProvenanceImpl} built from this trainer.
     */
    @Override
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }
573
    /**
     * Tuple of index and position. One day it'll be a record, but not today.
     * <p>
     * The E step uses it to pair each data vector with its position in the data array,
     * so the per-example bookkeeping can be updated from inside a stream.
     */
    static class IntAndVector {
        // Position of the vector in the training data array.
        final int idx;
        // The example's feature vector.
        final SGDVector vector;

        /**
         * Constructs an index and vector tuple.
         * @param idx The index.
         * @param vector The vector.
         */
        public IntAndVector(int idx, SGDVector vector) {
            this.idx = idx;
            this.vector = vector;
        }
    }
591
592    /**
593     * Used to allow FJPs to work with OpenSearch's SecureSM.
594     */
595    private static final class CustomForkJoinWorkerThreadFactory implements ForkJoinPool.ForkJoinWorkerThreadFactory {
596        public final ForkJoinWorkerThread newThread(ForkJoinPool pool) {
597            return AccessController.doPrivileged((PrivilegedAction<ForkJoinWorkerThread>) () -> new ForkJoinWorkerThread(pool) {});
598        }
599    }
600}