Class KMeansTrainer

java.lang.Object
org.tribuo.clustering.kmeans.KMeansTrainer
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, Trainer<ClusterID>

public class KMeansTrainer extends Object implements Trainer<ClusterID>
A K-Means trainer, which generates a K-means clustering of the supplied data. Training finds the cluster centres; predict must then be called to infer the centre assignment for each input example.

This is slightly contorted to fit the Tribuo Trainer and Model API: the cluster assignments can only be retrieved from the model after training, and retrieving them requires re-evaluating each example.
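To make the train-then-predict split concrete, here is a minimal, self-contained sketch of Lloyd's algorithm in plain Java. It is not Tribuo's implementation: the class LloydSketch, its methods, and the seeded-shuffle initialisation are hypothetical. `fit` plays the role of train and returns only the centroids; `predict` then recomputes each point's nearest centroid, which is the re-evaluation step described above.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical Lloyd's-algorithm sketch on dense double[] points (not Tribuo code). */
final class LloydSketch {

    /** Squared Euclidean distance between two equal-length vectors. */
    static double sqDist(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return sum;
    }

    /** Index of the centroid closest to x. */
    static int nearest(double[][] centroids, double[] x) {
        int best = 0;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int c = 0; c < centroids.length; c++) {
            double d = sqDist(centroids[c], x);
            if (d < bestDist) {
                bestDist = d;
                best = c;
            }
        }
        return best;
    }

    /** "Training": returns only the centroids, not the per-point assignments. */
    static double[][] fit(double[][] data, int k, int maxIterations, long seed) {
        Random rng = new Random(seed);
        int dim = data[0].length;
        // Start from k distinct data points picked by a seeded shuffle (an assumed scheme).
        int[] order = new int[data.length];
        for (int i = 0; i < order.length; i++) order[i] = i;
        for (int i = order.length - 1; i > 0; i--) {
            int j = rng.nextInt(i + 1);
            int tmp = order[i]; order[i] = order[j]; order[j] = tmp;
        }
        double[][] centroids = new double[k][];
        for (int c = 0; c < k; c++) centroids[c] = data[order[c]].clone();

        int[] assignment = new int[data.length];
        Arrays.fill(assignment, -1);
        for (int iter = 0; iter < maxIterations; iter++) {
            // Assignment step: move each point to its nearest centroid.
            boolean changed = false;
            for (int i = 0; i < data.length; i++) {
                int c = nearest(centroids, data[i]);
                if (c != assignment[i]) {
                    assignment[i] = c;
                    changed = true;
                }
            }
            if (!changed) break; // converged before the iteration cap
            // Update step: each non-empty cluster's centroid becomes the mean of its points.
            double[][] sums = new double[k][dim];
            int[] counts = new int[k];
            for (int i = 0; i < data.length; i++) {
                counts[assignment[i]]++;
                for (int j = 0; j < dim; j++) sums[assignment[i]][j] += data[i][j];
            }
            for (int c = 0; c < k; c++) {
                if (counts[c] == 0) continue; // keep the old centroid for an empty cluster
                for (int j = 0; j < dim; j++) centroids[c][j] = sums[c][j] / counts[c];
            }
        }
        return centroids; // assignments are discarded, mirroring the train/predict split
    }

    /** "Prediction": re-evaluates every point against the fitted centroids. */
    static int[] predict(double[][] centroids, double[][] data) {
        int[] labels = new int[data.length];
        for (int i = 0; i < data.length; i++) labels[i] = nearest(centroids, data[i]);
        return labels;
    }
}
```

Discarding the assignments inside `fit` is deliberate in this sketch, so that the only way to recover them is the extra `predict` pass over the data.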

The Trainer has a parameterised distance function and a selectable number of threads for the training step. The thread pool is local to each invocation of train, so multiple trainings can run concurrently.
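The sketch below illustrates both ideas using only standard-library types: the distance is passed in as a ToDoubleBiFunction, and each call to `assign` creates its own ForkJoinPool and shuts it down before returning, so concurrent calls do not share threads. The class, method, and distance names are hypothetical and are not Tribuo's Distance API.

```java
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.function.ToDoubleBiFunction;
import java.util.stream.IntStream;

/** Hypothetical nearest-centroid step with a pluggable distance and a per-call thread pool. */
final class ParallelAssignSketch {

    /** Example distances; illustrative stand-ins, not Tribuo's Distance implementations. */
    static final ToDoubleBiFunction<double[], double[]> EUCLIDEAN = (a, b) -> {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    };

    static final ToDoubleBiFunction<double[], double[]> MANHATTAN = (a, b) -> {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) sum += Math.abs(a[i] - b[i]);
        return sum;
    };

    /** Index of the centroid closest to x under the supplied distance. */
    static int nearest(double[][] centroids, double[] x, ToDoubleBiFunction<double[], double[]> dist) {
        int best = 0;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int c = 0; c < centroids.length; c++) {
            double d = dist.applyAsDouble(centroids[c], x);
            if (d < bestDist) {
                bestDist = d;
                best = c;
            }
        }
        return best;
    }

    /** Assigns every point to a centroid using a ForkJoinPool created (and shut down) by this call. */
    static int[] assign(double[][] data, double[][] centroids,
                        ToDoubleBiFunction<double[], double[]> dist, int numThreads) {
        ForkJoinPool pool = new ForkJoinPool(numThreads);
        try {
            // A parallel stream started from inside a pool task runs in that pool; this is a
            // widely used idiom, but it is JDK behaviour rather than a documented guarantee.
            return pool.submit(() -> IntStream.range(0, data.length)
                    .parallel()
                    .map(i -> nearest(centroids, data[i], dist))
                    .toArray()).get();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while assigning clusters", e);
        } catch (ExecutionException e) {
            throw new IllegalStateException("Cluster assignment failed", e.getCause());
        } finally {
            pool.shutdown(); // the pool never outlives this invocation
        }
    }
}
```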

The train method converts dense examples into dense vectors, which speeds up the computation.
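As a rough illustration of why dense vectors help, the hypothetical helpers below compute the same squared Euclidean distance two ways: on sparse vectors (sorted index arrays plus values), which needs a merge over the two index lists, and on densified arrays, which is a single loop over contiguous memory. This representation is an assumption for illustration, not Tribuo's SparseVector or DenseVector.

```java
/** Hypothetical sparse-vs-dense distance helpers (not Tribuo's vector classes). */
final class DensifySketch {

    /** Expands a sparse vector (feature indices plus values) into a dense array of length dim. */
    static double[] densify(int[] indices, double[] values, int dim) {
        double[] dense = new double[dim];
        for (int i = 0; i < indices.length; i++) dense[indices[i]] = values[i];
        return dense;
    }

    /** Dense squared distance: one straight loop, no index bookkeeping. */
    static double denseSqDist(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return sum;
    }

    /** Sparse squared distance: merges two sorted index lists, branching at every step. */
    static double sparseSqDist(int[] ia, double[] va, int[] ib, double[] vb) {
        double sum = 0.0;
        int p = 0;
        int q = 0;
        while (p < ia.length || q < ib.length) {
            if (q >= ib.length || (p < ia.length && ia[p] < ib[q])) {
                sum += va[p] * va[p]; // feature present only in a
                p++;
            } else if (p >= ia.length || ib[q] < ia[p]) {
                sum += vb[q] * vb[q]; // feature present only in b
                q++;
            } else {
                double d = va[p] - vb[q]; // feature present in both
                sum += d * d;
                p++;
                q++;
            }
        }
        return sum;
    }
}
```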

Note that parallel training uses a ForkJoinPool, which requires that the Tribuo codebase is given the "modifyThread" and "modifyThreadGroup" privileges when running under a SecurityManager.

See:

 J. Friedman, T. Hastie, & R. Tibshirani.
 "The Elements of Statistical Learning"
 Springer 2001.
 

For more on optional kmeans++ initialisation, see:

 D. Arthur, S. Vassilvitskii.
 "K-Means++: The Advantages of Careful Seeding"
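A minimal sketch of kmeans++ seeding as described by Arthur and Vassilvitskii: the first centre is chosen uniformly at random, and each later centre is sampled with probability proportional to its squared distance to the nearest centre already chosen. The class and method names are hypothetical, and Tribuo's own implementation may differ in details such as the random number generator used.

```java
import java.util.Random;

/** Hypothetical kmeans++ seeding sketch on dense double[] points. */
final class KMeansPlusPlusSketch {

    static double sqDist(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return sum;
    }

    /** Picks k initial centroids from data using D^2-weighted sampling. */
    static double[][] initialise(double[][] data, int k, long seed) {
        Random rng = new Random(seed);
        double[][] centroids = new double[k][];
        // First centre: uniform over the data points.
        centroids[0] = data[rng.nextInt(data.length)].clone();
        // minDist[i] = squared distance from point i to its nearest chosen centre.
        double[] minDist = new double[data.length];
        for (int i = 0; i < data.length; i++) minDist[i] = sqDist(data[i], centroids[0]);
        for (int c = 1; c < k; c++) {
            double total = 0.0;
            for (double d : minDist) total += d;
            // Sample index i with probability minDist[i] / total.
            double r = rng.nextDouble() * total;
            int chosen = data.length - 1; // only used if every distance is zero
            double cumulative = 0.0;
            for (int i = 0; i < data.length; i++) {
                cumulative += minDist[i];
                if (r < cumulative) {
                    chosen = i;
                    break;
                }
            }
            centroids[c] = data[chosen].clone();
            // Update nearest-centre distances with the newly added centre.
            for (int i = 0; i < data.length; i++) {
                minDist[i] = Math.min(minDist[i], sqDist(data[i], centroids[c]));
            }
        }
        return centroids;
    }
}
```

Because points that already coincide with a chosen centre have zero weight, the sampling never picks the same distinct point twice, which is what spreads the initial centres out.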
 
  • Constructor Details

    • KMeansTrainer

      @Deprecated public KMeansTrainer(int centroids, int iterations, KMeansTrainer.Distance distanceType, int numThreads, long seed)
      Deprecated.
      This constructor is deprecated as of version 4.3.
      Constructs a K-Means trainer using the supplied parameters and the default random initialisation.
      Parameters:
      centroids - The number of centroids to use.
      iterations - The maximum number of iterations.
      distanceType - The distance function.
      numThreads - The number of threads.
      seed - The random seed.
    • KMeansTrainer

      public KMeansTrainer(int centroids, int iterations, Distance dist, int numThreads, long seed)
      Constructs a K-Means trainer using the supplied parameters and the default random initialisation.
      Parameters:
      centroids - The number of centroids to use.
      iterations - The maximum number of iterations.
      dist - The distance function.
      numThreads - The number of threads.
      seed - The random seed.
    • KMeansTrainer

      @Deprecated public KMeansTrainer(int centroids, int iterations, KMeansTrainer.Distance distanceType, KMeansTrainer.Initialisation initialisationType, int numThreads, long seed)
      Deprecated.
      This constructor is deprecated as of version 4.3.
      Constructs a K-Means trainer using the supplied parameters.
      Parameters:
      centroids - The number of centroids to use.
      iterations - The maximum number of iterations.
      distanceType - The distance function.
      initialisationType - The centroid initialisation method.
      numThreads - The number of threads.
      seed - The random seed.
    • KMeansTrainer

      public KMeansTrainer(int centroids, int iterations, Distance dist, KMeansTrainer.Initialisation initialisationType, int numThreads, long seed)
      Constructs a K-Means trainer using the supplied parameters.
      Parameters:
      centroids - The number of centroids to use.
      iterations - The maximum number of iterations.
      dist - The distance function.
      initialisationType - The centroid initialisation method.
      numThreads - The number of threads.
      seed - The random seed.
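The constructors above take a seed that, together with the initialisation type, fixes the starting centroids. As an illustration of a seeded random initialisation, the sketch below draws each centroid coordinate uniformly between that feature's observed minimum and maximum. This scheme and the names RandomInitSketch and init are assumptions for illustration and may not match Tribuo's default random initialisation.

```java
import java.util.Arrays;
import java.util.SplittableRandom;

/** Hypothetical seeded random initialisation (an assumed scheme, not necessarily Tribuo's). */
final class RandomInitSketch {

    static double[][] init(double[][] data, int k, long seed) {
        if (k < 1) throw new IllegalArgumentException("k must be positive, found " + k);
        int dim = data[0].length;
        // Observed per-feature range.
        double[] min = new double[dim];
        double[] max = new double[dim];
        Arrays.fill(min, Double.POSITIVE_INFINITY);
        Arrays.fill(max, Double.NEGATIVE_INFINITY);
        for (double[] row : data) {
            for (int j = 0; j < dim; j++) {
                min[j] = Math.min(min[j], row[j]);
                max[j] = Math.max(max[j], row[j]);
            }
        }
        // Same seed -> same random stream -> same starting centroids.
        SplittableRandom rng = new SplittableRandom(seed);
        double[][] centroids = new double[k][dim];
        for (int c = 0; c < k; c++) {
            for (int j = 0; j < dim; j++) {
                // nextDouble() is in [0, 1), so each coordinate falls within the observed range.
                centroids[c][j] = min[j] + rng.nextDouble() * (max[j] - min[j]);
            }
        }
        return centroids;
    }
}
```

Fixing the seed makes training runs repeatable, which is why every constructor requires one.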
  • Method Details