/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.algos.tree;

import hex.genmodel.MojoModel;
import hex.genmodel.algos.tree.NaSplitDir;
import hex.genmodel.algos.tree.SharedTreeGraph;
import hex.genmodel.algos.tree.SharedTreeNode;
import hex.genmodel.algos.tree.SharedTreeSubgraph;
import hex.genmodel.utils.ByteBufferWrapper;
import hex.genmodel.utils.GenmodelBitSet;
import java.util.Arrays;
import java.util.HashMap;

/**
 * Base class for tree-ensemble MOJO models whose trees are stored in a compact byte encoding.
 *
 * <p>Trees are organised as {@code _ntree_groups} groups of {@code _ntrees_per_group} trees
 * (see {@link #treeIndex(int, int)}). Scoring walks the byte encoding of each tree directly;
 * {@link #_computeGraph(int)} instead materialises an explicit node graph for inspection, using
 * the per-node statistics stored in {@code _compressed_trees_aux}.
 *
 * <p>Node encoding as read by the scorers below: a node-type byte, a 2-byte column id
 * ({@code 0xFFFF} marks a leaf, followed by a 4-byte float value), an NA-direction byte, then
 * either a 4-byte float threshold (numeric split) or a bitset (categorical split), followed by
 * the child layout described by the node-type byte.
 */
public abstract class SharedTreeMojoModel extends MojoModel {
    // NA-direction codes as stored in the per-node naSplitDir byte.
    private static final int NsdNaVsRest = NaSplitDir.NAvsREST.value();
    private static final int NsdNaLeft = NaSplitDir.NALeft.value();
    private static final int NsdLeft = NaSplitDir.Left.value();

    /** MOJO format version; {@link #scoreAllTrees} dispatches on 1.0, 1.1 and 1.2. */
    protected Number _mojo_version;
    /** Number of tree groups; see {@link #treeIndex(int, int)} for the layout. */
    protected int _ntree_groups;
    /** Number of trees per group; tree {@code i} of a group feeds output slot {@code i + 1} (or 0 when {@code _nclasses == 1}). */
    protected int _ntrees_per_group;
    /** Compressed tree encodings indexed by {@link #treeIndex(int, int)}; {@code null} entries are skipped when scoring. */
    protected byte[][] _compressed_trees;
    /** Per-tree auxiliary node statistics (a sequence of {@link AuxInfo} records), parallel to {@code _compressed_trees}. */
    protected byte[][] _compressed_trees_aux;
    /**
     * Optional calibration coefficients: {@code [0]} multiplies {@code preds[1]}, {@code [1]} is added,
     * before the inverse logit. {@code null} disables calibration.
     */
    protected double[] _calib_glm_beta;

    protected SharedTreeMojoModel(String[] columns, String[][] domains) {
        super(columns, domains);
    }

    /**
     * Scores a single compressed tree (MOJO format version 1.2).
     *
     * <p>When {@code computeLeafAssignment} is true the returned double is not a prediction but a
     * bit pattern (decode it with {@link #getDecisionPath(double)}): bit {@code i} is set iff the
     * walk went right at depth {@code i}, and the highest set bit marks the depth of the leaf.
     *
     * @param tree                  compressed tree bytes
     * @param row                   feature values indexed by column id; NaN means missing
     * @param nclasses              only used to size the legacy leaf skip (lmask 16)
     * @param computeLeafAssignment return the encoded decision path instead of the leaf value
     * @param domains               per-column categorical domains, or {@code null}; a categorical index at or
     *                              beyond its column's domain length is routed like a missing value
     * @return the leaf value, or the encoded decision path
     */
    public static double scoreTree(byte[] tree, double[] row, int nclasses, boolean computeLeafAssignment, String[][] domains) {
        ByteBufferWrapper ab = new ByteBufferWrapper(tree);
        GenmodelBitSet bs = null; // allocated lazily and refilled for each categorical split
        long bitsRight = 0L;
        int level = 0;
        int lmask;
        do {
            int nodeType = ab.get1U();
            char colId = ab.get2();
            if (colId == '\uffff') {
                // A leaf where a split node was expected (presumably a single-leaf tree).
                return computeLeafAssignment ? encodeLeafAssignment(bitsRight, level) : ab.get4f();
            }
            int naSplitDir = ab.get1U();
            boolean naVsRest = naSplitDir == NsdNaVsRest;
            boolean leftward = naSplitDir == NsdNaLeft || naSplitDir == NsdLeft;
            lmask = nodeType & 0x33;
            int equal = nodeType & 0xC; // 0: numeric threshold, 8: fill2 bitset, 12: fill3 bitset
            assert (equal != 4);
            float splitVal = -1.0f;
            if (!naVsRest) {
                if (equal == 0) {
                    splitVal = ab.get4f();
                } else {
                    if (bs == null) {
                        bs = new GenmodelBitSet(0);
                    }
                    if (equal == 8) {
                        bs.fill2(tree, ab);
                    } else {
                        bs.fill3(tree, ab);
                    }
                }
            }
            double d = row[colId];
            boolean goRight;
            if (Double.isNaN(d)
                    || (equal != 0 && bs != null && !bs.isInRange((int) d))
                    || (domains != null && domains[colId] != null && domains[colId].length <= (int) d)) {
                // Missing value, or categorical level outside the split's range / column domain:
                // follow the node's NA direction.
                goRight = !leftward;
            } else {
                goRight = !naVsRest && (equal == 0 ? d >= splitVal : bs.contains((int) d));
            }
            if (goRight) {
                skipLeftSubtree(ab, lmask, nclasses, tree);
                if (computeLeafAssignment && level < 64) {
                    bitsRight |= 1L << level; // long shift: an int shift wraps/sign-extends past depth 30
                }
                lmask = (nodeType & 0xC0) >> 2; // right-child layout, shifted into the lmask bit positions
            } else if (lmask <= 3) {
                ab.skip(lmask + 1); // skip the left-subtree size field; the left child follows
            }
            ++level;
        } while ((lmask & 0x10) == 0); // 0x10 set: the chosen child is a leaf
        return computeLeafAssignment ? encodeLeafAssignment(bitsRight, level) : ab.get4f();
    }

    /**
     * Sets the end-of-path marker bit at {@code level} (when it fits in 64 bits) and packs the
     * result into a double without altering its bits.
     */
    private static double encodeLeafAssignment(long bitsRight, int level) {
        long bits = level < 64 ? bitsRight | (1L << level) : bitsRight;
        return Double.longBitsToDouble(bits);
    }

    /**
     * Advances {@code ab} past the left subtree of a node, leaving it at the start of the right child.
     *
     * @param ab       reader positioned just after the node's split information
     * @param lmask    left-child layout bits ({@code nodeType & 0x33})
     * @param nclasses selects a 1- or 2-byte skip for the legacy lmask 16 layout
     * @param tree     the full tree, used only for the assertion message
     */
    private static void skipLeftSubtree(ByteBufferWrapper ab, int lmask, int nclasses, byte[] tree) {
        switch (lmask) {
            case 0: {
                ab.skip(ab.get1U()); // 1-byte left-subtree size
                break;
            }
            case 1: {
                ab.skip(ab.get2()); // 2-byte size
                break;
            }
            case 2: {
                ab.skip(ab.get3()); // 3-byte size
                break;
            }
            case 3: {
                ab.skip(ab.get4()); // 4-byte size
                break;
            }
            case 16: {
                ab.skip(nclasses < 256 ? 1 : 2);
                break;
            }
            case 48: {
                ab.skip(4); // left child is a 4-byte leaf value
                break;
            }
            default: {
                assert (false) : "illegal lmask value " + lmask + " in tree " + Arrays.toString(tree);
                break;
            }
        }
    }

    /** Returns a new reader over {@code tree} advanced to byte offset {@code pos}. */
    private static ByteBufferWrapper readerAt(byte[] tree, int pos) {
        ByteBufferWrapper reader = new ByteBufferWrapper(tree);
        reader.skip(pos);
        return reader;
    }

    /**
     * Decodes a leaf assignment produced with {@code computeLeafAssignment == true} into a string
     * of {@code 'L'}/{@code 'R'} moves from the root.
     *
     * @param leafAssignment encoded path as returned by the tree scorers
     * @return the path; empty when the leaf is the root
     */
    public static String getDecisionPath(double leafAssignment) {
        long bits = Double.doubleToRawLongBits(leafAssignment);
        // The highest set bit is the end-of-path marker; every bit below it is one move.
        int length = bits == 0L ? 0 : 63 - Long.numberOfLeadingZeros(bits);
        StringBuilder sb = new StringBuilder(length);
        for (int i = 0; i < length; ++i) {
            sb.append(((bits >>> i) & 1L) == 1L ? 'R' : 'L');
        }
        return sb.toString();
    }

    /**
     * Recursively populates {@code node} and its descendants in {@code sg} from the tree encoding.
     *
     * @param sg       subgraph being built
     * @param node     node whose encoding starts at the current position of {@code ab}
     * @param tree     the full compressed tree
     * @param ab       reader positioned at {@code node}; advanced past the node's split information
     * @param auxMap   auxiliary statistics keyed by node number
     * @param nclasses passed through to the left-subtree skip
     */
    private void computeTreeGraph(SharedTreeSubgraph sg, SharedTreeNode node, byte[] tree, ByteBufferWrapper ab, HashMap<Integer, AuxInfo> auxMap, int nclasses) {
        int nodeType = ab.get1U();
        char colId = ab.get2();
        if (colId == '\uffff') {
            float leafValue = ab.get4f();
            node.setPredValue(leafValue);
            return;
        }
        String colName = this.getNames()[colId];
        node.setCol(colId, colName);
        int naSplitDir = ab.get1U();
        boolean naVsRest = naSplitDir == NsdNaVsRest;
        boolean leftward = naSplitDir == NsdNaLeft || naSplitDir == NsdLeft;
        node.setLeftward(leftward);
        node.setNaVsRest(naVsRest);
        int lmask = nodeType & 0x33;
        int equal = nodeType & 0xC;
        assert (equal != 4);
        if (!naVsRest) {
            if (equal == 0) {
                float splitVal = ab.get4f();
                node.setSplitValue(splitVal);
            } else {
                GenmodelBitSet bs = new GenmodelBitSet(0);
                if (equal == 8) {
                    bs.fill2(tree, ab);
                } else {
                    bs.fill3(tree, ab);
                }
                node.setBitset(this.getDomainValues(colId), bs);
            }
        }
        AuxInfo auxInfo = auxMap.get(node.getNodeNumber());
        // Both children are located relative to the end of this node's split information;
        // recursion below uses fresh readers, so this offset stays valid.
        int splitEnd = ab.position();

        // Right child: jump over the whole left subtree.
        ByteBufferWrapper rightAb = readerAt(tree, splitEnd);
        skipLeftSubtree(rightAb, lmask, nclasses, tree);
        int rmask = (nodeType & 0xC0) >> 2;
        SharedTreeNode rightChild = sg.makeRightChildNode(node);
        rightChild.setWeight(auxInfo.weightR);
        rightChild.setNodeNumber(auxInfo.nidR);
        rightChild.setPredValue(auxInfo.predR);
        rightChild.setSquaredError(auxInfo.sqErrR);
        if ((rmask & 0x10) != 0) {
            float leafValue = rightAb.get4f();
            rightChild.setPredValue(leafValue);
            auxInfo.predR = leafValue;
        } else {
            this.computeTreeGraph(sg, rightChild, tree, rightAb, auxMap, nclasses);
        }

        // Left child: directly after the left-subtree size field (if any).
        ByteBufferWrapper leftAb = readerAt(tree, splitEnd);
        if (lmask <= 3) {
            leftAb.skip(lmask + 1);
        }
        SharedTreeNode leftChild = sg.makeLeftChildNode(node);
        leftChild.setWeight(auxInfo.weightL);
        leftChild.setNodeNumber(auxInfo.nidL);
        leftChild.setPredValue(auxInfo.predL);
        leftChild.setSquaredError(auxInfo.sqErrL);
        if ((lmask & 0x10) != 0) {
            float leafValue = leftAb.get4f();
            leftChild.setPredValue(leafValue);
            auxInfo.predL = leafValue;
        } else {
            this.computeTreeGraph(sg, leftChild, tree, leftAb, auxMap, nclasses);
        }

        if (node.getNodeNumber() == 0) {
            // Root: derive prediction, error and weight from the children's aux statistics.
            float p = (float) (((double) auxInfo.predL * (double) auxInfo.weightL + (double) auxInfo.predR * (double) auxInfo.weightR)
                    / ((double) auxInfo.weightL + (double) auxInfo.weightR));
            if ((double) Math.abs(p) < 1.0E-7) {
                p = 0.0f;
            }
            node.setPredValue(p);
            node.setSquaredError(auxInfo.sqErrR + auxInfo.sqErrL);
            node.setWeight(auxInfo.weightL + auxInfo.weightR);
        }
        this.checkConsistency(auxInfo, node);
    }

    /**
     * Builds an explicit graph of the model's trees.
     *
     * @param treeToPrint group index to build, or a negative value to build every group
     * @return graph with one subgraph per tree in the selected group(s)
     * @throws IllegalArgumentException if {@code treeToPrint >= _ntree_groups}
     */
    public SharedTreeGraph _computeGraph(int treeToPrint) {
        SharedTreeGraph g = new SharedTreeGraph();
        if (treeToPrint >= this._ntree_groups) {
            throw new IllegalArgumentException("Tree " + treeToPrint + " does not exist (max " + this._ntree_groups + ")");
        }
        String[] domainValues = this.getDomainValues(this.getResponseIdx());
        int firstGroup = treeToPrint >= 0 ? treeToPrint : 0;
        int endGroup = treeToPrint >= 0 ? treeToPrint + 1 : this._ntree_groups;
        for (int j = firstGroup; j < endGroup; ++j) {
            for (int i = 0; i < this._ntrees_per_group; ++i) {
                String className = domainValues != null ? ", Class " + domainValues[i] : "";
                int itree = this.treeIndex(j, i);
                SharedTreeSubgraph sg = g.makeSubgraph("Tree " + j + className);
                SharedTreeNode node = sg.makeRootNode();
                node.setSquaredError(Float.NaN);
                node.setPredValue(Float.NaN);
                byte[] tree = this._compressed_trees[itree];
                ByteBufferWrapper ab = new ByteBufferWrapper(tree);
                ByteBufferWrapper abAux = new ByteBufferWrapper(this._compressed_trees_aux[itree]);
                HashMap<Integer, AuxInfo> auxMap = new HashMap<Integer, AuxInfo>();
                while (abAux.hasRemaining()) {
                    AuxInfo auxInfo = new AuxInfo(abAux);
                    auxMap.put(auxInfo.nid, auxInfo);
                }
                this.computeTreeGraph(sg, node, tree, ab, auxMap, this._nclasses);
            }
        }
        return g;
    }

    /**
     * Cross-checks a node and its children against the node's aux record and prints a diagnostic
     * to stdout on mismatch. Float fields are compared exactly; a non-root node's weight must match
     * the sum of its children's weights within a relative tolerance of 1e-5.
     */
    void checkConsistency(AuxInfo auxInfo, SharedTreeNode node) {
        boolean ok = true;
        ok &= auxInfo.nid == node.getNodeNumber();
        double sum = 0.0;
        if (node.leftChild != null) {
            ok &= auxInfo.nidL == node.leftChild.getNodeNumber();
            ok &= auxInfo.weightL == node.leftChild.getWeight();
            ok &= auxInfo.predL == node.leftChild.predValue;
            ok &= auxInfo.sqErrL == node.leftChild.squaredError;
            sum += (double) node.leftChild.getWeight();
        }
        if (node.rightChild != null) {
            ok &= auxInfo.nidR == node.rightChild.getNodeNumber();
            ok &= auxInfo.weightR == node.rightChild.getWeight();
            ok &= auxInfo.predR == node.rightChild.predValue;
            ok &= auxInfo.sqErrR == node.rightChild.squaredError;
            sum += (double) node.rightChild.getWeight();
        }
        if (node.parent != null) {
            ok &= auxInfo.pid == node.parent.getNodeNumber();
            ok &= Math.abs((double) node.getWeight() - sum) < 1.0E-5 * ((double) node.getWeight() + sum);
        }
        if (!ok) {
            System.out.println("\nTree inconsistency found:");
            node.print();
            // Children may legitimately be absent (checked above); don't NPE while reporting.
            if (node.leftChild != null) {
                node.leftChild.print();
            }
            if (node.rightChild != null) {
                node.rightChild.print();
            }
            System.out.println(auxInfo.toString());
        }
    }

    /**
     * Sums every non-null tree's output into {@code preds}: tree {@code i} of each group adds to
     * {@code preds[0]} when {@code _nclasses == 1}, otherwise to {@code preds[i + 1]}.
     */
    protected void scoreAllTrees(double[] row, double[] preds) {
        Arrays.fill(preds, 0.0);
        for (int i = 0; i < this._ntrees_per_group; ++i) {
            int k = this._nclasses == 1 ? 0 : i + 1;
            for (int j = 0; j < this._ntree_groups; ++j) {
                byte[] tree = this._compressed_trees[this.treeIndex(j, i)];
                if (tree == null) {
                    continue;
                }
                if (this._mojo_version.equals(1.0)) {
                    preds[k] += SharedTreeMojoModel.scoreTree0(tree, row, this._nclasses, false);
                } else if (this._mojo_version.equals(1.1)) {
                    preds[k] += SharedTreeMojoModel.scoreTree1(tree, row, this._nclasses, false);
                } else if (this._mojo_version.equals(1.2)) {
                    preds[k] += SharedTreeMojoModel.scoreTree(tree, row, this._nclasses, false, this._domains);
                }
                // NOTE(review): any other version silently contributes nothing — confirm subclasses cover newer formats.
            }
        }
    }

    /** Maps (group, tree-within-group) to an index into {@code _compressed_trees}: class-major layout. */
    protected int treeIndex(int groupIndex, int classIndex) {
        return classIndex * this._ntree_groups + groupIndex;
    }

    /**
     * Legacy scorer for MOJO format version 1.0: uses {@code fill3_1}/{@code contains0} for
     * categorical splits and routes only NaN values by the NA direction. Return value as in
     * {@link #scoreTree(byte[], double[], int, boolean, String[][])}.
     */
    public static double scoreTree0(byte[] tree, double[] row, int nclasses, boolean computeLeafAssignment) {
        ByteBufferWrapper ab = new ByteBufferWrapper(tree);
        GenmodelBitSet bs = null;
        long bitsRight = 0L;
        int level = 0;
        int lmask;
        do {
            int nodeType = ab.get1U();
            char colId = ab.get2();
            if (colId == '\uffff') {
                return computeLeafAssignment ? encodeLeafAssignment(bitsRight, level) : ab.get4f();
            }
            int naSplitDir = ab.get1U();
            boolean naVsRest = naSplitDir == NsdNaVsRest;
            boolean leftward = naSplitDir == NsdNaLeft || naSplitDir == NsdLeft;
            lmask = nodeType & 0x33;
            int equal = nodeType & 0xC;
            assert (equal != 4);
            float splitVal = -1.0f;
            if (!naVsRest) {
                if (equal == 0) {
                    splitVal = ab.get4f();
                } else {
                    if (bs == null) {
                        bs = new GenmodelBitSet(0);
                    }
                    if (equal == 8) {
                        bs.fill2(tree, ab);
                    } else {
                        bs.fill3_1(tree, ab);
                    }
                }
            }
            double d = row[colId];
            boolean goRight = Double.isNaN(d)
                    ? !leftward
                    : !naVsRest && (equal == 0 ? d >= splitVal : bs.contains0((int) d));
            if (goRight) {
                skipLeftSubtree(ab, lmask, nclasses, tree);
                if (computeLeafAssignment && level < 64) {
                    bitsRight |= 1L << level;
                }
                lmask = (nodeType & 0xC0) >> 2;
            } else if (lmask <= 3) {
                ab.skip(lmask + 1);
            }
            ++level;
        } while ((lmask & 0x10) == 0);
        return computeLeafAssignment ? encodeLeafAssignment(bitsRight, level) : ab.get4f();
    }

    /**
     * Legacy scorer for MOJO format version 1.1: uses {@code fill3_1} bitsets and routes NaN values
     * and categorical values outside the split's bitset range by the NA direction. Return value as
     * in {@link #scoreTree(byte[], double[], int, boolean, String[][])}.
     */
    public static double scoreTree1(byte[] tree, double[] row, int nclasses, boolean computeLeafAssignment) {
        ByteBufferWrapper ab = new ByteBufferWrapper(tree);
        GenmodelBitSet bs = null;
        long bitsRight = 0L;
        int level = 0;
        int lmask;
        do {
            int nodeType = ab.get1U();
            char colId = ab.get2();
            if (colId == '\uffff') {
                return computeLeafAssignment ? encodeLeafAssignment(bitsRight, level) : ab.get4f();
            }
            int naSplitDir = ab.get1U();
            boolean naVsRest = naSplitDir == NsdNaVsRest;
            boolean leftward = naSplitDir == NsdNaLeft || naSplitDir == NsdLeft;
            lmask = nodeType & 0x33;
            int equal = nodeType & 0xC;
            assert (equal != 4);
            float splitVal = -1.0f;
            if (!naVsRest) {
                if (equal == 0) {
                    splitVal = ab.get4f();
                } else {
                    if (bs == null) {
                        bs = new GenmodelBitSet(0);
                    }
                    if (equal == 8) {
                        bs.fill2(tree, ab);
                    } else {
                        bs.fill3_1(tree, ab);
                    }
                }
            }
            double d = row[colId];
            boolean goRight;
            if (Double.isNaN(d) || (equal != 0 && bs != null && !bs.isInRange((int) d))) {
                goRight = !leftward;
            } else {
                goRight = !naVsRest && (equal == 0 ? d >= splitVal : bs.contains((int) d));
            }
            if (goRight) {
                skipLeftSubtree(ab, lmask, nclasses, tree);
                if (computeLeafAssignment && level < 64) {
                    bitsRight |= 1L << level;
                }
                lmask = (nodeType & 0xC0) >> 2;
            } else if (lmask <= 3) {
                ab.skip(lmask + 1);
            }
            ++level;
        } while ((lmask & 0x10) == 0);
        return computeLeafAssignment ? encodeLeafAssignment(bitsRight, level) : ab.get4f();
    }

    /**
     * Applies the optional calibration to a binomial prediction: the probability
     * {@code logitInv(preds[1] * beta[0] + beta[1])} replaces {@code preds[2]}, and its
     * complement replaces {@code preds[1]}.
     *
     * @return {@code false} (and leaves {@code preds} untouched) when no calibration is configured
     */
    @Override
    public boolean calibrateClassProbabilities(double[] preds) {
        if (this._calib_glm_beta == null) {
            return false;
        }
        assert (this._nclasses == 2);
        assert (preds.length == this._nclasses + 1);
        double p = SharedTreeMojoModel.GLM_logitInv(preds[1] * this._calib_glm_beta[0] + this._calib_glm_beta[1]);
        preds[1] = 1.0 - p;
        preds[2] = p;
        return true;
    }

    /**
     * Auxiliary statistics for one split node, decoded from {@code _compressed_trees_aux} as ten
     * 4-byte fields in the order: nid, pid, weightL, weightR, predL, predR, sqErrL, sqErrR, nidL, nidR.
     */
    static class AuxInfo {
        public int nid;      // this node's number
        public int pid;      // parent node's number
        public int nidL;     // left child's node number
        public int nidR;     // right child's node number
        public float weightL;
        public float weightR;
        public float predL;
        public float predR;
        public float sqErrL;
        public float sqErrR;

        AuxInfo(ByteBufferWrapper abAux) {
            // Field read order defines the on-disk layout; do not reorder.
            this.nid = abAux.get4();
            this.pid = abAux.get4();
            this.weightL = abAux.get4f();
            this.weightR = abAux.get4f();
            this.predL = abAux.get4f();
            this.predR = abAux.get4f();
            this.sqErrL = abAux.get4f();
            this.sqErrR = abAux.get4f();
            this.nidL = abAux.get4();
            this.nidR = abAux.get4();
        }

        @Override
        public String toString() {
            return "nid: " + this.nid + "\n" + "pid: " + this.pid + "\n" + "nidL: " + this.nidL + "\n" + "nidR: " + this.nidR + "\n" + "weightL: " + this.weightL + "\n" + "weightR: " + this.weightR + "\n" + "predL: " + this.predL + "\n" + "predR: " + this.predR + "\n" + "sqErrL: " + this.sqErrL + "\n" + "sqErrR: " + this.sqErrR + "\n";
        }
    }
}

