001/* 002 * Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.clustering.kmeans; 018 019import com.oracle.labs.mlrg.olcut.config.Config; 020import com.oracle.labs.mlrg.olcut.config.PropertyException; 021import com.oracle.labs.mlrg.olcut.provenance.Provenance; 022import com.oracle.labs.mlrg.olcut.util.MutableLong; 023import com.oracle.labs.mlrg.olcut.util.StreamUtil; 024import org.tribuo.Dataset; 025import org.tribuo.Example; 026import org.tribuo.ImmutableFeatureMap; 027import org.tribuo.ImmutableOutputInfo; 028import org.tribuo.Trainer; 029import org.tribuo.clustering.ClusterID; 030import org.tribuo.clustering.ImmutableClusteringInfo; 031import org.tribuo.math.distance.DistanceType; 032import org.tribuo.math.la.DenseVector; 033import org.tribuo.math.la.SGDVector; 034import org.tribuo.math.la.SparseVector; 035import org.tribuo.provenance.ModelProvenance; 036import org.tribuo.provenance.TrainerProvenance; 037import org.tribuo.provenance.impl.TrainerProvenanceImpl; 038import org.tribuo.util.Util; 039 040import java.security.AccessController; 041import java.security.PrivilegedAction; 042import java.time.OffsetDateTime; 043import java.util.ArrayList; 044import java.util.Arrays; 045import java.util.Collections; 046import java.util.HashMap; 047import java.util.List; 048import java.util.Map; 049import java.util.Map.Entry; 050import java.util.SplittableRandom; 051import java.util.concurrent.ExecutionException; 052import java.util.concurrent.ForkJoinPool; 053import java.util.concurrent.ForkJoinWorkerThread; 054import java.util.concurrent.atomic.AtomicInteger; 055import java.util.function.Consumer; 056import java.util.logging.Level; 057import java.util.logging.Logger; 058import java.util.stream.IntStream; 059import java.util.stream.Stream; 060 061/** 062 * A K-Means trainer, which generates a K-means clustering of the supplied 063 * data. The model finds the centres, and then predict needs to be 064 * called to infer the centre assignments for the input data. 065 * <p> 066 * It's slightly contorted to fit the Tribuo Trainer and Model API, as the cluster assignments 067 * can only be retrieved from the model after training, and require re-evaluating each example. 068 * <p> 069 * The Trainer has a parameterised distance function, and a selectable number 070 * of threads used in the training step. The thread pool is local to an invocation of train, 071 * so there can be multiple concurrent trainings. 072 * <p> 073 * The train method will instantiate dense examples as dense vectors, speeding up the computation. 074 * <p> 075 * Note parallel training uses a {@link ForkJoinPool} which requires that the Tribuo codebase 076 * is given the "modifyThread" and "modifyThreadGroup" privileges when running under a 077 * {@link java.lang.SecurityManager}. 078 * <p> 079 * See: 080 * <pre> 081 * J. Friedman, T. Hastie, & R. Tibshirani. 082 * "The Elements of Statistical Learning" 083 * Springer 2001. <a href="http://web.stanford.edu/~hastie/ElemStatLearn/">PDF</a> 084 * </pre> 085 * <p> 086 * For more on optional kmeans++ initialisation, see: 087 * <pre> 088 * D. Arthur, S. Vassilvitskii. 089 * "K-Means++: The Advantages of Careful Seeding" 090 * <a href="https://theory.stanford.edu/~sergei/papers/kMeansPP-soda">PDF</a> 091 * </pre> 092 */ 093public class KMeansTrainer implements Trainer<ClusterID> { 094 private static final Logger logger = Logger.getLogger(KMeansTrainer.class.getName()); 095 096 // Thread factory for the FJP, to allow use with OpenSearch's SecureSM 097 private static final CustomForkJoinWorkerThreadFactory THREAD_FACTORY = new CustomForkJoinWorkerThreadFactory(); 098 099 /** 100 * Possible distance functions. 101 * @deprecated 102 * This Enum is deprecated in version 4.3, replaced by {@link DistanceType} 103 */ 104 @Deprecated 105 public enum Distance { 106 /** 107 * Euclidean (or l2) distance. 108 */ 109 EUCLIDEAN(DistanceType.L2), 110 /** 111 * Cosine similarity as a distance measure. 112 */ 113 COSINE(DistanceType.COSINE), 114 /** 115 * L1 (or Manhattan) distance. 116 */ 117 L1(DistanceType.L1); 118 119 private final DistanceType distanceType; 120 121 Distance(DistanceType distanceType) { 122 this.distanceType = distanceType; 123 } 124 125 /** 126 * Returns the {@link DistanceType} mapping for the enumeration's value. 127 * 128 * @return distanceType The {@link DistanceType} value. 129 */ 130 public DistanceType getDistanceType() { 131 return distanceType; 132 } 133 } 134 135 /** 136 * Possible initialization functions. 137 */ 138 public enum Initialisation { 139 /** 140 * Initialize centroids by choosing uniformly at random from the data 141 * points. 142 */ 143 RANDOM, 144 /** 145 * KMeans++ initialisation. 146 */ 147 PLUSPLUS 148 } 149 150 @Config(mandatory = true, description = "Number of centroids (i.e., the \"k\" in k-means).") 151 private int centroids; 152 153 @Config(mandatory = true, description = "The number of iterations to run.") 154 private int iterations; 155 156 @Deprecated 157 @Config(description = "The distance function to use. This is now deprecated.") 158 private Distance distanceType; 159 160 @Config(description = "The distance function to use.") 161 private org.tribuo.math.distance.Distance dist; 162 163 @Config(description = "The centroid initialisation method to use.") 164 private Initialisation initialisationType = Initialisation.RANDOM; 165 166 @Config(description = "The number of threads to use for training.") 167 private int numThreads = 1; 168 169 @Config(mandatory = true, description = "The seed to use for the RNG.") 170 private long seed; 171 172 private SplittableRandom rng; 173 174 private int trainInvocationCounter; 175 176 /** 177 * for olcut. 178 */ 179 private KMeansTrainer() { } 180 181 /** 182 * Constructs a K-Means trainer using the supplied parameters and the default random initialisation. 183 * @deprecated 184 * This Constructor is deprecated in version 4.3. 185 * 186 * @param centroids The number of centroids to use. 187 * @param iterations The maximum number of iterations. 188 * @param distanceType The distance function. 189 * @param numThreads The number of threads. 190 * @param seed The random seed. 191 */ 192 @Deprecated 193 public KMeansTrainer(int centroids, int iterations, Distance distanceType, int numThreads, long seed) { 194 this(centroids,iterations,distanceType,Initialisation.RANDOM,numThreads,seed); 195 } 196 197 /** 198 * Constructs a K-Means trainer using the supplied parameters and the default random initialisation. 199 * 200 * @param centroids The number of centroids to use. 201 * @param iterations The maximum number of iterations. 202 * @param dist The distance function. 203 * @param numThreads The number of threads. 204 * @param seed The random seed. 205 */ 206 public KMeansTrainer(int centroids, int iterations, org.tribuo.math.distance.Distance dist, int numThreads, long seed) { 207 this(centroids,iterations,dist,Initialisation.RANDOM,numThreads,seed); 208 } 209 210 /** 211 * Constructs a K-Means trainer using the supplied parameters. 212 * @deprecated 213 * This Constructor is deprecated in version 4.3. 214 * 215 * @param centroids The number of centroids to use. 216 * @param iterations The maximum number of iterations. 217 * @param distanceType The distance function. 218 * @param initialisationType The centroid initialization method. 219 * @param numThreads The number of threads. 220 * @param seed The random seed. 221 */ 222 @Deprecated 223 public KMeansTrainer(int centroids, int iterations, Distance distanceType, Initialisation initialisationType, int numThreads, long seed) { 224 this(centroids, iterations, distanceType.getDistanceType().getDistance(), initialisationType, numThreads, seed); 225 } 226 227 /** 228 * Constructs a K-Means trainer using the supplied parameters. 229 * 230 * @param centroids The number of centroids to use. 231 * @param iterations The maximum number of iterations. 232 * @param dist The distance function. 233 * @param initialisationType The centroid initialization method. 234 * @param numThreads The number of threads. 235 * @param seed The random seed. 236 */ 237 public KMeansTrainer(int centroids, int iterations, org.tribuo.math.distance.Distance dist, Initialisation initialisationType, int numThreads, long seed) { 238 this.centroids = centroids; 239 this.iterations = iterations; 240 this.dist = dist; 241 this.initialisationType = initialisationType; 242 this.numThreads = numThreads; 243 this.seed = seed; 244 postConfig(); 245 } 246 247 /** 248 * Used by the OLCUT configuration system, and should not be called by external code. 249 */ 250 @Override 251 public synchronized void postConfig() { 252 this.rng = new SplittableRandom(seed); 253 254 if (this.distanceType != null) { 255 if (this.dist != null) { 256 throw new PropertyException("dist", "Both dist and distanceType must not both be set."); 257 } else { 258 this.dist = this.distanceType.getDistanceType().getDistance(); 259 this.distanceType = null; 260 } 261 } 262 263 if (centroids < 1) { 264 throw new PropertyException("centroids", "Centroids must be positive, found " + centroids); 265 } 266 } 267 268 @Override 269 public KMeansModel train(Dataset<ClusterID> examples, Map<String, Provenance> runProvenance) { 270 return train(examples, runProvenance, INCREMENT_INVOCATION_COUNT); 271 } 272 273 @Override 274 public KMeansModel train(Dataset<ClusterID> examples, Map<String, Provenance> runProvenance, int invocationCount) { 275 // Creates a new local RNG and adds one to the invocation count. 276 TrainerProvenance trainerProvenance; 277 SplittableRandom localRNG; 278 synchronized (this) { 279 if(invocationCount != INCREMENT_INVOCATION_COUNT) { 280 setInvocationCount(invocationCount); 281 } 282 localRNG = rng.split(); 283 trainerProvenance = getProvenance(); 284 trainInvocationCounter++; 285 } 286 ImmutableFeatureMap featureMap = examples.getFeatureIDMap(); 287 288 int[] oldCentre = new int[examples.size()]; 289 SGDVector[] data = new SGDVector[examples.size()]; 290 double[] weights = new double[examples.size()]; 291 int n = 0; 292 for (Example<ClusterID> example : examples) { 293 weights[n] = example.getWeight(); 294 if (example.size() == featureMap.size()) { 295 data[n] = DenseVector.createDenseVector(example, featureMap, false); 296 } else { 297 data[n] = SparseVector.createSparseVector(example, featureMap, false); 298 } 299 oldCentre[n] = -1; 300 n++; 301 } 302 303 DenseVector[] centroidVectors; 304 switch (initialisationType) { 305 case RANDOM: 306 centroidVectors = initialiseRandomCentroids(centroids, featureMap, localRNG); 307 break; 308 case PLUSPLUS: 309 centroidVectors = initialisePlusPlusCentroids(centroids, data, localRNG, dist); 310 break; 311 default: 312 throw new IllegalStateException("Unknown initialisation" + initialisationType); 313 } 314 315 Map<Integer, List<Integer>> clusterAssignments = new HashMap<>(); 316 boolean parallel = numThreads > 1; 317 for (int i = 0; i < centroids; i++) { 318 clusterAssignments.put(i, parallel ? Collections.synchronizedList(new ArrayList<>()) : new ArrayList<>()); 319 } 320 321 AtomicInteger changeCounter = new AtomicInteger(0); 322 Consumer<IntAndVector> eStepFunc = (IntAndVector e) -> { 323 double minDist = Double.POSITIVE_INFINITY; 324 int clusterID = -1; 325 int id = e.idx; 326 SGDVector vector = e.vector; 327 for (int j = 0; j < centroids; j++) { 328 DenseVector cluster = centroidVectors[j]; 329 double distance = dist.computeDistance(cluster, vector); 330 if (distance < minDist) { 331 minDist = distance; 332 clusterID = j; 333 } 334 } 335 336 clusterAssignments.get(clusterID).add(id); 337 if (oldCentre[id] != clusterID) { 338 // Changed the centroid of this vector. 339 oldCentre[id] = clusterID; 340 changeCounter.incrementAndGet(); 341 } 342 }; 343 344 boolean converged = false; 345 ForkJoinPool fjp = null; 346 try { 347 if (parallel) { 348 if (System.getSecurityManager() == null) { 349 fjp = new ForkJoinPool(numThreads); 350 } else { 351 fjp = new ForkJoinPool(numThreads, THREAD_FACTORY, null, false); 352 } 353 } 354 for (int i = 0; (i < iterations) && !converged; i++) { 355 logger.log(Level.FINE,"Beginning iteration " + i); 356 changeCounter.set(0); 357 358 for (Entry<Integer, List<Integer>> e : clusterAssignments.entrySet()) { 359 e.getValue().clear(); 360 } 361 362 // E step 363 Stream<SGDVector> vecStream = Arrays.stream(data); 364 Stream<Integer> intStream = IntStream.range(0, data.length).boxed(); 365 Stream<IntAndVector> zipStream = StreamUtil.zip(intStream, vecStream, IntAndVector::new); 366 if (parallel) { 367 Stream<IntAndVector> parallelZipStream = StreamUtil.boundParallelism(zipStream.parallel()); 368 try { 369 fjp.submit(() -> parallelZipStream.forEach(eStepFunc)).get(); 370 } catch (InterruptedException | ExecutionException e) { 371 throw new RuntimeException("Parallel execution failed", e); 372 } 373 } else { 374 zipStream.forEach(eStepFunc); 375 } 376 logger.log(Level.FINE, "E step completed. " + changeCounter.get() + " words updated."); 377 378 mStep(fjp, centroidVectors, clusterAssignments, data, weights); 379 380 logger.log(Level.INFO, "Iteration " + i + " completed. " + changeCounter.get() + " examples updated."); 381 382 if (changeCounter.get() == 0) { 383 converged = true; 384 logger.log(Level.INFO, "K-Means converged at iteration " + i); 385 } 386 } 387 } finally { 388 if (fjp != null) { 389 fjp.shutdown(); 390 } 391 } 392 393 Map<Integer, MutableLong> counts = new HashMap<>(); 394 for (Entry<Integer, List<Integer>> e : clusterAssignments.entrySet()) { 395 counts.put(e.getKey(), new MutableLong(e.getValue().size())); 396 } 397 398 ImmutableOutputInfo<ClusterID> outputMap = new ImmutableClusteringInfo(counts); 399 400 ModelProvenance provenance = new ModelProvenance(KMeansModel.class.getName(), OffsetDateTime.now(), 401 examples.getProvenance(), trainerProvenance, runProvenance); 402 403 return new KMeansModel("k-means-model", provenance, featureMap, outputMap, centroidVectors, dist); 404 } 405 406 @Override 407 public KMeansModel train(Dataset<ClusterID> dataset) { 408 return train(dataset, Collections.emptyMap()); 409 } 410 411 @Override 412 public int getInvocationCount() { 413 return trainInvocationCounter; 414 } 415 416 @Override 417 public synchronized void setInvocationCount(int invocationCount){ 418 if(invocationCount < 0){ 419 throw new IllegalArgumentException("The supplied invocationCount is less than zero."); 420 } 421 422 rng = new SplittableRandom(seed); 423 424 for (trainInvocationCounter = 0; trainInvocationCounter < invocationCount; trainInvocationCounter++){ 425 SplittableRandom localRNG = rng.split(); 426 } 427 428 } 429 430 /** 431 * Initialisation method called at the start of each train call when using the default centroid initialisation. 432 * Centroids are initialised using a uniform random sample from the feature domain. 433 * 434 * @param centroids The number of centroids to create. 435 * @param featureMap The feature map to use for centroid sampling. 436 * @param rng The RNG to use. 437 * @return A {@link DenseVector} array of centroids. 438 */ 439 private static DenseVector[] initialiseRandomCentroids(int centroids, ImmutableFeatureMap featureMap, 440 SplittableRandom rng) { 441 DenseVector[] centroidVectors = new DenseVector[centroids]; 442 int numFeatures = featureMap.size(); 443 for (int i = 0; i < centroids; i++) { 444 double[] newCentroid = new double[numFeatures]; 445 446 for (int j = 0; j < numFeatures; j++) { 447 newCentroid[j] = featureMap.get(j).uniformSample(rng); 448 } 449 450 centroidVectors[i] = DenseVector.createDenseVector(newCentroid); 451 } 452 return centroidVectors; 453 } 454 455 /** 456 * Initialisation method called at the start of each train call when using kmeans++ centroid initialisation. 457 * 458 * @param centroids The number of centroids to create. 459 * @param data The dataset of {@link SGDVector} to use. 460 * @param rng The RNG to use. 461 * @param dist The distance function. 462 * @return A {@link DenseVector} array of centroids. 463 */ 464 private static DenseVector[] initialisePlusPlusCentroids(int centroids, SGDVector[] data, SplittableRandom rng, 465 org.tribuo.math.distance.Distance dist) { 466 if (centroids > data.length) { 467 throw new IllegalArgumentException("The number of centroids may not exceed the number of samples."); 468 } 469 470 double[] minDistancePerVector = new double[data.length]; 471 Arrays.fill(minDistancePerVector, Double.POSITIVE_INFINITY); 472 473 double[] squaredMinDistance = new double[data.length]; 474 double[] probabilities = new double[data.length]; 475 DenseVector[] centroidVectors = new DenseVector[centroids]; 476 477 // set first centroid randomly from the data 478 centroidVectors[0] = getRandomCentroidFromData(data, rng); 479 480 // Set each uninitialised centroid remaining 481 for (int i = 1; i < centroids; i++) { 482 DenseVector prevCentroid = centroidVectors[i - 1]; 483 484 // go through every vector and see if the min distance to the 485 // newest centroid is smaller than previous min distance for vec 486 for (int j = 0; j < data.length; j++) { 487 double tempDistance = dist.computeDistance(prevCentroid, data[j]); 488 minDistancePerVector[j] = Math.min(minDistancePerVector[j], tempDistance); 489 } 490 491 // square the distances and get total for normalization 492 double total = 0.0; 493 for (int j = 0; j < data.length; j++) { 494 squaredMinDistance[j] = minDistancePerVector[j] * minDistancePerVector[j]; 495 total += squaredMinDistance[j]; 496 } 497 498 // compute probabilities as p[i] = D(xi)^2 / sum(D(x)^2) 499 for (int j = 0; j < probabilities.length; j++) { 500 probabilities[j] = squaredMinDistance[j] / total; 501 } 502 503 // sample from probabilities to get the new centroid from data 504 double[] cdf = Util.generateCDF(probabilities); 505 int idx = Util.sampleFromCDF(cdf, rng); 506 centroidVectors[i] = DenseVector.createDenseVector(data[idx].toArray()); 507 } 508 return centroidVectors; 509 } 510 511 /** 512 * Randomly select a piece of data as the starting centroid. 513 * 514 * @param data The dataset of {@link SparseVector} to use. 515 * @param rng The RNG to use. 516 * @return A {@link DenseVector} representing a centroid. 517 */ 518 private static DenseVector getRandomCentroidFromData(SGDVector[] data, SplittableRandom rng) { 519 int randIdx = rng.nextInt(data.length); 520 return DenseVector.createDenseVector(data[randIdx].toArray()); 521 } 522 523 /** 524 * Runs the mStep, writing to the {@code centroidVectors} array. 525 * <p> 526 * Note in 4.2 this method changed signature slightly, and overrides of the old 527 * version will not match. 528 * @param fjp The ForkJoinPool to run the computation in if it should be executed in parallel. 529 * If the fjp is null then the computation is executed sequentially on the main thread. 530 * @param centroidVectors The centroid vectors to write out. 531 * @param clusterAssignments The current cluster assignments. 532 * @param data The data points. 533 * @param weights The example weights. 534 */ 535 protected void mStep(ForkJoinPool fjp, DenseVector[] centroidVectors, Map<Integer, List<Integer>> clusterAssignments, SGDVector[] data, double[] weights) { 536 // M step 537 Consumer<Entry<Integer, List<Integer>>> mStepFunc = (e) -> { 538 DenseVector newCentroid = centroidVectors[e.getKey()]; 539 newCentroid.fill(0.0); 540 541 double weightSum = 0.0; 542 for (Integer idx : e.getValue()) { 543 newCentroid.intersectAndAddInPlace(data[idx], (double f) -> f * weights[idx]); 544 weightSum += weights[idx]; 545 } 546 if (weightSum != 0.0) { 547 newCentroid.scaleInPlace(1.0 / weightSum); 548 } 549 }; 550 551 Stream<Entry<Integer, List<Integer>>> mStream = clusterAssignments.entrySet().stream(); 552 if (fjp != null) { 553 Stream<Entry<Integer, List<Integer>>> parallelMStream = StreamUtil.boundParallelism(mStream.parallel()); 554 try { 555 fjp.submit(() -> parallelMStream.forEach(mStepFunc)).get(); 556 } catch (InterruptedException | ExecutionException e) { 557 throw new RuntimeException("Parallel execution failed", e); 558 } 559 } else { 560 mStream.forEach(mStepFunc); 561 } 562 } 563 564 @Override 565 public String toString() { 566 return "KMeansTrainer(centroids=" + centroids + ",distance=" + dist + ",seed=" + seed + ",numThreads=" + numThreads + ", initialisationType=" + initialisationType + ")"; 567 } 568 569 @Override 570 public TrainerProvenance getProvenance() { 571 return new TrainerProvenanceImpl(this); 572 } 573 574 /** 575 * Tuple of index and position. One day it'll be a record, but not today. 576 */ 577 static class IntAndVector { 578 final int idx; 579 final SGDVector vector; 580 581 /** 582 * Constructs an index and vector tuple. 583 * @param idx The index. 584 * @param vector The vector. 585 */ 586 public IntAndVector(int idx, SGDVector vector) { 587 this.idx = idx; 588 this.vector = vector; 589 } 590 } 591 592 /** 593 * Used to allow FJPs to work with OpenSearch's SecureSM. 594 */ 595 private static final class CustomForkJoinWorkerThreadFactory implements ForkJoinPool.ForkJoinWorkerThreadFactory { 596 public final ForkJoinWorkerThread newThread(ForkJoinPool pool) { 597 return AccessController.doPrivileged((PrivilegedAction<ForkJoinWorkerThread>) () -> new ForkJoinWorkerThread(pool) {}); 598 } 599 } 600}