/*
 * Decompiled with CFR 0.152.
 */
package hivemall.smile.regression;

import hivemall.annotations.VisibleForTesting;
import hivemall.smile.classification.PredictionHandler;
import hivemall.smile.utils.SmileExtUtils;
import hivemall.smile.utils.VariableOrder;
import hivemall.utils.collections.arrays.SparseIntArray;
import hivemall.utils.collections.lists.IntArrayList;
import hivemall.utils.function.Consumer;
import hivemall.utils.function.IntPredicate;
import hivemall.utils.lang.ArrayUtils;
import hivemall.utils.lang.ObjectUtils;
import hivemall.utils.lang.StringUtils;
import hivemall.utils.lang.mutable.MutableInt;
import hivemall.utils.random.PRNG;
import hivemall.utils.random.RandomNumberGeneratorFactory;
import hivemall.utils.sampling.IntReservoirSampler;
import it.unimi.dsi.fastutil.ints.Int2DoubleOpenHashMap;
import it.unimi.dsi.fastutil.ints.Int2IntMap;
import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap;
import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import matrix4j.matrix.Matrix;
import matrix4j.vector.DenseVector;
import matrix4j.vector.SparseVector;
import matrix4j.vector.Vector;
import matrix4j.vector.VectorProcedure;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.roaringbitmap.IntConsumer;
import org.roaringbitmap.RoaringBitmap;
import smile.math.Math;
import smile.regression.Regression;

public final class RegressionTree
implements Regression<Vector> {
    private static final Log logger = LogFactory.getLog(RegressionTree.class);
    /** Training feature matrix; rows are referenced through {@code _sampleIndex}. */
    private final Matrix _X;
    /** Training responses, one per row of {@code _X}. */
    private final double[] _y;
    /** Per-row sample counts (all 1 when the constructor receives no {@code samples} array). */
    @Nonnull
    private final int[] _samples;
    /** Per-column row order built by {@code SmileExtUtils.sort}; scanned for numeric splits and re-partitioned on each split. */
    @Nonnull
    private final VariableOrder _order;
    /**
     * Row indices whose sample count is non-zero. Partitioned in place on every split so that
     * each training node owns a contiguous range {@code [low, high)} of this array.
     */
    @Nonnull
    private final int[] _sampleIndex;
    /** Column indices treated as nominal (equality splits) instead of numeric (threshold splits). */
    @Nonnull
    private final RoaringBitmap _nominalAttrs;
    /** Per-feature importance: split scores added on each accepted split, subtracted when pruned. */
    private final Vector _importance;
    /** Root of the trained tree. */
    private final Node _root;
    /** Maximum depth; the root is at depth 1. */
    private final int _maxDepth;
    /** Minimum samples to attempt a split (may have been raised to {@code 2 * _minSamplesLeaf}). */
    private final int _minSamplesSplit;
    /** Minimum samples each child must receive for a split to be kept. */
    private final int _minSamplesLeaf;
    /** Number of candidate columns sampled at each node. */
    private final int _numVars;
    /** Random source used to seed the per-node variable sampler. */
    private final PRNG _rnd;

    /**
     * Appends {@code depth} indentation levels (two spaces each) to {@code builder}.
     * A non-positive depth appends nothing.
     */
    private static void indent(StringBuilder builder, int depth) {
        int remaining = depth;
        while (remaining-- > 0) {
            builder.append("  ");
        }
    }

    /**
     * Stable in-place partition of the entries of one sparse column order whose keys lie in
     * {@code [low, high)}. Entries whose value satisfies {@code goesLeft} are re-keyed to
     * {@code low, low+1, ...}; the remaining entries are re-keyed to {@code pivot, pivot+1, ...}.
     * Keys therefore stay increasing and no entry is added or dropped.
     *
     * NOTE(review): assumes {@code a.keys()} / {@code a.values()} expose the backing arrays of
     * {@code a} (otherwise the writes below would be lost). Confirm in SparseIntArray.
     *
     * @param a sparse array mapping position to row index
     * @param low inclusive start of the range being partitioned
     * @param pivot first position of the right-going half
     * @param high exclusive end of the range being partitioned
     * @param goesLeft predicate evaluated on each value (row index)
     * @param buf scratch space for right-going values; must hold all of them
     * @throws IndexOutOfBoundsException if {@code buf} is too small
     * @throws IllegalStateException if the rewritten entries do not exactly fill the original span
     */
    private static void partitionArray(@Nonnull SparseIntArray a, int low, int pivot, int high, @Nonnull IntPredicate goesLeft, @Nonnull int[] buf) {
        int i;
        int[] keys = a.keys();
        int[] values = a.values();
        int size = a.size();
        // Span [startPos, endPos) of stored entries whose keys fall in [low, high)
        // (presumably insertionPoint returns the first index with key >= bound — confirm).
        int startPos = ArrayUtils.insertionPoint(keys, size, low);
        int endPos = ArrayUtils.insertionPoint(keys, size, high);
        int pos = startPos;
        int k = 0;
        int j = 0;
        for (i = startPos; i < endPos; ++i) {
            int a_i = values[i];
            if (goesLeft.test(a_i)) {
                // Safe to overwrite in place: pos never exceeds i.
                keys[pos] = low + j;
                values[pos] = a_i;
                ++pos;
                ++j;
                continue;
            }
            if (k >= buf.length) {
                throw new IndexOutOfBoundsException(String.format("low=%d, pivot=%d, high=%d, a.size()=%d, buf.length=%d, i=%d, j=%d, k=%d", low, pivot, high, a.size(), buf.length, i, j, k));
            }
            buf[k++] = a_i;
        }
        // Append the right-going entries after the left-going ones, re-keyed from pivot.
        for (i = 0; i < k; ++i) {
            keys[pos] = pivot + i;
            values[pos] = buf[i];
            ++pos;
        }
        if (pos != endPos) {
            throw new IllegalStateException(String.format("pos=%d, startPos=%d, endPos=%d, k=%d", pos, startPos, endPos, k));
        }
    }

    /**
     * Stable in-place partition of {@code a[low..high)}: elements satisfying {@code goesLeft}
     * are packed into {@code [low, pivot)} and the rest into {@code [pivot, high)}, each group
     * keeping its relative order.
     *
     * @param buf scratch space; must hold at least {@code high - pivot} elements
     * @throws IndexOutOfBoundsException if the range exceeds {@code a}, {@code buf} is too small,
     *         or the number of left-going elements does not equal {@code pivot - low}
     */
    private static void partitionArray(@Nonnull int[] a, int low, int pivot, int high, @Nonnull IntPredicate goesLeft, @Nonnull int[] buf) {
        int leftWrite = low;   // next slot for a left-going element
        int rightCount = 0;    // right-going elements parked in buf so far
        for (int readPos = low; readPos < high; readPos++) {
            if (readPos >= a.length) {
                throw new IndexOutOfBoundsException(String.format("low=%d, pivot=%d, high=%d, a.length=%d, buf.length=%d, i=%d, j=%d, k=%d", low, pivot, high, a.length, buf.length, readPos, leftWrite, rightCount));
            }
            final int value = a[readPos];
            if (!goesLeft.test(value)) {
                if (rightCount >= buf.length) {
                    throw new IndexOutOfBoundsException(String.format("low=%d, pivot=%d, high=%d, a.length=%d, buf.length=%d, i=%d, j=%d, k=%d", low, pivot, high, a.length, buf.length, readPos, leftWrite, rightCount));
                }
                buf[rightCount++] = value;
            } else {
                a[leftWrite++] = value;
            }
        }
        final boolean consistent = rightCount == high - pivot && leftWrite == pivot;
        if (!consistent) {
            throw new IndexOutOfBoundsException(String.format("low=%d, pivot=%d, high=%d, a.length=%d, buf.length=%d, j=%d, k=%d", low, pivot, high, a.length, buf.length, leftWrite, rightCount));
        }
        System.arraycopy(buf, 0, a, pivot, rightCount);
    }

    /**
     * Post-order pass that collapses every internal node whose two children are leaves with
     * identical outputs. The split score of each collapsed node is subtracted from
     * {@code importance}.
     */
    private static void pruneRedundantLeaves(@Nonnull Node node, @Nonnull Vector importance) {
        if (!node.isLeaf()) {
            final Node left = node.trueChild;
            final Node right = node.falseChild;
            pruneRedundantLeaves(left, importance);
            pruneRedundantLeaves(right, importance);
            final boolean redundant = left.isLeaf() && right.isLeaf() && left.output == right.output;
            if (redundant) {
                node.trueChild = null;
                node.falseChild = null;
                importance.decr(node.splitFeature, node.splitScore);
            }
        }
    }

    /**
     * Trains a tree on every row of {@code x}, sampling from all columns, with
     * {@code maxDepth = Integer.MAX_VALUE}, {@code minSamplesSplit = 5},
     * {@code minSamplesLeaf = 1}, at most {@code maxLeafs} leaves, and a newly created PRNG.
     */
    public RegressionTree(@Nullable RoaringBitmap nominalAttrs, @Nonnull Matrix x, @Nonnull double[] y, int maxLeafs) {
        this(nominalAttrs, x, y,
            /* numVars */ x.numColumns(),
            /* maxDepth */ Integer.MAX_VALUE,
            /* maxLeafNodes */ maxLeafs,
            /* minSamplesSplit */ 5,
            /* minSamplesLeaf */ 1,
            /* samples */ null,
            /* rand */ null);
    }

    /**
     * Same defaults as the four-argument constructor, but draws randomness from {@code rand}
     * ({@code null} creates a new PRNG).
     */
    public RegressionTree(@Nullable RoaringBitmap nominalAttrs, @Nonnull Matrix x, @Nonnull double[] y, int maxLeafs, @Nullable PRNG rand) {
        this(nominalAttrs, x, y,
            /* numVars */ x.numColumns(),
            /* maxDepth */ Integer.MAX_VALUE,
            /* maxLeafNodes */ maxLeafs,
            /* minSamplesSplit */ 5,
            /* minSamplesLeaf */ 1,
            /* samples */ null,
            /* rand */ rand);
    }

    /**
     * Fully parameterized constructor without a post-training leaf-output strategy; delegates
     * with {@code output = null}.
     */
    public RegressionTree(@Nullable RoaringBitmap nominalAttrs, @Nonnull Matrix x, @Nonnull double[] y, int numVars, int maxDepth, int maxLeafNodes, int minSamplesSplit, int minSamplesLeaf, @Nullable int[] samples, @Nullable PRNG rand) {
        this(nominalAttrs, x, y, numVars, maxDepth, maxLeafNodes, minSamplesSplit, minSamplesLeaf,
            samples,
            /* output */ null,
            rand);
    }

    public RegressionTree(@Nullable RoaringBitmap nominalAttrs, @Nonnull Matrix x, @Nonnull double[] y, int numVars, int maxDepth, int maxLeafNodes, int minSamplesSplit, int minSamplesLeaf, @Nullable int[] samples, @Nullable NodeOutput output, @Nullable PRNG rand) {
        int[] sampleIndex;
        RegressionTree.checkArgument(x, y, numVars, maxDepth, maxLeafNodes, minSamplesSplit, minSamplesLeaf);
        this._X = x;
        this._y = y;
        if (nominalAttrs == null) {
            nominalAttrs = new RoaringBitmap();
        }
        this._nominalAttrs = nominalAttrs;
        this._numVars = numVars;
        this._maxDepth = maxDepth;
        if (minSamplesSplit < minSamplesLeaf * 2) {
            if (logger.isInfoEnabled()) {
                logger.info((Object)String.format("min_sample_leaf = %d replaces min_sample_split = %d with min_sample_split = %d", minSamplesLeaf, minSamplesSplit, minSamplesLeaf * 2));
            }
            minSamplesSplit = minSamplesLeaf * 2;
        }
        this._minSamplesSplit = minSamplesSplit;
        this._minSamplesLeaf = minSamplesLeaf;
        this._importance = x.isSparse() ? new SparseVector() : new DenseVector(x.numColumns());
        this._rnd = rand == null ? RandomNumberGeneratorFactory.createPRNG() : rand;
        int n = 0;
        double sum = 0.0;
        if (samples == null) {
            n = y.length;
            samples = new int[n];
            sampleIndex = new int[n];
            for (int i = 0; i < n; ++i) {
                samples[i] = 1;
                sum += y[i];
                sampleIndex[i] = i;
            }
        } else {
            IntArrayList positions = new IntArrayList(n);
            int end = y.length;
            for (int i = 0; i < end; ++i) {
                int sample = samples[i];
                if (sample == 0) continue;
                n += sample;
                sum += (double)sample * y[i];
                positions.add(i);
            }
            sampleIndex = positions.toArray(true);
        }
        this._samples = samples;
        this._order = SmileExtUtils.sort(nominalAttrs, x, samples);
        this._sampleIndex = sampleIndex;
        this._root = new Node(sum / (double)n);
        TrainNode trainRoot = new TrainNode(this._root, 1, 0, this._sampleIndex.length, n);
        if (maxLeafNodes == Integer.MAX_VALUE) {
            if (trainRoot.findBestSplit()) {
                trainRoot.split(null);
            }
        } else {
            TrainNode node;
            PriorityQueue<TrainNode> nextSplits = new PriorityQueue<TrainNode>();
            if (trainRoot.findBestSplit()) {
                nextSplits.add(trainRoot);
            }
            for (int leaves = 1; leaves < maxLeafNodes && (node = (TrainNode)nextSplits.poll()) != null; ++leaves) {
                if (node.split(nextSplits)) continue;
                --leaves;
            }
            RegressionTree.pruneRedundantLeaves(this._root, this._importance);
        }
        if (output != null) {
            trainRoot.calculateOutput(output);
        }
    }

    /**
     * Validates the constructor arguments, throwing {@link IllegalArgumentException} on the
     * first violation found (checked in a fixed order).
     */
    private static void checkArgument(@Nonnull Matrix x, @Nonnull double[] y, int numVars, int maxDepth, int maxLeafNodes, int minSamplesSplit, int minSamplesLeaf) {
        final int numRows = x.numRows();
        if (numRows != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", numRows, y.length));
        }
        if (y.length == 0) {
            throw new IllegalArgumentException("No training example given");
        }
        final boolean numVarsInRange = numVars > 0 && numVars <= x.numColumns();
        if (!numVarsInRange) {
            throw new IllegalArgumentException("Invalid number of variables to split on at a node of the tree: " + numVars);
        }
        if (!(maxDepth >= 2)) {
            throw new IllegalArgumentException("maxDepth should be greater than 1: " + maxDepth);
        }
        if (!(maxLeafNodes >= 2)) {
            throw new IllegalArgumentException("Invalid maximum leaves: " + maxLeafNodes);
        }
        if (!(minSamplesSplit >= 2)) {
            throw new IllegalArgumentException("Invalid minimum number of samples required to split an internal node: " + minSamplesSplit);
        }
        if (!(minSamplesLeaf >= 1)) {
            throw new IllegalArgumentException("Invalid minimum size of leaf nodes: " + minSamplesLeaf);
        }
    }

    /**
     * Returns the live per-feature importance vector (sum of accepted split scores per feature).
     * The internal instance is returned, not a copy.
     */
    public Vector importance() {
        final Vector result = _importance;
        return result;
    }

    /** Convenience overload: wraps {@code x} in a dense vector and predicts. */
    @Override
    @VisibleForTesting
    public double predict(@Nonnull double[] x) {
        final Vector dense = new DenseVector(x);
        return predict(dense);
    }

    /** Predicts the response for {@code x} by descending from the root to a leaf. */
    @Override
    public double predict(@Nonnull Vector x) {
        final Node root = _root;
        return root.predict(x);
    }

    /**
     * Renders the tree as nested JavaScript-style if/else statements.
     *
     * @param featureNames names substituted for feature indices; may be {@code null} (as passed
     *        by {@link #toString()}), in which case features are rendered as {@code x[index]}
     * @return the generated code
     */
    @Nonnull
    public String predictJsCodegen(@Nullable String[] featureNames) {
        StringBuilder buf = new StringBuilder(1024);
        this._root.exportJavascript(buf, featureNames, 0);
        return buf.toString();
    }

    /**
     * Emits the tree as a list of opcode lines, terminated by {@code "call end"} and joined
     * with {@code sep}.
     */
    @Deprecated
    @Nonnull
    public String predictOpCodegen(@Nonnull String sep) {
        final ArrayList<String> ops = new ArrayList<String>();
        _root.opCodegen(ops, 0);
        ops.add("call end");
        return StringUtils.concat(ops, sep);
    }

    /**
     * Serializes the root node (and with it the whole tree).
     *
     * @param compress whether to use the compressed encoding
     * @throws HiveException wrapping any failure during serialization
     */
    @Nonnull
    public byte[] serialize(boolean compress) throws HiveException {
        try {
            return compress ? ObjectUtils.toCompressedBytes(_root) : ObjectUtils.toBytes(_root);
        } catch (IOException ioe) {
            throw new HiveException("IOException cause while serializing DecisionTree object", ioe);
        } catch (Exception e) {
            throw new HiveException("Exception cause while serializing DecisionTree object", e);
        }
    }

    /**
     * Reads a tree previously produced by {@link #serialize(boolean)}.
     *
     * @param serializedObj buffer holding the serialized tree
     * @param length number of meaningful bytes in {@code serializedObj}
     * @param compressed whether the buffer uses the compressed encoding
     * @return the reconstructed root node
     * @throws HiveException wrapping any failure during deserialization
     */
    @Nonnull
    public static Node deserialize(@Nonnull byte[] serializedObj, int length, boolean compressed) throws HiveException {
        final Node root = new Node();
        try {
            if (!compressed) {
                ObjectUtils.readObject(serializedObj, length, root);
            } else {
                ObjectUtils.readCompressedObject(serializedObj, 0, length, root);
            }
            return root;
        } catch (IOException ioe) {
            throw new HiveException("IOException cause while deserializing DecisionTree object", ioe);
        } catch (Exception e) {
            throw new HiveException("Exception cause while deserializing DecisionTree object", e);
        }
    }

    /** Returns the JavaScript rendering of the tree using index-based feature references. */
    @Override
    public String toString() {
        if (_root == null) {
            return "";
        }
        return predictJsCodegen(null);
    }

    /**
     * Training-time state for one {@link Node}. Each instance tracks:
     * <ul>
     * <li>the contiguous range {@code [low, high)} of {@code _sampleIndex} it owns,</li>
     * <li>its depth,</li>
     * <li>its weighted sample count,</li>
     * <li>the features already known to be constant within it.</li>
     * </ul>
     * Instances compare by descending split score, so a {@link PriorityQueue} polls the most
     * promising split first.
     */
    private final class TrainNode
    implements Comparable<TrainNode> {
        /** The tree node being grown. */
        @Nonnull
        final Node node;
        /** Depth of this node; the root is at depth 1. */
        final int depth;
        /** Inclusive start of this node's range in {@code _sampleIndex}. */
        final int low;
        /** Exclusive end of this node's range in {@code _sampleIndex}. */
        final int high;
        /** Sum of per-row sample counts over {@code [low, high)}. */
        final int samples;
        @Nullable
        TrainNode trueChild;
        @Nullable
        TrainNode falseChild;
        /** Sorted set of features found constant here; handed to children and nulled after a split. */
        @Nullable
        int[] constFeatures;

        public TrainNode(Node node, int depth, int low, int high, int samples) {
            this(node, depth, low, high, samples, new int[0]);
        }

        public TrainNode(Node node, int depth, int low, int high, int samples, int[] constFeatures) {
            if (low >= high) {
                throw new IllegalArgumentException("Unexpected condition was met. low=" + low + ", high=" + high);
            }
            this.node = node;
            this.depth = depth;
            this.low = low;
            this.high = high;
            this.samples = samples;
            this.constFeatures = constFeatures;
        }

        /**
         * Orders by descending split score. {@link Double#compare} is used instead of
         * subtraction so that infinities and NaN still yield a consistent ordering.
         */
        @Override
        public int compareTo(TrainNode a) {
            return Double.compare(a.node.splitScore, this.node.splitScore);
        }

        /** Recomputes the output of every leaf below this node using {@code output}. */
        public void calculateOutput(NodeOutput output) {
            if (this.node.trueChild == null && this.node.falseChild == null) {
                int[] samples = this.getSamples();
                this.node.output = output.calculate(samples);
            } else {
                if (this.trueChild != null) {
                    this.trueChild.calculateOutput(output);
                }
                if (this.falseChild != null) {
                    this.falseChild.calculateOutput(output);
                }
            }
        }

        /** Row indices in {@code [low, high)} whose sample count is positive. */
        @Nonnull
        private int[] getSamples() {
            int size = this.high - this.low;
            IntArrayList result = new IntArrayList(size);
            int[] sampleIndex = RegressionTree.this._sampleIndex;
            int[] samples = RegressionTree.this._samples;
            int end = this.high;
            for (int i = this.low; i < end; ++i) {
                int index = sampleIndex[i];
                int sample = samples[index];
                if (sample <= 0) {
                    continue;
                }
                result.add(index);
            }
            return result.toArray(true);
        }

        /**
         * Searches the sampled candidate variables for the best split and stores it on
         * {@link #node}.
         *
         * @return whether a split with positive score was found
         */
        public boolean findBestSplit() {
            if (this.depth >= RegressionTree.this._maxDepth) {
                return false;
            }
            if (this.samples <= RegressionTree.this._minSamplesSplit) {
                return false;
            }
            int[] constFeatures_ = this.constFeatures;
            double sum = this.node.output * (double) this.samples;
            for (int varJ : this.variableIndex()) {
                if (ArrayUtils.contains(constFeatures_, varJ)) {
                    continue;
                }
                Node split = this.findBestSplit(this.samples, sum, varJ);
                if (split.splitScore > this.node.splitScore) {
                    this.node.splitFeature = split.splitFeature;
                    this.node.quantitativeFeature = split.quantitativeFeature;
                    this.node.splitValue = split.splitValue;
                    this.node.splitScore = split.splitScore;
                    this.node.trueChildOutput = split.trueChildOutput;
                    this.node.falseChildOutput = split.falseChildOutput;
                }
            }
            return this.node.splitFeature != -1;
        }

        /**
         * Reservoir-samples up to {@code _numVars} candidate columns. For sparse matrices, only
         * columns that appear in this node's rows are candidates.
         */
        @Nonnull
        private int[] variableIndex() {
            Matrix X = RegressionTree.this._X;
            final IntReservoirSampler sampler = new IntReservoirSampler(RegressionTree.this._numVars, RegressionTree.this._rnd.nextLong());
            if (X.isSparse()) {
                final RoaringBitmap cols = new RoaringBitmap();
                VectorProcedure proc = new VectorProcedure() {

                    @Override
                    public void apply(int col) {
                        cols.add(col);
                    }
                };
                int[] sampleIndex = RegressionTree.this._sampleIndex;
                int end = this.high;
                for (int i = this.low; i < end; ++i) {
                    int row = sampleIndex[i];
                    X.eachColumnIndexInRow(row, proc);
                }
                cols.forEach(new IntConsumer() {

                    @Override
                    public void accept(int k) {
                        sampler.add(k);
                    }
                });
            } else {
                int ncols = X.numColumns();
                for (int i = 0; i < ncols; ++i) {
                    sampler.add(i);
                }
            }
            return sampler.getSample();
        }

        /**
         * Finds the best split on column {@code j} over this node's rows. All statistics are
         * weighted by the per-row sample counts, matching {@code n} and {@code sum}. Column
         * {@code j} is recorded in {@link #constFeatures} when it takes at most one distinct
         * value here (NaN counts as one value).
         *
         * @param n weighted number of samples in this node
         * @param sum weighted sum of responses in this node
         * @param j column index
         * @return a scratch node holding the best split found (splitFeature -1 if none)
         */
        private Node findBestSplit(final int n, final double sum, final int j) {
            final int[] samples = RegressionTree.this._samples;
            int[] sampleIndex = RegressionTree.this._sampleIndex;
            Matrix X = RegressionTree.this._X;
            double[] y = RegressionTree.this._y;
            final Node split = new Node(0.0);
            if (RegressionTree.this._nominalAttrs.contains(j)) {
                Int2DoubleOpenHashMap trueSum = new Int2DoubleOpenHashMap();
                Int2IntOpenHashMap trueCount = new Int2IntOpenHashMap();
                int countNaN = 0;
                int end = this.high;
                for (int i = this.low; i < end; ++i) {
                    // i is a position in _sampleIndex; index is the actual row of X / y.
                    int index = sampleIndex[i];
                    int numSamples = samples[index];
                    if (numSamples == 0) {
                        continue;
                    }
                    double v = X.get(index, j, Double.NaN);
                    if (Double.isNaN(v)) {
                        ++countNaN;
                        continue;
                    }
                    int x_ij = (int) v;
                    trueSum.addTo(x_ij, (double) numSamples * y[index]);
                    trueCount.addTo(x_ij, numSamples);
                }
                int countDistinctX = trueCount.size() + (countNaN == 0 ? 0 : 1);
                if (countDistinctX <= 1) {
                    this.constFeatures = ArrayUtils.sortedArraySet(this.constFeatures, j);
                }
                for (Int2IntMap.Entry e : trueCount.int2IntEntrySet()) {
                    int k = e.getIntKey();
                    double tc = e.getIntValue();
                    double fc = (double) n - tc;
                    if (tc < (double) RegressionTree.this._minSamplesSplit || fc < (double) RegressionTree.this._minSamplesSplit) {
                        continue;
                    }
                    double trueSum_k = trueSum.get(k);
                    double trueMean = trueSum_k / tc;
                    double falseMean = (sum - trueSum_k) / fc;
                    double gain = tc * trueMean * trueMean + fc * falseMean * falseMean - (double) n * split.output * split.output;
                    if (gain > split.splitScore) {
                        split.splitFeature = j;
                        split.quantitativeFeature = false;
                        split.splitValue = k;
                        split.splitScore = gain;
                        split.trueChildOutput = trueMean;
                        split.falseChildOutput = falseMean;
                    }
                }
            } else {
                final MutableInt countNaN = new MutableInt(0);
                final MutableInt replaceCount = new MutableInt(0);
                // Scan rows in sorted order of column j; every boundary between distinct values
                // is a candidate threshold.
                RegressionTree.this._order.eachNonNullInColumn(j, this.low, this.high, new Consumer() {
                    double trueSum = 0.0;
                    int trueCount = 0;
                    double prevx = Double.NaN;
                    double lastx = Double.NaN;

                    @Override
                    public void accept(int pos, int i) {
                        int numSamples = samples[i];
                        if (numSamples == 0) {
                            return;
                        }
                        double x_ij = RegressionTree.this._X.get(i, j, Double.NaN);
                        if (Double.isNaN(x_ij)) {
                            countNaN.incr();
                            return;
                        }
                        if (this.lastx != x_ij) {
                            this.lastx = x_ij;
                            replaceCount.incr();
                        }
                        double y_i = RegressionTree.this._y[i];
                        if (Double.isNaN(this.prevx) || x_ij == this.prevx) {
                            this.prevx = x_ij;
                            this.trueSum += (double) numSamples * y_i;
                            this.trueCount += numSamples;
                            return;
                        }
                        double falseCount = n - this.trueCount;
                        if (this.trueCount < RegressionTree.this._minSamplesSplit || falseCount < (double) RegressionTree.this._minSamplesSplit) {
                            this.prevx = x_ij;
                            this.trueSum += (double) numSamples * y_i;
                            this.trueCount += numSamples;
                            return;
                        }
                        double trueMean = this.trueSum / (double) this.trueCount;
                        double falseMean = (sum - this.trueSum) / falseCount;
                        double gain = (double) this.trueCount * trueMean * trueMean + falseCount * falseMean * falseMean - (double) n * split.output * split.output;
                        if (gain > split.splitScore) {
                            split.splitFeature = j;
                            split.quantitativeFeature = true;
                            split.splitValue = (x_ij + this.prevx) / 2.0;
                            split.splitScore = gain;
                            split.trueChildOutput = trueMean;
                            split.falseChildOutput = falseMean;
                        }
                        this.prevx = x_ij;
                        this.trueSum += (double) numSamples * y_i;
                        this.trueCount += numSamples;
                    }
                });
                int countDistinctX = replaceCount.get() + (countNaN.get() == 0 ? 0 : 1);
                if (countDistinctX <= 1) {
                    this.constFeatures = ArrayUtils.sortedArraySet(this.constFeatures, j);
                }
            }
            return split;
        }

        /**
         * Applies the split stored on {@link #node}: partitions the rows, creates the children,
         * and either queues them ({@code nextSplits != null}) or splits them recursively.
         *
         * @return {@code false} if the split was rejected and the node turned back into a leaf
         */
        public boolean split(@Nullable PriorityQueue<TrainNode> nextSplits) {
            if (this.node.splitFeature < 0) {
                throw new IllegalStateException("Split a node with invalid feature.");
            }
            IntPredicate goesLeft = this.getPredicate();
            MutableInt tc_ = new MutableInt(0);
            MutableInt fc_ = new MutableInt(0);
            int pivot = this.splitSamples(tc_, fc_, goesLeft);
            int tc = tc_.get();
            int fc = fc_.get();
            if (tc < RegressionTree.this._minSamplesLeaf || fc < RegressionTree.this._minSamplesLeaf) {
                this.node.markAsLeaf();
                return false;
            }
            this.partitionOrder(this.low, pivot, this.high, goesLeft);
            int leaves = 0;
            this.node.trueChild = new Node(this.node.trueChildOutput);
            this.trueChild = new TrainNode(this.node.trueChild, this.depth + 1, this.low, pivot, tc, this.constFeatures.clone());
            this.node.falseChild = new Node(this.node.falseChildOutput);
            this.falseChild = new TrainNode(this.node.falseChild, this.depth + 1, pivot, this.high, fc, this.constFeatures);
            this.constFeatures = null;
            if (tc >= RegressionTree.this._minSamplesSplit && this.trueChild.findBestSplit()) {
                if (nextSplits != null) {
                    nextSplits.add(this.trueChild);
                } else if (!this.trueChild.split(null)) {
                    ++leaves;
                }
            } else {
                ++leaves;
            }
            if (fc >= RegressionTree.this._minSamplesSplit && this.falseChild.findBestSplit()) {
                if (nextSplits != null) {
                    nextSplits.add(this.falseChild);
                } else if (!this.falseChild.split(null)) {
                    ++leaves;
                }
            } else {
                ++leaves;
            }
            // Two leaf children with the same output add nothing; collapse them.
            if (leaves == 2 && this.node.trueChild.output == this.node.falseChild.output) {
                this.node.markAsLeaf();
                return false;
            }
            RegressionTree.this._importance.incr(this.node.splitFeature, this.node.splitScore);
            return true;
        }

        /**
         * Accumulates weighted true/false counts into {@code tc}/{@code fc}.
         *
         * @return the pivot position, i.e. {@code low} plus the number of rows going left
         */
        private int splitSamples(@Nonnull MutableInt tc, @Nonnull MutableInt fc, @Nonnull IntPredicate goesLeft) {
            int[] sampleIndex = RegressionTree.this._sampleIndex;
            int[] samples = RegressionTree.this._samples;
            int pivot = this.low;
            int end = this.high;
            for (int k = this.low; k < end; ++k) {
                int i = sampleIndex[k];
                int numSamples = samples[i];
                if (goesLeft.test(i)) {
                    tc.addValue(numSamples);
                    ++pivot;
                    continue;
                }
                fc.addValue(numSamples);
            }
            return pivot;
        }

        /** Partitions every column order and {@code _sampleIndex} around {@code pivot}. */
        private void partitionOrder(final int low, final int pivot, final int high, final @Nonnull IntPredicate goesLeft) {
            final int[] buf = new int[high - pivot];
            RegressionTree.this._order.eachRow(new Consumer() {

                @Override
                public void accept(int col, @Nonnull SparseIntArray row) {
                    RegressionTree.partitionArray(row, low, pivot, high, goesLeft, buf);
                }
            });
            RegressionTree.partitionArray(RegressionTree.this._sampleIndex, low, pivot, high, goesLeft, buf);
        }

        /** Predicate on a row index: {@code <=} threshold for numeric splits, {@code ==} for nominal. */
        @Nonnull
        private IntPredicate getPredicate() {
            if (this.node.quantitativeFeature) {
                return new IntPredicate() {

                    @Override
                    public boolean test(int i) {
                        return RegressionTree.this._X.get(i, TrainNode.this.node.splitFeature, Double.NaN) <= TrainNode.this.node.splitValue;
                    }
                };
            }
            return new IntPredicate() {

                @Override
                public boolean test(int i) {
                    return RegressionTree.this._X.get(i, TrainNode.this.node.splitFeature, Double.NaN) == TrainNode.this.node.splitValue;
                }
            };
        }
    }

    public static final class Node
    implements Externalizable {
        /** Prediction returned when this node is a leaf. */
        double output = 0.0;
        /** Feature tested at this node; -1 when there is no split. */
        int splitFeature = -1;
        /** {@code true}: test {@code x <= splitValue}; {@code false}: test {@code x == splitValue} (nominal). */
        boolean quantitativeFeature = true;
        /** Threshold (numeric) or category value (nominal) of the split; NaN when unset. */
        double splitValue = Double.NaN;
        /** Gain of the chosen split; also what is added to the feature importance. */
        double splitScore = 0.0;
        /** Child followed when the split test holds. */
        Node trueChild;
        /** Child followed when the split test fails (including NaN feature values). */
        Node falseChild;
        /** Mean response of the true side, used as the true child's initial output. */
        double trueChildOutput = 0.0;
        /** Mean response of the false side, used as the false child's initial output. */
        double falseChildOutput = 0.0;

        /** Creates a leaf with output 0; used e.g. by {@code deserialize} before reading state. */
        public Node() {
            this(0.0);
        }

        /** Creates a leaf that predicts {@code output}. */
        public Node(double output) {
            super();
            this.output = output;
        }

        /** A node is a leaf exactly when it has no children. */
        private boolean isLeaf() {
            final boolean hasChild = trueChild != null || falseChild != null;
            return !hasChild;
        }

        /** Drops the children and resets the split description; {@code output} is kept. */
        private void markAsLeaf() {
            trueChild = null;
            falseChild = null;
            splitFeature = -1;
            splitValue = Double.NaN;
            splitScore = 0.0;
        }

        /** Convenience overload: wraps {@code x} in a dense vector and predicts. */
        @VisibleForTesting
        public double predict(@Nonnull double[] x) {
            final Vector dense = new DenseVector(x);
            return predict(dense);
        }

        /**
         * Walks from this node to a leaf and returns the leaf's output. Missing features read
         * as NaN, which fails both the {@code <=} and {@code ==} tests and so goes to the false
         * child.
         */
        public double predict(@Nonnull Vector x) {
            Node current = this;
            while (!current.isLeaf()) {
                final double value = x.get(current.splitFeature, Double.NaN);
                final boolean takeTrue = current.quantitativeFeature
                        ? value <= current.splitValue
                        : value == current.splitValue;
                current = takeTrue ? current.trueChild : current.falseChild;
            }
            return current.output;
        }

        /**
         * Predicts like {@link #predict(Vector)} while reporting every branch taken and the
         * final leaf to {@code handler}.
         *
         * The recursive calls pass {@code handler} along. Previously they called the
         * handler-less overload, so only the root branch was reported and {@code visitLeaf}
         * never fired for trees deeper than one node.
         */
        public double predict(@Nonnull Vector x, @Nonnull PredictionHandler handler) {
            if (this.isLeaf()) {
                handler.visitLeaf(this.output);
                return this.output;
            }
            double feature = x.get(this.splitFeature, Double.NaN);
            if (this.quantitativeFeature) {
                if (feature <= this.splitValue) {
                    handler.visitBranch(PredictionHandler.Operator.LE, this.splitFeature, feature, this.splitValue);
                    return this.trueChild.predict(x, handler);
                }
                handler.visitBranch(PredictionHandler.Operator.GT, this.splitFeature, feature, this.splitValue);
                return this.falseChild.predict(x, handler);
            }
            if (feature == this.splitValue) {
                handler.visitBranch(PredictionHandler.Operator.EQ, this.splitFeature, feature, this.splitValue);
                return this.trueChild.predict(x, handler);
            }
            handler.visitBranch(PredictionHandler.Operator.NE, this.splitFeature, feature, this.splitValue);
            return this.falseChild.predict(x, handler);
        }

        /**
         * Predicts from an integer feature array. Every split is evaluated as an equality test
         * against {@code (int) splitValue}, regardless of {@code quantitativeFeature}.
         */
        public double predict(int[] x) {
            Node current = this;
            while (!current.isLeaf()) {
                final boolean matches = x[current.splitFeature] == (int) current.splitValue;
                current = matches ? current.trueChild : current.falseChild;
            }
            return current.output;
        }

        /**
         * Appends this subtree as JavaScript-style code: a leaf becomes {@code output;}, an
         * internal node becomes {@code if( test ) { ... } else { ... }}.
         *
         * @param featureNames feature names, or {@code null} to render features as {@code x[i]}
         * @param depth indentation level of this node
         */
        public void exportJavascript(@Nonnull StringBuilder builder, @Nullable String[] featureNames, int depth) {
            RegressionTree.indent(builder, depth);
            if (isLeaf()) {
                builder.append(output).append(";\n");
                return;
            }
            // Numeric splits test "<=", nominal splits test "==".
            final String indexedOp = quantitativeFeature ? "] <= " : "] == ";
            final String namedOp = quantitativeFeature ? " <= " : " == ";
            if (featureNames != null) {
                builder.append("if( ").append(SmileExtUtils.resolveFeatureName(splitFeature, featureNames)).append(namedOp).append(splitValue).append(") {\n");
            } else {
                builder.append("if( x[").append(splitFeature).append(indexedOp).append(splitValue).append(") {\n");
            }
            trueChild.exportJavascript(builder, featureNames, depth + 1);
            RegressionTree.indent(builder, depth);
            builder.append("} else {\n");
            falseChild.exportJavascript(builder, featureNames, depth + 1);
            RegressionTree.indent(builder, depth);
            builder.append("}\n");
        }

        /**
         * Appends Graphviz DOT statements for the subtree rooted at this node.
         * <p>
         * The node id is read from {@code nodeIdGenerator}. Leaves are drawn as ellipses labelled
         * {@code outputName = output}. Internal nodes are labelled with the resolved feature name
         * and split value, using {@code &le;} for quantitative splits and {@code =} otherwise.
         * An edge from {@code parentNodeId} is emitted unless the ids coincide. Edges leaving
         * node 0 also carry a True label (when the child id is 1) or a False label. Before
         * recursing into each child, the generator is advanced by one.
         *
         * @param builder destination buffer for DOT statements
         * @param featureNames optional feature names; may be null
         * @param outputName label used for leaf outputs
         * @param nodeIdGenerator shared counter supplying node ids
         * @param parentNodeId id of the parent node
         */
        public void exportGraphviz(@Nonnull StringBuilder builder, @Nullable String[] featureNames, @Nonnull String outputName, @Nonnull MutableInt nodeIdGenerator, int parentNodeId) {
            final int nodeId = nodeIdGenerator.getValue();
            final boolean leaf = this.isLeaf();

            final String nodeLine;
            if (leaf) {
                nodeLine = String.format(" %d [label=<%s = %s>, fillcolor=\"#00000000\", shape=ellipse];\n", nodeId, outputName, Double.toString(this.output));
            } else {
                final String pattern = this.quantitativeFeature
                        ? " %d [label=<%s &le; %s>, fillcolor=\"#00000000\"];\n"
                        : " %d [label=<%s = %s>, fillcolor=\"#00000000\"];\n";
                nodeLine = String.format(pattern, nodeId, SmileExtUtils.resolveFeatureName(this.splitFeature, featureNames), Double.toString(this.splitValue));
            }
            builder.append(nodeLine);

            if (nodeId != parentNodeId) {
                builder.append(' ').append(parentNodeId).append(" -> ").append(nodeId);
                if (parentNodeId == 0) {
                    builder.append(nodeId == 1
                            ? " [labeldistance=2.5, labelangle=45, headlabel=\"True\"]"
                            : " [labeldistance=2.5, labelangle=-45, headlabel=\"False\"]");
                }
                builder.append(";\n");
            }

            if (leaf) {
                return;
            }
            nodeIdGenerator.addValue(1);
            this.trueChild.exportGraphviz(builder, featureNames, outputName, nodeIdGenerator, nodeId);
            nodeIdGenerator.addValue(1);
            this.falseChild.exportGraphviz(builder, featureNames, outputName, nodeIdGenerator, nodeId);
        }

        /**
         * Appends instructions for the subtree rooted at this node to {@code scripts}.
         * <p>
         * A leaf contributes {@code push <output>} and {@code goto last}. An internal node pushes
         * the feature and the split value, then adds a conditional instruction. That instruction
         * is {@code ifle} for quantitative splits and {@code ifeq} otherwise. Its operand is
         * patched, once the true branch has been generated, to the index where the false branch
         * starts.
         *
         * @param scripts instruction list to append to
         * @param depth index at which this node's first instruction lands. The patch position is
         *        computed from it, so it is expected to equal {@code scripts.size()} on entry.
         * @return the number of instructions appended for this subtree
         */
        @Deprecated
        public int opCodegen(@Nonnull List<String> scripts, int depth) {
            if (this.isLeaf()) {
                scripts.add("push " + this.output);
                scripts.add("goto last");
                return 2;
            }

            final String branchOp = this.quantitativeFeature ? "ifle" : "ifeq";
            scripts.add("push x[" + this.splitFeature + "]");
            scripts.add("push " + this.splitValue);
            // placeholder; its jump target is only known after the true branch is emitted
            scripts.add(branchOp + " ");

            final int trueStart = depth + 3;
            final int trueLength = this.trueChild.opCodegen(scripts, trueStart);
            scripts.set(trueStart - 1, branchOp + " " + (trueStart + trueLength));
            final int falseLength = this.falseChild.opCodegen(scripts, trueStart + trueLength);
            return 3 + trueLength + falseLength;
        }

        /**
         * Serializes this node and, recursively, its children.
         * <p>
         * The layout is:
         * <ul>
         * <li>int split feature</li>
         * <li>byte type id: 1 when quantitative, 2 otherwise</li>
         * <li>double split value</li>
         * <li>boolean leaf flag</li>
         * </ul>
         * A leaf then writes its double output. An internal node instead writes, for the true
         * child and then the false child, a presence boolean followed by that child's own
         * encoding when present.
         *
         * @param out destination stream
         * @throws IOException if writing fails
         */
        @Override
        public void writeExternal(ObjectOutput out) throws IOException {
            final int typeId = this.quantitativeFeature ? 1 : 2;
            out.writeInt(this.splitFeature);
            out.writeByte(typeId);
            out.writeDouble(this.splitValue);

            final boolean leaf = this.isLeaf();
            out.writeBoolean(leaf);
            if (leaf) {
                out.writeDouble(this.output);
                return;
            }

            final Node left = this.trueChild;
            out.writeBoolean(left != null);
            if (left != null) {
                left.writeExternal(out);
            }
            final Node right = this.falseChild;
            out.writeBoolean(right != null);
            if (right != null) {
                right.writeExternal(out);
            }
        }

        /**
         * Restores this node, and recursively its children, from the layout produced by
         * {@code writeExternal}.
         * <p>
         * The layout is:
         * <ul>
         * <li>int split feature</li>
         * <li>byte type id: 1 when quantitative, 2 otherwise</li>
         * <li>double split value</li>
         * <li>boolean leaf flag</li>
         * </ul>
         * A leaf then carries its double output. An internal node instead carries a presence
         * boolean and the encoded child, for the true child and then the false child.
         *
         * @param in source stream
         * @throws IOException if reading fails, or if the type id is neither 1 nor 2. Only those
         *         two values are ever written, so any other value means the stream is corrupt or
         *         incompatible, and it is rejected rather than decoded as a non-quantitative split.
         * @throws ClassNotFoundException declared by {@link java.io.Externalizable}
         */
        @Override
        public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
            this.splitFeature = in.readInt();
            final byte typeId = in.readByte();
            if (typeId == 1) {
                this.quantitativeFeature = true;
            } else if (typeId == 2) {
                this.quantitativeFeature = false;
            } else {
                throw new IOException("Unexpected feature type id in serialized node: " + typeId);
            }
            this.splitValue = in.readDouble();
            if (in.readBoolean()) {
                this.output = in.readDouble();
            } else {
                if (in.readBoolean()) {
                    this.trueChild = new Node();
                    this.trueChild.readExternal(in);
                }
                if (in.readBoolean()) {
                    this.falseChild = new Node();
                    this.falseChild.readExternal(in);
                }
            }
        }
    }

    /**
     * Callback that maps an {@code int[]} input to a {@code double} result.
     */
    public interface NodeOutput {
        /**
         * Computes a value for the given input.
         *
         * @param x input values
         * @return the computed value
         */
        double calculate(int[] x);
    }
}

