/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.gbm;

import hex.Model;
import hex.VarImp;
import hex.schemas.GBMV2;
import hex.tree.DHistogram;
import hex.tree.DTree;
import hex.tree.SharedTree;
import hex.tree.gbm.GBMModel;
import hex.tree.gbm.ResidualsCollector;
import water.AutoBuffer;
import water.H2O;
import water.Job;
import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.Timer;

/**
 * Gradient Boosting Machine (GBM) model builder.
 *
 * <p>Each boosting iteration (see {@link GBMDriver#buildModel()}):
 * <ol>
 *   <li>turns the accumulated per-class predictions into probabilities (or raw predictions for
 *       regression) in the work columns ({@link GBMDriver.ComputeProb});</li>
 *   <li>converts those into residuals ({@link GBMDriver.ComputeRes});</li>
 *   <li>fits one regression tree per class on the residuals and adds the leaf values to the
 *       accumulated prediction columns (the buildNextKTrees step).</li>
 * </ol>
 * For a 2-class, non-bernoulli response only tree 0 is built; {@link #score1} derives the second
 * class score from it.
 */
public class GBM
extends SharedTree<GBMModel, GBMModel.GBMParameters, GBMModel.GBMOutput> {
    /** Returns the model categories this builder supports: regression, binomial and multinomial. */
    public Model.ModelCategory[] can_build() {
        return new Model.ModelCategory[]{
            Model.ModelCategory.Regression,
            Model.ModelCategory.Binomial,
            Model.ModelCategory.Multinomial
        };
    }

    /**
     * Creates a GBM builder and immediately runs the non-expensive parameter validation.
     *
     * @param parms GBM parameters
     */
    public GBM(GBMModel.GBMParameters parms) {
        super("GBM", parms);
        this.init(false);
    }

    /** Returns a fresh REST schema object for GBM. */
    public GBMV2 schema() {
        return new GBMV2();
    }

    /**
     * Starts the training job with a {@link GBMDriver}, passing the requested number of trees as
     * the job's work amount (presumably used for progress tracking; confirm in ModelBuilder).
     */
    public Job<GBMModel> trainModel() {
        return this.start(new GBMDriver(), _parms._ntrees);
    }

    /** Returns the superclass's response vector, falling back to {@code response()} when it is null. */
    public Vec vresponse() {
        Vec v = super.vresponse();
        return v != null ? v : this.response();
    }

    /**
     * Validates GBM-specific parameters on top of the shared-tree checks.
     *
     * <p>Reports an error when the learn rate is outside (0, 1], including NaN. For bernoulli
     * loss, it reports an error unless the response has exactly 2 classes. It also sets
     * {@code _initialPrediction} to the log-odds {@code log(mean / (1 - mean))} of the response mean.
     *
     * @param expensive forwarded to {@code super.init}
     */
    @Override
    public void init(boolean expensive) {
        super.init(expensive);
        // Negated conjunction so that a NaN learn rate also fails the check.
        if (!(0.0 < _parms._learn_rate && _parms._learn_rate <= 1.0)) {
            this.error("_learn_rate", "learn_rate must be between 0 and 1");
        }
        if (_parms._loss == GBMModel.GBMParameters.Family.bernoulli) {
            if (_nclass != 2) {
                this.error("_loss", "Bernoulli requires the response to be a 2-class categorical");
            }
            // Guard against a missing response vector (assumption: super.init leaves it null when
            // no response could be resolved) instead of failing with an NPE during validation.
            if (_response != null) {
                double mean = _response.mean();
                // NOTE(review): a mean of exactly 0 or 1 yields an infinite initial prediction;
                // confirm this case is rejected or handled elsewhere.
                _initialPrediction = Math.log(mean / (1.0 - mean));
            }
        }
    }

    /** Creates GBM's decided-node type for the given undecided node and its histograms. */
    @Override
    protected DTree.DecidedNode makeDecided(DTree.UndecidedNode udn, DHistogram[] hs) {
        return new GBMDecidedNode(udn, hs);
    }

    /** Variable importance is not implemented for GBM; always throws {@code H2O.unimpl()}. */
    @Override
    protected VarImp doVarImpCalc(boolean scale) {
        throw H2O.unimpl();
    }

    /**
     * Converts the accumulated per-class tree predictions of one row into per-class scores.
     *
     * <ul>
     *   <li>bernoulli: {@code fs[1] = 1/(1+exp(f))}, {@code fs[2] = 1 - fs[1]}; returns their sum;</li>
     *   <li>regression ({@code _nclass == 1}): stores the raw prediction in {@code fs[0]} and returns it;</li>
     *   <li>binomial: {@code fs[1] = exp(f)}, {@code fs[2] = exp(-f)} from the single tree 0; returns the sum;</li>
     *   <li>multinomial: {@code fs[k+1] = exp(f_k)}; returns the sum.</li>
     * </ul>
     * {@link GBMDriver.ComputeProb} normalizes classification scores by dividing by the returned sum.
     */
    @Override
    protected float score1(Chunk[] chks, float[] fs, int row) {
        if (_parms._loss == GBMModel.GBMParameters.Family.bernoulli) {
            fs[1] = 1.0f / (float) (1.0 + Math.exp(this.chk_tree(chks, 0).at0(row)));
            fs[2] = 1.0f - fs[1];
            return fs[1] + fs[2];
        }
        if (_nclass == 1) {
            fs[0] = (float) this.chk_tree(chks, 0).at0(row);
            return fs[0];
        }
        if (_nclass == 2) {
            // Only tree 0 exists for binomial; the other class score is exp(-f) == 1/exp(f).
            fs[1] = (float) Math.exp(this.chk_tree(chks, 0).at0(row));
            fs[2] = 1.0f / fs[1];
            return fs[1] + fs[2];
        }
        float sum = 0.0f;
        for (int k = 0; k < _nclass; ++k) {
            float f = (float) Math.exp(this.chk_tree(chks, k).at0(row));
            fs[k + 1] = f;
            sum += f;
        }
        return sum;
    }

    /** GBM leaf: serializes its prediction as a single 4-byte float. */
    static class GBMLeafNode
    extends DTree.LeafNode {
        GBMLeafNode(DTree tree, int pid) {
            super(tree, pid);
        }

        GBMLeafNode(DTree tree, int pid, int nid) {
            super(tree, pid, nid);
        }

        @Override
        protected AutoBuffer compress(AutoBuffer ab) {
            assert !Double.isNaN(this._pred);
            return ab.put4f((float) this._pred);
        }

        /** Serialized size in bytes; matches the single {@code put4f} in {@link #compress}. */
        @Override
        protected int size() {
            return 4;
        }
    }

    /** GBM undecided node. */
    static class GBMUndecidedNode
    extends DTree.UndecidedNode {
        GBMUndecidedNode(DTree tree, int pid, DHistogram[] hs) {
            super(tree, pid, hs);
        }

        /**
         * Returns {@code null}. Presumably this means "score all columns", i.e. no per-node column
         * sampling; confirm against {@code DTree.UndecidedNode}.
         */
        @Override
        public int[] scoreCols(DHistogram[] hs) {
            return null;
        }
    }

    /** GBM decided node: chooses the split with the lowest squared error across all columns. */
    static class GBMDecidedNode
    extends DTree.DecidedNode {
        GBMDecidedNode(DTree.UndecidedNode n, DHistogram[] hs) {
            super(n, hs);
        }

        @Override
        public DTree.UndecidedNode makeUndecidedNode(DHistogram[] hs) {
            return new GBMUndecidedNode(this._tree, this._nid, hs);
        }

        /**
         * Returns the column split with the smallest squared error.
         *
         * <p>If nothing is splittable, it returns a sentinel split with column -1 and an error of
         * {@code Double.MAX_VALUE}. Histograms that are null or have at most one bin are skipped.
         * The search stops early once a split with non-positive error is found.
         */
        @Override
        public DTree.Split bestCol(DTree.UndecidedNode u, DHistogram[] hs) {
            DTree.Split best = new DTree.Split(-1, -1, null, 0, Double.MAX_VALUE, Double.MAX_VALUE, 0L, 0L, 0.0, 0.0);
            if (hs == null) {
                return best;
            }
            for (int i = 0; i < hs.length; ++i) {
                if (hs[i] == null || hs[i].nbins() <= 1) continue;
                DTree.Split s = hs[i].scoreMSE(i);
                if (s == null) continue;
                if (s.se() < best.se()) {
                    best = s;
                }
                if (s.se() <= 0.0) break; // cannot do better than a zero-error split
            }
            return best;
        }
    }

    /** Runs the boosting iterations for one GBM training job. */
    private class GBMDriver
    extends SharedTree.Driver {
        /**
         * Minimum histogram bin count passed to the root histograms and {@code buildLayer}.
         * The larger of this and {@code _nbins} is used.
         */
        private static final int TOP_LEVEL_EXTRA_BINS = 1024;

        /**
         * One improvement accumulator per predictor column. It is allocated only when
         * {@code _importance} is set and is not read anywhere in this file yet.
         */
        private transient float[] _improvPerVar;

        private GBMDriver() {
        }

        /**
         * Builds {@code _ntrees} boosting iterations. Each iteration scores and saves the model,
         * recomputes probabilities and residuals, and builds the next set of per-class trees.
         * It stops early if the job is no longer running; otherwise it does a final scoring and save.
         */
        @Override
        protected void buildModel() {
            if (_parms._importance) {
                _improvPerVar = new float[_ncols];
            }
            if (_parms._checkpoint) {
                Timer t = new Timer();
                new ResidualsCollector(_ncols, _nclass, _model._output._treeKeys).doAll(_train);
                Log.info("Reconstructing tree residuals stats from checkpointed model took " + t);
            }
            for (int tid = 0; tid < _parms._ntrees; ++tid) {
                // On a checkpoint restart the first intermediate scoring is skipped
                // (presumably the restored model was already scored).
                if (tid != 0 || !_parms._checkpoint) {
                    doScoringAndSaveModel(false, false, false);
                }
                new ComputeProb().doAll(_train);
                new ComputeRes().doAll(_train);
                Timer kbTimer = new Timer();
                buildNextKTrees();
                Log.info((tid + 1) + ". tree was built in " + kbTimer);
                if (!isRunning()) {
                    return;
                }
            }
            doScoringAndSaveModel(true, false, false);
        }

        /**
         * Builds one tree per class on the current residuals and adds their leaf values to the
         * accumulated prediction columns. It then appends the trees to the model output.
         * If the job stops running during layer building, it returns without adding any trees.
         */
        private void buildNextKTrees() {
            final DTree[] ktrees = new DTree[_nclass];
            DHistogram[][][] hcs = new DHistogram[_nclass][1][_ncols];
            int nbins = Math.max(TOP_LEVEL_EXTRA_BINS, _parms._nbins);
            for (int k = 0; k < _nclass; ++k) {
                // Skip classes with a zero count in _distribution (presumably absent from the
                // training data); a binomial response only needs tree 0.
                if (_model._output._distribution[k] == 0L || (k == 1 && _nclass == 2)) continue;
                ktrees[k] = new DTree(_train._names, _ncols, (char) _parms._nbins, (char) _nclass, _parms._min_rows);
                new GBMUndecidedNode(ktrees[k], -1, DHistogram.initialHist(_train, _ncols, nbins, hcs[k][0], false, false));
            }
            int[] leafs = new int[_nclass];
            for (int depth = 0; depth < _parms._max_depth; ++depth) {
                if (!isRunning()) {
                    return;
                }
                hcs = buildLayer(_train, nbins, ktrees, leafs, hcs, false, false);
                if (hcs == null) break; // no more layers to build
            }
            markLeaves(ktrees, leafs);
            GammaPass gp = new GammaPass(ktrees, leafs, _parms._loss == GBMModel.GBMParameters.Family.bernoulli).doAll(_train);
            setLeafPredictions(ktrees, leafs, gp);
            addLeafPredictions(ktrees);
            for (int k = 0; k < ktrees.length; ++k) {
                if (ktrees[k] == null) continue;
                ktrees[k]._leaves = ktrees[k].len() - leafs[k];
            }
            _model._output.addKTrees(ktrees);
        }

        /**
         * Turns each tree's frontier into leaves. A child slot of a decided node gets a fresh
         * {@link GBMLeafNode} when it is empty (-1), still undecided (the depth loop ended), or
         * decided without a usable split (column -1).
         *
         * <p>Before any leaf is added, {@code leafs[k]} is set to tree k's node count, so all leaves
         * of tree k have ids {@code >= leafs[k]}. A tree whose root found no split gets its root
         * replaced by a leaf.
         */
        private void markLeaves(DTree[] ktrees, int[] leafs) {
            for (int k = 0; k < _nclass; ++k) {
                DTree tree = ktrees[k];
                if (tree == null) continue;
                int leaf = leafs[k] = tree.len();
                for (int nid = 0; nid < leaf; ++nid) {
                    if (!(tree.node(nid) instanceof DTree.DecidedNode)) continue;
                    DTree.DecidedNode dn = tree.decided(nid);
                    for (int i = 0; i < dn._nids.length; ++i) {
                        int cnid = dn._nids[i];
                        if (cnid == -1
                                || tree.node(cnid) instanceof DTree.UndecidedNode
                                || (tree.node(cnid) instanceof DTree.DecidedNode
                                    && ((DTree.DecidedNode) tree.node(cnid))._split.col() == -1)) {
                            dn._nids[i] = new GBMLeafNode(tree, nid).nid();
                        }
                    }
                    // Trivial non-splitting tree: replace the root with a leaf.
                    if (nid == 0 && dn._split.col() == -1) {
                        new GBMLeafNode(tree, -1, 0);
                    }
                }
            }
        }

        /**
         * Sets each new leaf's prediction to
         * {@code learn_rate * m1class * sum(residual) / sum(denominator)} from the GammaPass sums.
         *
         * <p>{@code m1class} is {@code (K-1)/K} for non-bernoulli classification and 1 otherwise.
         * A leaf with a zero denominator gets 0 when its residual sum is also 0, and a capped
         * 1000 otherwise.
         */
        private void setLeafPredictions(DTree[] ktrees, int[] leafs, GammaPass gp) {
            double m1class = _nclass > 1 && _parms._loss != GBMModel.GBMParameters.Family.bernoulli
                ? (double) (_nclass - 1) / (double) _nclass
                : 1.0;
            for (int k = 0; k < _nclass; ++k) {
                DTree tree = ktrees[k];
                if (tree == null) continue;
                for (int i = 0; i < tree._len - leafs[k]; ++i) {
                    double gss = gp._gss[k][i];
                    double rss = gp._rss[k][i];
                    // The whole conditional must land in g: the decompiled form assigned the
                    // constant-response branch to a dead local, leaving g unassigned there.
                    double g = gss == 0.0
                        ? (rss == 0.0 ? 0.0 : 1000.0)
                        : _parms._learn_rate * m1class * rss / gss;
                    assert !Double.isNaN(g);
                    ((DTree.LeafNode) tree.node(leafs[k] + i))._pred = g;
                }
            }
        }

        /**
         * For every row that has a non-negative node id, adds its leaf's prediction to the
         * accumulated tree column of each built class tree. It then resets the row's node id to 0
         * for the next iteration.
         */
        private void addLeafPredictions(final DTree[] ktrees) {
            new MRTask() {
                @Override
                public void map(Chunk[] chks) {
                    for (int k = 0; k < _nclass; ++k) {
                        DTree tree = ktrees[k];
                        if (tree == null) continue;
                        Chunk nids = chk_nids(chks, k);
                        Chunk ct = chk_tree(chks, k);
                        for (int row = 0; row < nids._len; ++row) {
                            int nid = (int) nids.at80(row);
                            if (nid < 0) continue;
                            ct.set0(row, (float) (ct.at0(row) + (float) ((DTree.LeafNode) tree.node(nid))._pred));
                            nids.set0(row, 0L);
                        }
                    }
                }
            }.doAll(_train);
        }

        protected GBMModel makeModel(Key modelKey, GBMModel.GBMParameters parms, double mse_train, double mse_valid) {
            return new GBMModel(modelKey, parms, new GBMModel.GBMOutput(GBM.this, mse_train, mse_valid));
        }

        /**
         * Routes each row to its final leaf, rewriting the row's node id to that leaf's id. It also
         * accumulates per-leaf sums used for the leaf predictions.
         *
         * <p>{@code _rss} is the sum of residuals. {@code _gss} is the denominator:
         * {@code p(1-p)} with {@code p = y - residual} for bernoulli, {@code |r|(1-|r|)} for other
         * classification, and a row count for regression. Arrays are indexed
         * {@code [class][leafId - leafs[class]]}.
         */
        private class GammaPass
        extends MRTask<GammaPass> {
            final DTree[] _trees;
            final int[] _leafs;
            final boolean _isBernoulli;
            double[][] _rss;
            double[][] _gss;

            GammaPass(DTree[] trees, int[] leafs, boolean isBernoulli) {
                this._leafs = leafs;
                this._trees = trees;
                this._isBernoulli = isBernoulli;
            }

            @Override
            public void map(Chunk[] chks) {
                _gss = new double[_nclass][];
                _rss = new double[_nclass][];
                Chunk resp = chk_resp(chks);
                for (int k = 0; k < _nclass; ++k) {
                    DTree tree = _trees[k];
                    if (tree == null) continue;
                    int leaf = _leafs[k];
                    double[] gs = _gss[k] = new double[tree._len - leaf];
                    double[] rs = _rss[k] = new double[tree._len - leaf];
                    // Trivial tree whose root was replaced by a leaf: the sums stay zero, so its
                    // leaf prediction becomes 0.
                    if (tree.root() instanceof DTree.LeafNode) continue;
                    Chunk nids = chk_nids(chks, k);
                    Chunk ress = chk_work(chks, k);
                    for (int row = 0; row < nids._len; ++row) {
                        int nid = (int) nids.at80(row);
                        if (nid < 0) continue; // presumably rows excluded from this tree
                        // A row parked on an undecided node is routed from that node's parent.
                        if (tree.node(nid) instanceof DTree.UndecidedNode) {
                            nid = tree.node(nid)._pid;
                        }
                        DTree.DecidedNode dn = tree.decided(nid);
                        // A decided node without a usable split also routes from its parent.
                        if (dn._split._col == -1) {
                            nid = dn._pid;
                            dn = tree.decided(nid);
                        }
                        int leafnid = dn.ns(chks, row);
                        assert leaf <= leafnid && leafnid < tree._len;
                        assert tree.node(leafnid) instanceof DTree.LeafNode;
                        nids.set0(row, (long) leafnid);
                        assert !ress.isNA0(row);
                        double res = ress.at0(row);
                        int n = leafnid - leaf;
                        if (_isBernoulli) {
                            double prob = resp.at0(row) - res;
                            gs[n] += prob * (1.0 - prob);
                        } else {
                            double ares = Math.abs(res);
                            gs[n] += _nclass > 1 ? ares * (1.0 - ares) : 1.0;
                        }
                        rs[n] += res;
                    }
                }
            }

            @Override
            public void reduce(GammaPass gp) {
                ArrayUtils.add(_gss, gp._gss);
                ArrayUtils.add(_rss, gp._rss);
            }
        }

        /**
         * Turns the work columns from predictions/probabilities into residuals, in place.
         *
         * <ul>
         *   <li>bernoulli: {@code y - 1 + w}, i.e. {@code y - (1 - w)};</li>
         *   <li>classification: {@code [y == k] - w_k} for classes with a non-zero distribution count;</li>
         *   <li>regression: {@code y - w}.</li>
         * </ul>
         * Classification rows with an NA response are left unchanged.
         */
        class ComputeRes
        extends MRTask<ComputeRes> {
            ComputeRes() {
            }

            @Override
            public void map(Chunk[] chks) {
                Chunk ys = chk_resp(chks);
                if (_parms._loss == GBMModel.GBMParameters.Family.bernoulli) {
                    Chunk wk = chk_work(chks, 0);
                    for (int row = 0; row < ys._len; ++row) {
                        if (ys.isNA0(row)) continue;
                        int y = (int) ys.at80(row);
                        wk.set0(row, (float) y - 1.0f + (float) wk.at0(row));
                    }
                } else if (_nclass > 1) {
                    for (int row = 0; row < ys._len; ++row) {
                        if (ys.isNA0(row)) continue;
                        int y = (int) ys.at80(row);
                        for (int k = 0; k < _nclass; ++k) {
                            if (_model._output._distribution[k] == 0L) continue;
                            Chunk wk = chk_work(chks, k);
                            wk.set0(row, (y == k ? 1.0f : 0.0f) - (float) wk.at0(row));
                        }
                    }
                } else {
                    Chunk wk = chk_work(chks, 0);
                    for (int row = 0; row < ys._len; ++row) {
                        wk.set0(row, (float) (ys.at0(row) - wk.at0(row)));
                    }
                }
            }
        }

        /**
         * Fills the work columns from the accumulated tree columns.
         *
         * <ul>
         *   <li>bernoulli: {@code 1/(1+exp(f))};</li>
         *   <li>classification: normalized {@link GBM#score1} scores. If the score sum overflows to
         *       infinity, each class with an infinite score gets 1 and every other class gets 0;</li>
         *   <li>regression: the raw prediction.</li>
         * </ul>
         */
        class ComputeProb
        extends MRTask<ComputeProb> {
            ComputeProb() {
            }

            @Override
            public void map(Chunk[] chks) {
                Chunk ys = chk_resp(chks);
                if (_parms._loss == GBMModel.GBMParameters.Family.bernoulli) {
                    Chunk tr = chk_tree(chks, 0);
                    Chunk wk = chk_work(chks, 0);
                    for (int row = 0; row < ys._len; ++row) {
                        wk.set0(row, 1.0 / (1.0 + Math.exp(tr.at0(row))));
                    }
                } else if (_nclass > 1) {
                    float[] fs = new float[_nclass + 1];
                    for (int row = 0; row < ys._len; ++row) {
                        float sum = score1(chks, fs, row);
                        if (Float.isInfinite(sum)) {
                            for (int k = 0; k < _nclass; ++k) {
                                chk_work(chks, k).set0(row, Float.isInfinite(fs[k + 1]) ? 1.0f : 0.0f);
                            }
                            continue;
                        }
                        for (int k = 0; k < _nclass; ++k) {
                            chk_work(chks, k).set0(row, fs[k + 1] / sum);
                        }
                    }
                } else {
                    Chunk tr = chk_tree(chks, 0);
                    Chunk wk = chk_work(chks, 0);
                    for (int row = 0; row < ys._len; ++row) {
                        wk.set0(row, (float) tr.at0(row));
                    }
                }
            }
        }
    }
}

