001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.clustering.kmeans; 018 019import com.oracle.labs.mlrg.olcut.config.Option; 020import com.oracle.labs.mlrg.olcut.config.Options; 021import org.tribuo.Trainer; 022import org.tribuo.clustering.kmeans.KMeansTrainer.Initialisation; 023import org.tribuo.math.distance.DistanceType; 024 025import java.util.logging.Logger; 026 027/** 028 * OLCUT {@link Options} for the K-Means implementation. 029 */ 030public class KMeansOptions implements Options { 031 private static final Logger logger = Logger.getLogger(KMeansOptions.class.getName()); 032 033 /** 034 * Iterations of the k-means algorithm. Defaults to 10. 035 */ 036 @Option(longName = "kmeans-interations", usage = "Iterations of the k-means algorithm. Defaults to 10.") 037 public int iterations = 10; 038 /** 039 * Number of centroids in K-Means. Defaults to 10. 040 */ 041 @Option(longName = "kmeans-num-centroids", usage = "Number of centroids in K-Means. Defaults to 10.") 042 public int centroids = 10; 043 /** 044 * Distance function in K-Means. Defaults to L2 (EUCLIDEAN). 045 */ 046 @Option(longName = "kmeans-distance-type", usage = "The type of distance function to use for various distance calculations.") 047 public DistanceType distType = DistanceType.L2; 048 /** 049 * Initialisation function in K-Means. Defaults to RANDOM. 050 */ 051 @Option(longName = "kmeans-initialisation", usage = "Initialisation function in K-Means. Defaults to RANDOM.") 052 public Initialisation initialisation = Initialisation.RANDOM; 053 /** 054 * Number of computation threads in K-Means. Defaults to 4. 055 */ 056 @Option(longName = "kmeans-num-threads", usage = "Number of computation threads in K-Means. Defaults to 4.") 057 public int numThreads = 4; 058 @Option(longName = "kmeans-seed", usage = "Sets the random seed for K-Means.") 059 private long seed = Trainer.DEFAULT_SEED; 060 061 /** 062 * Gets the configured KMeansTrainer using the options in this object. 063 * @return A KMeansTrainer. 064 */ 065 public KMeansTrainer getTrainer() { 066 logger.info("Configuring K-Means Trainer"); 067 return new KMeansTrainer(centroids, iterations, distType.getDistance(), initialisation, numThreads, seed); 068 } 069}