001/* 002 * Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.clustering.kmeans; 018 019import com.google.protobuf.Any; 020import com.google.protobuf.InvalidProtocolBufferException; 021import com.oracle.labs.mlrg.olcut.util.Pair; 022import org.tribuo.Example; 023import org.tribuo.Excuse; 024import org.tribuo.Feature; 025import org.tribuo.ImmutableFeatureMap; 026import org.tribuo.ImmutableOutputInfo; 027import org.tribuo.Model; 028import org.tribuo.Prediction; 029import org.tribuo.clustering.ClusterID; 030import org.tribuo.clustering.kmeans.KMeansTrainer.Distance; 031import org.tribuo.clustering.kmeans.protos.KMeansModelProto; 032import org.tribuo.impl.ModelDataCarrier; 033import org.tribuo.math.la.DenseVector; 034import org.tribuo.math.la.SGDVector; 035import org.tribuo.math.la.SparseVector; 036import org.tribuo.math.la.Tensor; 037import org.tribuo.math.la.VectorTuple; 038import org.tribuo.math.protos.TensorProto; 039import org.tribuo.protos.ProtoUtil; 040import org.tribuo.protos.core.ModelProto; 041import org.tribuo.provenance.ModelProvenance; 042 043import java.io.IOException; 044import java.util.ArrayList; 045import java.util.Collections; 046import java.util.List; 047import java.util.Map; 048import java.util.Optional; 049 050/** 051 * A K-Means model with a selectable distance function. 052 * <p> 053 * The predict method of this model assigns centres to the provided input, 054 * but it does not update the model's centroids. 055 * <p> 056 * The predict method is single threaded. 057 * <p> 058 * See: 059 * <pre> 060 * J. Friedman, T. Hastie, & R. Tibshirani. 061 * "The Elements of Statistical Learning" 062 * Springer 2001. <a href="http://web.stanford.edu/~hastie/ElemStatLearn/">PDF</a> 063 * </pre> 064 */ 065public class KMeansModel extends Model<ClusterID> { 066 private static final long serialVersionUID = 1L; 067 068 /** 069 * Protobuf serialization version. 070 */ 071 public static final int CURRENT_VERSION = 0; 072 073 private final DenseVector[] centroidVectors; 074 075 @Deprecated 076 private Distance distanceType; 077 078 // This is not final to support deserialization of older models. It will be final in a future version which doesn't 079 // maintain serialization compatibility with 4.X. 080 private org.tribuo.math.distance.Distance dist; 081 082 KMeansModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, 083 ImmutableOutputInfo<ClusterID> outputIDInfo, DenseVector[] centroidVectors, org.tribuo.math.distance.Distance dist) { 084 super(name,description,featureIDMap,outputIDInfo,false); 085 this.centroidVectors = centroidVectors; 086 this.dist = dist; 087 } 088 089 /** 090 * Deserialization factory. 091 * @param version The serialized object version. 092 * @param className The class name. 093 * @param message The serialized data. 094 * @throws InvalidProtocolBufferException If the protobuf could not be parsed from the {@code message}. 095 * @return The deserialized object. 096 */ 097 public static KMeansModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException { 098 if (version < 0 || version > CURRENT_VERSION) { 099 throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION); 100 } 101 KMeansModelProto proto = message.unpack(KMeansModelProto.class); 102 103 ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata()); 104 if (!carrier.outputDomain().getOutput(0).getClass().equals(ClusterID.class)) { 105 throw new IllegalStateException("Invalid protobuf, output domain is not a clustering domain, found " + carrier.outputDomain().getClass()); 106 } 107 @SuppressWarnings("unchecked") // guarded by getClass 108 ImmutableOutputInfo<ClusterID> outputDomain = (ImmutableOutputInfo<ClusterID>) carrier.outputDomain(); 109 110 ImmutableFeatureMap featureDomain = carrier.featureDomain(); 111 112 if (proto.getCentroidVectorsCount() == 0) { 113 throw new IllegalStateException("Invalid protobuf, no centroids were found"); 114 } 115 DenseVector[] centroids = new DenseVector[proto.getCentroidVectorsCount()]; 116 List<TensorProto> centroidProtos = proto.getCentroidVectorsList(); 117 for (int i = 0; i < centroids.length; i++) { 118 Tensor centroidTensor = Tensor.deserialize(centroidProtos.get(i)); 119 if (centroidTensor instanceof DenseVector) { 120 DenseVector centroid = (DenseVector) centroidTensor; 121 if (centroid.size() != featureDomain.size()) { 122 throw new IllegalStateException("Invalid protobuf, centroid did not contain all the features, found " + centroid.size() + " expected " + featureDomain.size()); 123 } 124 centroids[i] = centroid; 125 } else { 126 throw new IllegalStateException("Invalid protobuf, expected centroid to be a dense vector, found " + centroidTensor.getClass()); 127 } 128 } 129 130 org.tribuo.math.distance.Distance dist = ProtoUtil.deserialize(proto.getDistance()); 131 132 return new KMeansModel(carrier.name(), carrier.provenance(), featureDomain, outputDomain, centroids, dist); 133 } 134 135 /** 136 * Returns a copy of the centroids. 137 * <p> 138 * In most cases you should prefer {@link #getCentroids} as 139 * it performs the mapping from Tribuo's internal feature ids 140 * to the externally visible feature names for you. 141 * This method provides direct access to the centroid vectors 142 * for use in downstream processing if the ids are not relevant 143 * (or are known to match). 144 * @return The centroids. 145 */ 146 public DenseVector[] getCentroidVectors() { 147 DenseVector[] copies = new DenseVector[centroidVectors.length]; 148 149 for (int i = 0; i < copies.length; i++) { 150 copies[i] = centroidVectors[i].copy(); 151 } 152 153 return copies; 154 } 155 156 /** 157 * Returns a list of features, one per centroid. 158 * <p> 159 * This should be used in preference to {@link #getCentroidVectors()} 160 * as it performs the mapping from Tribuo's internal feature ids to 161 * the externally visible feature names. 162 * </p> 163 * @return A list containing all the centroids. 164 */ 165 public List<List<Feature>> getCentroids() { 166 List<List<Feature>> output = new ArrayList<>(centroidVectors.length); 167 168 for (int i = 0; i < centroidVectors.length; i++) { 169 List<Feature> features = new ArrayList<>(featureIDMap.size()); 170 171 for (VectorTuple v : centroidVectors[i]) { 172 Feature f = new Feature(featureIDMap.get(v.index).getName(),v.value); 173 features.add(f); 174 } 175 176 output.add(features); 177 } 178 179 return output; 180 } 181 182 @Override 183 public Prediction<ClusterID> predict(Example<ClusterID> example) { 184 SGDVector vector; 185 if (example.size() == featureIDMap.size()) { 186 vector = DenseVector.createDenseVector(example, featureIDMap, false); 187 } else { 188 vector = SparseVector.createSparseVector(example, featureIDMap, false); 189 } 190 if (vector.numActiveElements() == 0) { 191 throw new IllegalArgumentException("No features found in Example " + example.toString()); 192 } 193 double minDistance = Double.POSITIVE_INFINITY; 194 int id = -1; 195 for (int i = 0; i < centroidVectors.length; i++) { 196 double distance = dist.computeDistance(centroidVectors[i], vector); 197 198 if (distance < minDistance) { 199 minDistance = distance; 200 id = i; 201 } 202 } 203 return new Prediction<>(new ClusterID(id),vector.size(),example); 204 } 205 206 @Override 207 public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) { 208 return Collections.emptyMap(); 209 } 210 211 @Override 212 public Optional<Excuse<ClusterID>> getExcuse(Example<ClusterID> example) { 213 return Optional.empty(); 214 } 215 216 @Override 217 public ModelProto serialize() { 218 ModelDataCarrier<ClusterID> carrier = createDataCarrier(); 219 220 KMeansModelProto.Builder modelBuilder = KMeansModelProto.newBuilder(); 221 modelBuilder.setMetadata(carrier.serialize()); 222 modelBuilder.setDistance(dist.serialize()); 223 for (DenseVector e : centroidVectors) { 224 modelBuilder.addCentroidVectors(e.serialize()); 225 } 226 227 ModelProto.Builder builder = ModelProto.newBuilder(); 228 builder.setSerializedData(Any.pack(modelBuilder.build())); 229 builder.setClassName(KMeansModel.class.getName()); 230 builder.setVersion(CURRENT_VERSION); 231 232 return builder.build(); 233 } 234 235 @Override 236 protected KMeansModel copy(String newName, ModelProvenance newProvenance) { 237 DenseVector[] newCentroids = new DenseVector[centroidVectors.length]; 238 for (int i = 0; i < centroidVectors.length; i++) { 239 newCentroids[i] = centroidVectors[i].copy(); 240 } 241 return new KMeansModel(newName,newProvenance,featureIDMap,outputIDInfo,newCentroids,dist); 242 } 243 244 private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException { 245 in.defaultReadObject(); 246 if (dist == null) { 247 dist = distanceType.getDistanceType().getDistance(); 248 } 249 } 250}