/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.drf;

import hex.ModelBuilder;
import hex.ModelCategory;
import hex.genmodel.GenModel;
import hex.schemas.DRFV3;
import hex.tree.DHistogram;
import hex.tree.DTree;
import hex.tree.ScoreBuildHistogram;
import hex.tree.SharedTree;
import hex.tree.drf.DRFModel;
import hex.tree.drf.OOBScorer;
import hex.tree.drf.TreeMeasuresCollector;
import java.util.Arrays;
import java.util.Random;
import water.AutoBuffer;
import water.Job;
import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.Log;
import water.util.RandomUtils;
import water.util.Timer;

/**
 * Distributed Random Forest model builder.
 *
 * <p>Each training round grows one tree per entry of {@code _nclass}. Rows are bagged
 * per tree with {@code _sample_rate}, and each split considers a random subset of
 * {@code _mtry} candidate columns. Rows left out of the bag are scored right after each
 * round, which lets an out-of-bag (OOB) error be reported without a validation frame.
 */
public class DRF extends SharedTree<DRFModel, DRFModel.DRFParameters, DRFModel.DRFOutput> {
    /** Number of candidate columns examined per split; resolved at the start of training. */
    protected int _mtry;
    /** Seed actually used for training: the user seed, or a derived one when the user passed -1. */
    protected long _actual_seed;
    static final boolean DEBUG_DETERMINISTIC = false;

    public ModelCategory[] can_build() {
        return new ModelCategory[]{ModelCategory.Regression, ModelCategory.Binomial, ModelCategory.Multinomial};
    }

    public ModelBuilder.BuilderVisibility builderVisibility() {
        return ModelBuilder.BuilderVisibility.Stable;
    }

    public DRF(DRFModel.DRFParameters parms) {
        super("DRF", parms);
        this.init(false);
    }

    public DRFV3 schema() {
        return new DRFV3();
    }

    public Job<DRFModel> trainModel() {
        return this.start(new DRFDriver(), drfParms()._ntrees);
    }

    /** Returns the superclass's response vec, falling back to {@code response()} when that is null. */
    public Vec vresponse() {
        Vec v = super.vresponse();
        return v == null ? this.response() : v;
    }

    /**
     * Validates DRF-specific parameters and resolves the training seed.
     *
     * @throws IllegalArgumentException if {@code _sample_rate} is outside (0, 1] (or NaN)
     */
    @Override
    public void init(boolean expensive) {
        super.init(expensive);
        DRFModel.DRFParameters p = drfParms();
        // Written as a negated range test so that NaN is rejected as well.
        if (!(0.0 < p._sample_rate && p._sample_rate <= 1.0)) {
            throw new IllegalArgumentException("Sample rate should be interval (0,1> but it is " + p._sample_rate);
        }
        // -1 means "no seed given": draw one from an RNG seeded with a fixed constant.
        this._actual_seed = p._seed == -1L
                ? RandomUtils.getRNG(new long[]{-3278530047320914430L}).nextLong()
                : p._seed;
        if (p._mtries < 1 && p._mtries != -1) {
            this.error("_mtries", "mtries must be -1 (converted to sqrt(features)), or >= 1 but it is " + p._mtries);
        }
        if (this._train != null) {
            int ncols = this._train.numCols();
            if (p._mtries != -1 && (p._mtries < 1 || p._mtries >= ncols)) {
                this.error("_mtries", "Computed mtries should be -1 or in interval <1,#cols> but it is " + p._mtries);
            }
        }
        // With a 100% sample rate no row is ever out of bag, so OOB error cannot be computed.
        if (p._sample_rate == 1.0f && this._valid == null) {
            this.error("_sample_rate", "Sample rate is 100% and no validation dataset is specified.  There are no OOB data to compute out-of-bag error estimation!");
        }
    }

    /** Typed view of {@code _parms} (the cast matches this class's generic binding). */
    private DRFModel.DRFParameters drfParms() {
        return (DRFModel.DRFParameters) this._parms;
    }

    /** Typed view of the output object of the model currently being built. */
    private DRFModel.DRFOutput drfOutput() {
        return (DRFModel.DRFOutput) ((DRFModel) this._model)._output;
    }

    @Override
    protected DTree.DecidedNode makeDecided(DTree.UndecidedNode udn, DHistogram[] hs) {
        return new DRFDecidedNode(udn, hs);
    }

    /**
     * Fills {@code fs} with the accumulated OOB prediction for {@code row} and returns their sum.
     *
     * <p>For classification, {@code fs[1..nclass]} receive the per-class tree columns. For
     * regression, {@code fs[0]} is the tree column divided by the OOB column. That column
     * counts how many trees saw the row as out of bag, so the result is the mean OOB
     * prediction. {@code fs[1]} is set to 0.
     *
     * <p>NOTE(review): a regression row that was never out of bag has a zero count and yields
     * 0/0 = NaN; confirm callers expect that.
     */
    @Override
    protected double score1(Chunk[] chks, double[] fs, int row) {
        double sum = 0.0;
        if (this._nclass > 1) {
            for (int k = 0; k < this._nclass; ++k) {
                double acc = this.chk_tree(chks, k).atd(row);
                fs[k + 1] = acc;
                sum += acc;
            }
        } else {
            fs[0] = this.chk_tree(chks, 0).atd(row) / this.chk_oobt(chks).atd(row);
            sum += fs[0];
            fs[1] = 0.0;
        }
        return sum;
    }

    /** Bagging task: marks the rows that one tree leaves out of its bag. */
    static class Sample extends MRTask<Sample> {
        final DRFTree _tree;
        final float _rate;

        Sample(DRFTree tree, float rate) {
            this._tree = tree;
            this._rate = rate;
        }

        public void map(Chunk nids, Chunk ys) {
            // The RNG depends only on the chunk index, via the tree's pre-drawn per-chunk seeds.
            Random rand = this._tree.rngForChunk(nids.cidx());
            for (int row = 0; row < nids._len; ++row) {
                // A random draw is consumed for every row. A row is out of bag when it is not
                // drawn or its response is missing.
                if (rand.nextFloat() >= this._rate || Double.isNaN(ys.atd(row))) {
                    // -2: out-of-bag marker (presumably ScoreBuildHistogram's OOB encoding; confirm).
                    nids.set(row, -2L);
                }
            }
        }
    }

    /** Leaf that serializes its prediction as a single 4-byte float. */
    static class DRFLeafNode extends DTree.LeafNode {
        DRFLeafNode(DTree tree, int pid) {
            super(tree, pid);
        }

        DRFLeafNode(DTree tree, int pid, int nid) {
            super(tree, pid, nid);
        }

        @Override
        protected AutoBuffer compress(AutoBuffer ab) {
            assert (!Double.isNaN(this._pred));
            return ab.put4f(this._pred);
        }

        @Override
        protected int size() {
            return 4;
        }
    }

    /** Undecided node that restricts split search to a random subset of columns. */
    static class DRFUndecidedNode extends DTree.UndecidedNode {
        DRFUndecidedNode(DTree tree, int pid, DHistogram[] hs) {
            super(tree, pid, hs);
        }

        /**
         * Returns up to {@code _mtrys} randomly chosen indices of columns that still have a
         * histogram.
         */
        @Override
        public int[] scoreCols(DHistogram[] hs) {
            DRFTree tree = (DRFTree) this._tree;
            int[] cols = new int[hs.length];
            int len = 0;
            for (int i = 0; i < hs.length; ++i) {
                if (hs[i] == null) continue;
                assert (hs[i]._min < hs[i]._maxEx && hs[i].nbins() > 1) : "broken histo range " + hs[i];
                cols[len++] = i;
            }
            int choices = len;
            assert (choices > 0);
            // Partial Fisher-Yates shuffle: move randomly chosen columns to the tail of [0, choices).
            for (int i = 0; i < tree._mtrys && len != 0; ++i) {
                int pick = tree._rand.nextInt(len);
                int col = cols[pick];
                cols[pick] = cols[--len];
                cols[len] = col;
            }
            assert (choices - len > 0);
            return Arrays.copyOfRange(cols, len, choices);
        }
    }

    /** Decided node that picks the lowest squared-error split among its sampled columns. */
    static class DRFDecidedNode extends DTree.DecidedNode {
        DRFDecidedNode(DTree.UndecidedNode n, DHistogram[] hs) {
            super(n, hs);
        }

        @Override
        public DTree.UndecidedNode makeUndecidedNode(DHistogram[] hs) {
            return new DRFUndecidedNode(this._tree, this._nid, hs);
        }

        /**
         * Returns the split with the smallest {@code se()} over {@code u._scoreCols}.
         * Returns a sentinel split (column -1) when {@code hs} is null or no column can be split.
         */
        @Override
        public DTree.Split bestCol(DTree.UndecidedNode u, DHistogram[] hs) {
            DTree.Split best = new DTree.Split(-1, -1, null, 0, Double.MAX_VALUE, Double.MAX_VALUE, Double.MAX_VALUE, 0L, 0L, 0.0, 0.0);
            if (hs == null) {
                return best;
            }
            for (int i = 0; i < u._scoreCols.length; ++i) {
                int col = u._scoreCols[i];
                DTree.Split s = hs[col].scoreMSE(col, this._tree._min_rows);
                if (s == null) continue;
                if (s.se() < best.se()) {
                    best = s;
                }
                if (s.se() <= 0.0) break; // a zero-error split cannot be beaten
            }
            return best;
        }
    }

    /** Training driver: builds trees round by round and tracks OOB measures. */
    private class DRFDriver extends SharedTree.Driver {
        /** Trees restored from a checkpoint; not assigned anywhere in this file. */
        protected int _ntreesFromCheckpoint;
        public transient TreeMeasuresCollector.TreeMeasures _treeMeasuresOnOOB;
        public transient TreeMeasuresCollector.TreeMeasures[] _treeMeasuresOnSOOB;
        private transient float[] _improvPerVar;

        private DRFDriver() {
        }

        /** Allocates per-tree OOB collectors: vote counts for classifiers, SSE otherwise. */
        private void initTreeMeasurements() {
            int ncols = DRF.this._ncols;
            this._improvPerVar = new float[ncols];
            int ntrees = DRF.this.drfParms()._ntrees;
            if (DRF.this.drfOutput().isClassifier()) {
                this._treeMeasuresOnOOB = new TreeMeasuresCollector.TreeVotes(ntrees);
                this._treeMeasuresOnSOOB = new TreeMeasuresCollector.TreeVotes[ncols];
                for (int i = 0; i < ncols; ++i) {
                    this._treeMeasuresOnSOOB[i] = new TreeMeasuresCollector.TreeVotes(ntrees);
                }
            } else {
                this._treeMeasuresOnOOB = new TreeMeasuresCollector.TreeSSE(ntrees);
                this._treeMeasuresOnSOOB = new TreeMeasuresCollector.TreeSSE[ncols];
                for (int i = 0; i < ncols; ++i) {
                    this._treeMeasuresOnSOOB[i] = new TreeMeasuresCollector.TreeSSE(ntrees);
                }
            }
        }

        /**
         * Resolves {@code _mtry}, prepares the work and OOB columns, then builds {@code _ntrees}
         * rounds of trees. Scoring runs before each round, and training stops early once the
         * returned R^2 reaches {@code _r2_stopping}.
         *
         * @throws IllegalArgumentException if the resolved mtry is outside [1, _ncols]
         */
        @Override
        protected void buildModel() {
            DRFModel.DRFParameters p = DRF.this.drfParms();
            // Always store the resolved value in _mtry. -1 selects sqrt(#cols) for
            // classification and #cols/3 for regression, at least 1 in both cases.
            if (p._mtries == -1) {
                DRF.this._mtry = DRF.this.isClassifier()
                        ? Math.max((int) Math.sqrt(DRF.this._ncols), 1)
                        : Math.max(DRF.this._ncols / 3, 1);
            } else {
                DRF.this._mtry = p._mtries;
            }
            if (DRF.this._mtry < 1 || DRF.this._mtry > DRF.this._ncols) {
                throw new IllegalArgumentException("Computed mtry should be in interval <1,#cols> but it is " + DRF.this._mtry);
            }
            this.initTreeMeasurements();
            // Extra zero-initialized column (presumably the one read back through chk_oobt).
            DRF.this._train.add("OUT_BAG_TREES", DRF.this._response.makeZero());
            new SetWrkTask().doAll(DRF.this._train);
            if (DRF.this._valid == null && p._checkpoint) {
                Timer t = new Timer();
                new OOBScorer(DRF.this._ncols, DRF.this._nclass, p._sample_rate, DRF.this.drfOutput()._treeKeys).doAll(DRF.this._train);
                Log.info("Reconstructing oob stats from checkpointed model took " + t);
            }
            // Each round consumes one nextLong(), so skip the values used by checkpointed trees.
            Random rand = SharedTree.createRNG(DRF.this._actual_seed);
            for (int i = 0; i < this._ntreesFromCheckpoint; ++i) {
                rand.nextLong();
            }
            for (int tid = 0; tid < p._ntrees; ++tid) {
                // When resuming from a checkpoint, the first scoring pass is skipped.
                boolean skipScoring = tid == 0 && p._checkpoint;
                if (!skipScoring) {
                    double trainingR2 = DRF.this.doScoringAndSaveModel(false, DRF.this._valid == null, p._build_tree_one_node);
                    if (trainingR2 >= p._r2_stopping) {
                        return;
                    }
                }
                Timer kbTimer = new Timer();
                this.buildNextKTrees(DRF.this._train, DRF.this._mtry, p._sample_rate, rand, tid);
                Log.info((tid + 1) + ". tree was built " + kbTimer);
                DRF.this.update(1L);
                if (!DRF.this.isRunning()) {
                    return;
                }
            }
            DRF.this.doScoringAndSaveModel(true, DRF.this._valid == null, p._build_tree_one_node);
        }

        /**
         * Builds one round of trees, one per class with a non-zero count in
         * {@code _distribution}. The round bags the rows, grows the trees layer by layer,
         * turns unsplit children into leaves, scores OOB rows and appends the trees to the
         * model output.
         */
        private void buildNextKTrees(Frame fr, int mtrys, float sample_rate, Random rand, int tid) {
            DRFModel.DRFParameters p = DRF.this.drfParms();
            int nclass = DRF.this._nclass;
            DTree[] ktrees = new DTree[nclass];
            DHistogram[][][] hcs = new DHistogram[nclass][1][DRF.this._ncols];
            // Root histograms use at least 1024 bins.
            int adjNbins = Math.max(1024, p._nbins);
            long[] distribution = DRF.this.drfOutput()._distribution;
            // All class-trees share one seed, so they bag the same rows. CollectPreds relies on
            // this when it asserts that the k trees agree on OOB rows.
            long rseed = rand.nextLong();
            boolean isClassifier = DRF.this.isClassifier();
            for (int k = 0; k < nclass; ++k) {
                if (distribution[k] == 0L) continue;
                ktrees[k] = new DRFTree(fr, DRF.this._ncols, (char) p._nbins, (char) nclass, p._min_rows, mtrys, rseed);
                new DRFUndecidedNode(ktrees[k], -1, DHistogram.initialHist(fr, DRF.this._ncols, adjNbins, hcs[k][0], isClassifier));
            }

            // Bagging: fork one Sample task per tree, then wait for all of them.
            Timer t1 = new Timer();
            Sample[] ss = new Sample[nclass];
            for (int k = 0; k < nclass; ++k) {
                if (ktrees[k] == null) continue;
                ss[k] = (Sample) new Sample((DRFTree) ktrees[k], sample_rate).dfork(0, new Frame(new Vec[]{DRF.this.vec_nids(fr, k), DRF.this.vec_resp(fr)}), p._build_tree_one_node);
            }
            for (Sample s : ss) {
                if (s != null) {
                    s.getResult();
                }
            }
            Log.debug("Sampling took: " + t1);

            // Grow the trees one layer at a time; buildLayer returns null when it is done.
            int[] leafs = new int[nclass];
            Timer t2 = new Timer();
            for (int depth = 0; depth < p._max_depth; ++depth) {
                if (!DRF.this.isRunning()) {
                    return;
                }
                hcs = DRF.this.buildLayer(fr, p._nbins, ktrees, leafs, hcs, true, p._build_tree_one_node);
                if (hcs == null) break;
            }
            Log.debug("Tree build took: " + t2);

            Timer t3 = new Timer();
            for (int k = 0; k < nclass; ++k) {
                DTree tree = ktrees[k];
                if (tree == null) continue;
                leafs[k] = this.attachLeaves(tree, k);
            }
            Log.debug("Nodes propagation: " + t3);

            Timer t4 = new Timer();
            CollectPreds cp = (CollectPreds) new CollectPreds(ktrees, leafs, ((DRFModel) DRF.this._model).defaultThreshold()).doAll(fr, p._build_tree_one_node);
            if (DRF.this.isClassifier()) {
                TreeMeasuresCollector.asVotes(this._treeMeasuresOnOOB).append(cp.rightVotes, cp.allRows);
            } else {
                TreeMeasuresCollector.asSSE(this._treeMeasuresOnOOB).append((float) cp.sse, cp.allRows);
            }
            Log.debug("CollectPreds done: " + t4);
            DRF.this.drfOutput().addKTrees(ktrees);
        }

        /**
         * Replaces every unsplit child of a decided node in {@code tree} with a DRF leaf that
         * holds the parent's prediction. A split-less root becomes a leaf predicting the class
         * prior (classification) or the response mean (regression).
         *
         * @return the node count of {@code tree} before any leaves were added
         */
        private int attachLeaves(DTree tree, int k) {
            int leaf = tree.len();
            for (int nid = 0; nid < leaf; ++nid) {
                if (!(tree.node(nid) instanceof DTree.DecidedNode)) continue;
                DTree.DecidedNode dn = tree.decided(nid);
                if (dn._split._col == -1) {
                    if (nid == 0) {
                        DRFLeafNode ln = new DRFLeafNode(tree, -1, 0);
                        ln._pred = (float) (DRF.this.isClassifier() ? DRF.this.drfOutput()._priorClassDist[k] : DRF.this._response.mean());
                    }
                    continue;
                }
                for (int i = 0; i < dn._nids.length; ++i) {
                    if (!this.isUnsplitChild(tree, dn._nids[i])) continue;
                    DRFLeafNode ln = new DRFLeafNode(tree, nid);
                    ln._pred = (float) dn.pred(i);
                    dn._nids[i] = ln.nid();
                }
            }
            return leaf;
        }

        /** True when child {@code cnid} is missing, still undecided, or decided without a split. */
        private boolean isUnsplitChild(DTree tree, int cnid) {
            if (cnid == -1) {
                return true;
            }
            if (tree.node(cnid) instanceof DTree.UndecidedNode) {
                return true;
            }
            return tree.node(cnid) instanceof DTree.DecidedNode
                    && ((DTree.DecidedNode) tree.node(cnid))._split.col() != -1 == false;
        }

        protected DRFModel makeModel(Key modelKey, DRFModel.DRFParameters parms, double mse_train, double mse_valid) {
            return new DRFModel(modelKey, parms, new DRFModel.DRFOutput(DRF.this, mse_train, mse_valid));
        }

        /**
         * Scores the newly built trees on out-of-bag rows.
         *
         * <p>For each OOB row, the task adds each tree's leaf prediction to the row's tree
         * column and updates the OOB column (a count for regression, a flag for
         * classification). It then accumulates correct votes (classification) or squared error
         * (regression) over rows with a known response. It also resets every row's node id to 0.
         */
        private class CollectPreds extends MRTask<CollectPreds> {
            final DTree[] _trees;
            double _threshold;
            long rightVotes;
            long allRows;
            /** Summed squared error; kept in double to limit round-off across many rows. */
            double sse;

            /** {@code leafs} is currently unused. */
            CollectPreds(DTree[] trees, int[] leafs, double threshold) {
                this._trees = trees;
                this._threshold = threshold;
            }

            public void map(Chunk[] chks) {
                Chunk y = DRF.this.chk_resp(chks);
                double[] rpred = new double[1 + DRF.this._nclass];
                double[] rowdata = new double[DRF.this._ncols];
                Chunk oobt = DRF.this.chk_oobt(chks);
                boolean classifier = DRF.this.isClassifier();
                for (int row = 0; row < oobt._len; ++row) {
                    boolean wasOOBRow = false;
                    // Check OOB agreement against the first tree that exists, which need not be
                    // class 0 when that class has no training rows.
                    boolean firstTree = true;
                    for (int k = 0; k < DRF.this._nclass; ++k) {
                        DTree tree = this._trees[k];
                        if (tree == null) continue;
                        Chunk ct = DRF.this.chk_tree(chks, k);
                        Chunk nids = DRF.this.chk_nids(chks, k);
                        int nid = (int) nids.at8(row);
                        if (ScoreBuildHistogram.isOOBRow(nid)) {
                            assert (firstTree || wasOOBRow) : "Something is wrong: k-class trees oob row computing is broken! All k-trees should agree on oob row!";
                            wasOOBRow = true;
                            int leafnid = this.leafFor(tree, ScoreBuildHistogram.oob2Nid(nid), chks, row);
                            double prediction = ((DTree.LeafNode) tree.node(leafnid)).pred();
                            rpred[1 + k] = (float) prediction;
                            ct.set(row, (float) (ct.atd(row) + prediction));
                            // Regression counts OOB trees per row; classification only flags the row.
                            oobt.set(row, DRF.this._nclass > 1 ? 1.0 : oobt.atd(row) + 1.0);
                        }
                        nids.set(row, 0L);
                        firstTree = false;
                    }
                    if (!wasOOBRow || y.isNA(row)) continue;
                    if (classifier) {
                        int treePred = GenModel.getPrediction(rpred, DRF.this.data_row(chks, row, rowdata), this._threshold);
                        int actuPred = (int) y.at8(row);
                        if (treePred == actuPred) {
                            ++this.rightVotes;
                        }
                    } else {
                        double err = y.atd(row) - rpred[1];
                        this.sse += err * err;
                    }
                    ++this.allRows;
                }
            }

            /**
             * Finds the leaf that an OOB row reaches. {@code nid} is the row's last recorded
             * node, with the OOB encoding already removed.
             */
            private int leafFor(DTree tree, int nid, Chunk[] chks, int row) {
                if (tree.node(nid) instanceof DTree.UndecidedNode) {
                    nid = tree.node(nid).pid();
                }
                if (tree.root() instanceof DTree.LeafNode) {
                    return 0;
                }
                DTree.DecidedNode dn = tree.decided(nid);
                if (dn._split.col() == -1) {
                    dn = tree.decided(tree.node(nid).pid());
                }
                return dn.ns(chks, row);
            }

            public void reduce(CollectPreds mrt) {
                this.rightVotes += mrt.rightVotes;
                this.allRows += mrt.allRows;
                this.sse += mrt.sse;
            }
        }
    }

    /**
     * Initializes the work columns from the response. A classification row sets 1 in the
     * work column of its class; a regression row copies its response (as float) into work
     * column 0. Rows with a missing response are skipped.
     */
    private class SetWrkTask extends MRTask<SetWrkTask> {
        private SetWrkTask() {
        }

        public void map(Chunk[] chks) {
            Chunk cy = DRF.this.chk_resp(chks);
            boolean classifier = DRF.this.isClassifier();
            for (int i = 0; i < cy._len; ++i) {
                if (cy.isNA(i)) continue;
                if (classifier) {
                    int cls = (int) cy.at8(i);
                    DRF.this.chk_work(chks, cls).set(i, 1L);
                } else {
                    float pred = (float) cy.atd(i);
                    DRF.this.chk_work(chks, 0).set(i, pred);
                }
            }
        }
    }

    /**
     * DRF tree. It stores the per-split column budget and an RNG seeded from the tree seed,
     * and pre-draws one seed per chunk of {@code fr}'s first vec for bagging.
     */
    static class DRFTree extends DTree {
        final int _mtrys;
        final long[] _seeds;
        final transient Random _rand;

        DRFTree(Frame fr, int ncols, char nbins, char nclass, int min_rows, int mtrys, long seed) {
            super(fr._names, ncols, nbins, nclass, min_rows, seed);
            this._mtrys = mtrys;
            this._rand = SharedTree.createRNG(seed);
            this._seeds = new long[fr.vecs()[0].nChunks()];
            for (int i = 0; i < this._seeds.length; ++i) {
                this._seeds[i] = this._rand.nextLong();
            }
        }

        /** Returns a fresh RNG seeded with the pre-drawn seed of chunk {@code cidx}. */
        public Random rngForChunk(int cidx) {
            return SharedTree.createRNG(this._seeds[cidx]);
        }
    }
}

