001/*
002 * Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.clustering.kmeans;
018
019import com.google.protobuf.Any;
020import com.google.protobuf.InvalidProtocolBufferException;
021import com.oracle.labs.mlrg.olcut.util.Pair;
022import org.tribuo.Example;
023import org.tribuo.Excuse;
024import org.tribuo.Feature;
025import org.tribuo.ImmutableFeatureMap;
026import org.tribuo.ImmutableOutputInfo;
027import org.tribuo.Model;
028import org.tribuo.Prediction;
029import org.tribuo.clustering.ClusterID;
030import org.tribuo.clustering.kmeans.KMeansTrainer.Distance;
031import org.tribuo.clustering.kmeans.protos.KMeansModelProto;
032import org.tribuo.impl.ModelDataCarrier;
033import org.tribuo.math.la.DenseVector;
034import org.tribuo.math.la.SGDVector;
035import org.tribuo.math.la.SparseVector;
036import org.tribuo.math.la.Tensor;
037import org.tribuo.math.la.VectorTuple;
038import org.tribuo.math.protos.TensorProto;
039import org.tribuo.protos.ProtoUtil;
040import org.tribuo.protos.core.ModelProto;
041import org.tribuo.provenance.ModelProvenance;
042
043import java.io.IOException;
044import java.util.ArrayList;
045import java.util.Collections;
046import java.util.List;
047import java.util.Map;
048import java.util.Optional;
049
050/**
051 * A K-Means model with a selectable distance function.
052 * <p>
053 * The predict method of this model assigns centres to the provided input,
054 * but it does not update the model's centroids.
055 * <p>
056 * The predict method is single threaded.
057 * <p>
058 * See:
059 * <pre>
060 * J. Friedman, T. Hastie, &amp; R. Tibshirani.
061 * "The Elements of Statistical Learning"
062 * Springer 2001. <a href="http://web.stanford.edu/~hastie/ElemStatLearn/">PDF</a>
063 * </pre>
064 */
065public class KMeansModel extends Model<ClusterID> {
066    private static final long serialVersionUID = 1L;
067
068    /**
069     * Protobuf serialization version.
070     */
071    public static final int CURRENT_VERSION = 0;
072
073    private final DenseVector[] centroidVectors;
074
075    @Deprecated
076    private Distance distanceType;
077
078    // This is not final to support deserialization of older models. It will be final in a future version which doesn't
079    // maintain serialization compatibility with 4.X.
080    private org.tribuo.math.distance.Distance dist;
081
082    KMeansModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap,
083                ImmutableOutputInfo<ClusterID> outputIDInfo, DenseVector[] centroidVectors, org.tribuo.math.distance.Distance dist) {
084        super(name,description,featureIDMap,outputIDInfo,false);
085        this.centroidVectors = centroidVectors;
086        this.dist = dist;
087    }
088
089    /**
090     * Deserialization factory.
091     * @param version The serialized object version.
092     * @param className The class name.
093     * @param message The serialized data.
094     * @throws InvalidProtocolBufferException If the protobuf could not be parsed from the {@code message}.
095     * @return The deserialized object.
096     */
097    public static KMeansModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
098        if (version < 0 || version > CURRENT_VERSION) {
099            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
100        }
101        KMeansModelProto proto = message.unpack(KMeansModelProto.class);
102
103        ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata());
104        if (!carrier.outputDomain().getOutput(0).getClass().equals(ClusterID.class)) {
105            throw new IllegalStateException("Invalid protobuf, output domain is not a clustering domain, found " + carrier.outputDomain().getClass());
106        }
107        @SuppressWarnings("unchecked") // guarded by getClass
108        ImmutableOutputInfo<ClusterID> outputDomain = (ImmutableOutputInfo<ClusterID>) carrier.outputDomain();
109
110        ImmutableFeatureMap featureDomain = carrier.featureDomain();
111
112        if (proto.getCentroidVectorsCount() == 0) {
113            throw new IllegalStateException("Invalid protobuf, no centroids were found");
114        }
115        DenseVector[] centroids = new DenseVector[proto.getCentroidVectorsCount()];
116        List<TensorProto> centroidProtos = proto.getCentroidVectorsList();
117        for (int i = 0; i < centroids.length; i++) {
118            Tensor centroidTensor = Tensor.deserialize(centroidProtos.get(i));
119            if (centroidTensor instanceof DenseVector) {
120                DenseVector centroid = (DenseVector) centroidTensor;
121                if (centroid.size() != featureDomain.size()) {
122                    throw new IllegalStateException("Invalid protobuf, centroid did not contain all the features, found " + centroid.size() + " expected " + featureDomain.size());
123                }
124                centroids[i] = centroid;
125            } else {
126                throw new IllegalStateException("Invalid protobuf, expected centroid to be a dense vector, found " + centroidTensor.getClass());
127            }
128        }
129
130        org.tribuo.math.distance.Distance dist = ProtoUtil.deserialize(proto.getDistance());
131
132        return new KMeansModel(carrier.name(), carrier.provenance(), featureDomain, outputDomain, centroids, dist);
133    }
134
135    /**
136     * Returns a copy of the centroids.
137     * <p>
138     * In most cases you should prefer {@link #getCentroids} as
139     * it performs the mapping from Tribuo's internal feature ids
140     * to the externally visible feature names for you.
141     * This method provides direct access to the centroid vectors
142     * for use in downstream processing if the ids are not relevant
143     * (or are known to match).
144     * @return The centroids.
145     */
146    public DenseVector[] getCentroidVectors() {
147        DenseVector[] copies = new DenseVector[centroidVectors.length];
148
149        for (int i = 0; i < copies.length; i++) {
150            copies[i] = centroidVectors[i].copy();
151        }
152
153        return copies;
154    }
155
156    /**
157     * Returns a list of features, one per centroid.
158     * <p>
159     * This should be used in preference to {@link #getCentroidVectors()}
160     * as it performs the mapping from Tribuo's internal feature ids to
161     * the externally visible feature names.
162     * </p>
163     * @return A list containing all the centroids.
164     */
165    public List<List<Feature>> getCentroids() {
166        List<List<Feature>> output = new ArrayList<>(centroidVectors.length);
167
168        for (int i = 0; i < centroidVectors.length; i++) {
169            List<Feature> features = new ArrayList<>(featureIDMap.size());
170
171            for (VectorTuple v : centroidVectors[i]) {
172                Feature f = new Feature(featureIDMap.get(v.index).getName(),v.value);
173                features.add(f);
174            }
175
176            output.add(features);
177        }
178
179        return output;
180    }
181
182    @Override
183    public Prediction<ClusterID> predict(Example<ClusterID> example) {
184        SGDVector vector;
185        if (example.size() == featureIDMap.size()) {
186            vector = DenseVector.createDenseVector(example, featureIDMap, false);
187        } else {
188            vector = SparseVector.createSparseVector(example, featureIDMap, false);
189        }
190        if (vector.numActiveElements() == 0) {
191            throw new IllegalArgumentException("No features found in Example " + example.toString());
192        }
193        double minDistance = Double.POSITIVE_INFINITY;
194        int id = -1;
195        for (int i = 0; i < centroidVectors.length; i++) {
196            double distance = dist.computeDistance(centroidVectors[i], vector);
197
198            if (distance < minDistance) {
199                minDistance = distance;
200                id = i;
201            }
202        }
203        return new Prediction<>(new ClusterID(id),vector.size(),example);
204    }
205
206    @Override
207    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
208        return Collections.emptyMap();
209    }
210
211    @Override
212    public Optional<Excuse<ClusterID>> getExcuse(Example<ClusterID> example) {
213        return Optional.empty();
214    }
215
216    @Override
217    public ModelProto serialize() {
218        ModelDataCarrier<ClusterID> carrier = createDataCarrier();
219
220        KMeansModelProto.Builder modelBuilder = KMeansModelProto.newBuilder();
221        modelBuilder.setMetadata(carrier.serialize());
222        modelBuilder.setDistance(dist.serialize());
223        for (DenseVector e : centroidVectors) {
224            modelBuilder.addCentroidVectors(e.serialize());
225        }
226
227        ModelProto.Builder builder = ModelProto.newBuilder();
228        builder.setSerializedData(Any.pack(modelBuilder.build()));
229        builder.setClassName(KMeansModel.class.getName());
230        builder.setVersion(CURRENT_VERSION);
231
232        return builder.build();
233    }
234
235    @Override
236    protected KMeansModel copy(String newName, ModelProvenance newProvenance) {
237        DenseVector[] newCentroids = new DenseVector[centroidVectors.length];
238        for (int i = 0; i < centroidVectors.length; i++) {
239            newCentroids[i] = centroidVectors[i].copy();
240        }
241        return new KMeansModel(newName,newProvenance,featureIDMap,outputIDInfo,newCentroids,dist);
242    }
243
244    private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException {
245        in.defaultReadObject();
246        if (dist == null) {
247            dist = distanceType.getDistanceType().getDistance();
248        }
249    }
250}