/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.common.tree;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.Set;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.SparseModel;
import org.tribuo.common.tree.LeafNode;
import org.tribuo.common.tree.Node;
import org.tribuo.common.tree.SplitNode;
import org.tribuo.common.tree.protos.LeafNodeProto;
import org.tribuo.common.tree.protos.SplitNodeProto;
import org.tribuo.common.tree.protos.TreeModelProto;
import org.tribuo.common.tree.protos.TreeNodeProto;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.math.la.SparseVector;
import org.tribuo.protos.core.ModelDataProto;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.ModelProvenance;

/**
 * A {@link SparseModel} whose predictions are produced by walking a binary decision tree.
 * <p>
 * Interior nodes are {@link SplitNode}s which test a single feature and route to either a
 * "greater than" or a "less than or equal" child; terminal nodes are {@link LeafNode}s which
 * produce the {@link Prediction}.
 *
 * @param <T> the output type
 */
public class TreeModel<T extends Output<T>>
extends SparseModel<T> {
    private static final long serialVersionUID = 3L;

    /** The protobuf version written by {@link #serialize()} and accepted by {@link #deserializeFromProto}. */
    public static final int CURRENT_VERSION = 0;

    /**
     * The root of the tree. This is {@code null} when the model was created through the protected
     * constructor that takes pre-computed active features (presumably for subclasses that hold
     * their trees elsewhere — confirm against subclasses).
     */
    private final Node<T> root;

    /**
     * Builds a tree model around an existing tree, deriving the active feature set from the
     * features used by the tree's split nodes.
     *
     * @param name the model name
     * @param description the model provenance
     * @param featureIDMap the feature domain
     * @param outputIDInfo the output domain
     * @param generatesProbabilities whether the model emits probabilities
     * @param root the root node of the tree
     */
    TreeModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, boolean generatesProbabilities, Node<T> root) {
        super(name, description, featureIDMap, outputIDInfo, generatesProbabilities, TreeModel.gatherActiveFeatures(featureIDMap, root));
        this.root = root;
    }

    /**
     * Builds a tree model with no root node, using the supplied active feature map.
     *
     * @param name the model name
     * @param description the model provenance
     * @param featureIDMap the feature domain
     * @param outputIDInfo the output domain
     * @param generatesProbabilities whether the model emits probabilities
     * @param activeFeatures the active features, keyed by output name
     */
    protected TreeModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, boolean generatesProbabilities, Map<String, List<String>> activeFeatures) {
        super(name, description, featureIDMap, outputIDInfo, generatesProbabilities, activeFeatures);
        this.root = null;
    }

    /**
     * Deserializes a tree model from its protobuf representation.
     *
     * @param version the serialized version, must equal {@link #CURRENT_VERSION}
     * @param className the class name stored alongside the message (unused here)
     * @param message the packed {@link TreeModelProto}
     * @return the deserialized model
     * @throws InvalidProtocolBufferException if the message is not a valid {@link TreeModelProto}
     * @throws IllegalArgumentException if the version is unknown
     * @throws IllegalStateException if the protobuf contains no nodes or an inconsistent tree
     */
    public static TreeModel<?> deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        TreeModelProto proto = (TreeModelProto)message.unpack(TreeModelProto.class);
        ModelDataCarrier carrier = ModelDataCarrier.deserialize((ModelDataProto)proto.getMetadata());
        // The output class is taken from the first element of the output domain.
        Class<?> outputClass = carrier.outputDomain().getOutput(0).getClass();
        if (proto.getNodesCount() == 0) {
            throw new IllegalStateException("Invalid protobuf, tree must contain nodes");
        }
        List<TreeNodeProto> nodeProtos = proto.getNodesList();
        List<Node<?>> nodes = TreeModel.deserializeFromProtos(nodeProtos, outputClass);
        // serializeToNodes writes the root at index 0.
        return new TreeModel(carrier.name(), carrier.provenance(), carrier.featureDomain(), carrier.outputDomain(), carrier.generatesProbabilities(), nodes.get(0));
    }

    /**
     * Unpacks a single node proto into an unbuilt node builder.
     *
     * @param proto the node proto
     * @return a {@link SplitNode.SplitNodeBuilder} or a {@link LeafNode.LeafNodeBuilder}
     * @throws InvalidProtocolBufferException if the payload cannot be unpacked
     * @throws IllegalStateException if the payload is neither a split nor a leaf node
     */
    private static Node<?> deserializeNodeProto(TreeNodeProto proto) throws InvalidProtocolBufferException {
        Any message = proto.getSerializedData();
        if (message.is(SplitNodeProto.class)) {
            SplitNodeProto splitProto = (SplitNodeProto)message.unpack(SplitNodeProto.class);
            return new SplitNode.SplitNodeBuilder(splitProto);
        }
        if (message.is(LeafNodeProto.class)) {
            LeafNodeProto leafProto = (LeafNodeProto)message.unpack(LeafNodeProto.class);
            return new LeafNode.LeafNodeBuilder(leafProto);
        }
        throw new IllegalStateException("Invalid protobuf, expected leaf or split node, found " + message.getTypeUrl());
    }

    /**
     * Rebuilds a tree from its flattened protobuf node list.
     * <p>
     * Every proto is first turned into a builder. Leaf builders are built first; each built node
     * replaces its builder in the list at its own index and is attached to its parent split
     * builder. A split builder is queued for building once {@code canBuild()} reports it ready, so
     * the tree is assembled bottom-up.
     *
     * @param nodeProtos the serialized nodes; each node's current index is expected to equal its
     *                   position in this list
     * @param outputClass the expected class of every leaf output
     * @param <U> the output type
     * @return the built nodes, in the same order as {@code nodeProtos}
     * @throws InvalidProtocolBufferException if a node payload cannot be unpacked
     * @throws IllegalStateException if the parent/child indices are inconsistent, a node is left
     *                               unbuilt, or a leaf output has the wrong type
     */
    protected static <U extends Output<U>> List<Node<U>> deserializeFromProtos(List<TreeNodeProto> nodeProtos, Class<U> outputClass) throws InvalidProtocolBufferException {
        ArrayList<Node<U>> nodes = new ArrayList<Node<U>>(nodeProtos.size());
        for (TreeNodeProto treeNodeProto : nodeProtos) {
            Node<?> node = TreeModel.deserializeNodeProto(treeNodeProto);
            nodes.add(node);
        }
        // Seed the work queue with the leaves; splits join once all their children are attached.
        ArrayDeque<Node> nodeQueue = new ArrayDeque<Node>();
        for (Node node : nodes) {
            if (!(node instanceof LeafNode.LeafNodeBuilder)) continue;
            nodeQueue.offer(node);
        }
        while (!nodeQueue.isEmpty()) {
            int parentIdx;
            NodeBuilder builder;
            Node node = (Node)nodeQueue.poll();
            int n = -1;
            Node parent = null;
            Node builtNode = null;
            if (node instanceof LeafNode.LeafNodeBuilder) {
                builder = (LeafNode.LeafNodeBuilder)node;
                Node leaf = ((LeafNode.LeafNodeBuilder)builder).build();
                nodes.set(((LeafNode.LeafNodeBuilder)builder).getCurIdx(), leaf);
                builtNode = leaf;
                n = ((LeafNode.LeafNodeBuilder)builder).getCurIdx();
                parentIdx = ((LeafNode.LeafNodeBuilder)builder).getParentIdx();
                if (parentIdx != -1) {
                    parent = (Node)nodes.get(parentIdx);
                }
            } else if (node instanceof SplitNode.SplitNodeBuilder) {
                builder = (SplitNode.SplitNodeBuilder)node;
                Node split = ((SplitNode.SplitNodeBuilder)builder).build();
                nodes.set(((SplitNode.SplitNodeBuilder)builder).getCurIdx(), split);
                builtNode = split;
                n = ((SplitNode.SplitNodeBuilder)builder).getCurIdx();
                parentIdx = ((SplitNode.SplitNodeBuilder)builder).getParentIdx();
                if (parentIdx != -1) {
                    parent = (Node)nodes.get(parentIdx);
                }
            } else {
                throw new IllegalStateException("Invalid protobuf, found a constructed node was added to the build queue, found " + node.getClass());
            }
            if (parent instanceof SplitNode.SplitNodeBuilder) {
                SplitNode.SplitNodeBuilder splitBuilder = (SplitNode.SplitNodeBuilder)parent;
                // Attach the freshly built child to whichever branch of the parent references its index.
                if (n == splitBuilder.getGreaterThanIdx()) {
                    splitBuilder.setGreaterThan(builtNode);
                } else if (n == splitBuilder.getLessThanOrEqualIdx()) {
                    splitBuilder.setLessThanOrEqual(builtNode);
                } else {
                    throw new IllegalStateException("Invalid protobuf, found a child node which didn't map into a parent");
                }
                if (!splitBuilder.canBuild()) continue;
                nodeQueue.offer(splitBuilder);
                continue;
            }
            // A null parent means this node had parentIdx == -1, i.e. it is the root.
            if (parent == null) continue;
            throw new IllegalStateException("Invalid protobuf, found a " + parent.getClass() + " when a SplitNodeBuilder was expected");
        }
        // Every slot must now hold a built node, and every leaf must carry the expected output type.
        for (Node node : nodes) {
            Object cur;
            if (!(node instanceof SplitNode) && !(node instanceof LeafNode)) {
                throw new IllegalStateException("Invalid protobuf, found unbuilt node, " + node);
            }
            if (!(node instanceof LeafNode) || outputClass.isAssignableFrom((cur = ((LeafNode)node).getOutput()).getClass())) continue;
            throw new IllegalStateException("Invalid protobuf, node output did not match output domain, found " + cur.getClass() + ", expected " + outputClass);
        }
        return nodes;
    }

    /**
     * Collects the names of all features used by split nodes reachable from {@code root}.
     *
     * @param fMap the feature domain used to translate feature ids into names
     * @param root the tree root, may be {@code null}
     * @param <T> the output type
     * @return a single-entry map keyed by {@code "ALL_OUTPUTS"}, listing features in first-seen
     *         breadth-first order
     */
    private static <T extends Output<T>> Map<String, List<String>> gatherActiveFeatures(ImmutableFeatureMap fMap, Node<T> root) {
        LinkedHashSet<String> activeFeatures = new LinkedHashSet<String>();
        LinkedList nodeQueue = new LinkedList();
        nodeQueue.offer(root);
        while (!nodeQueue.isEmpty()) {
            Node node = (Node)nodeQueue.poll();
            if (node == null || node.isLeaf()) continue;
            SplitNode splitNode = (SplitNode)node;
            String featureName = fMap.get(splitNode.getFeatureID()).getName();
            activeFeatures.add(featureName);
            nodeQueue.offer(splitNode.getGreaterThan());
            nodeQueue.offer(splitNode.getLessThanOrEqual());
        }
        return Collections.singletonMap("ALL_OUTPUTS", new ArrayList(activeFeatures));
    }

    /**
     * Returns the depth of this model's tree, counting the root as depth 0.
     *
     * @return the number of split levels on the longest root-to-leaf path
     */
    public int getDepth() {
        return TreeModel.computeDepth(0, this.root);
    }

    /**
     * Computes the depth of the tree rooted at {@code root}.
     * <p>
     * Each split level adds one to {@code initialDepth}; the depth only grows when a
     * {@link LeafNode} child is reached. A {@code null} or leaf root yields {@code initialDepth}.
     *
     * @param initialDepth the depth assigned to {@code root}
     * @param root the tree root, may be {@code null}
     * @param <T> the output type
     * @return the maximum depth at which a leaf was found
     */
    protected static <T extends Output<T>> int computeDepth(int initialDepth, Node<T> root) {
        int maxDepth = initialDepth;
        LinkedList<Pair> nodeQueue = new LinkedList<Pair>();
        nodeQueue.offer(new Pair((Object)initialDepth, root));
        while (!nodeQueue.isEmpty()) {
            Pair nodePair = (Pair)nodeQueue.poll();
            int curDepth = (Integer)nodePair.getA() + 1;
            Node node = (Node)nodePair.getB();
            if (node == null || node.isLeaf()) continue;
            SplitNode splitNode = (SplitNode)node;
            Node greaterThan = splitNode.getGreaterThan();
            Node lessThan = splitNode.getLessThanOrEqual();
            if (greaterThan instanceof LeafNode) {
                if (maxDepth < curDepth) {
                    maxDepth = curDepth;
                }
            } else {
                nodeQueue.offer(new Pair((Object)curDepth, greaterThan));
            }
            if (lessThan instanceof LeafNode) {
                if (maxDepth >= curDepth) continue;
                maxDepth = curDepth;
                continue;
            }
            nodeQueue.offer(new Pair((Object)curDepth, lessThan));
        }
        return maxDepth;
    }

    /**
     * Predicts the output for an example by following splits from the root to a leaf.
     *
     * @param example the example to predict
     * @return the prediction produced by the reached leaf
     * @throws IllegalArgumentException if none of the example's features are in the feature domain
     */
    public Prediction<T> predict(Example<T> example) {
        SparseVector vec = SparseVector.createSparseVector(example, (ImmutableFeatureMap)this.featureIDMap, (boolean)false);
        if (vec.numActiveElements() == 0) {
            throw new IllegalArgumentException("No features found in Example " + example.toString());
        }
        // Walk until getNextNode returns null; the last non-null node is treated as the leaf.
        Node<T> oldNode = this.root;
        Node<T> curNode = this.root;
        while (curNode != null) {
            oldNode = curNode;
            curNode = oldNode.getNextNode(vec);
        }
        return ((LeafNode)oldNode).getPrediction(vec.numActiveElements(), example);
    }

    /**
     * Ranks features by the number of split nodes that test them.
     *
     * @param n the maximum number of features to return; a negative value means every feature in
     *          the feature domain. Zero yields an empty ranking.
     * @return a single-entry map keyed by {@code "ALL_OUTPUTS"} holding (feature name, split count)
     *         pairs, highest count first
     */
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        int maxFeatures = n < 0 ? this.featureIDMap.size() : n;
        List<Pair<String, Double>> topFeatures = new ArrayList<>();
        // PriorityQueue rejects a capacity below 1, and with a zero budget there is nothing to rank.
        if (maxFeatures > 0) {
            HashMap<String, Integer> featureCounts = new HashMap<String, Integer>();
            LinkedList nodeQueue = new LinkedList();
            nodeQueue.offer(this.root);
            while (!nodeQueue.isEmpty()) {
                Node node = (Node)nodeQueue.poll();
                if (node == null || node.isLeaf()) continue;
                SplitNode splitNode = (SplitNode)node;
                String featureName = this.featureIDMap.get(splitNode.getFeatureID()).getName();
                featureCounts.put(featureName, featureCounts.getOrDefault(featureName, 0) + 1);
                nodeQueue.offer(splitNode.getGreaterThan());
                nodeQueue.offer(splitNode.getLessThanOrEqual());
            }
            // Min-heap of the current best maxFeatures entries; the head is the weakest candidate.
            Comparator<Pair<String, Double>> comparator = Comparator.comparingDouble(p -> Math.abs(p.getB()));
            PriorityQueue<Pair<String, Double>> q = new PriorityQueue<>(maxFeatures, comparator);
            for (Map.Entry<String, Integer> e : featureCounts.entrySet()) {
                // Store the count as a Double so the comparator above can read it without a ClassCastException.
                Pair<String, Double> cur = new Pair<>(e.getKey(), e.getValue().doubleValue());
                if (q.size() < maxFeatures) {
                    q.offer(cur);
                } else if (comparator.compare(cur, q.peek()) > 0) {
                    q.poll();
                    q.offer(cur);
                }
            }
            while (!q.isEmpty()) {
                topFeatures.add(q.poll());
            }
            // The heap drains weakest-first; flip it so the most used feature comes first.
            Collections.reverse(topFeatures);
        }
        HashMap<String, List<Pair<String, Double>>> map = new HashMap<String, List<Pair<String, Double>>>();
        map.put("ALL_OUTPUTS", topFeatures);
        return map;
    }

    /**
     * Explains a prediction by listing the split features on the decision path.
     * <p>
     * Features are scored by path position: the root split receives {@code pathLength + 1} and
     * each subsequent split one less.
     *
     * @param example the example to explain
     * @return an excuse holding the prediction and the scored path features under {@code "ALL_OUTPUTS"}
     */
    public Optional<Excuse<T>> getExcuse(Example<T> example) {
        ArrayList<String> list = new ArrayList<String>();
        SparseVector vec = SparseVector.createSparseVector(example, (ImmutableFeatureMap)this.featureIDMap, (boolean)false);
        Node<T> oldNode = this.root;
        Node<T> curNode = this.root;
        while (curNode != null) {
            oldNode = curNode;
            if (oldNode instanceof SplitNode) {
                SplitNode node = (SplitNode)curNode;
                list.add(this.featureIDMap.get(node.getFeatureID()).getName());
            }
            curNode = oldNode.getNextNode(vec);
        }
        Prediction<T> pred = ((LeafNode)oldNode).getPrediction(vec.numActiveElements(), example);
        ArrayList<Pair> pairs = new ArrayList<Pair>();
        int i = list.size() + 1;
        for (String s : list) {
            pairs.add(new Pair((Object)s, (Object)((double)i + 0.0)));
            --i;
        }
        HashMap<String, ArrayList<Pair>> map = new HashMap<String, ArrayList<Pair>>();
        map.put("ALL_OUTPUTS", pairs);
        return Optional.of(new Excuse(example, pred, map));
    }

    /**
     * Copies this model with a new name and provenance, deep-copying the tree via {@code Node.copy()}.
     *
     * @param newName the name of the copy
     * @param newProvenance the provenance of the copy
     * @return the copied model
     */
    protected TreeModel<T> copy(String newName, ModelProvenance newProvenance) {
        return new TreeModel<T>(newName, newProvenance, this.featureIDMap, this.outputIDInfo, this.generatesProbabilities, this.root.copy());
    }

    /**
     * Returns the names of every feature tested by a split node in this tree.
     *
     * @return the set of split feature names
     */
    public Set<String> getFeatures() {
        HashSet<String> features = new HashSet<String>();
        LinkedList nodeQueue = new LinkedList();
        nodeQueue.offer(this.root);
        while (!nodeQueue.isEmpty()) {
            Node node = (Node)nodeQueue.poll();
            if (node == null || node.isLeaf()) continue;
            SplitNode splitNode = (SplitNode)node;
            features.add(this.featureIDMap.get(splitNode.getFeatureID()).getName());
            nodeQueue.offer(splitNode.getGreaterThan());
            nodeQueue.offer(splitNode.getLessThanOrEqual());
        }
        return features;
    }

    /**
     * Counts the non-null nodes (splits and leaves) reachable from {@code root}.
     *
     * @param root the tree root, may be {@code null}
     * @return the node count
     */
    public int countNodes(Node<T> root) {
        LinkedList nodeQueue = new LinkedList();
        int counter = 0;
        nodeQueue.offer(root);
        while (!nodeQueue.isEmpty()) {
            Node node = (Node)nodeQueue.poll();
            if (node == null) continue;
            ++counter;
            if (node.isLeaf()) continue;
            SplitNode splitNode = (SplitNode)node;
            nodeQueue.offer(splitNode.getGreaterThan());
            nodeQueue.offer(splitNode.getLessThanOrEqual());
        }
        return counter;
    }

    @Override
    public String toString() {
        return "TreeModel(description=" + this.provenance.toString() + ",\n\t\ttree=" + this.root.toString() + ")";
    }

    /**
     * Returns the root node of the tree.
     *
     * @return the root, or {@code null} if this model was built without one
     */
    public Node<T> getRoot() {
        return this.root;
    }

    /**
     * Serializes this model into a {@link ModelProto} wrapping a {@link TreeModelProto}.
     *
     * @return the serialized model
     */
    public ModelProto serialize() {
        ModelDataCarrier carrier = this.createDataCarrier();
        TreeModelProto.Builder modelBuilder = TreeModelProto.newBuilder();
        modelBuilder.setMetadata(carrier.serialize());
        modelBuilder.addAllNodes(this.serializeToNodes(this.root));
        ModelProto.Builder builder = ModelProto.newBuilder();
        builder.setSerializedData(Any.pack((Message)modelBuilder.build()));
        builder.setClassName(TreeModel.class.getName());
        builder.setVersion(CURRENT_VERSION);
        return builder.build();
    }

    /**
     * Flattens the tree into a list of node protos in breadth-first order.
     * <p>
     * The root is stored at index 0 with parent index -1. Each split node's children are
     * assigned the next two indices, greater-than first, and every proto records its parent,
     * its own index and (for splits) its children's indices.
     * <p>
     * NOTE(review): the output array is sized by {@link #countNodes}, which skips {@code null}
     * children, while this traversal assigns indices to both children of every split. This
     * assumes every split node has two non-null children.
     *
     * @param root the tree root
     * @return the serialized nodes, indexed by node id
     * @throws IllegalStateException if a node is neither a {@link SplitNode} nor a {@link LeafNode}
     */
    protected List<TreeNodeProto> serializeToNodes(Node<T> root) {
        int numNodes = this.countNodes(root);
        TreeNodeProto[] protos = new TreeNodeProto[numNodes];
        int counter = 0;
        ArrayDeque nodeQueue = new ArrayDeque();
        nodeQueue.offer(new SerializationState<T>(-1, counter, root));
        while (!nodeQueue.isEmpty()) {
            Node node;
            SerializationState state = (SerializationState)nodeQueue.poll();
            if (state.node instanceof SplitNode) {
                node = (SplitNode)state.node;
                int greaterIdx = ++counter;
                int lessIdx = ++counter;
                protos[state.curIdx] = ((SplitNode)node).serialize(state.parentIdx, state.curIdx, greaterIdx, lessIdx);
                nodeQueue.offer(new SerializationState(state.curIdx, greaterIdx, ((SplitNode)node).getGreaterThan()));
                nodeQueue.offer(new SerializationState(state.curIdx, lessIdx, ((SplitNode)node).getLessThanOrEqual()));
                continue;
            }
            if (state.node instanceof LeafNode) {
                node = (LeafNode)state.node;
                protos[state.curIdx] = ((LeafNode)node).serialize(state.parentIdx, state.curIdx);
                continue;
            }
            throw new IllegalStateException("Invalid tree structure, contained a node which wasn't a SplitNode or a LeafNode, found " + state.node.getClass());
        }
        return Arrays.asList(protos);
    }

    /**
     * A pending node in the serialization traversal, with its parent's index and its own index.
     *
     * @param <T> the output type
     */
    private static final class SerializationState<T extends Output<T>> {
        /** Index of the parent node, or -1 for the root. */
        final int parentIdx;
        /** Index assigned to {@link #node} in the serialized list. */
        final int curIdx;
        /** The node to serialize. */
        final Node<T> node;

        SerializationState(int parentIdx, int curIdx, Node<T> node) {
            this.parentIdx = parentIdx;
            this.curIdx = curIdx;
            this.node = node;
        }
    }

    /**
     * Base class for the builders used when reconstructing a tree from protobuf.
     */
    static abstract class NodeBuilder {
        NodeBuilder() {
        }

        /**
         * Returns the parent's index in the serialized node list.
         *
         * @return the parent index, -1 for the root
         */
        abstract int getParentIdx();

        /**
         * Returns this node's index in the serialized node list.
         *
         * @return the node index
         */
        abstract int getCurIdx();

        /**
         * Builds the node this builder describes.
         *
         * @return the constructed node
         */
        abstract Node<?> build();
    }
}

