/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.regression.rtree;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.Set;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.common.tree.LeafNode;
import org.tribuo.common.tree.Node;
import org.tribuo.common.tree.SplitNode;
import org.tribuo.common.tree.TreeModel;
import org.tribuo.common.tree.protos.TreeNodeProto;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.math.la.SparseVector;
import org.tribuo.protos.core.ModelDataProto;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.rtree.protos.IndependentRegressionTreeModelProto;
import org.tribuo.regression.rtree.protos.TreeNodeListProto;

/**
 * A regression tree model holding one independent decision tree per output dimension.
 * <p>
 * Trees are stored in a map keyed by output dimension name. Predictions are formed by running
 * the example down every tree and concatenating the per-dimension leaf outputs into a single
 * {@link Regressor}.
 */
public final class IndependentRegressionTreeModel
extends TreeModel<Regressor> {
    private static final long serialVersionUID = 1L;

    /**
     * Protobuf serialization version.
     */
    public static final int CURRENT_VERSION = 0;

    /**
     * Tree roots keyed by output dimension name.
     */
    private final Map<String, Node<Regressor>> roots;

    /**
     * Builds the model from a set of per-dimension trees.
     *
     * @param name The model name.
     * @param description The model provenance.
     * @param featureIDMap The feature domain.
     * @param outputIDInfo The output domain.
     * @param generatesProbabilities Whether the model generates probabilities.
     * @param roots The tree roots, keyed by output dimension name. Stored without copying.
     */
    IndependentRegressionTreeModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> outputIDInfo, boolean generatesProbabilities, Map<String, Node<Regressor>> roots) {
        super(name, description, featureIDMap, outputIDInfo, generatesProbabilities, gatherActiveFeatures(featureIDMap, roots));
        this.roots = roots;
    }

    /**
     * Deserialization factory.
     *
     * @param version The serialized object version.
     * @param className The class name.
     * @param message The serialized data.
     * @return The deserialized model.
     * @throws InvalidProtocolBufferException If the message is not an
     *         {@link IndependentRegressionTreeModelProto}.
     * @throws IllegalArgumentException If the version is unsupported.
     * @throws IllegalStateException If the protobuf is structurally invalid.
     */
    public static IndependentRegressionTreeModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        IndependentRegressionTreeModelProto proto = message.unpack(IndependentRegressionTreeModelProto.class);
        ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata());
        ImmutableOutputInfo<?> outputDomain = carrier.outputDomain();

        // Validate sizes first. After these checks the output domain is known to be non-empty,
        // so probing element 0 below is safe.
        int treeCount = proto.getNodesCount();
        if (treeCount == 0) {
            throw new IllegalStateException("Invalid protobuf, tree must contain nodes");
        }
        if (treeCount != outputDomain.size()) {
            throw new IllegalStateException("Invalid protobuf, must have one tree per output dimension, found " + treeCount);
        }
        Object firstOutput = outputDomain.getOutput(0);
        if (firstOutput == null || !firstOutput.getClass().equals(Regressor.class)) {
            throw new IllegalStateException("Invalid protobuf, output domain is not a regression domain, found " + outputDomain.getClass());
        }
        @SuppressWarnings("unchecked") // element type checked just above
        ImmutableOutputInfo<Regressor> regressionDomain = (ImmutableOutputInfo<Regressor>) outputDomain;

        Map<String, Node<Regressor>> treeRoots = new HashMap<>();
        for (Map.Entry<String, TreeNodeListProto> e : proto.getNodesMap().entrySet()) {
            List<TreeNodeProto> nodeProtos = e.getValue().getNodesList();
            if (nodeProtos.isEmpty()) {
                throw new IllegalStateException("Invalid protobuf, tree must contain nodes");
            }
            List<?> nodes = deserializeFromProtos(nodeProtos, Regressor.class);
            // The first deserialized node is the tree root.
            @SuppressWarnings("unchecked")
            Node<Regressor> root = (Node<Regressor>) nodes.get(0);
            treeRoots.put(e.getKey(), root);
        }
        return new IndependentRegressionTreeModel(carrier.name(), carrier.provenance(), carrier.featureDomain(), regressionDomain, carrier.generatesProbabilities(), treeRoots);
    }

    /**
     * Walks a tree breadth-first and lists the feature name used at each split node.
     * <p>
     * The greater-than child is enqueued before the less-than-or-equal child. A feature name
     * appears once per split that uses it, in visiting order.
     *
     * @param fMap The feature domain used to map feature ids to names.
     * @param root The tree root (may be null, yielding an empty list).
     * @return The split feature names in breadth-first order, with repeats.
     */
    private static List<String> splitFeatureNames(ImmutableFeatureMap fMap, Node<Regressor> root) {
        List<String> names = new ArrayList<>();
        // ArrayDeque rejects nulls, so absent children are filtered before enqueueing.
        Deque<Node<Regressor>> queue = new ArrayDeque<>();
        if (root != null) {
            queue.add(root);
        }
        while (!queue.isEmpty()) {
            Node<Regressor> node = queue.poll();
            if (node.isLeaf()) {
                continue;
            }
            SplitNode<Regressor> split = (SplitNode<Regressor>) node;
            names.add(fMap.get(split.getFeatureID()).getName());
            Node<Regressor> greater = split.getGreaterThan();
            if (greater != null) {
                queue.add(greater);
            }
            Node<Regressor> lessOrEqual = split.getLessThanOrEqual();
            if (lessOrEqual != null) {
                queue.add(lessOrEqual);
            }
        }
        return names;
    }

    /**
     * Computes the features used by each tree, in first-visited breadth-first order.
     *
     * @param fMap The feature domain.
     * @param roots The tree roots keyed by output dimension name.
     * @return A map from output dimension name to that tree's distinct split features.
     */
    private static Map<String, List<String>> gatherActiveFeatures(ImmutableFeatureMap fMap, Map<String, Node<Regressor>> roots) {
        Map<String, List<String>> outputMap = new HashMap<>();
        for (Map.Entry<String, Node<Regressor>> e : roots.entrySet()) {
            // LinkedHashSet removes duplicates while keeping first-visit order.
            Set<String> activeFeatures = new LinkedHashSet<>(splitFeatureNames(fMap, e.getValue()));
            outputMap.put(e.getKey(), new ArrayList<>(activeFeatures));
        }
        return outputMap;
    }

    /**
     * Returns the maximum depth across all the per-dimension trees.
     *
     * @return The maximum tree depth.
     */
    public int getDepth() {
        int maxDepth = 0;
        for (Node<Regressor> curRoot : roots.values()) {
            maxDepth = Math.max(maxDepth, computeDepth(0, curRoot));
        }
        return maxDepth;
    }

    /**
     * Predicts every output dimension by running the example down the matching tree.
     *
     * @param example The example to predict.
     * @return The combined multi-dimensional prediction.
     * @throws IllegalArgumentException If none of the example's features are in the feature domain.
     */
    @Override
    public Prediction<Regressor> predict(Example<Regressor> example) {
        SparseVector vec = SparseVector.createSparseVector(example, featureIDMap, false);
        if (vec.numActiveElements() == 0) {
            throw new IllegalArgumentException("No features found in Example " + example.toString());
        }
        List<Prediction<Regressor>> predictions = new ArrayList<>(roots.size());
        for (Node<Regressor> root : roots.values()) {
            predictions.add(findLeaf(root, vec).getPrediction(vec.numActiveElements(), example));
        }
        return combine(predictions);
    }

    /**
     * Follows {@link Node#getNextNode} from the root until it returns null, and returns the
     * last node visited (a leaf).
     */
    private static LeafNode<Regressor> findLeaf(Node<Regressor> root, SparseVector vec) {
        Node<Regressor> last = root;
        Node<Regressor> cur = root;
        while (cur != null) {
            last = cur;
            cur = cur.getNextNode(vec);
        }
        return (LeafNode<Regressor>) last;
    }

    /**
     * Ranks features within each tree by the number of split nodes that use them.
     *
     * @param n The number of features to return per output dimension. A negative value means
     *          all of them; zero yields empty lists.
     * @return A map from output dimension name to (feature name, split count) pairs, sorted by
     *         count in descending order.
     */
    @Override
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        Comparator<Pair<String, Double>> byCountDescending =
                Comparator.comparingDouble((Pair<String, Double> p) -> p.getB()).reversed();
        Map<String, List<Pair<String, Double>>> map = new HashMap<>();
        for (Map.Entry<String, Node<Regressor>> e : roots.entrySet()) {
            Map<String, Integer> featureCounts = new HashMap<>();
            for (String name : splitFeatureNames(featureIDMap, e.getValue())) {
                featureCounts.merge(name, 1, Integer::sum);
            }
            // Counts are stored as Double to match the declared return type; earlier code
            // stored Integer and then cast to Double, which threw ClassCastException.
            List<Pair<String, Double>> ranked = new ArrayList<>(featureCounts.size());
            for (Map.Entry<String, Integer> count : featureCounts.entrySet()) {
                ranked.add(new Pair<>(count.getKey(), count.getValue().doubleValue()));
            }
            ranked.sort(byCountDescending);
            int limit = n < 0 ? ranked.size() : Math.min(n, ranked.size());
            map.put(e.getKey(), new ArrayList<>(ranked.subList(0, limit)));
        }
        return map;
    }

    /**
     * Explains a prediction by listing, per tree, the features split on along the decision path.
     * <p>
     * Each path feature is scored from {@code pathLength + 1} down in steps of 1, so earlier
     * splits receive higher scores.
     *
     * @param example The example to explain.
     * @return The excuse, or {@link Optional#empty()} if none of the example's features are in
     *         the feature domain.
     */
    @Override
    public Optional<Excuse<Regressor>> getExcuse(Example<Regressor> example) {
        SparseVector vec = SparseVector.createSparseVector(example, featureIDMap, false);
        if (vec.numActiveElements() == 0) {
            return Optional.empty();
        }
        List<Prediction<Regressor>> predictions = new ArrayList<>(roots.size());
        Map<String, List<Pair<String, Double>>> explanations = new HashMap<>();
        for (Map.Entry<String, Node<Regressor>> e : roots.entrySet()) {
            List<String> path = new ArrayList<>();
            Node<Regressor> last = e.getValue();
            Node<Regressor> cur = e.getValue();
            while (cur != null) {
                last = cur;
                if (cur instanceof SplitNode) {
                    SplitNode<Regressor> split = (SplitNode<Regressor>) cur;
                    path.add(featureIDMap.get(split.getFeatureID()).getName());
                }
                cur = cur.getNextNode(vec);
            }
            predictions.add(((LeafNode<Regressor>) last).getPrediction(vec.numActiveElements(), example));
            List<Pair<String, Double>> pairs = new ArrayList<>(path.size());
            int score = path.size() + 1;
            for (String featureName : path) {
                pairs.add(new Pair<>(featureName, (double) score));
                --score;
            }
            explanations.put(e.getKey(), pairs);
        }
        Excuse<Regressor> excuse = new Excuse<>(example, combine(predictions), explanations);
        return Optional.of(excuse);
    }

    /**
     * Creates a copy of this model with deep-copied trees.
     *
     * @param newName The new model name.
     * @param newProvenance The new provenance.
     * @return The copy.
     */
    @Override
    protected IndependentRegressionTreeModel copy(String newName, ModelProvenance newProvenance) {
        Map<String, Node<Regressor>> newRoots = new HashMap<>();
        for (Map.Entry<String, Node<Regressor>> e : roots.entrySet()) {
            newRoots.put(e.getKey(), e.getValue().copy());
        }
        return new IndependentRegressionTreeModel(newName, newProvenance, featureIDMap, outputIDInfo, generatesProbabilities, newRoots);
    }

    /**
     * Merges per-dimension leaf predictions into one multi-dimensional prediction.
     * <p>
     * The combined prediction uses the largest active-feature count among the inputs and the
     * example from the first prediction.
     *
     * @param predictions One prediction per tree; each output must be a
     *                    {@link Regressor.DimensionTuple}.
     * @return The combined prediction.
     * @throws IllegalStateException If a leaf output is not a {@code DimensionTuple}.
     */
    private Prediction<Regressor> combine(List<Prediction<Regressor>> predictions) {
        Regressor.DimensionTuple[] tuples = new Regressor.DimensionTuple[predictions.size()];
        int numUsed = 0;
        int i = 0;
        for (Prediction<Regressor> p : predictions) {
            numUsed = Math.max(numUsed, p.getNumActiveFeatures());
            Regressor output = p.getOutput();
            if (!(output instanceof Regressor.DimensionTuple)) {
                throw new IllegalStateException("All the leaves should contain DimensionTuple not Regressor");
            }
            tuples[i++] = (Regressor.DimensionTuple) output;
        }
        Example<Regressor> example = predictions.get(0).getExample();
        return new Prediction<>(new Regressor(tuples), numUsed, example);
    }

    /**
     * Returns the names of all features used in a split by any of the trees.
     *
     * @return The set of split feature names.
     */
    public Set<String> getFeatures() {
        Set<String> features = new HashSet<>();
        for (Node<Regressor> root : roots.values()) {
            features.addAll(splitFeatureNames(featureIDMap, root));
        }
        return features;
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        for (Map.Entry<String, Node<Regressor>> curRoot : roots.entrySet()) {
            sb.append("Output '");
            sb.append(curRoot.getKey());
            sb.append("' - tree = ");
            sb.append(curRoot.getValue().toString());
            sb.append('\n');
        }
        return "IndependentTreeModel(description=" + this.provenance.toString() + ",\n" + sb.toString() + ")";
    }

    /**
     * Returns an unmodifiable view of the tree roots, keyed by output dimension name.
     *
     * @return The tree roots.
     */
    public Map<String, Node<Regressor>> getRoots() {
        return Collections.unmodifiableMap(roots);
    }

    /**
     * Always returns null, because this model has one root per output dimension.
     * <p>
     * Use {@link #getRoots()} to access the trees.
     *
     * @return null.
     */
    @Override
    public Node<Regressor> getRoot() {
        return null;
    }

    /**
     * Serializes this model to a {@link ModelProto}, storing each tree's node list under its
     * output dimension name.
     *
     * @return The serialized model.
     */
    public ModelProto serialize() {
        IndependentRegressionTreeModelProto.Builder modelBuilder = IndependentRegressionTreeModelProto.newBuilder();
        modelBuilder.setMetadata(createDataCarrier().serialize());
        for (Map.Entry<String, Node<Regressor>> e : roots.entrySet()) {
            TreeNodeListProto listProto = TreeNodeListProto.newBuilder().addAllNodes(serializeToNodes(e.getValue())).build();
            modelBuilder.putNodes(e.getKey(), listProto);
        }
        return ModelProto.newBuilder()
                .setSerializedData(Any.pack(modelBuilder.build()))
                .setClassName(IndependentRegressionTreeModel.class.getName())
                .setVersion(CURRENT_VERSION)
                .build();
    }
}

