/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.drf;

import hex.Distribution;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.genmodel.GenModel;
import hex.schemas.DRFV3;
import hex.tree.DHistogram;
import hex.tree.DTree;
import hex.tree.Sample;
import hex.tree.ScoreBuildHistogram;
import hex.tree.SharedTree;
import hex.tree.drf.DRFModel;
import hex.tree.drf.TreeMeasuresCollector;
import java.util.Random;
import water.Job;
import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;

public class DRF
extends SharedTree<DRFModel, DRFModel.DRFParameters, DRFModel.DRFOutput> {
    public ModelCategory[] can_build() {
        return new ModelCategory[]{ModelCategory.Regression, ModelCategory.Binomial, ModelCategory.Multinomial};
    }

    public ModelBuilder.BuilderVisibility builderVisibility() {
        return ModelBuilder.BuilderVisibility.Stable;
    }

    public DRF(DRFModel.DRFParameters parms) {
        super("DRF", parms);
        this.init(false);
    }

    public DRFV3 schema() {
        return new DRFV3();
    }

    protected Job<DRFModel> trainModelImpl(long work, boolean restartTimer) {
        return this.start(new DRFDriver(), work, restartTimer);
    }

    @Override
    public void init(boolean expensive) {
        super.init(expensive);
        if (((DRFModel.DRFParameters)this._parms)._mtries < 1 && ((DRFModel.DRFParameters)this._parms)._mtries != -1) {
            this.error("_mtries", "mtries must be -1 (converted to sqrt(features)), or >= 1 but it is " + ((DRFModel.DRFParameters)this._parms)._mtries);
        }
        if (this._train != null) {
            int ncols = this._train.numCols();
            if (((DRFModel.DRFParameters)this._parms)._mtries != -1 && (1 > ((DRFModel.DRFParameters)this._parms)._mtries || ((DRFModel.DRFParameters)this._parms)._mtries >= ncols)) {
                this.error("_mtries", "Computed mtries should be -1 or in interval [1," + ncols + "[ but it is " + ((DRFModel.DRFParameters)this._parms)._mtries);
            }
        }
        if (((DRFModel.DRFParameters)this._parms)._distribution == Distribution.Family.AUTO) {
            if (this._nclass == 1) {
                ((DRFModel.DRFParameters)this._parms)._distribution = Distribution.Family.gaussian;
            }
            if (this._nclass >= 2) {
                ((DRFModel.DRFParameters)this._parms)._distribution = Distribution.Family.multinomial;
            }
        }
        if (((DRFModel.DRFParameters)this._parms)._sample_rate == 1.0f && this._valid == null) {
            this.error("_sample_rate", "Sample rate is 100% and no validation dataset is specified.  There are no OOB data to compute out-of-bag error estimation!");
        }
        if (this.hasOffsetCol()) {
            this.error("_offset_column", "Offsets are not yet supported for DRF.");
        }
        if (this.hasOffsetCol() && this.isClassifier()) {
            this.error("_offset_column", "Offset is only supported for regression.");
        }
    }

    @Override
    protected double score1(Chunk[] chks, double weight, double offset, double[] fs, int row) {
        double sum = 0.0;
        if (this._nclass > 2 || this._nclass == 2 && !((DRFModel)this._model).binomialOpt()) {
            for (int k = 0; k < this._nclass; ++k) {
                double d = this.chk_tree(chks, k).atd(row) / this.chk_oobt(chks).atd(row);
                fs[k + 1] = d;
                sum += d;
            }
        } else if (this._nclass == 2 && ((DRFModel)this._model).binomialOpt()) {
            fs[1] = this.chk_tree(chks, 0).atd(row) / this.chk_oobt(chks).atd(row);
            assert (fs[1] >= 0.0 && fs[1] <= 1.0);
            fs[2] = 1.0 - fs[1];
        } else {
            fs[0] = this.chk_tree(chks, 0).atd(row) / this.chk_oobt(chks).atd(row);
            sum += fs[0];
            fs[1] = 0.0;
        }
        return sum;
    }

    private class DRFDriver
    extends SharedTree.Driver {
        public transient TreeMeasuresCollector.TreeMeasures _treeMeasuresOnOOB;
        public transient TreeMeasuresCollector.TreeMeasures[] _treeMeasuresOnSOOB;
        private transient float[] _improvPerVar;

        private DRFDriver() {
        }

        @Override
        protected boolean doOOBScoring() {
            return true;
        }

        private void initTreeMeasurements() {
            this._improvPerVar = new float[DRF.this._ncols];
            int ntrees = ((DRFModel.DRFParameters)DRF.this._parms)._ntrees;
            if (((DRFModel.DRFOutput)((DRFModel)((DRF)DRF.this)._model)._output).isClassifier()) {
                this._treeMeasuresOnOOB = new TreeMeasuresCollector.TreeVotes(ntrees);
                this._treeMeasuresOnSOOB = new TreeMeasuresCollector.TreeVotes[DRF.this._ncols];
                for (int i = 0; i < DRF.this._ncols; ++i) {
                    this._treeMeasuresOnSOOB[i] = new TreeMeasuresCollector.TreeVotes(ntrees);
                }
            } else {
                this._treeMeasuresOnOOB = new TreeMeasuresCollector.TreeSSE(ntrees);
                this._treeMeasuresOnSOOB = new TreeMeasuresCollector.TreeSSE[DRF.this._ncols];
                for (int i = 0; i < DRF.this._ncols; ++i) {
                    this._treeMeasuresOnSOOB[i] = new TreeMeasuresCollector.TreeSSE(ntrees);
                }
            }
        }

        @Override
        protected void initializeModelSpecifics() {
            DRF.this._mtry = ((DRFModel.DRFParameters)DRF.this._parms)._mtries == -1 ? (DRF.this.isClassifier() ? Math.max((int)Math.sqrt(DRF.this._ncols), 1) : Math.max(DRF.this._ncols / 3, 1)) : ((DRFModel.DRFParameters)DRF.this._parms)._mtries;
            if (1 > DRF.this._mtry || DRF.this._mtry > DRF.this._ncols) {
                throw new IllegalArgumentException("Computed mtry should be in interval <1," + DRF.this._ncols + "> but it is " + DRF.this._mtry);
            }
            DRF.this._initialPrediction = DRF.this.isClassifier() ? 0.0 : DRF.this.getInitialValue();
            this.initTreeMeasurements();
            new MRTask(){

                public void map(Chunk[] chks) {
                    Chunk cy = DRF.this.chk_resp(chks);
                    for (int i = 0; i < cy._len; ++i) {
                        if (cy.isNA(i)) continue;
                        if (DRF.this.isClassifier()) {
                            int cls = (int)cy.at8(i);
                            DRF.this.chk_work(chks, cls).set(i, 1L);
                            continue;
                        }
                        float pred = (float)cy.atd(i);
                        DRF.this.chk_work(chks, 0).set(i, pred);
                    }
                }
            }.doAll(DRF.this._train);
        }

        @Override
        protected void buildNextKTrees() {
            DTree[] ktrees = new DTree[DRF.this._nclass];
            int[] leafs = new int[DRF.this._nclass];
            this.growTrees(ktrees, leafs, DRF.this._rand);
            CollectPreds cp = (CollectPreds)new CollectPreds(ktrees, leafs, ((DRFModel)DRF.this._model).defaultThreshold()).doAll(DRF.this._train, ((DRFModel.DRFParameters)DRF.this._parms)._build_tree_one_node);
            if (DRF.this.isClassifier()) {
                TreeMeasuresCollector.asVotes(this._treeMeasuresOnOOB).append(cp.rightVotes, cp.allRows);
            } else {
                TreeMeasuresCollector.asSSE(this._treeMeasuresOnOOB).append(cp.sse, cp.allRows);
            }
            ((DRFModel.DRFOutput)((DRFModel)((DRF)DRF.this)._model)._output).addKTrees(ktrees);
        }

        private void growTrees(DTree[] ktrees, int[] leafs, Random rand) {
            int k;
            DHistogram[][][] hcs = new DHistogram[DRF.this._nclass][1][DRF.this._ncols];
            int adj_nbins = Math.max(((DRFModel.DRFParameters)DRF.this._parms)._nbins_top_level, ((DRFModel.DRFParameters)DRF.this._parms)._nbins);
            long rseed = rand.nextLong();
            for (int k2 = 0; k2 < DRF.this._nclass; ++k2) {
                if (((DRFModel.DRFOutput)((DRFModel)((DRF)DRF.this)._model)._output)._distribution[k2] == 0.0 || k2 == 1 && DRF.this._nclass == 2 && ((DRFModel)DRF.this._model).binomialOpt()) continue;
                ktrees[k2] = new DTree(DRF.this._train, DRF.this._ncols, (char)((DRFModel.DRFParameters)DRF.this._parms)._nbins, (char)((DRFModel.DRFParameters)DRF.this._parms)._nbins_cats, (char)DRF.this._nclass, ((DRFModel.DRFParameters)DRF.this._parms)._min_rows, DRF.this._mtry, rseed);
                new DTree.UndecidedNode(ktrees[k2], -1, DHistogram.initialHist(DRF.this._train, DRF.this._ncols, adj_nbins, ((DRFModel.DRFParameters)DRF.this._parms)._nbins_cats, hcs[k2][0]));
            }
            Sample[] ss = new Sample[DRF.this._nclass];
            for (k = 0; k < DRF.this._nclass; ++k) {
                if (ktrees[k] == null) continue;
                ss[k] = (Sample)new Sample(ktrees[k], ((DRFModel.DRFParameters)DRF.this._parms)._sample_rate).dfork(null, new Frame(new Vec[]{DRF.this.vec_nids(DRF.this._train, k), DRF.this.vec_resp(DRF.this._train)}), ((DRFModel.DRFParameters)DRF.this._parms)._build_tree_one_node);
            }
            for (k = 0; k < DRF.this._nclass; ++k) {
                if (ss[k] == null) continue;
                ss[k].getResult();
            }
            for (int depth = 0; depth < ((DRFModel.DRFParameters)DRF.this._parms)._max_depth; ++depth) {
                if (!DRF.this.isRunning()) {
                    return;
                }
                hcs = DRF.this.buildLayer(DRF.this._train, ((DRFModel.DRFParameters)DRF.this._parms)._nbins, ((DRFModel.DRFParameters)DRF.this._parms)._nbins_cats, ktrees, leafs, hcs, DRF.this._mtry < ((DRFModel.DRFOutput)((DRFModel)((DRF)DRF.this)._model)._output).nfeatures(), ((DRFModel.DRFParameters)DRF.this._parms)._build_tree_one_node);
                if (hcs == null) break;
            }
            for (int k3 = 0; k3 < DRF.this._nclass; ++k3) {
                DTree tree = ktrees[k3];
                if (tree == null) continue;
                int leaf = leafs[k3] = tree.len();
                for (int nid = 0; nid < leaf; ++nid) {
                    if (!(tree.node(nid) instanceof DTree.DecidedNode)) continue;
                    DTree.DecidedNode dn = tree.decided(nid);
                    if (dn._split._col == -1) {
                        if (nid != 0) continue;
                        DTree.LeafNode ln = new DTree.LeafNode(tree, -1, 0);
                        ln._pred = (float)(DRF.this.isClassifier() ? ((DRFModel.DRFOutput)((DRFModel)((DRF)DRF.this)._model)._output)._priorClassDist[k3] : DRF.this._initialPrediction);
                        continue;
                    }
                    for (int i = 0; i < dn._nids.length; ++i) {
                        int cnid = dn._nids[i];
                        if (cnid != -1 && !(tree.node(cnid) instanceof DTree.UndecidedNode) && (!(tree.node(cnid) instanceof DTree.DecidedNode) || ((DTree.DecidedNode)tree.node((int)cnid))._split.col() != -1)) continue;
                        DTree.LeafNode ln = new DTree.LeafNode(tree, nid);
                        ln._pred = (float)dn.pred(i);
                        dn._nids[i] = ln.nid();
                    }
                }
            }
        }

        protected DRFModel makeModel(Key modelKey, DRFModel.DRFParameters parms, double mse_train, double mse_valid) {
            return new DRFModel(modelKey, parms, new DRFModel.DRFOutput(DRF.this, mse_train, mse_valid));
        }

        private class CollectPreds
        extends MRTask<CollectPreds> {
            final DTree[] _trees;
            double _threshold;
            long rightVotes;
            long allRows;
            float sse;
            final boolean importance = true;

            CollectPreds(DTree[] trees, int[] leafs, double threshold) {
                this._trees = trees;
                this._threshold = threshold;
            }

            public void map(Chunk[] chks) {
                Chunk y = DRF.this.chk_resp(chks);
                double[] rpred = new double[1 + DRF.this._nclass];
                double[] rowdata = new double[DRF.this._ncols];
                Chunk oobt = DRF.this.chk_oobt(chks);
                for (int row = 0; row < oobt._len; ++row) {
                    boolean wasOOBRow = ScoreBuildHistogram.isOOBRow((int)DRF.this.chk_nids(chks, 0).at8(row));
                    for (int k = 0; k < DRF.this._nclass; ++k) {
                        DTree tree = this._trees[k];
                        if (tree == null) continue;
                        Chunk nids = DRF.this.chk_nids(chks, k);
                        int nid = (int)nids.at8(row);
                        if (wasOOBRow) {
                            int leafnid;
                            Chunk ct = DRF.this.chk_tree(chks, k);
                            if (tree.node(nid = ScoreBuildHistogram.oob2Nid(nid)) instanceof DTree.UndecidedNode) {
                                nid = tree.node(nid).pid();
                            }
                            if (tree.root() instanceof DTree.LeafNode) {
                                leafnid = 0;
                            } else {
                                DTree.DecidedNode dn = tree.decided(nid);
                                if (dn._split.col() == -1) {
                                    dn = tree.decided(tree.node(nid).pid());
                                }
                                leafnid = dn.ns(chks, row);
                            }
                            double prediction = ((DTree.LeafNode)tree.node(leafnid)).pred();
                            rpred[1 + k] = (float)prediction;
                            ct.set(row, (float)(ct.atd(row) + prediction));
                        }
                        nids.set(row, 0L);
                    }
                    if (wasOOBRow) {
                        oobt.set(row, oobt.atd(row) + 1.0);
                    }
                    if (!wasOOBRow || y.isNA(row)) continue;
                    if (DRF.this.isClassifier()) {
                        int actuPred;
                        int treePred = GenModel.getPrediction((double[])rpred, (double[])((DRFModel.DRFOutput)((DRFModel)((DRF)DRF.this)._model)._output)._priorClassDist, (double[])DRF.this.data_row(chks, row, rowdata), (double)this._threshold);
                        if (treePred == (actuPred = (int)y.at8(row))) {
                            ++this.rightVotes;
                        }
                    } else {
                        double treePred = rpred[1];
                        double actuPred = y.atd(row);
                        this.sse = (float)((double)this.sse + (actuPred - treePred) * (actuPred - treePred));
                    }
                    ++this.allRows;
                }
            }

            public void reduce(CollectPreds mrt) {
                this.rightVotes += mrt.rightVotes;
                this.allRows += mrt.allRows;
                this.sse += mrt.sse;
            }
        }
    }
}

