001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.clustering.kmeans;
018
019import com.oracle.labs.mlrg.olcut.config.Option;
020import com.oracle.labs.mlrg.olcut.config.Options;
021import org.tribuo.Trainer;
022import org.tribuo.clustering.kmeans.KMeansTrainer.Initialisation;
023import org.tribuo.math.distance.DistanceType;
024
025import java.util.logging.Logger;
026
027/**
028 * OLCUT {@link Options} for the K-Means implementation.
029 */
030public class KMeansOptions implements Options {
031    private static final Logger logger = Logger.getLogger(KMeansOptions.class.getName());
032
033    /**
034     * Iterations of the k-means algorithm. Defaults to 10.
035     */
036    @Option(longName = "kmeans-interations", usage = "Iterations of the k-means algorithm. Defaults to 10.")
037    public int iterations = 10;
038    /**
039     * Number of centroids in K-Means. Defaults to 10.
040     */
041    @Option(longName = "kmeans-num-centroids", usage = "Number of centroids in K-Means. Defaults to 10.")
042    public int centroids = 10;
043    /**
044     * Distance function in K-Means. Defaults to L2 (EUCLIDEAN).
045     */
046    @Option(longName = "kmeans-distance-type", usage = "The type of distance function to use for various distance calculations.")
047    public DistanceType distType = DistanceType.L2;
048    /**
049     * Initialisation function in K-Means. Defaults to RANDOM.
050     */
051    @Option(longName = "kmeans-initialisation", usage = "Initialisation function in K-Means. Defaults to RANDOM.")
052    public Initialisation initialisation = Initialisation.RANDOM;
053    /**
054     * Number of computation threads in K-Means. Defaults to 4.
055     */
056    @Option(longName = "kmeans-num-threads", usage = "Number of computation threads in K-Means. Defaults to 4.")
057    public int numThreads = 4;
058    @Option(longName = "kmeans-seed", usage = "Sets the random seed for K-Means.")
059    private long seed = Trainer.DEFAULT_SEED;
060
061    /**
062     * Gets the configured KMeansTrainer using the options in this object.
063     * @return A KMeansTrainer.
064     */
065    public KMeansTrainer getTrainer() {
066        logger.info("Configuring K-Means Trainer");
067        return new KMeansTrainer(centroids, iterations, distType.getDistance(), initialisation, numThreads, seed);
068    }
069}