/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.algos.tree;

import hex.genmodel.MojoModel;
import hex.genmodel.algos.tree.NaSplitDir;
import hex.genmodel.utils.ByteBufferWrapper;
import hex.genmodel.utils.GenmodelBitSet;
import java.util.Arrays;

/**
 * Base class for MOJO models that score a row by walking a collection of byte-encoded decision
 * trees held in {@link #_compressed_trees}.
 */
public abstract class SharedTreeMojoModel extends MojoModel {
    private static final int NSD_NA_VS_REST = NaSplitDir.NAvsREST.value();
    private static final int NSD_NA_LEFT = NaSplitDir.NALeft.value();
    private static final int NSD_LEFT = NaSplitDir.Left.value();

    /** Value stored in a node's 2-byte column-id field to mark the node as a leaf. */
    private static final int LEAF_COLUMN_ID = 0xFFFF;

    /**
     * Deepest decision path a leaf assignment can encode: one bit per decision plus one terminating
     * marker bit must all fit into a single 64-bit {@code long}.
     */
    private static final int MAX_LEAF_ASSIGNMENT_DEPTH = Long.SIZE - 1;

    /** Number of trees scored per class; also the stride between class groups in {@link #_compressed_trees}. */
    protected int _ntrees;
    /** Not read within this class. NOTE(review): presumably populated by the MOJO reader — confirm. */
    protected int _ntrees_per_class;
    /**
     * Byte-encoded trees, grouped by class: tree {@code j} of class {@code i} is stored at index
     * {@code i * _ntrees + j} (see {@link #scoreAllTrees}).
     */
    protected byte[][] _compressed_trees;

    /**
     * Walks one byte-encoded tree from the root to a leaf for the given row.
     *
     * <p>Each internal node is laid out as: 1 byte of node-type flags, a 2-byte column id, a 1-byte
     * NA-split direction, then (unless the split is NA-vs-rest) either a 4-byte float threshold or a
     * categorical bitset. The flags select the left child's encoding:
     * <ul>
     *   <li>{@code nodeType & 0x33} in 0..3: the left child is an internal node, preceded by a size
     *       field of that value + 1 bytes;</li>
     *   <li>{@code nodeType & 0x33} of 16 or 48: the left child is a leaf.</li>
     * </ul>
     * The right child follows the left one. The bits {@code nodeType & 0xC0}, shifted right by 2,
     * give the same leaf flags for the right child. A leaf node at the root is marked by column id
     * {@code 0xFFFF} and is followed by its 4-byte float value.
     *
     * @param tree                  the encoded tree
     * @param row                   feature values, indexed by the column ids stored in the tree
     * @param nclasses              number of classes; only used to size a leaf of type 16 (1 byte if
     *                              fewer than 256 classes, else 2) when skipping past it
     * @param computeLeafAssignment if {@code true}, return the encoded decision path (see
     *                              {@link #getDecisionPath(double)}) instead of the leaf value
     * @return the reached leaf's value (a float widened to double), or the encoded decision path
     * @throws IllegalStateException if {@code computeLeafAssignment} is set and the path is longer
     *                               than 63 decisions, which cannot be encoded in 64 bits
     */
    public static double scoreTree(byte[] tree, double[] row, int nclasses, boolean computeLeafAssignment) {
        ByteBufferWrapper ab = new ByteBufferWrapper(tree);
        GenmodelBitSet bs = null; // lazily allocated, then refilled for every categorical split
        long bitsRight = 0L;      // bit i set <=> the walk went right at depth i
        int level = 0;
        int lmask;
        do {
            int nodeType = ab.get1U();
            int colId = ab.get2();
            if (colId == LEAF_COLUMN_ID) {
                // The loop only continues into a child that is an internal node, so this is
                // effectively the single-leaf-tree case (level == 0).
                if (computeLeafAssignment) {
                    return encodeLeafAssignment(bitsRight, level);
                }
                return ab.get4f();
            }
            int naSplitDir = ab.get1U();
            boolean naVsRest = naSplitDir == NSD_NA_VS_REST;
            boolean naGoesLeft = naSplitDir == NSD_NA_LEFT || naSplitDir == NSD_LEFT;
            lmask = nodeType & 0x33;
            int equal = nodeType & 0x0C; // 0: numeric threshold, 8: bitset via fill2, 12: bitset via fill3
            assert equal != 4;
            float splitVal = -1.0f;
            if (!naVsRest) { // NA-vs-rest splits carry no threshold/bitset payload
                if (equal == 0) {
                    splitVal = ab.get4f();
                } else {
                    if (bs == null) {
                        bs = new GenmodelBitSet(0);
                    }
                    if (equal == 8) {
                        bs.fill2(tree, ab);
                    } else {
                        bs.fill3(tree, ab);
                    }
                }
            }

            double d = row[colId];
            boolean goRight;
            if (Double.isNaN(d)) {
                goRight = !naGoesLeft;
            } else if (naVsRest) {
                goRight = false; // every non-NA value goes left
            } else if (equal == 0) {
                goRight = d >= splitVal;
            } else {
                goRight = bs.contains((int) d);
            }

            if (goRight) {
                skipLeftChild(ab, lmask, nclasses, tree);
                if (computeLeafAssignment && level < Long.SIZE) {
                    // 1L, not 1: an int shift would wrap for level >= 32 and sign-extend at 31.
                    bitsRight |= 1L << level;
                }
                lmask = (nodeType & 0xC0) >> 2; // right child's leaf flags (16 / 48) or 0
            } else if (lmask <= 3) {
                ab.skip(lmask + 1); // skip the left child's size field; the left child follows
            }
            ++level;
        } while ((lmask & 0x10) == 0);

        if (computeLeafAssignment) {
            return encodeLeafAssignment(bitsRight, level);
        }
        // NOTE(review): this reads a 4-byte float even when the reached leaf has lmask 16, which
        // the right-branch skip treats as 1-2 bytes — confirm 16 is never used for value leaves.
        return ab.get4f();
    }

    /**
     * Advances {@code ab} past the encoded left child of the current node so that it points at the
     * right child.
     */
    private static void skipLeftChild(ByteBufferWrapper ab, int lmask, int nclasses, byte[] tree) {
        switch (lmask) {
            case 0:
                ab.skip(ab.get1U());
                break;
            case 1:
                ab.skip(ab.get2());
                break;
            case 2:
                ab.skip(ab.get3());
                break;
            case 3:
                ab.skip(ab.get4());
                break;
            case 16:
                ab.skip(nclasses < 256 ? 1 : 2);
                break;
            case 48:
                ab.skip(4);
                break;
            default:
                assert false : "illegal lmask value " + lmask + " in tree " + Arrays.toString(tree);
                break;
        }
    }

    /**
     * Packs a decision path into a double: the "went right" bits for depths {@code 0..level-1}, plus
     * a marker bit at position {@code level} that records the path length.
     */
    private static double encodeLeafAssignment(long bitsRight, int level) {
        if (level > MAX_LEAF_ASSIGNMENT_DEPTH) {
            throw new IllegalStateException("Cannot compute leaf assignment for trees deeper than "
                    + MAX_LEAF_ASSIGNMENT_DEPTH + " levels (reached level " + level + ")");
        }
        return Double.longBitsToDouble(bitsRight | (1L << level));
    }

    /**
     * Scores one tree and returns the reached leaf's value.
     * Equivalent to {@code scoreTree(tree, row, nclasses, false)}.
     */
    public static double scoreTree(byte[] tree, double[] row, int nclasses) {
        return SharedTreeMojoModel.scoreTree(tree, row, nclasses, false);
    }

    /**
     * Decodes a leaf assignment produced by {@code scoreTree(..., true)} into its decision path.
     *
     * @param leafAssignment the encoded path
     * @return one character per decision from the root, {@code 'L'} or {@code 'R'}; the empty
     *         string for a root leaf (or a value with no bits set)
     */
    public static String getDecisionPath(double leafAssignment) {
        long bits = Double.doubleToRawLongBits(leafAssignment);
        if (bits == 0L) {
            return "";
        }
        // The highest set bit is the terminating marker; its index is the path length.
        int pathLength = Long.SIZE - 1 - Long.numberOfLeadingZeros(bits);
        StringBuilder sb = new StringBuilder(pathLength);
        for (int i = 0; i < pathLength; ++i) {
            sb.append(((bits >>> i) & 1L) != 0L ? 'R' : 'L');
        }
        return sb.toString();
    }

    protected SharedTreeMojoModel(String[] columns, String[][] domains) {
        super(columns, domains);
    }

    /**
     * Sums tree predictions per class into {@code preds}.
     *
     * <p>{@code preds} is zeroed first. With {@code _nclasses == 1}, every tree of group 0 adds into
     * {@code preds[0]}. Otherwise the trees of class {@code i} add into {@code preds[i + 1]}, and
     * {@code preds[0]} is left at 0 (NOTE(review): presumably reserved for the predicted label by
     * subclasses — confirm).
     *
     * @param row             feature values for one observation
     * @param preds           output buffer; must have room for index {@code nClassesToScore}
     *                        (or index 0 when {@code _nclasses == 1})
     * @param nClassesToScore number of class groups of trees to score
     */
    protected void scoreAllTrees(double[] row, double[] preds, int nClassesToScore) {
        Arrays.fill(preds, 0.0);
        for (int cls = 0; cls < nClassesToScore; ++cls) {
            int target = this._nclasses == 1 ? 0 : cls + 1;
            int base = cls * this._ntrees; // trees are stored grouped by class
            for (int j = 0; j < this._ntrees; ++j) {
                preds[target] += scoreTree(this._compressed_trees[base + j], row, this._nclasses);
            }
        }
    }
}

