/*
 * Decompiled with CFR 0.152.
 */
package hivemall.smile.classification;

import hivemall.annotations.VisibleForTesting;
import hivemall.smile.classification.PredictionHandler;
import hivemall.smile.utils.SmileExtUtils;
import hivemall.smile.utils.VariableOrder;
import hivemall.utils.collections.arrays.SparseIntArray;
import hivemall.utils.collections.lists.IntArrayList;
import hivemall.utils.function.Consumer;
import hivemall.utils.function.IntPredicate;
import hivemall.utils.lang.ArrayUtils;
import hivemall.utils.lang.ObjectUtils;
import hivemall.utils.lang.StringUtils;
import hivemall.utils.lang.mutable.MutableInt;
import hivemall.utils.random.PRNG;
import hivemall.utils.random.RandomNumberGeneratorFactory;
import hivemall.utils.sampling.IntReservoirSampler;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.PriorityQueue;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import matrix4j.matrix.Matrix;
import matrix4j.vector.DenseVector;
import matrix4j.vector.SparseVector;
import matrix4j.vector.Vector;
import matrix4j.vector.VectorProcedure;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.roaringbitmap.IntConsumer;
import org.roaringbitmap.RoaringBitmap;
import smile.classification.Classifier;
import smile.math.Math;

public class DecisionTree
implements Classifier<Vector> {
    /** Logger shared by all instances of this class. */
    private static final Log logger = LogFactory.getLog(DecisionTree.class);
    /** Training feature matrix; the constructor requires one row per entry of {@code _y}. */
    @Nonnull
    private final Matrix _X;
    /** Class label of each training row; used directly as an index into per-class count arrays. */
    @Nonnull
    private final int[] _y;
    /** Per-row sample counts (all 1 when the caller supplies none); rows with count 0 are not trained on. */
    @Nonnull
    private final int[] _samples;
    /** Result of {@code SmileExtUtils.sort(nominalAttrs, x, samples)}; scanned per column for numeric splits and re-partitioned on every split. */
    @Nonnull
    private final VariableOrder _order;
    /** Row indices taking part in training; each node works on a contiguous {@code [low, high)} slice that is partitioned in place when the node splits. */
    @Nonnull
    private final int[] _sampleIndex;
    /** Column indices that are split by equality test; every other column is split by threshold. */
    @Nonnull
    private final RoaringBitmap _nominalAttrs;
    /** Per-feature importance: a split's score is added when the split is kept and subtracted when it is pruned. */
    @Nonnull
    private final Vector _importance;
    /** Root of the trained tree. */
    @Nonnull
    private final Node _root;
    /** Nodes whose depth is greater than or equal to this value are never split (the root has depth 1). */
    private final int _maxDepth;
    /** Impurity measure used to score candidate splits. */
    @Nonnull
    private final SplitRule _rule;
    /** Number of classes, computed as the largest label plus one. */
    private final int _k;
    /** Capacity of the reservoir sampler that picks the candidate columns at each node. */
    private final int _numVars;
    /** Sample-count threshold for splitting; raised to {@code 2 * _minSamplesLeaf} by the constructor when smaller. */
    private final int _minSamplesSplit;
    /** A split is rejected when either side would hold fewer samples than this. */
    private final int _minSamplesLeaf;
    /** Supplies the seed of the per-node column sampler. */
    @Nonnull
    private final PRNG _rnd;

    /**
     * Appends {@code depth} levels of indentation (two spaces per level) to {@code builder}.
     * A non-positive {@code depth} appends nothing.
     */
    private static void indent(StringBuilder builder, int depth) {
        int remaining = depth;
        while (remaining > 0) {
            builder.append("  ");
            --remaining;
        }
    }

    /**
     * Partitions, in place, the entries of {@code a} that lie between the insertion points of
     * {@code low} and {@code high} in its key array. Entries whose value passes {@code goesLeft}
     * keep their relative order and are re-keyed {@code low, low+1, ...}; the remaining entries
     * are parked in {@code buf} and then written back, in order, re-keyed {@code pivot, pivot+1, ...}.
     *
     * NOTE(review): presumably {@code ArrayUtils.insertionPoint} returns the first slot whose key
     * is not smaller than the probe, so the processed range is "keys in [low, high)" - confirm.
     *
     * @param a sparse array whose key/value arrays are rewritten in place
     * @param low first key handed out to entries that go left
     * @param pivot first key handed out to entries that go right
     * @param high probe that bounds the processed range from above
     * @param goesLeft predicate evaluated on each entry's value
     * @param buf scratch space that must be able to hold every entry that does not go left
     * @throws IndexOutOfBoundsException if {@code buf} is too small
     * @throws IllegalStateException if the number of entries written back differs from the number read
     */
    private static void partitionArray(@Nonnull SparseIntArray a, int low, int pivot, int high, @Nonnull IntPredicate goesLeft, @Nonnull int[] buf) {
        int i;
        int[] rowIndexes = a.keys();
        int[] rowPtrs = a.values();
        int size = a.size();
        int startPos = ArrayUtils.insertionPoint(rowIndexes, size, low);
        int endPos = ArrayUtils.insertionPoint(rowIndexes, size, high);
        // pos is the write cursor; it never overtakes the read cursor i, so unread entries are safe.
        int pos = startPos;
        // k counts entries parked in buf, j is the next key for the left side.
        int k = 0;
        int j = low;
        for (i = startPos; i < endPos; ++i) {
            int rowPtr = rowPtrs[i];
            if (goesLeft.test(rowPtr)) {
                rowIndexes[pos] = j++;
                rowPtrs[pos] = rowPtr;
                ++pos;
                continue;
            }
            if (k >= buf.length) {
                throw new IndexOutOfBoundsException(String.format("low=%d, pivot=%d, high=%d, a.size()=%d, buf.length=%d, i=%d, j=%d, k=%d, startPos=%d, endPos=%d\na=%s\nbuf=%s", low, pivot, high, a.size(), buf.length, i, j, k, startPos, endPos, a.toString(), Arrays.toString(buf)));
            }
            buf[k++] = rowPtr;
        }
        // Append the parked (right-side) entries directly after the left-side ones.
        for (i = 0; i < k; ++i) {
            rowIndexes[pos] = pivot + i;
            rowPtrs[pos] = buf[i];
            ++pos;
        }
        // Sanity check: every slot that was read must have been written back.
        if (pos != endPos) {
            throw new IllegalStateException(String.format("pos=%d, startPos=%d, endPos=%d, k=%d\na=%s", pos, startPos, endPos, k, a.toString()));
        }
    }

    /**
     * Stable in-place partition of {@code a[low..high)}: elements accepted by {@code goesLeft}
     * are compacted to the front of the range, all others are collected in {@code buf} and copied
     * back starting at {@code pivot}.
     *
     * @param a array to partition
     * @param low inclusive start of the range
     * @param pivot index at which the rejected elements must start after partitioning
     * @param high exclusive end of the range
     * @param goesLeft decides which elements stay on the left side
     * @param buf scratch space for the rejected elements
     * @throws IndexOutOfBoundsException if the range exceeds {@code a}, {@code buf} is too small,
     *         or the number of accepted/rejected elements does not agree with {@code pivot}
     */
    private static void partitionArray(@Nonnull int[] a, int low, int pivot, int high, @Nonnull IntPredicate goesLeft, @Nonnull int[] buf) {
        int leftCursor = low;
        int parked = 0;
        int readPos = low;
        while (readPos < high) {
            if (readPos >= a.length) {
                throw new IndexOutOfBoundsException(String.format("low=%d, pivot=%d, high=%d, a.length=%d, buf.length=%d, i=%d, j=%d, k=%d", low, pivot, high, a.length, buf.length, readPos, leftCursor, parked));
            }
            final int row = a[readPos];
            if (!goesLeft.test(row)) {
                if (parked >= buf.length) {
                    throw new IndexOutOfBoundsException(String.format("low=%d, pivot=%d, high=%d, a.length=%d, buf.length=%d, i=%d, j=%d, k=%d", low, pivot, high, a.length, buf.length, readPos, leftCursor, parked));
                }
                buf[parked] = row;
                ++parked;
            } else {
                a[leftCursor] = row;
                ++leftCursor;
            }
            ++readPos;
        }
        // The caller-provided pivot must match what the predicate actually produced.
        final boolean consistent = parked == high - pivot && leftCursor == pivot;
        if (!consistent) {
            throw new IndexOutOfBoundsException(String.format("low=%d, pivot=%d, high=%d, a.length=%d, buf.length=%d, j=%d, k=%d", low, pivot, high, a.length, buf.length, leftCursor, parked));
        }
        System.arraycopy(buf, 0, a, pivot, parked);
    }

    /**
     * Computes the impurity of a node from its per-class sample counts.
     *
     * @param count number of samples per class; non-positive entries are ignored
     * @param n total number of samples, used as the denominator of every class proportion
     * @param rule impurity measure to apply
     * @return the Gini index, the entropy (base 2) or the classification error, depending on
     *         {@code rule}; {@code 0.0} for any rule not handled here
     */
    private static double impurity(@Nonnull int[] count, int n, @Nonnull SplitRule rule) {
        final double total = (double)n;
        switch (rule) {
            case GINI: {
                // 1 - sum(p^2), subtracting term by term
                double gini = 1.0;
                for (int c = 0; c < count.length; ++c) {
                    if (count[c] > 0) {
                        final double p = (double)count[c] / total;
                        gini -= p * p;
                    }
                }
                return gini;
            }
            case ENTROPY: {
                // -sum(p * log2(p))
                double entropy = 0.0;
                for (int c = 0; c < count.length; ++c) {
                    if (count[c] > 0) {
                        final double p = (double)count[c] / total;
                        entropy -= p * Math.log2(p);
                    }
                }
                return entropy;
            }
            case CLASSIFICATION_ERROR: {
                // |1 - max(p)|
                double largest = 0.0;
                for (int c = 0; c < count.length; ++c) {
                    if (count[c] > 0) {
                        largest = Math.max(largest, (double)count[c] / total);
                    }
                }
                return Math.abs(1.0 - largest);
            }
        }
        return 0.0;
    }

    /**
     * Post-order pass that collapses every internal node whose two children are leaves predicting
     * the same class. A collapsed node loses its children and its split score is subtracted from
     * {@code importance}; an internal node that is kept has its posteriori cleared.
     *
     * @param node subtree root to process
     * @param importance per-feature importance vector to correct for removed splits
     */
    private static void pruneRedundantLeaves(@Nonnull Node node, @Nonnull Vector importance) {
        if (node.isLeaf()) {
            return;
        }
        final Node onTrue = node.trueChild;
        final Node onFalse = node.falseChild;
        DecisionTree.pruneRedundantLeaves(onTrue, importance);
        DecisionTree.pruneRedundantLeaves(onFalse, importance);
        final boolean redundant = onTrue.isLeaf() && onFalse.isLeaf() && onTrue.output == onFalse.output;
        if (!redundant) {
            // Stays an internal node.
            node.posteriori = null;
            return;
        }
        node.trueChild = null;
        node.falseChild = null;
        importance.decr(node.splitFeature, node.splitScore);
    }

    /**
     * Convenience constructor that delegates to the 11-argument constructor with: all columns of
     * {@code x} as split candidates, {@code Integer.MAX_VALUE} as maximum depth, a minimum of 2
     * samples to split, a minimum of 1 sample per leaf, no per-row sample counts, the Gini rule
     * and a default random number generator.
     *
     * NOTE(review): despite its name, {@code numSamplesLeaf} is forwarded in the
     * {@code maxLeafNodes} position of the delegate constructor - confirm the intended meaning.
     *
     * @param nominalAttrs indices of nominal columns, or {@code null} if there are none
     * @param x training feature matrix
     * @param y class label per row of {@code x}
     * @param numSamplesLeaf value passed on as {@code maxLeafNodes}
     */
    public DecisionTree(@Nullable RoaringBitmap nominalAttrs, @Nonnull Matrix x, @Nonnull int[] y, int numSamplesLeaf) {
        this(nominalAttrs, x, y, x.numColumns(), Integer.MAX_VALUE, numSamplesLeaf, 2, 1, null, SplitRule.GINI, null);
    }

    /**
     * Convenience constructor that delegates to the 11-argument constructor with: all columns of
     * {@code x} as split candidates, {@code Integer.MAX_VALUE} as maximum depth, a minimum of 2
     * samples to split, a minimum of 1 sample per leaf, no per-row sample counts and the Gini rule.
     *
     * <p>{@code x} and {@code y} must not be {@code null}: {@code x} is dereferenced right here and
     * both are required non-null by the delegate constructor.
     *
     * NOTE(review): despite its name, {@code numSamplesLeaf} is forwarded in the
     * {@code maxLeafNodes} position of the delegate constructor - confirm the intended meaning.
     *
     * @param nominalAttrs indices of nominal columns, or {@code null} if there are none
     * @param x training feature matrix
     * @param y class label per row of {@code x}
     * @param numSamplesLeaf value passed on as {@code maxLeafNodes}
     * @param rand random number generator, or {@code null} to use a default one
     */
    public DecisionTree(@Nullable RoaringBitmap nominalAttrs, @Nonnull Matrix x, @Nonnull int[] y, int numSamplesLeaf, @Nullable PRNG rand) {
        this(nominalAttrs, x, y, x.numColumns(), Integer.MAX_VALUE, numSamplesLeaf, 2, 1, null, SplitRule.GINI, rand);
    }

    /**
     * Trains a classification tree on {@code x} / {@code y}.
     *
     * <p>When {@code maxLeafNodes} is {@code Integer.MAX_VALUE} the tree is grown by recursive
     * splitting. Otherwise splits are taken from a priority queue, best split score first, until
     * {@code maxLeafNodes} leaves exist or no candidate is left, and redundant leaves are pruned
     * afterwards.
     *
     * @param nominalAttrs indices of nominal columns, or {@code null} if there are none
     * @param x training feature matrix, one row per training example
     * @param y class label per row; the number of classes is taken to be {@code max(y) + 1}
     * @param numVars number of candidate columns sampled at each node
     * @param maxDepth nodes at this depth or deeper are not split (the root has depth 1)
     * @param maxLeafNodes upper bound on the number of leaves, or {@code Integer.MAX_VALUE}
     * @param minSamplesSplit sample-count threshold for splitting a node; raised to
     *        {@code 2 * minSamplesLeaf} when smaller
     * @param minSamplesLeaf minimum number of samples on each side of a split
     * @param samples per-row sample counts (rows with 0 are skipped), or {@code null} to count
     *        every row once
     * @param rule impurity measure used to score splits
     * @param rand random number generator, or {@code null} to create a default one
     * @throws IllegalArgumentException if an argument is rejected by {@code checkArgument} or
     *         {@code y} yields fewer than two classes
     */
    public DecisionTree(@Nullable RoaringBitmap nominalAttrs, @Nonnull Matrix x, @Nonnull int[] y, int numVars, int maxDepth, int maxLeafNodes, int minSamplesSplit, int minSamplesLeaf, @Nullable int[] samples, @Nonnull SplitRule rule, @Nullable PRNG rand) {
        DecisionTree.checkArgument(x, y, numVars, maxDepth, maxLeafNodes, minSamplesSplit, minSamplesLeaf);
        this._X = x;
        this._y = y;
        this._k = Math.max(y) + 1;
        if (this._k < 2) {
            throw new IllegalArgumentException("Only one class or negative class labels.");
        }
        if (nominalAttrs == null) {
            nominalAttrs = new RoaringBitmap();
        }
        this._nominalAttrs = nominalAttrs;
        this._numVars = numVars;
        this._maxDepth = maxDepth;
        // A node must hold at least two minimum-size leaves to be splittable.
        if (minSamplesSplit < minSamplesLeaf * 2) {
            if (logger.isInfoEnabled()) {
                logger.info(String.format("min_sample_leaf = %d replaces min_sample_split = %d with min_sample_split = %d", minSamplesLeaf, minSamplesSplit, minSamplesLeaf * 2));
            }
            minSamplesSplit = minSamplesLeaf * 2;
        }
        this._minSamplesSplit = minSamplesSplit;
        this._minSamplesLeaf = minSamplesLeaf;
        this._rule = rule;
        this._importance = x.isSparse() ? new SparseVector() : new DenseVector(x.numColumns());
        this._rnd = rand == null ? RandomNumberGeneratorFactory.createPRNG() : rand;

        final int n = y.length;
        final int[] count = new int[this._k];
        final int[] sampleIndex;
        int totalNumSamples = 0;
        if (samples == null) {
            // No weights given: every row counts exactly once.
            samples = new int[n];
            sampleIndex = new int[n];
            for (int i = 0; i < n; ++i) {
                samples[i] = 1;
                count[y[i]]++;
                sampleIndex[i] = i;
            }
            totalNumSamples = n;
        } else {
            // Keep only the rows that were actually drawn and weight the class counts.
            final IntArrayList positions = new IntArrayList(n);
            for (int i = 0; i < n; ++i) {
                final int sample = samples[i];
                if (sample == 0) {
                    continue;
                }
                count[y[i]] += sample;
                positions.add(i);
                totalNumSamples += sample;
            }
            sampleIndex = positions.toArray(true);
        }
        this._samples = samples;
        this._order = SmileExtUtils.sort(nominalAttrs, x, samples);
        this._sampleIndex = sampleIndex;

        // Normalise by the weighted total (not by y.length) so that the root distribution sums to
        // one even when sample counts are supplied, matching how child nodes are normalised.
        final double[] posteriori = new double[this._k];
        for (int i = 0; i < this._k; ++i) {
            posteriori[i] = (double)count[i] / (double)totalNumSamples;
        }
        this._root = new Node(Math.whichMax(count), posteriori);
        final TrainNode trainRoot = new TrainNode(this._root, 1, 0, this._sampleIndex.length, totalNumSamples);
        if (maxLeafNodes == Integer.MAX_VALUE) {
            if (trainRoot.findBestSplit()) {
                trainRoot.split(null);
            }
        } else {
            final PriorityQueue<TrainNode> nextSplits = new PriorityQueue<TrainNode>();
            if (trainRoot.findBestSplit()) {
                nextSplits.add(trainRoot);
            }
            // Each successful split turns one leaf into two, i.e. adds one leaf.
            int leaves = 1;
            while (leaves < maxLeafNodes) {
                final TrainNode node = nextSplits.poll();
                if (node == null) {
                    break;
                }
                if (node.split(nextSplits)) {
                    ++leaves;
                }
            }
            DecisionTree.pruneRedundantLeaves(this._root, this._importance);
        }
    }

    /** Returns the root node of the trained tree. */
    @VisibleForTesting
    Node getRootNode() {
        return this._root;
    }

    /**
     * Validates the constructor arguments, failing on the first violated constraint.
     *
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length, {@code y} is
     *         empty, {@code numVars} is outside {@code [1, x.numColumns()]}, {@code maxDepth},
     *         {@code maxLeafNodes} or {@code minSamplesSplit} is below 2, or
     *         {@code minSamplesLeaf} is below 1
     */
    private static void checkArgument(@Nonnull Matrix x, @Nonnull int[] y, int numVars, int maxDepth, int maxLeafNodes, int minSamplesSplit, int minSamplesLeaf) {
        final int numLabels = y.length;
        if (x.numRows() != numLabels) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.numRows(), y.length));
        }
        if (numLabels == 0) {
            throw new IllegalArgumentException("No training example given");
        }
        final boolean numVarsInRange = numVars >= 1 && numVars <= x.numColumns();
        if (!numVarsInRange) {
            throw new IllegalArgumentException("Invalid number of variables to split on at a node of the tree: " + numVars);
        }
        if (!(maxDepth >= 2)) {
            throw new IllegalArgumentException("maxDepth should be greater than 1: " + maxDepth);
        }
        if (!(maxLeafNodes >= 2)) {
            throw new IllegalArgumentException("Invalid maximum leaves: " + maxLeafNodes);
        }
        if (!(minSamplesSplit >= 2)) {
            throw new IllegalArgumentException("Invalid minimum number of samples required to split an internal node: " + minSamplesSplit);
        }
        if (!(minSamplesLeaf >= 1)) {
            throw new IllegalArgumentException("Invalid minimum size of leaf nodes: " + minSamplesLeaf);
        }
    }

    /**
     * Returns the per-feature importance vector accumulated during training. The internal
     * instance is returned, not a copy.
     */
    @Nonnull
    public Vector importance() {
        return this._importance;
    }

    /**
     * Predicts the class of a dense feature array by wrapping it in a {@code DenseVector} and
     * delegating to {@code predict(Vector)}.
     *
     * @param x feature values
     * @return predicted class label
     */
    @Override
    @VisibleForTesting
    public int predict(@Nonnull double[] x) {
        return this.predict(new DenseVector(x));
    }

    /**
     * Predicts the class of {@code x} by walking the tree from the root.
     *
     * @param x feature vector
     * @return predicted class label
     */
    @Override
    public int predict(@Nonnull Vector x) {
        return this._root.predict(x);
    }

    /**
     * Walks the tree for {@code x}, reporting every branch taken and the final leaf to
     * {@code handler}.
     *
     * @param x feature vector
     * @param handler callback notified of the visited branches and the leaf
     */
    public void predict(@Nonnull Vector x, @Nonnull PredictionHandler handler) {
        this._root.predict(x, handler);
    }

    /**
     * Not supported by this implementation.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public int predict(Vector x, double[] posteriori) {
        throw new UnsupportedOperationException("Not supported.");
    }

    /**
     * Renders the tree as JavaScript source, starting from the root at indentation depth 0.
     *
     * @param featureNames names passed to the node exporter for features
     * @param classNames names passed to the node exporter for classes
     * @return the generated script
     */
    @Nonnull
    public String predictJsCodegen(@Nonnull String[] featureNames, @Nonnull String[] classNames) {
        final StringBuilder script = new StringBuilder(1024);
        final int rootDepth = 0;
        this._root.exportJavascript(script, featureNames, classNames, rootDepth);
        return script.toString();
    }

    /**
     * Renders the tree as a list of op-codes, appends the terminating {@code "call end"}
     * instruction and joins everything with {@code sep}.
     *
     * @param sep separator placed between instructions
     * @return the generated op-code script
     */
    @Deprecated
    @Nonnull
    public String predictOpCodegen(@Nonnull String sep) {
        final ArrayList<String> instructions = new ArrayList<String>();
        this._root.opCodegen(instructions, 0);
        instructions.add("call end");
        return StringUtils.concat(instructions, sep);
    }

    /**
     * Serializes the root node (and thereby the whole tree) to bytes.
     *
     * @param compress whether to use the compressed encoding
     * @return the serialized tree
     * @throws HiveException wrapping any exception raised while serializing
     */
    @Nonnull
    public byte[] serialize(boolean compress) throws HiveException {
        try {
            return compress ? ObjectUtils.toCompressedBytes(this._root) : ObjectUtils.toBytes(this._root);
        }
        catch (IOException ioe) {
            throw new HiveException("IOException cause while serializing DecisionTree object", (Throwable)ioe);
        }
        catch (Exception e) {
            throw new HiveException("Exception cause while serializing DecisionTree object", (Throwable)e);
        }
    }

    /**
     * Restores a tree from bytes by reading them into a freshly created root node.
     *
     * @param serializedObj buffer holding the serialized tree
     * @param length number of bytes of {@code serializedObj} to read
     * @param compressed whether the bytes use the compressed encoding
     * @return the root node of the restored tree
     * @throws HiveException wrapping any exception raised while deserializing
     */
    @Nonnull
    public static Node deserialize(@Nonnull byte[] serializedObj, int length, boolean compressed) throws HiveException {
        // The reader fills an existing instance, so start from an empty node.
        Node root = new Node();
        try {
            if (compressed) {
                ObjectUtils.readCompressedObject(serializedObj, 0, length, root);
            } else {
                ObjectUtils.readObject(serializedObj, length, root);
            }
        }
        catch (IOException ioe) {
            throw new HiveException("IOException cause while deserializing DecisionTree object", (Throwable)ioe);
        }
        catch (Exception e) {
            throw new HiveException("Exception cause while deserializing DecisionTree object", (Throwable)e);
        }
        return root;
    }

    /**
     * Returns the JavaScript rendering of the tree without feature or class names, or an empty
     * string when no root node is available.
     */
    @Override
    public String toString() {
        return this._root == null ? "" : this.predictJsCodegen(null, null);
    }

    private final class TrainNode
    implements Comparable<TrainNode> {
        /** Tree node being trained. */
        @Nonnull
        final Node node;
        /** Depth of this node; the root is created with depth 1. */
        final int depth;
        /** Inclusive start of this node's slice of the shared sample index. */
        final int low;
        /** Exclusive end of this node's slice of the shared sample index. */
        final int high;
        /** Weighted number of samples that fall into this node. */
        final int samples;
        /** Features found to be constant for this node; set to {@code null} once the node has been split. */
        @Nullable
        int[] constFeatures;

        /**
         * Creates a training node with no features known to be constant.
         *
         * @param node tree node being trained
         * @param depth depth of the node
         * @param low inclusive start of the node's slice of the sample index
         * @param high exclusive end of the node's slice of the sample index
         * @param samples weighted number of samples in the node
         */
        public TrainNode(Node node, int depth, int low, int high, int samples) {
            this(node, depth, low, high, samples, new int[0]);
        }

        /**
         * Creates a training node over the slice {@code [low, high)} of the sample index.
         *
         * @param node tree node being trained
         * @param depth depth of the node
         * @param low inclusive start of the node's slice of the sample index
         * @param high exclusive end of the node's slice of the sample index
         * @param samples weighted number of samples in the node
         * @param constFeatures features already known to be constant for this node; the array is
         *        stored as is, not copied
         * @throws IllegalArgumentException if the slice is empty ({@code low >= high})
         */
        public TrainNode(Node node, int depth, int low, int high, int samples, @Nonnull int[] constFeatures) {
            if (low >= high) {
                throw new IllegalArgumentException("Unexpected condition was met. low=" + low + ", high=" + high);
            }
            this.node = node;
            this.depth = depth;
            this.low = low;
            this.high = high;
            this.samples = samples;
            this.constFeatures = constFeatures;
        }

        /**
         * Orders training nodes by descending split score, so that a {@code PriorityQueue} hands
         * out the node with the highest score first.
         */
        @Override
        public int compareTo(TrainNode a) {
            final double scoreDiff = a.node.splitScore - this.node.splitScore;
            final double sign = Math.signum(scoreDiff);
            return (int)sign;
        }

        /**
         * Searches the sampled candidate features for the split with the highest score and
         * records it on {@code node}.
         *
         * @return {@code true} if a split was found; {@code false} if the node is too deep, holds
         *         too few samples, is pure, or no candidate feature yields a positive score
         */
        public boolean findBestSplit() {
            final boolean tooDeep = this.depth >= DecisionTree.this._maxDepth;
            if (tooDeep || this.samples <= DecisionTree.this._minSamplesSplit) {
                return false;
            }
            final int numClasses = DecisionTree.this._k;
            final int[] classCounts = new int[numClasses];
            if (this.countSamples(classCounts)) {
                // Only one class present: nothing to gain from splitting.
                return false;
            }
            // Snapshot taken before the scan: features detected as constant during this scan
            // are recorded on the field but do not affect the candidates of this scan.
            final int[] knownConstant = this.constFeatures;
            final double nodeImpurity = DecisionTree.impurity(classCounts, this.samples, DecisionTree.this._rule);
            final int[] scratchCounts = new int[numClasses];
            final int[] candidates = this.variableIndex();
            for (int c = 0; c < candidates.length; ++c) {
                final int feature = candidates[c];
                if (ArrayUtils.contains(knownConstant, feature)) {
                    continue;
                }
                final Node candidate = this.findBestSplit(this.samples, classCounts, scratchCounts, nodeImpurity, feature);
                if (candidate.splitScore > this.node.splitScore) {
                    this.node.splitFeature = candidate.splitFeature;
                    this.node.quantitativeFeature = candidate.quantitativeFeature;
                    this.node.splitValue = candidate.splitValue;
                    this.node.splitScore = candidate.splitScore;
                }
            }
            return this.node.splitFeature != -1;
        }

        /**
         * Draws the candidate features for this node with a reservoir sampler of capacity
         * {@code _numVars}, seeded from the tree's random number generator. For a sparse matrix
         * only the columns that occur in this node's rows are offered to the sampler; for a dense
         * matrix every column is offered.
         *
         * @return the sampled column indices
         */
        @Nonnull
        private int[] variableIndex() {
            Matrix X = DecisionTree.this._X;
            final IntReservoirSampler sampler = new IntReservoirSampler(DecisionTree.this._numVars, DecisionTree.this._rnd.nextLong());
            if (X.isSparse()) {
                // Collect the distinct columns present in the rows of this node first, so that
                // each column is offered to the sampler exactly once.
                final RoaringBitmap cols = new RoaringBitmap();
                VectorProcedure proc = new VectorProcedure(){

                    @Override
                    public void apply(int col) {
                        cols.add(col);
                    }
                };
                int[] sampleIndex = DecisionTree.this._sampleIndex;
                int end = this.high;
                for (int i = this.low; i < end; ++i) {
                    int row = sampleIndex[i];
                    assert (DecisionTree.this._samples[row] != 0) : row;
                    X.eachColumnIndexInRow(row, proc);
                }
                cols.forEach(new IntConsumer(){

                    @Override
                    public void accept(int k) {
                        sampler.add(k);
                    }
                });
            } else {
                int ncols = X.numColumns();
                for (int i = 0; i < ncols; ++i) {
                    sampler.add(i);
                }
            }
            return sampler.getSample();
        }

        /**
         * Adds the weighted class counts of this node's rows to {@code count}.
         *
         * @param count per-class accumulator, incremented in place
         * @return {@code true} if every row of the node carries the same label
         */
        private boolean countSamples(@Nonnull int[] count) {
            final int[] rows = DecisionTree.this._sampleIndex;
            final int[] weights = DecisionTree.this._samples;
            final int[] labels = DecisionTree.this._y;
            int firstLabel = -1;
            boolean singleClass = true;
            for (int pos = this.low; pos < this.high; ++pos) {
                final int row = rows[pos];
                final int label = labels[row];
                count[label] += weights[row];
                if (firstLabel == -1) {
                    // First row seen: remember its label as the reference.
                    firstLabel = label;
                } else if (label != firstLabel) {
                    singleClass = false;
                }
            }
            return singleClass;
        }

        /**
         * Finds the best split of this node on feature {@code j}.
         *
         * <p>For a nominal feature every distinct value is tried as an equality split. For any
         * other feature the values are scanned through {@code _order} and midpoints between
         * neighbouring values are tried as thresholds. A feature with at most one distinct value
         * (a missing value counts as one) is added to {@code constFeatures}.
         *
         * @param n weighted number of samples in this node
         * @param count per-class sample counts of this node
         * @param falseCount scratch array, overwritten with the class counts of the false branch
         * @param impurity impurity of this node before splitting
         * @param j feature to examine
         * @return a detached node describing the best split; its {@code splitFeature} stays
         *         {@code -1} when no split with positive gain exists
         */
        private Node findBestSplit(final int n, final int[] count, final int[] falseCount, final double impurity, final int j) {
            final int[] samples = DecisionTree.this._samples;
            int[] sampleIndex = DecisionTree.this._sampleIndex;
            final Matrix X = DecisionTree.this._X;
            final int[] y = DecisionTree.this._y;
            final int classes = DecisionTree.this._k;
            final Node splitNode = new Node();
            if (DecisionTree.this._nominalAttrs.contains(j)) {
                // Nominal feature: class counts per distinct value, missing values counted apart.
                Int2ObjectOpenHashMap<int[]> trueCount = new Int2ObjectOpenHashMap<int[]>();
                int countNaN = 0;
                int end = this.high;
                for (int i = this.low; i < end; ++i) {
                    int y_i;
                    int n2 = sampleIndex[i];
                    int numSamples = samples[n2];
                    if (numSamples == 0) continue;
                    double v = X.get(n2, j, Double.NaN);
                    if (Double.isNaN(v)) {
                        ++countNaN;
                        continue;
                    }
                    // Nominal values are truncated to int to serve as map keys.
                    int x_ij = (int)v;
                    int[] tc_x = (int[])trueCount.get(x_ij);
                    if (tc_x == null) {
                        tc_x = new int[classes];
                        trueCount.put(x_ij, tc_x);
                    }
                    int n3 = y_i = y[n2];
                    tc_x[n3] = tc_x[n3] + numSamples;
                }
                int countDistinctX = trueCount.size() + (countNaN == 0 ? 0 : 1);
                if (countDistinctX <= 1) {
                    // Constant within this node; remember it so descendants can skip it.
                    this.constFeatures = ArrayUtils.sortedArraySet(this.constFeatures, j);
                }
                // Evaluate "feature == l" for every distinct value l.
                for (Int2ObjectMap.Entry entry : trueCount.int2ObjectEntrySet()) {
                    int l = entry.getIntKey();
                    int[] trueCount_l = (int[])entry.getValue();
                    int tc = Math.sum(trueCount_l);
                    int fc = n - tc;
                    if (tc < DecisionTree.this._minSamplesSplit || fc < DecisionTree.this._minSamplesSplit) continue;
                    for (int k = 0; k < classes; ++k) {
                        falseCount[k] = count[k] - trueCount_l[k];
                    }
                    // Impurity decrease: parent impurity minus the size-weighted child impurities.
                    double gain = impurity - (double)tc / (double)n * DecisionTree.impurity(trueCount_l, tc, DecisionTree.this._rule) - (double)fc / (double)n * DecisionTree.impurity(falseCount, fc, DecisionTree.this._rule);
                    if (!(gain > splitNode.splitScore)) continue;
                    splitNode.splitFeature = j;
                    splitNode.quantitativeFeature = false;
                    splitNode.splitValue = l;
                    splitNode.splitScore = gain;
                }
            } else {
                // Threshold split: trueCount accumulates the class counts of the values visited so far.
                final int[] trueCount = new int[classes];
                final MutableInt countNaN = new MutableInt(0);
                // Incremented each time the visited value differs from the one before it.
                final MutableInt replaceCount = new MutableInt(0);
                // NOTE(review): presumably the callback receives the rows of [low, high) in
                // ascending order of column j - confirm against VariableOrder.
                DecisionTree.this._order.eachNonNullInColumn(j, this.low, this.high, new Consumer(){
                    double prevx = Double.NaN;
                    double lastx = Double.NaN;
                    int prevy = -1;

                    @Override
                    public void accept(int pos, int i) {
                        int numSamples = samples[i];
                        if (numSamples == 0) {
                            return;
                        }
                        double x_ij = X.get(i, j, Double.NaN);
                        if (Double.isNaN(x_ij)) {
                            countNaN.incr();
                            return;
                        }
                        if (this.lastx != x_ij) {
                            this.lastx = x_ij;
                            replaceCount.incr();
                        }
                        int y_i = y[i];
                        // A threshold is only considered where both the value and the label
                        // change; otherwise just absorb the row into the true side.
                        if (Double.isNaN(this.prevx) || x_ij == this.prevx || y_i == this.prevy) {
                            this.prevx = x_ij;
                            this.prevy = y_i;
                            int n2 = y_i;
                            trueCount[n2] = trueCount[n2] + numSamples;
                            return;
                        }
                        int tc = Math.sum(trueCount);
                        int fc = n - tc;
                        // Either side too small: not a valid split point, absorb the row and go on.
                        if (tc < DecisionTree.this._minSamplesSplit || fc < DecisionTree.this._minSamplesSplit) {
                            this.prevx = x_ij;
                            this.prevy = y_i;
                            int n3 = y_i;
                            trueCount[n3] = trueCount[n3] + numSamples;
                            return;
                        }
                        for (int l = 0; l < classes; ++l) {
                            falseCount[l] = count[l] - trueCount[l];
                        }
                        double gain = impurity - (double)tc / (double)n * DecisionTree.impurity(trueCount, tc, DecisionTree.this._rule) - (double)fc / (double)n * DecisionTree.impurity(falseCount, fc, DecisionTree.this._rule);
                        if (gain > splitNode.splitScore) {
                            splitNode.splitFeature = j;
                            splitNode.quantitativeFeature = true;
                            // Threshold halfway between the previous and the current value.
                            splitNode.splitValue = (x_ij + this.prevx) / 2.0;
                            splitNode.splitScore = gain;
                        }
                        this.prevx = x_ij;
                        this.prevy = y_i;
                        int n4 = y_i;
                        trueCount[n4] = trueCount[n4] + numSamples;
                    }
                });
                int countDistinctX = replaceCount.get() + (countNaN.get() == 0 ? 0 : 1);
                if (countDistinctX <= 1) {
                    // Constant within this node; remember it so descendants can skip it.
                    this.constFeatures = ArrayUtils.sortedArraySet(this.constFeatures, j);
                }
            }
            return splitNode;
        }

        /**
         * Splits this node on the split previously recorded by {@code findBestSplit()}.
         *
         * <p>With {@code nextSplits == null} the children are split recursively right away;
         * otherwise children that have a viable split are only added to the queue.
         *
         * @param nextSplits queue collecting children to split later, or {@code null} for
         *        immediate recursive splitting
         * @return {@code true} if the node was split; {@code false} if it was turned into a leaf
         *         because a child would be smaller than the minimum leaf size or because both
         *         children ended up as leaves predicting the same class
         * @throws IllegalStateException if no split feature has been recorded on the node
         */
        public boolean split(@Nullable PriorityQueue<TrainNode> nextSplits) {
            if (this.node.splitFeature < 0) {
                throw new IllegalStateException("Split a node with invalid feature.");
            }
            IntPredicate goesLeft = this.getPredicate();
            double[] trueChildPosteriori = new double[DecisionTree.this._k];
            double[] falseChildPosteriori = new double[DecisionTree.this._k];
            MutableInt tc_ = new MutableInt(0);
            MutableInt fc_ = new MutableInt(0);
            // Counting pass only: yields the weighted child sizes, the per-class totals and the
            // position at which the false side will start; nothing is moved yet.
            int pivot = this.splitSamples(tc_, fc_, trueChildPosteriori, falseChildPosteriori, goesLeft);
            int tc = tc_.get();
            int fc = fc_.get();
            if (tc < DecisionTree.this._minSamplesLeaf || fc < DecisionTree.this._minSamplesLeaf) {
                this.node.markAsLeaf();
                return false;
            }
            // Turn the per-class totals into class proportions.
            int i = 0;
            while (i < DecisionTree.this._k) {
                int n = i;
                trueChildPosteriori[n] = trueChildPosteriori[n] / (double)tc;
                int n2 = i++;
                falseChildPosteriori[n2] = falseChildPosteriori[n2] / (double)fc;
            }
            // Physically rearrange the rows so that each child owns a contiguous slice.
            this.partitionOrder(this.low, pivot, this.high, goesLeft);
            int leaves = 0;
            this.node.trueChild = new Node(trueChildPosteriori);
            // The true child gets its own copy of the constant-feature set; the false child takes
            // over this node's array, which is why the reference is dropped below.
            TrainNode trueChild = new TrainNode(this.node.trueChild, this.depth + 1, this.low, pivot, tc, (int[])this.constFeatures.clone());
            this.node.falseChild = new Node(falseChildPosteriori);
            TrainNode falseChild = new TrainNode(this.node.falseChild, this.depth + 1, pivot, this.high, fc, this.constFeatures);
            this.constFeatures = null;
            // leaves counts the children that are known to stay leaves.
            if (tc >= DecisionTree.this._minSamplesSplit && trueChild.findBestSplit()) {
                if (nextSplits != null) {
                    nextSplits.add(trueChild);
                } else if (!trueChild.split(null)) {
                    ++leaves;
                }
            } else {
                ++leaves;
            }
            if (fc >= DecisionTree.this._minSamplesSplit && falseChild.findBestSplit()) {
                if (nextSplits != null) {
                    nextSplits.add(falseChild);
                } else if (!falseChild.split(null)) {
                    ++leaves;
                }
            } else {
                ++leaves;
            }
            // Two leaves with the same prediction make the split pointless: undo it.
            if (leaves == 2 && this.node.trueChild.output == this.node.falseChild.output) {
                this.node.markAsLeaf();
                return false;
            }
            DecisionTree.this._importance.incr(this.node.splitFeature, this.node.splitScore);
            // The posteriori is cleared here only in recursive mode.
            if (nextSplits == null) {
                this.node.posteriori = null;
            }
            return true;
        }

        /**
         * Counting pass over this node's rows: accumulates, for each side of the split, the
         * weighted sample total and the weighted per-class totals. No row is moved.
         *
         * @param tc receives the weighted sample total of the true side
         * @param fc receives the weighted sample total of the false side
         * @param trueChildPosteriori receives the weighted per-class totals of the true side
         * @param falseChildPosteriori receives the weighted per-class totals of the false side
         * @param goesLeft decides whether a row belongs to the true side
         * @return {@code low} plus the number of rows on the true side, i.e. the index at which
         *         the false side starts once the slice is partitioned
         */
        private int splitSamples(@Nonnull MutableInt tc, @Nonnull MutableInt fc, @Nonnull double[] trueChildPosteriori, @Nonnull double[] falseChildPosteriori, @Nonnull IntPredicate goesLeft) {
            final int[] rows = DecisionTree.this._sampleIndex;
            final int[] weights = DecisionTree.this._samples;
            final int[] labels = DecisionTree.this._y;
            int leftRows = 0;
            for (int pos = this.low; pos < this.high; ++pos) {
                final int row = rows[pos];
                final int weight = weights[row];
                final int label = labels[row];
                if (goesLeft.test(row)) {
                    tc.addValue(weight);
                    trueChildPosteriori[label] += (double)weight;
                    ++leftRows;
                } else {
                    fc.addValue(weight);
                    falseChildPosteriori[label] += (double)weight;
                }
            }
            return this.low + leftRows;
        }

        /**
         * Partitions every row of {@code _order} and then {@code _sampleIndex} around
         * {@code pivot}, so that entries accepted by {@code goesLeft} occupy {@code [low, pivot)}
         * and the others {@code [pivot, high)}.
         *
         * @param low inclusive start of the slice
         * @param pivot index at which the false side starts
         * @param high exclusive end of the slice
         * @param goesLeft decides whether an entry belongs to the true side
         */
        private void partitionOrder(final int low, final int pivot, final int high, final @Nonnull IntPredicate goesLeft) {
            // One scratch buffer, sized for the false side, is reused for every partition call.
            final int[] buf = new int[high - pivot];
            DecisionTree.this._order.eachRow(new Consumer(){

                @Override
                public void accept(int col, @Nonnull SparseIntArray row) {
                    DecisionTree.partitionArray(row, low, pivot, high, goesLeft, buf);
                }
            });
            DecisionTree.partitionArray(DecisionTree.this._sampleIndex, low, pivot, high, goesLeft, buf);
        }

        /**
         * Builds the row predicate of this node's split: {@code value <= splitValue} for a
         * quantitative feature, {@code value == splitValue} otherwise. A missing value is read as
         * {@code NaN}, for which both comparisons are false, so such rows go to the false side.
         * The predicate reads the node's split fields each time it is evaluated.
         *
         * @return predicate over row indices of the training matrix
         */
        @Nonnull
        private IntPredicate getPredicate() {
            if (this.node.quantitativeFeature) {
                return new IntPredicate(){

                    @Override
                    public boolean test(int i) {
                        return DecisionTree.this._X.get(i, TrainNode.this.node.splitFeature, Double.NaN) <= TrainNode.this.node.splitValue;
                    }
                };
            }
            return new IntPredicate(){

                @Override
                public boolean test(int i) {
                    return DecisionTree.this._X.get(i, TrainNode.this.node.splitFeature, Double.NaN) == TrainNode.this.node.splitValue;
                }
            };
        }
    }

    /**
     * A node of the decision tree: either a leaf carrying a predicted class, or an internal
     * node that tests one feature against {@link #splitValue} and delegates to
     * {@link #trueChild} / {@link #falseChild}.
     *
     * <p>A node is treated as a leaf only when <em>both</em> children are {@code null}.</p>
     */
    public static final class Node
    implements Externalizable {
        /** Type id written to the stream when the split is a numeric {@code <=} test. */
        private static final byte NUMERIC_TYPE_ID = 1;
        /** Type id written to the stream when the split is a nominal {@code ==} test. */
        private static final byte NOMINAL_TYPE_ID = 2;

        /** Class label returned when prediction ends at this node; {@code -1} until assigned. */
        int output = -1;
        /**
         * Per-class scores handed to {@link PredictionHandler#visitLeaf}; presumably posterior
         * probabilities (the single-argument constructor uses their arg-max as the output).
         */
        @Nullable
        double[] posteriori = null;
        /** Index of the feature tested at this node; {@code -1} when there is no split. */
        int splitFeature = -1;
        /** {@code true}: test is {@code feature <= splitValue}; {@code false}: {@code feature == splitValue}. */
        boolean quantitativeFeature = true;
        /** Threshold (numeric split) or category value (nominal split). */
        double splitValue = Double.NaN;
        /** Score of the split; not read inside this class, only reset by {@link #markAsLeaf()}. */
        double splitScore = 0.0;
        /** Subtree taken when the split test holds. */
        Node trueChild = null;
        /** Subtree taken when the split test does not hold. */
        Node falseChild = null;

        /** Creates an empty node; required for {@link Externalizable} deserialization. */
        public Node() {
        }

        /**
         * Creates a node whose output is the index of the largest entry of {@code posteriori}.
         *
         * @param posteriori per-class scores
         */
        public Node(@Nonnull double[] posteriori) {
            this(Math.whichMax(posteriori), posteriori);
        }

        /**
         * Creates a node with an explicit output label.
         *
         * @param output class label of this node
         * @param posteriori per-class scores (stored by reference, not copied)
         */
        public Node(int output, @Nonnull double[] posteriori) {
            this.output = output;
            this.posteriori = posteriori;
        }

        /** @return {@code true} when this node has neither a true nor a false child */
        private boolean isLeaf() {
            return this.trueChild == null && this.falseChild == null;
        }

        /** Drops both children and resets all split-related fields to their defaults. */
        private void markAsLeaf() {
            this.splitFeature = -1;
            this.splitValue = Double.NaN;
            this.splitScore = 0.0;
            this.trueChild = null;
            this.falseChild = null;
        }

        /**
         * Evaluates this node's split test for a feature value.
         *
         * <p>A {@code NaN} value (the default handed to {@code x.get} by the predict methods)
         * fails both the {@code <=} and the {@code ==} comparison and therefore routes to the
         * false child.</p>
         *
         * @param feature value of {@link #splitFeature} in the instance being classified
         * @return {@code true} to descend into {@link #trueChild}, {@code false} for {@link #falseChild}
         */
        private boolean acceptsAsTrue(double feature) {
            return this.quantitativeFeature ? feature <= this.splitValue : feature == this.splitValue;
        }

        /**
         * Formats with {@link java.util.Locale#ROOT} so that digits and the decimal separator
         * written into generated scripts never depend on the JVM default locale.
         */
        private static String format(String pattern, Object... args) {
            return String.format(java.util.Locale.ROOT, pattern, args);
        }

        /**
         * Predicts the class of a dense feature array.
         *
         * @param x feature values
         * @return predicted class label
         */
        @VisibleForTesting
        public int predict(@Nonnull double[] x) {
            return this.predict(new DenseVector(x));
        }

        /**
         * Walks from this node down to a leaf and returns the leaf's output.
         *
         * @param x feature vector
         * @return {@link #output} of the leaf that was reached
         */
        public int predict(@Nonnull Vector x) {
            // Iterative descent: same path as the recursive form, without one stack frame per level.
            Node node = this;
            while (!node.isLeaf()) {
                final double feature = x.get(node.splitFeature, Double.NaN);
                node = node.acceptsAsTrue(feature) ? node.trueChild : node.falseChild;
            }
            return node.output;
        }

        /**
         * Walks from this node down to a leaf, reporting every evaluated split and finally the
         * leaf itself to {@code handler}.
         *
         * @param x feature vector
         * @param handler receives {@code visitBranch} for each internal node on the path (in
         *        root-to-leaf order) and then a single {@code visitLeaf}
         */
        public void predict(@Nonnull Vector x, @Nonnull PredictionHandler handler) {
            Node node = this;
            while (!node.isLeaf()) {
                final double feature = x.get(node.splitFeature, Double.NaN);
                final boolean takeTrue = node.acceptsAsTrue(feature);
                final PredictionHandler.Operator op;
                if (node.quantitativeFeature) {
                    op = takeTrue ? PredictionHandler.Operator.LE : PredictionHandler.Operator.GT;
                } else {
                    op = takeTrue ? PredictionHandler.Operator.EQ : PredictionHandler.Operator.NE;
                }
                handler.visitBranch(op, node.splitFeature, feature, node.splitValue);
                node = takeTrue ? node.trueChild : node.falseChild;
            }
            handler.visitLeaf(node.output, node.posteriori);
        }

        /**
         * Appends this subtree to {@code builder} as nested JavaScript {@code if/else} blocks;
         * a leaf is emitted as its resolved class name followed by {@code ;}.
         *
         * @param builder destination buffer
         * @param featureNames names used for split features; when {@code null} features are
         *        written as {@code x[index]}
         * @param classNames names passed to {@code SmileExtUtils.resolveName} for leaf outputs
         * @param depth indentation level passed to {@code DecisionTree.indent}
         */
        public void exportJavascript(@Nonnull StringBuilder builder, @Nullable String[] featureNames, @Nullable String[] classNames, int depth) {
            DecisionTree.indent(builder, depth);
            if (this.isLeaf()) {
                builder.append(SmileExtUtils.resolveName(this.output, classNames)).append(";\n");
                return;
            }

            builder.append("if( ");
            if (featureNames == null) {
                builder.append("x[").append(this.splitFeature).append(']');
            } else {
                builder.append(SmileExtUtils.resolveFeatureName(this.splitFeature, featureNames));
            }
            builder.append(this.quantitativeFeature ? " <= " : " == ").append(this.splitValue).append(" ) {\n");

            this.trueChild.exportJavascript(builder, featureNames, classNames, depth + 1);
            DecisionTree.indent(builder, depth);
            builder.append("} else  {\n");
            this.falseChild.exportJavascript(builder, featureNames, classNames, depth + 1);
            DecisionTree.indent(builder, depth);
            builder.append("}\n");
        }

        /**
         * Appends this subtree to {@code builder} as Graphviz node and edge statements.
         *
         * <p>Node ids are taken from {@code nodeIdGenerator}, which is advanced by one before
         * each child is visited. All numbers are formatted locale-independently so the output
         * stays parseable on JVMs whose default locale uses a decimal comma or non-ASCII digits.</p>
         *
         * @param builder destination buffer
         * @param featureNames names passed to {@code SmileExtUtils.resolveFeatureName}
         * @param classNames names passed to {@code SmileExtUtils.resolveName}
         * @param outputName label prefix for leaf nodes
         * @param colorBrew hue per class label; a leaf whose output is outside this array (or
         *        when the array is {@code null}) is filled with {@code #00000000}
         * @param nodeIdGenerator counter holding the id of the node currently being written
         * @param parentNodeId id of the parent; equal to this node's id for the root, in which
         *        case no edge is written
         */
        public void exportGraphviz(@Nonnull StringBuilder builder, @Nullable String[] featureNames, @Nullable String[] classNames, @Nonnull String outputName, @Nullable double[] colorBrew, @Nonnull MutableInt nodeIdGenerator, int parentNodeId) {
            final int myNodeId = nodeIdGenerator.getValue();

            if (this.isLeaf()) {
                // output may still be the -1 default, so both bounds are checked.
                final boolean noColor = colorBrew == null || this.output < 0 || this.output >= colorBrew.length;
                final String hsvColor = noColor ? "#00000000" : format("%.4f,1.000,1.000", colorBrew[this.output]);
                builder.append(format(" %d [label=<%s = %s>, fillcolor=\"%s\", shape=ellipse];\n", myNodeId, outputName, SmileExtUtils.resolveName(this.output, classNames), hsvColor));
                appendGraphvizEdge(builder, parentNodeId, myNodeId);
                return;
            }

            final String operator = this.quantitativeFeature ? "&le;" : "=";
            builder.append(format(" %d [label=<%s %s %s>, fillcolor=\"#00000000\"];\n", myNodeId, SmileExtUtils.resolveFeatureName(this.splitFeature, featureNames), operator, Double.toString(this.splitValue)));
            appendGraphvizEdge(builder, parentNodeId, myNodeId);

            nodeIdGenerator.addValue(1);
            this.trueChild.exportGraphviz(builder, featureNames, classNames, outputName, colorBrew, nodeIdGenerator, myNodeId);
            nodeIdGenerator.addValue(1);
            this.falseChild.exportGraphviz(builder, featureNames, classNames, outputName, colorBrew, nodeIdGenerator, myNodeId);
        }

        /**
         * Writes the {@code parent -> child} edge. Edges leaving node 0 additionally get a
         * "True" head label when the child id is 1 and a "False" head label otherwise.
         * Nothing is written when both ids are equal (the root call).
         */
        private static void appendGraphvizEdge(StringBuilder builder, int parentNodeId, int childNodeId) {
            if (childNodeId == parentNodeId) {
                return;
            }
            builder.append(' ').append(parentNodeId).append(" -> ").append(childNodeId);
            if (parentNodeId == 0) {
                if (childNodeId == 1) {
                    builder.append(" [labeldistance=2.5, labelangle=45, headlabel=\"True\"]");
                } else {
                    builder.append(" [labeldistance=2.5, labelangle=-45, headlabel=\"False\"]");
                }
            }
            builder.append(";\n");
        }

        /**
         * Appends the instructions for this subtree to {@code scripts}.
         *
         * <p>A leaf emits {@code push <output>} and {@code goto last}. An internal node emits
         * {@code push x[feature]}, {@code push <splitValue>} and a conditional jump
         * ({@code ifle} or {@code ifeq}) whose target is patched, once the true subtree has
         * been emitted, to {@code depth + 3 + <size of true subtree>}.</p>
         *
         * @param scripts instruction list to append to
         * @param depth position used to locate and patch the jump instruction; assumed to equal
         *        {@code scripts.size()} on entry (TODO confirm against callers)
         * @return number of instructions appended for this subtree
         */
        @Deprecated
        public int opCodegen(@Nonnull List<String> scripts, int depth) {
            if (this.isLeaf()) {
                scripts.add("push " + this.output);
                scripts.add("goto last");
                return 2;
            }

            final String jumpOp = this.quantitativeFeature ? "ifle " : "ifeq ";
            scripts.add("push x[" + this.splitFeature + "]");
            scripts.add("push " + this.splitValue);
            // Placeholder: the jump target is unknown until the true subtree has been emitted.
            scripts.add(jumpOp);

            final int bodyStart = depth + 3;
            final int trueDepth = this.trueChild.opCodegen(scripts, bodyStart);
            scripts.set(bodyStart - 1, jumpOp + String.valueOf(bodyStart + trueDepth));
            final int falseDepth = this.falseChild.opCodegen(scripts, bodyStart + trueDepth);
            return 3 + trueDepth + falseDepth;
        }

        /**
         * Serializes this subtree.
         *
         * <p>Layout: splitFeature (int), type id (byte), splitValue (double), leaf flag
         * (boolean); then for a leaf: output (int), posteriori length (int), posteriori values
         * (double each); for an internal node: a presence flag followed by the child, for the
         * true child and then the false child.</p>
         *
         * <p>A leaf whose {@link #posteriori} is {@code null} is written with length 0 (and is
         * therefore read back as an empty array) instead of failing with a
         * {@link NullPointerException}.</p>
         *
         * @param out stream to write to
         * @throws IOException if writing fails
         */
        @Override
        public void writeExternal(ObjectOutput out) throws IOException {
            out.writeInt(this.splitFeature);
            out.writeByte(this.quantitativeFeature ? NUMERIC_TYPE_ID : NOMINAL_TYPE_ID);
            out.writeDouble(this.splitValue);

            if (this.isLeaf()) {
                out.writeBoolean(true);
                out.writeInt(this.output);
                final int size = (this.posteriori == null) ? 0 : this.posteriori.length;
                out.writeInt(size);
                for (int i = 0; i < size; ++i) {
                    out.writeDouble(this.posteriori[i]);
                }
                return;
            }

            out.writeBoolean(false);
            writeChild(out, this.trueChild);
            writeChild(out, this.falseChild);
        }

        /** Writes a presence flag and, when {@code child} is non-null, the child subtree. */
        private static void writeChild(ObjectOutput out, @Nullable Node child) throws IOException {
            if (child == null) {
                out.writeBoolean(false);
            } else {
                out.writeBoolean(true);
                child.writeExternal(out);
            }
        }

        /**
         * Restores this subtree from the layout produced by {@link #writeExternal}.
         * Any type id other than the numeric one is read as a nominal split.
         *
         * @param in stream to read from
         * @throws IOException if reading fails or the stream declares a negative posteriori length
         * @throws ClassNotFoundException declared by {@link Externalizable}; not thrown here
         */
        @Override
        public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
            this.splitFeature = in.readInt();
            final byte typeId = in.readByte();
            this.quantitativeFeature = typeId == NUMERIC_TYPE_ID;
            this.splitValue = in.readDouble();

            final boolean leaf = in.readBoolean();
            if (leaf) {
                this.output = in.readInt();
                final int size = in.readInt();
                if (size < 0) {
                    throw new IOException("Corrupted stream: negative posteriori length " + size);
                }
                final double[] values = new double[size];
                for (int i = 0; i < size; ++i) {
                    values[i] = in.readDouble();
                }
                this.posteriori = values;
                return;
            }

            if (in.readBoolean()) {
                this.trueChild = new Node();
                this.trueChild.readExternal(in);
            }
            if (in.readBoolean()) {
                this.falseChild = new Node();
                this.falseChild.readExternal(in);
            }
        }
    }

    /**
     * Identifiers for the rule used to score candidate splits. How each rule is applied is
     * decided by the code that consumes this enum; the declaration order is kept as-is because
     * callers may rely on {@code ordinal()} or {@code values()} order.
     */
    public static enum SplitRule {
        /** Selects the rule named GINI (presumably Gini impurity). */
        GINI,
        /** Selects the rule named ENTROPY (presumably information entropy). */
        ENTROPY,
        /** Selects the rule named CLASSIFICATION_ERROR (presumably misclassification rate). */
        CLASSIFICATION_ERROR
    }
}

