001/*
002 * Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression.rtree;
018
019import com.google.protobuf.Any;
020import com.google.protobuf.InvalidProtocolBufferException;
021import com.oracle.labs.mlrg.olcut.util.Pair;
022import org.tribuo.Example;
023import org.tribuo.Excuse;
024import org.tribuo.ImmutableFeatureMap;
025import org.tribuo.ImmutableOutputInfo;
026import org.tribuo.Model;
027import org.tribuo.Prediction;
028import org.tribuo.common.tree.LeafNode;
029import org.tribuo.common.tree.Node;
030import org.tribuo.common.tree.SplitNode;
031import org.tribuo.common.tree.TreeModel;
032import org.tribuo.common.tree.protos.TreeNodeProto;
033import org.tribuo.impl.ModelDataCarrier;
034import org.tribuo.math.la.SparseVector;
035import org.tribuo.protos.core.ModelProto;
036import org.tribuo.provenance.ModelProvenance;
037import org.tribuo.regression.Regressor;
038import org.tribuo.regression.Regressor.DimensionTuple;
039import org.tribuo.regression.rtree.protos.IndependentRegressionTreeModelProto;
040import org.tribuo.regression.rtree.protos.TreeNodeListProto;
041
042import java.util.ArrayList;
043import java.util.Collections;
044import java.util.Comparator;
045import java.util.HashMap;
046import java.util.HashSet;
047import java.util.LinkedHashSet;
048import java.util.LinkedList;
049import java.util.List;
050import java.util.Map;
051import java.util.Optional;
052import java.util.PriorityQueue;
053import java.util.Queue;
054import java.util.Set;
055
056/**
057 * A {@link Model} wrapped around a list of decision tree root {@link Node}s used
058 * to generate independent predictions for each dimension in a regression.
059 */
060public final class IndependentRegressionTreeModel extends TreeModel<Regressor> {
061    private static final long serialVersionUID = 1L;
062
063    /**
064     * Protobuf serialization version.
065     */
066    public static final int CURRENT_VERSION = 0;
067
068    private final Map<String,Node<Regressor>> roots;
069
070    IndependentRegressionTreeModel(String name, ModelProvenance description,
071                                   ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> outputIDInfo, boolean generatesProbabilities,
072                                   Map<String,Node<Regressor>> roots) {
073        super(name, description, featureIDMap, outputIDInfo, generatesProbabilities, gatherActiveFeatures(featureIDMap,roots));
074        this.roots = roots;
075    }
076
077    /**
078     * Deserialization factory.
079     * @param version The serialized object version.
080     * @param className The class name.
081     * @param message The serialized data.
082     * @throws InvalidProtocolBufferException If the protobuf could not be parsed from the {@code message}.
083     * @return The deserialized object.
084     */
085    @SuppressWarnings({"unchecked","rawtypes"}) // guarded by getClass to ensure all the output types are the same.
086    public static IndependentRegressionTreeModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
087        if (version < 0 || version > CURRENT_VERSION) {
088            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
089        }
090        IndependentRegressionTreeModelProto proto = message.unpack(IndependentRegressionTreeModelProto.class);
091
092        ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata());
093        if (!carrier.outputDomain().getOutput(0).getClass().equals(Regressor.class)) {
094            throw new IllegalStateException("Invalid protobuf, output domain is not a regression domain, found " + carrier.outputDomain().getClass());
095        }
096        @SuppressWarnings("unchecked") // guarded by getClass
097        ImmutableOutputInfo<Regressor> outputDomain = (ImmutableOutputInfo<Regressor>) carrier.outputDomain();
098
099        if (proto.getNodesCount() == 0) {
100            throw new IllegalStateException("Invalid protobuf, tree must contain nodes");
101        } else if (proto.getNodesCount() != outputDomain.size()) {
102            throw new IllegalStateException("Invalid protobuf, must have one tree per output dimension, found " + proto.getNodesCount());
103        }
104
105        Map<String,Node<Regressor>> map = new HashMap<>();
106
107        for (Map.Entry<String, TreeNodeListProto> e : proto.getNodesMap().entrySet()) {
108            List<TreeNodeProto> nodeProtos = e.getValue().getNodesList();
109            if (nodeProtos.size() == 0) {
110                throw new IllegalStateException("Invalid protobuf, tree must contain nodes");
111            }
112            List<Node<Regressor>> nodes = deserializeFromProtos(nodeProtos, Regressor.class);
113            map.put(e.getKey(), nodes.get(0));
114        }
115
116        return new IndependentRegressionTreeModel(carrier.name(),carrier.provenance(),carrier.featureDomain(),outputDomain,carrier.generatesProbabilities(),map);
117    }
118
119    private static Map<String,List<String>> gatherActiveFeatures(ImmutableFeatureMap fMap, Map<String,Node<Regressor>> roots) {
120        HashMap<String,List<String>> outputMap = new HashMap<>();
121        for (Map.Entry<String,Node<Regressor>> e : roots.entrySet()) {
122            Set<String> activeFeatures = new LinkedHashSet<>();
123
124            Queue<Node<Regressor>> nodeQueue = new LinkedList<>();
125
126            nodeQueue.offer(e.getValue());
127
128            while (!nodeQueue.isEmpty()) {
129                Node<Regressor> node = nodeQueue.poll();
130                if ((node != null) && (!node.isLeaf())) {
131                    SplitNode<Regressor> splitNode = (SplitNode<Regressor>) node;
132                    String featureName = fMap.get(splitNode.getFeatureID()).getName();
133                    activeFeatures.add(featureName);
134                    nodeQueue.offer(splitNode.getGreaterThan());
135                    nodeQueue.offer(splitNode.getLessThanOrEqual());
136                }
137            }
138            outputMap.put(e.getKey(), new ArrayList<>(activeFeatures));
139        }
140        return outputMap;
141    }
142
143    /**
144     * Probes the trees to find the depth.
145     * @return The maximum depth across the trees.
146     */
147    @Override
148    public int getDepth() {
149        int maxDepth = 0;
150        for (Node<Regressor> curRoot : roots.values()) {
151            int thisDepth = computeDepth(0,curRoot);
152            if (maxDepth < thisDepth) {
153                maxDepth = thisDepth;
154            }
155        }
156        return maxDepth;
157    }
158
159    @Override
160    public Prediction<Regressor> predict(Example<Regressor> example) {
161        //
162        // Ensures we handle collisions correctly
163        SparseVector vec = SparseVector.createSparseVector(example,featureIDMap,false);
164        if (vec.numActiveElements() == 0) {
165            throw new IllegalArgumentException("No features found in Example " + example.toString());
166        }
167
168        List<Prediction<Regressor>> predictionList = new ArrayList<>();
169        for (Map.Entry<String,Node<Regressor>> e : roots.entrySet()) {
170            Node<Regressor> oldNode = e.getValue();
171            Node<Regressor> curNode = e.getValue();
172
173            while (curNode != null) {
174                oldNode = curNode;
175                curNode = oldNode.getNextNode(vec);
176            }
177
178            //
179            // oldNode must be a LeafNode.
180            predictionList.add(((LeafNode<Regressor>) oldNode).getPrediction(vec.numActiveElements(), example));
181        }
182        return combine(predictionList);
183    }
184
185    @Override
186    public Map<String, List<Pair<String,Double>>> getTopFeatures(int n) {
187        int maxFeatures = n < 0 ? featureIDMap.size() : n;
188
189        Map<String, List<Pair<String, Double>>> map = new HashMap<>();
190        Map<String, Integer> featureCounts = new HashMap<>();
191        Queue<Node<Regressor>> nodeQueue = new LinkedList<>();
192
193        for (Map.Entry<String,Node<Regressor>> e : roots.entrySet()) {
194            featureCounts.clear();
195            nodeQueue.clear();
196
197            nodeQueue.offer(e.getValue());
198
199            while (!nodeQueue.isEmpty()) {
200                Node<Regressor> node = nodeQueue.poll();
201                if ((node != null) && !node.isLeaf()) {
202                    SplitNode<Regressor> splitNode = (SplitNode<Regressor>) node;
203                    String featureName = featureIDMap.get(splitNode.getFeatureID()).getName();
204                    featureCounts.put(featureName, featureCounts.getOrDefault(featureName, 0) + 1);
205                    nodeQueue.offer(splitNode.getGreaterThan());
206                    nodeQueue.offer(splitNode.getLessThanOrEqual());
207                }
208            }
209
210            Comparator<Pair<String, Double>> comparator = Comparator.comparingDouble(p -> Math.abs(p.getB()));
211            PriorityQueue<Pair<String, Double>> q = new PriorityQueue<>(maxFeatures, comparator);
212
213            for (Map.Entry<String, Integer> featureCount : featureCounts.entrySet()) {
214                Pair<String, Double> cur = new Pair<>(featureCount.getKey(), (double) featureCount.getValue());
215                if (q.size() < maxFeatures) {
216                    q.offer(cur);
217                } else if (comparator.compare(cur, q.peek()) > 0) {
218                    q.poll();
219                    q.offer(cur);
220                }
221            }
222            List<Pair<String, Double>> list = new ArrayList<>();
223            while (q.size() > 0) {
224                list.add(q.poll());
225            }
226            Collections.reverse(list);
227
228            map.put(e.getKey(), list);
229        }
230
231        return map;
232    }
233
234    @Override
235    public Optional<Excuse<Regressor>> getExcuse(Example<Regressor> example) {
236        SparseVector vec = SparseVector.createSparseVector(example, featureIDMap, false);
237        if (vec.numActiveElements() == 0) {
238            return Optional.empty();
239        }
240
241        List<String> list = new ArrayList<>();
242        List<Prediction<Regressor>> predList = new ArrayList<>();
243        Map<String, List<Pair<String, Double>>> map = new HashMap<>();
244
245        for (Map.Entry<String,Node<Regressor>> e : roots.entrySet()) {
246            list.clear();
247
248            //
249            // Ensures we handle collisions correctly
250            Node<Regressor> oldNode = e.getValue();
251            Node<Regressor> curNode = e.getValue();
252
253            while (curNode != null) {
254                oldNode = curNode;
255                if (oldNode instanceof SplitNode) {
256                    SplitNode<?> node = (SplitNode<?>) curNode;
257                    list.add(featureIDMap.get(node.getFeatureID()).getName());
258                }
259                curNode = oldNode.getNextNode(vec);
260            }
261
262            //
263            // oldNode must be a LeafNode.
264            predList.add(((LeafNode<Regressor>) oldNode).getPrediction(vec.numActiveElements(), example));
265
266            List<Pair<String, Double>> pairs = new ArrayList<>();
267            int i = list.size() + 1;
268            for (String s : list) {
269                pairs.add(new Pair<>(s, i + 0.0));
270                i--;
271            }
272
273            map.put(e.getKey(), pairs);
274        }
275        Prediction<Regressor> combinedPrediction = combine(predList);
276
277        return Optional.of(new Excuse<>(example,combinedPrediction,map));
278    }
279
280    @Override
281    protected IndependentRegressionTreeModel copy(String newName, ModelProvenance newProvenance) {
282        Map<String,Node<Regressor>> newRoots = new HashMap<>();
283        for (Map.Entry<String,Node<Regressor>> e : roots.entrySet()) {
284            newRoots.put(e.getKey(),e.getValue().copy());
285        }
286        return new IndependentRegressionTreeModel(newName,newProvenance,featureIDMap,outputIDInfo,generatesProbabilities,newRoots);
287    }
288
289    private Prediction<Regressor> combine(List<Prediction<Regressor>> predictions) {
290        DimensionTuple[] tuples = new DimensionTuple[predictions.size()];
291        int numUsed = 0;
292        int i = 0;
293        for (Prediction<Regressor> p : predictions) {
294            if (numUsed < p.getNumActiveFeatures()) {
295                numUsed = p.getNumActiveFeatures();
296            }
297            Regressor output = p.getOutput();
298            if (output instanceof DimensionTuple) {
299                tuples[i] = (DimensionTuple)output;
300            } else {
301                throw new IllegalStateException("All the leaves should contain DimensionTuple not Regressor");
302            }
303            i++;
304        }
305
306        Example<Regressor> example = predictions.get(0).getExample();
307        return new Prediction<>(new Regressor(tuples),numUsed,example);
308    }
309
310    @Override
311    public Set<String> getFeatures() {
312        Set<String> features = new HashSet<>();
313
314        Queue<Node<Regressor>> nodeQueue = new LinkedList<>();
315
316        for (Map.Entry<String,Node<Regressor>> e : roots.entrySet()) {
317            nodeQueue.offer(e.getValue());
318
319            while (!nodeQueue.isEmpty()) {
320                Node<Regressor> node = nodeQueue.poll();
321                if ((node != null) && !node.isLeaf()) {
322                    SplitNode<Regressor> splitNode = (SplitNode<Regressor>) node;
323                    features.add(featureIDMap.get(splitNode.getFeatureID()).getName());
324                    nodeQueue.offer(splitNode.getGreaterThan());
325                    nodeQueue.offer(splitNode.getLessThanOrEqual());
326                }
327            }
328        }
329
330        return features;
331    }
332
333    @Override
334    public String toString() {
335        StringBuilder sb = new StringBuilder();
336        for (Map.Entry<String,Node<Regressor>> curRoot : roots.entrySet()) {
337            sb.append("Output '");
338            sb.append(curRoot.getKey());
339            sb.append("' - tree = ");
340            sb.append(curRoot.getValue().toString());
341            sb.append('\n');
342        }
343        return "IndependentTreeModel(description="+provenance.toString()+",\n"+sb.toString()+")";
344    }
345
346    /**
347     * Returns an unmodifiable view on the root node collection.
348     * <p>
349     * The nodes themselves are immutable.
350     * @return The root node collection.
351     */
352    public Map<String,Node<Regressor>> getRoots() {
353        return Collections.unmodifiableMap(roots);
354    }
355
356    /**
357     * Returns null, as this model contains multiple roots, one per regression output dimension.
358     * <p>
359     * Use {@link #getRoots()} instead.
360     * @return null.
361     */
362    @Override
363    public Node<Regressor> getRoot() {
364        return null;
365    }
366
367    @Override
368    public ModelProto serialize() {
369        ModelDataCarrier<Regressor> carrier = createDataCarrier();
370
371        IndependentRegressionTreeModelProto.Builder modelBuilder = IndependentRegressionTreeModelProto.newBuilder();
372        modelBuilder.setMetadata(carrier.serialize());
373        for (Map.Entry<String, Node<Regressor>> e : roots.entrySet()) {
374            TreeNodeListProto listProto = TreeNodeListProto.newBuilder().addAllNodes(serializeToNodes(e.getValue())).build();
375            modelBuilder.putNodes(e.getKey(), listProto);
376        }
377
378        ModelProto.Builder builder = ModelProto.newBuilder();
379        builder.setSerializedData(Any.pack(modelBuilder.build()));
380        builder.setClassName(IndependentRegressionTreeModel.class.getName());
381        builder.setVersion(CURRENT_VERSION);
382
383        return builder.build();
384    }
385}