001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.clustering.kmeans; 018 019import com.oracle.labs.mlrg.olcut.config.ConfigurationManager; 020import com.oracle.labs.mlrg.olcut.config.Option; 021import com.oracle.labs.mlrg.olcut.config.Options; 022import com.oracle.labs.mlrg.olcut.config.UsageException; 023import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter; 024import com.oracle.labs.mlrg.olcut.util.Pair; 025import org.tribuo.Dataset; 026import org.tribuo.Model; 027import org.tribuo.clustering.ClusterID; 028import org.tribuo.clustering.ClusteringFactory; 029import org.tribuo.clustering.evaluation.ClusteringEvaluation; 030import org.tribuo.clustering.kmeans.KMeansTrainer.Initialisation; 031import org.tribuo.data.DataOptions; 032import org.tribuo.math.distance.DistanceType; 033 034import java.io.IOException; 035import java.util.logging.Logger; 036 037 038/** 039 * Build and run a k-means clustering model for a standard dataset. 040 */ 041public class TrainTest { 042 043 private static final Logger logger = Logger.getLogger(TrainTest.class.getName()); 044 045 /** 046 * Options for the K-Means CLI. 047 */ 048 public static class KMeansOptions implements Options { 049 @Override 050 public String getOptionsDescription() { 051 return "Trains and evaluates a K-Means model on the specified dataset."; 052 } 053 054 /** 055 * The data loading options. 056 */ 057 public DataOptions general; 058 059 /** 060 * Number of clusters to infer. 061 */ 062 @Option(charName = 'n', longName = "num-clusters", usage = "Number of clusters to infer.") 063 public int centroids = 5; 064 /** 065 * Maximum number of iterations. 066 */ 067 @Option(charName = 'i', longName = "iterations", usage = "Maximum number of iterations.") 068 public int iterations = 10; 069 /** 070 * Distance function to use in the e step. 071 */ 072 @Option(charName = 'd', longName = "distance-type", usage = "Distance function to use in the e step.") 073 public DistanceType distType = DistanceType.L2; 074 /** 075 * Type of initialisation to use for centroids. 076 */ 077 @Option(charName = 's', longName = "initialisation", usage = "Type of initialisation to use for centroids.") 078 public Initialisation initialisation = Initialisation.RANDOM; 079 /** 080 * Number of threads to use (range (1, num hw threads)). 081 */ 082 @Option(charName = 't', longName = "num-threads", usage = "Number of threads to use (range (1, num hw threads)).") 083 public int numThreads = 4; 084 } 085 086 /** 087 * Runs a TrainTest CLI. 088 * @param args the command line arguments 089 * @throws IOException if there is any error reading the examples. 090 */ 091 public static void main(String[] args) throws IOException { 092 // 093 // Use the labs format logging. 094 LabsLogFormatter.setAllLogFormatters(); 095 096 KMeansOptions o = new KMeansOptions(); 097 ConfigurationManager cm; 098 try { 099 cm = new ConfigurationManager(args,o); 100 } catch (UsageException e) { 101 logger.info(e.getMessage()); 102 return; 103 } 104 105 if (o.general.trainingPath == null) { 106 logger.info(cm.usage()); 107 return; 108 } 109 110 ClusteringFactory factory = new ClusteringFactory(); 111 112 Pair<Dataset<ClusterID>,Dataset<ClusterID>> data = o.general.load(factory); 113 Dataset<ClusterID> train = data.getA(); 114 115 //public KMeansTrainer(int centroids, int iterations, DistanceType distType, int numThreads, int seed) 116 KMeansTrainer trainer = new KMeansTrainer(o.centroids,o.iterations, 117 o.distType.getDistance(),o.initialisation,o.numThreads,o.general.seed); 118 Model<ClusterID> model = trainer.train(train); 119 logger.info("Finished training model"); 120 ClusteringEvaluation evaluation = factory.getEvaluator().evaluate(model,train); 121 logger.info("Finished evaluating model"); 122 System.out.println("Normalized MI = " + evaluation.normalizedMI()); 123 System.out.println("Adjusted MI = " + evaluation.adjustedMI()); 124 125 if (o.general.outputPath != null) { 126 o.general.saveModel(model); 127 } 128 } 129}