/*
 * Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.regression.rtree;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.common.tree.LeafNode;
import org.tribuo.common.tree.Node;
import org.tribuo.common.tree.SplitNode;
import org.tribuo.common.tree.TreeModel;
import org.tribuo.common.tree.protos.TreeNodeProto;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.Regressor.DimensionTuple;
import org.tribuo.regression.rtree.protos.IndependentRegressionTreeModelProto;
import org.tribuo.regression.rtree.protos.TreeNodeListProto;
import org.tribuo.math.la.SparseVector;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import 
java.util.Map; 051import java.util.Optional; 052import java.util.PriorityQueue; 053import java.util.Queue; 054import java.util.Set; 055 056/** 057 * A {@link Model} wrapped around a list of decision tree root {@link Node}s used 058 * to generate independent predictions for each dimension in a regression. 059 */ 060public final class IndependentRegressionTreeModel extends TreeModel<Regressor> { 061 private static final long serialVersionUID = 1L; 062 063 /** 064 * Protobuf serialization version. 065 */ 066 public static final int CURRENT_VERSION = 0; 067 068 private final Map<String,Node<Regressor>> roots; 069 070 IndependentRegressionTreeModel(String name, ModelProvenance description, 071 ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> outputIDInfo, boolean generatesProbabilities, 072 Map<String,Node<Regressor>> roots) { 073 super(name, description, featureIDMap, outputIDInfo, generatesProbabilities, gatherActiveFeatures(featureIDMap,roots)); 074 this.roots = roots; 075 } 076 077 /** 078 * Deserialization factory. 079 * @param version The serialized object version. 080 * @param className The class name. 081 * @param message The serialized data. 082 * @throws InvalidProtocolBufferException If the protobuf could not be parsed from the {@code message}. 083 * @return The deserialized object. 084 */ 085 @SuppressWarnings({"unchecked","rawtypes"}) // guarded by getClass to ensure all the output types are the same. 
086 public static IndependentRegressionTreeModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException { 087 if (version < 0 || version > CURRENT_VERSION) { 088 throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION); 089 } 090 IndependentRegressionTreeModelProto proto = message.unpack(IndependentRegressionTreeModelProto.class); 091 092 ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata()); 093 if (!carrier.outputDomain().getOutput(0).getClass().equals(Regressor.class)) { 094 throw new IllegalStateException("Invalid protobuf, output domain is not a regression domain, found " + carrier.outputDomain().getClass()); 095 } 096 @SuppressWarnings("unchecked") // guarded by getClass 097 ImmutableOutputInfo<Regressor> outputDomain = (ImmutableOutputInfo<Regressor>) carrier.outputDomain(); 098 099 if (proto.getNodesCount() == 0) { 100 throw new IllegalStateException("Invalid protobuf, tree must contain nodes"); 101 } else if (proto.getNodesCount() != outputDomain.size()) { 102 throw new IllegalStateException("Invalid protobuf, must have one tree per output dimension, found " + proto.getNodesCount()); 103 } 104 105 Map<String,Node<Regressor>> map = new HashMap<>(); 106 107 for (Map.Entry<String, TreeNodeListProto> e : proto.getNodesMap().entrySet()) { 108 List<TreeNodeProto> nodeProtos = e.getValue().getNodesList(); 109 if (nodeProtos.size() == 0) { 110 throw new IllegalStateException("Invalid protobuf, tree must contain nodes"); 111 } 112 List<Node<Regressor>> nodes = deserializeFromProtos(nodeProtos, Regressor.class); 113 map.put(e.getKey(), nodes.get(0)); 114 } 115 116 return new IndependentRegressionTreeModel(carrier.name(),carrier.provenance(),carrier.featureDomain(),outputDomain,carrier.generatesProbabilities(),map); 117 } 118 119 private static Map<String,List<String>> gatherActiveFeatures(ImmutableFeatureMap fMap, 
Map<String,Node<Regressor>> roots) { 120 HashMap<String,List<String>> outputMap = new HashMap<>(); 121 for (Map.Entry<String,Node<Regressor>> e : roots.entrySet()) { 122 Set<String> activeFeatures = new LinkedHashSet<>(); 123 124 Queue<Node<Regressor>> nodeQueue = new LinkedList<>(); 125 126 nodeQueue.offer(e.getValue()); 127 128 while (!nodeQueue.isEmpty()) { 129 Node<Regressor> node = nodeQueue.poll(); 130 if ((node != null) && (!node.isLeaf())) { 131 SplitNode<Regressor> splitNode = (SplitNode<Regressor>) node; 132 String featureName = fMap.get(splitNode.getFeatureID()).getName(); 133 activeFeatures.add(featureName); 134 nodeQueue.offer(splitNode.getGreaterThan()); 135 nodeQueue.offer(splitNode.getLessThanOrEqual()); 136 } 137 } 138 outputMap.put(e.getKey(), new ArrayList<>(activeFeatures)); 139 } 140 return outputMap; 141 } 142 143 /** 144 * Probes the trees to find the depth. 145 * @return The maximum depth across the trees. 146 */ 147 @Override 148 public int getDepth() { 149 int maxDepth = 0; 150 for (Node<Regressor> curRoot : roots.values()) { 151 int thisDepth = computeDepth(0,curRoot); 152 if (maxDepth < thisDepth) { 153 maxDepth = thisDepth; 154 } 155 } 156 return maxDepth; 157 } 158 159 @Override 160 public Prediction<Regressor> predict(Example<Regressor> example) { 161 // 162 // Ensures we handle collisions correctly 163 SparseVector vec = SparseVector.createSparseVector(example,featureIDMap,false); 164 if (vec.numActiveElements() == 0) { 165 throw new IllegalArgumentException("No features found in Example " + example.toString()); 166 } 167 168 List<Prediction<Regressor>> predictionList = new ArrayList<>(); 169 for (Map.Entry<String,Node<Regressor>> e : roots.entrySet()) { 170 Node<Regressor> oldNode = e.getValue(); 171 Node<Regressor> curNode = e.getValue(); 172 173 while (curNode != null) { 174 oldNode = curNode; 175 curNode = oldNode.getNextNode(vec); 176 } 177 178 // 179 // oldNode must be a LeafNode. 
180 predictionList.add(((LeafNode<Regressor>) oldNode).getPrediction(vec.numActiveElements(), example)); 181 } 182 return combine(predictionList); 183 } 184 185 @Override 186 public Map<String, List<Pair<String,Double>>> getTopFeatures(int n) { 187 int maxFeatures = n < 0 ? featureIDMap.size() : n; 188 189 Map<String, List<Pair<String, Double>>> map = new HashMap<>(); 190 Map<String, Integer> featureCounts = new HashMap<>(); 191 Queue<Node<Regressor>> nodeQueue = new LinkedList<>(); 192 193 for (Map.Entry<String,Node<Regressor>> e : roots.entrySet()) { 194 featureCounts.clear(); 195 nodeQueue.clear(); 196 197 nodeQueue.offer(e.getValue()); 198 199 while (!nodeQueue.isEmpty()) { 200 Node<Regressor> node = nodeQueue.poll(); 201 if ((node != null) && !node.isLeaf()) { 202 SplitNode<Regressor> splitNode = (SplitNode<Regressor>) node; 203 String featureName = featureIDMap.get(splitNode.getFeatureID()).getName(); 204 featureCounts.put(featureName, featureCounts.getOrDefault(featureName, 0) + 1); 205 nodeQueue.offer(splitNode.getGreaterThan()); 206 nodeQueue.offer(splitNode.getLessThanOrEqual()); 207 } 208 } 209 210 Comparator<Pair<String, Double>> comparator = Comparator.comparingDouble(p -> Math.abs(p.getB())); 211 PriorityQueue<Pair<String, Double>> q = new PriorityQueue<>(maxFeatures, comparator); 212 213 for (Map.Entry<String, Integer> featureCount : featureCounts.entrySet()) { 214 Pair<String, Double> cur = new Pair<>(featureCount.getKey(), (double) featureCount.getValue()); 215 if (q.size() < maxFeatures) { 216 q.offer(cur); 217 } else if (comparator.compare(cur, q.peek()) > 0) { 218 q.poll(); 219 q.offer(cur); 220 } 221 } 222 List<Pair<String, Double>> list = new ArrayList<>(); 223 while (q.size() > 0) { 224 list.add(q.poll()); 225 } 226 Collections.reverse(list); 227 228 map.put(e.getKey(), list); 229 } 230 231 return map; 232 } 233 234 @Override 235 public Optional<Excuse<Regressor>> getExcuse(Example<Regressor> example) { 236 SparseVector vec = 
SparseVector.createSparseVector(example, featureIDMap, false); 237 if (vec.numActiveElements() == 0) { 238 return Optional.empty(); 239 } 240 241 List<String> list = new ArrayList<>(); 242 List<Prediction<Regressor>> predList = new ArrayList<>(); 243 Map<String, List<Pair<String, Double>>> map = new HashMap<>(); 244 245 for (Map.Entry<String,Node<Regressor>> e : roots.entrySet()) { 246 list.clear(); 247 248 // 249 // Ensures we handle collisions correctly 250 Node<Regressor> oldNode = e.getValue(); 251 Node<Regressor> curNode = e.getValue(); 252 253 while (curNode != null) { 254 oldNode = curNode; 255 if (oldNode instanceof SplitNode) { 256 SplitNode<?> node = (SplitNode<?>) curNode; 257 list.add(featureIDMap.get(node.getFeatureID()).getName()); 258 } 259 curNode = oldNode.getNextNode(vec); 260 } 261 262 // 263 // oldNode must be a LeafNode. 264 predList.add(((LeafNode<Regressor>) oldNode).getPrediction(vec.numActiveElements(), example)); 265 266 List<Pair<String, Double>> pairs = new ArrayList<>(); 267 int i = list.size() + 1; 268 for (String s : list) { 269 pairs.add(new Pair<>(s, i + 0.0)); 270 i--; 271 } 272 273 map.put(e.getKey(), pairs); 274 } 275 Prediction<Regressor> combinedPrediction = combine(predList); 276 277 return Optional.of(new Excuse<>(example,combinedPrediction,map)); 278 } 279 280 @Override 281 protected IndependentRegressionTreeModel copy(String newName, ModelProvenance newProvenance) { 282 Map<String,Node<Regressor>> newRoots = new HashMap<>(); 283 for (Map.Entry<String,Node<Regressor>> e : roots.entrySet()) { 284 newRoots.put(e.getKey(),e.getValue().copy()); 285 } 286 return new IndependentRegressionTreeModel(newName,newProvenance,featureIDMap,outputIDInfo,generatesProbabilities,newRoots); 287 } 288 289 private Prediction<Regressor> combine(List<Prediction<Regressor>> predictions) { 290 DimensionTuple[] tuples = new DimensionTuple[predictions.size()]; 291 int numUsed = 0; 292 int i = 0; 293 for (Prediction<Regressor> p : predictions) { 294 if 
(numUsed < p.getNumActiveFeatures()) { 295 numUsed = p.getNumActiveFeatures(); 296 } 297 Regressor output = p.getOutput(); 298 if (output instanceof DimensionTuple) { 299 tuples[i] = (DimensionTuple)output; 300 } else { 301 throw new IllegalStateException("All the leaves should contain DimensionTuple not Regressor"); 302 } 303 i++; 304 } 305 306 Example<Regressor> example = predictions.get(0).getExample(); 307 return new Prediction<>(new Regressor(tuples),numUsed,example); 308 } 309 310 @Override 311 public Set<String> getFeatures() { 312 Set<String> features = new HashSet<>(); 313 314 Queue<Node<Regressor>> nodeQueue = new LinkedList<>(); 315 316 for (Map.Entry<String,Node<Regressor>> e : roots.entrySet()) { 317 nodeQueue.offer(e.getValue()); 318 319 while (!nodeQueue.isEmpty()) { 320 Node<Regressor> node = nodeQueue.poll(); 321 if ((node != null) && !node.isLeaf()) { 322 SplitNode<Regressor> splitNode = (SplitNode<Regressor>) node; 323 features.add(featureIDMap.get(splitNode.getFeatureID()).getName()); 324 nodeQueue.offer(splitNode.getGreaterThan()); 325 nodeQueue.offer(splitNode.getLessThanOrEqual()); 326 } 327 } 328 } 329 330 return features; 331 } 332 333 @Override 334 public String toString() { 335 StringBuilder sb = new StringBuilder(); 336 for (Map.Entry<String,Node<Regressor>> curRoot : roots.entrySet()) { 337 sb.append("Output '"); 338 sb.append(curRoot.getKey()); 339 sb.append("' - tree = "); 340 sb.append(curRoot.getValue().toString()); 341 sb.append('\n'); 342 } 343 return "IndependentTreeModel(description="+provenance.toString()+",\n"+sb.toString()+")"; 344 } 345 346 /** 347 * Returns an unmodifiable view on the root node collection. 348 * <p> 349 * The nodes themselves are immutable. 350 * @return The root node collection. 351 */ 352 public Map<String,Node<Regressor>> getRoots() { 353 return Collections.unmodifiableMap(roots); 354 } 355 356 /** 357 * Returns null, as this model contains multiple roots, one per regression output dimension. 
358 * <p> 359 * Use {@link #getRoots()} instead. 360 * @return null. 361 */ 362 @Override 363 public Node<Regressor> getRoot() { 364 return null; 365 } 366 367 @Override 368 public ModelProto serialize() { 369 ModelDataCarrier<Regressor> carrier = createDataCarrier(); 370 371 IndependentRegressionTreeModelProto.Builder modelBuilder = IndependentRegressionTreeModelProto.newBuilder(); 372 modelBuilder.setMetadata(carrier.serialize()); 373 for (Map.Entry<String, Node<Regressor>> e : roots.entrySet()) { 374 TreeNodeListProto listProto = TreeNodeListProto.newBuilder().addAllNodes(serializeToNodes(e.getValue())).build(); 375 modelBuilder.putNodes(e.getKey(), listProto); 376 } 377 378 ModelProto.Builder builder = ModelProto.newBuilder(); 379 builder.setSerializedData(Any.pack(modelBuilder.build())); 380 builder.setClassName(IndependentRegressionTreeModel.class.getName()); 381 builder.setVersion(CURRENT_VERSION); 382 383 return builder.build(); 384 } 385}