/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.gbm;

import hex.Model;
import hex.ModelBuilder;
import hex.genmodel.GenModel;
import hex.schemas.GBMV3;
import hex.tree.DHistogram;
import hex.tree.DTree;
import hex.tree.SharedTree;
import hex.tree.gbm.GBMModel;
import hex.tree.gbm.ResidualsCollector;
import water.AutoBuffer;
import water.Job;
import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.Timer;

public class GBM
extends SharedTree<GBMModel, GBMModel.GBMParameters, GBMModel.GBMOutput> {
    /**
     * Lists the model categories this builder is able to produce.
     *
     * @return a fresh array holding Regression, Binomial and Multinomial, in that order
     */
    public Model.ModelCategory[] can_build() {
        final Model.ModelCategory[] supported = {
            Model.ModelCategory.Regression,
            Model.ModelCategory.Binomial,
            Model.ModelCategory.Multinomial
        };
        return supported;
    }

    /**
     * Reports the visibility level of this builder.
     *
     * @return always {@code BuilderVisibility.Stable}
     */
    public ModelBuilder.BuilderVisibility builderVisibility() {
        final ModelBuilder.BuilderVisibility visibility = ModelBuilder.BuilderVisibility.Stable;
        return visibility;
    }

    /**
     * Creates a GBM builder named "GBM" for the given parameters and immediately
     * runs the cheap (non-expensive) validation pass.
     *
     * @param parms the GBM parameters handed to the superclass constructor
     */
    public GBM(GBMModel.GBMParameters parms) {
        super("GBM", parms);
        // Cheap validation only; the expensive pass is requested elsewhere with init(true).
        init(false);
    }

    /**
     * Creates the schema object associated with this builder.
     *
     * @return a new {@code GBMV3} instance on every call
     */
    public GBMV3 schema() {
        final GBMV3 schema = new GBMV3();
        return schema;
    }

    /**
     * Starts model training with a fresh {@code GBMDriver}, passing the configured
     * number of trees as the second argument to {@code start}.
     *
     * @return whatever {@code start} returns for the launched driver
     */
    public Job<GBMModel> trainModel() {
        final GBMModel.GBMParameters parms = (GBMModel.GBMParameters)this._parms;
        final GBMDriver driver = new GBMDriver();
        return start(driver, parms._ntrees);
    }

    /**
     * Returns the superclass's {@code vresponse()} when it is non-null, and falls
     * back to {@code response()} otherwise.
     *
     * @return the response vector selected as described above
     */
    public Vec vresponse() {
        if (super.vresponse() == null) {
            return response();
        }
        return super.vresponse();
    }

    /**
     * Validates the GBM-specific parameters and, on the expensive pass, derives the
     * initial prediction and resolves an {@code AUTO} distribution.
     *
     * <p>On the expensive pass the initial prediction is seeded from the response mean:
     * the mean itself for a single-class (regression) response,
     * {@code -0.5 * log(mean / (1 - mean))} for two classes, and {@code 0.0} otherwise.
     * A bernoulli distribution then overrides it with {@code log(mean / (1 - mean))}.
     *
     * <p>Validation failures are reported through {@code error(field, message)}; this
     * method itself does not throw for invalid parameters.
     *
     * @param expensive when {@code true}, the response mean is read and an {@code AUTO}
     *                  distribution is replaced by a concrete family chosen from the
     *                  number of response classes
     */
    @Override
    public void init(boolean expensive) {
        super.init(expensive);
        final GBMModel.GBMParameters parms = (GBMModel.GBMParameters)this._parms;
        double mean = 0.0;
        if (expensive) {
            mean = this._response.mean();
            // Assign the field on every path. Previously the assignment sat inside the
            // else-arm of a ternary, so the single-class (regression) case computed the
            // mean into a dead local and never stored it.
            if (this._nclass == 1) {
                this._initialPrediction = mean;
            } else if (this._nclass == 2) {
                this._initialPrediction = -0.5 * Math.log(mean / (1.0 - mean));
            } else {
                this._initialPrediction = 0.0;
            }
            if (parms._distribution == GBMModel.GBMParameters.Family.AUTO) {
                // Pick the concrete family from the number of response classes.
                if (this._nclass == 1) {
                    parms._distribution = GBMModel.GBMParameters.Family.gaussian;
                } else if (this._nclass == 2) {
                    parms._distribution = GBMModel.GBMParameters.Family.bernoulli;
                } else if (this._nclass >= 3) {
                    parms._distribution = GBMModel.GBMParameters.Family.multinomial;
                }
            }
        }
        switch (parms._distribution) {
            case bernoulli: {
                if (this._nclass != 2) {
                    this.error("_distribution", "Binomial requires the response to be a 2-class categorical");
                    break;
                }
                if (this._response == null) break;
                // NOTE(review): on a non-expensive pass `mean` is still 0.0 here, so this
                // evaluates log(0) = -Infinity; a later expensive pass overwrites it.
                // Confirm nothing reads the field in between.
                this._initialPrediction = Math.log(mean / (1.0 - mean));
                break;
            }
            case multinomial: {
                if (this.isClassifier()) break;
                this.error("_distribution", "Multinomial requires an enum response.");
                break;
            }
            case gaussian: {
                if (!this.isClassifier()) break;
                this.error("_distribution", "Gaussian requires the response to be numeric.");
                break;
            }
            case AUTO: {
                // Still unresolved on a non-expensive pass; nothing to validate yet.
                break;
            }
            default: {
                this.error("_distribution", "Invalid distribution: " + parms._distribution);
            }
        }
        // Written as a negated conjunction so that a NaN learn rate is rejected too.
        final double learnRate = parms._learn_rate;
        if (!(0.0 < learnRate && learnRate <= 1.0)) {
            this.error("_learn_rate", "learn_rate must be between 0 and 1");
        }
    }

    /**
     * Wraps an undecided node and its histograms into a GBM-specific decided node.
     *
     * @param udn the undecided node being decided
     * @param hs  the histograms forwarded to the decided-node constructor
     * @return a new {@code GBMDecidedNode}
     */
    @Override
    protected DTree.DecidedNode makeDecided(DTree.UndecidedNode udn, DHistogram[] hs) {
        final GBMDecidedNode decided = new GBMDecidedNode(udn, hs);
        return decided;
    }

    /**
     * Fills {@code fs} with the per-class scores of one row, read from the tree
     * columns of {@code chks}, and returns the matching normalisation sum.
     *
     * <ul>
     *   <li>bernoulli: {@code fs[1] = 1 / (1 + exp(f))}, {@code fs[2] = 1 - fs[1]}, returns 1</li>
     *   <li>one class: {@code fs[0] = f}, returns {@code f}</li>
     *   <li>two classes (non-bernoulli): {@code fs[1] = exp(f)}, {@code fs[2] = 1 / fs[1]},
     *       returns their sum</li>
     *   <li>otherwise: {@code fs[k + 1]} is the raw value of tree column {@code k}; the
     *       result of {@code GenModel.log_rescale(fs)} is returned</li>
     * </ul>
     * where {@code f} is the value of tree column 0 for the row.
     *
     * @param chks chunks of the training frame
     * @param fs   output array; written at index 0 (one class) or 1..nclass
     * @param row  row index within the chunks
     * @return the sum described above
     */
    @Override
    protected double score1(Chunk[] chks, double[] fs, int row) {
        final boolean bernoulli =
            ((GBMModel.GBMParameters)this._parms)._distribution == GBMModel.GBMParameters.Family.bernoulli;
        if (bernoulli) {
            final double first = 1.0 / (1.0 + Math.exp(chk_tree(chks, 0).atd(row)));
            fs[1] = first;
            fs[2] = 1.0 - first;
            return 1.0;
        }
        switch (this._nclass) {
            case 1: {
                final double prediction = chk_tree(chks, 0).atd(row);
                fs[0] = prediction;
                return prediction;
            }
            case 2: {
                // Only tree 0 exists for two classes; the second score is its reciprocal.
                final double first = Math.exp(chk_tree(chks, 0).atd(row));
                fs[1] = first;
                final double second = 1.0 / first;
                fs[2] = second;
                return first + second;
            }
            default: {
                for (int cls = 0; cls < this._nclass; ++cls) {
                    fs[cls + 1] = chk_tree(chks, cls).atd(row);
                }
                return GenModel.log_rescale(fs);
            }
        }
    }

    /**
     * GBM leaf node: serialises its prediction {@code _pred} with a single
     * {@code put4f} call and reports a size of 4.
     */
    static class GBMLeafNode
    extends DTree.LeafNode {
        /** Creates a leaf in {@code tree} under parent {@code pid}. */
        GBMLeafNode(DTree tree, int pid) {
            super(tree, pid);
        }

        /** Creates a leaf in {@code tree} under parent {@code pid} with the explicit node id {@code nid}. */
        GBMLeafNode(DTree tree, int pid, int nid) {
            super(tree, pid, nid);
        }

        /**
         * Appends the leaf prediction to the buffer.
         *
         * @param ab the buffer to write into
         * @return the result of {@code ab.put4f(_pred)}
         */
        @Override
        protected AutoBuffer compress(AutoBuffer ab) {
            // A NaN prediction means the leaf value was never computed properly.
            assert !Double.isNaN(_pred);
            final AutoBuffer written = ab.put4f(_pred);
            return written;
        }

        /**
         * @return 4 (presumably the byte width of the value written by {@code compress})
         */
        @Override
        protected int size() {
            return 4;
        }
    }

    /**
     * GBM undecided node; adds nothing to the base class except a
     * {@code scoreCols} override that always answers {@code null}.
     */
    static class GBMUndecidedNode
    extends DTree.UndecidedNode {
        /** Creates an undecided node in {@code tree} under parent {@code pid} with histograms {@code hs}. */
        GBMUndecidedNode(DTree tree, int pid, DHistogram[] hs) {
            super(tree, pid, hs);
        }

        /**
         * @param hs ignored
         * @return always {@code null} (presumably "no column restriction" - confirm against the base class)
         */
        @Override
        public int[] scoreCols(DHistogram[] hs) {
            final int[] noSelection = null;
            return noSelection;
        }
    }

    /**
     * GBM decided node: creates GBM undecided children and picks the split with
     * the smallest squared error among the column histograms.
     */
    static class GBMDecidedNode
    extends DTree.DecidedNode {
        /** Decides node {@code n} using histograms {@code hs}. */
        GBMDecidedNode(DTree.UndecidedNode n, DHistogram[] hs) {
            super(n, hs);
        }

        /**
         * Creates a GBM undecided child whose parent is this node.
         *
         * @param hs histograms for the child
         * @return the new child node
         */
        @Override
        public DTree.UndecidedNode makeUndecidedNode(DHistogram[] hs) {
            final GBMUndecidedNode child = new GBMUndecidedNode(_tree, _nid, hs);
            return child;
        }

        /**
         * Scans the histograms in column order and keeps the split whose {@code se()}
         * is strictly smallest. Histograms that are {@code null}, have at most one bin,
         * or yield no split are skipped; the scan stops at the first split whose
         * {@code se()} is not positive.
         *
         * @param u  unused here
         * @param hs per-column histograms, may be {@code null}
         * @return the best split found, or a placeholder split with column -1 when none qualifies
         */
        @Override
        public DTree.Split bestCol(DTree.UndecidedNode u, DHistogram[] hs) {
            // Placeholder "no split" result; its MAX_VALUE error loses to any real candidate.
            DTree.Split winner = new DTree.Split(-1, -1, null, 0, Double.MAX_VALUE, Double.MAX_VALUE, Double.MAX_VALUE, 0L, 0L, 0.0, 0.0);
            if (hs == null) {
                return winner;
            }
            for (int col = 0; col < hs.length; ++col) {
                final DHistogram hist = hs[col];
                if (hist == null || hist.nbins() <= 1) {
                    continue;
                }
                final DTree.Split candidate = hist.scoreMSE(col, _tree._min_rows);
                if (candidate == null) {
                    continue;
                }
                if (candidate.se() < winner.se()) {
                    winner = candidate;
                }
                if (candidate.se() <= 0.0) {
                    // Error cannot get any lower; no point in looking further.
                    break;
                }
            }
            return winner;
        }
    }

    private class GBMDriver
    extends SharedTree.Driver {
        /** Creates the driver; only the enclosing builder instantiates it. */
        private GBMDriver() {
            super();
        }

        /**
         * Runs the boosting loop.
         *
         * <ol>
         *   <li>When the initial prediction is non-zero, writes it into every row of
         *       tree column 0 of the training frame.</li>
         *   <li>When {@code _checkpoint} is set, runs a {@code ResidualsCollector}
         *       over the training frame and logs how long that took.</li>
         *   <li>For each tree index up to {@code _ntrees}: scores the model (skipped on the
         *       first iteration of a checkpointed run) and returns if the result reaches
         *       {@code _r2_stopping}; otherwise runs {@code ComputeProb} and {@code ComputeRes},
         *       builds the next set of trees, reports one unit of progress, and returns if
         *       the job is no longer running.</li>
         *   <li>After the last tree, performs a final scoring pass.</li>
         * </ol>
         */
        @Override
        protected void buildModel() {
            final double startValue = GBM.this._initialPrediction;
            if (startValue != 0.0) {
                new MRTask(){

                    public void map(Chunk chunk) {
                        for (int row = 0; row < chunk._len; ++row) {
                            chunk.set(row, startValue);
                        }
                    }
                }.doAll(new Vec[]{GBM.this.vec_tree(GBM.this._train, 0)});
            }

            if (((GBMModel.GBMParameters)GBM.this._parms)._checkpoint) {
                final Timer restoreTimer = new Timer();
                final ResidualsCollector collector = new ResidualsCollector(
                    GBM.this._ncols,
                    GBM.this._nclass,
                    ((GBMModel.GBMOutput)((GBMModel)GBM.this._model)._output)._treeKeys);
                collector.doAll(GBM.this._train);
                Log.info(new Object[]{"Reconstructing tree residuals stats from checkpointed model took " + restoreTimer});
            }

            for (int tid = 0; tid < ((GBMModel.GBMParameters)GBM.this._parms)._ntrees; ++tid) {
                // Score before every tree, except before the first tree of a checkpointed run.
                if (tid != 0 || !((GBMModel.GBMParameters)GBM.this._parms)._checkpoint) {
                    final double trainingR2 = GBM.this.doScoringAndSaveModel(false, false, false);
                    if (trainingR2 >= ((GBMModel.GBMParameters)GBM.this._parms)._r2_stopping) {
                        return;
                    }
                }
                new ComputeProb().doAll(GBM.this._train);
                new ComputeRes().doAll(GBM.this._train);
                final Timer treeTimer = new Timer();
                buildNextKTrees();
                Log.info(new Object[]{tid + 1 + ". tree was built in " + treeTimer.toString()});
                GBM.this.update(1L);
                if (!GBM.this.isRunning()) {
                    return;
                }
            }
            GBM.this.doScoringAndSaveModel(true, false, false);
        }

        /**
         * Builds the trees of one boosting iteration (at most one per class), computes
         * their leaf values, adds those values to the per-row tree columns of the
         * training frame, and appends the trees to the model output.
         *
         * <p>Returns early, without adding any trees, if the job stops running while
         * layers are being built.
         */
        private void buildNextKTrees() {
            int i;
            final DTree[] ktrees = new DTree[GBM.this._nclass];
            DHistogram[][][] hcs = new DHistogram[GBM.this._nclass][1][GBM.this._ncols];
            // NOTE(review): unused local; it carries the same value as the literal
            // floor used in the nbins computation on the next line.
            int top_level_extra_bins = 1024;
            // Bin count for the initial histograms and for buildLayer: at least 1024.
            int nbins = Math.max(1024, ((GBMModel.GBMParameters)GBM.this._parms)._nbins);
            for (int k = 0; k < GBM.this._nclass; ++k) {
                // No tree for a class whose entry in the output distribution is zero, nor
                // for the second class of a two-class problem (ktrees[k] stays null).
                if (((GBMModel.GBMOutput)((GBMModel)((GBM)GBM.this)._model)._output)._distribution[k] == 0L || k == 1 && GBM.this._nclass == 2) continue;
                ktrees[k] = new DTree(((GBM)GBM.this)._train._names, GBM.this._ncols, (char)((GBMModel.GBMParameters)GBM.this._parms)._nbins, (char)GBM.this._nclass, ((GBMModel.GBMParameters)GBM.this._parms)._min_rows);
                // The node is created only for its side effect; presumably the constructor
                // registers it in the tree as the root (parent id -1) - confirm in DTree.
                new GBMUndecidedNode(ktrees[k], -1, DHistogram.initialHist(GBM.this._train, GBM.this._ncols, nbins, hcs[k][0], false));
            }
            int[] leafs = new int[GBM.this._nclass];
            // Grow layer by layer up to _max_depth; stop early when buildLayer returns null.
            for (int depth = 0; depth < ((GBMModel.GBMParameters)GBM.this._parms)._max_depth; ++depth) {
                if (!GBM.this.isRunning()) {
                    return;
                }
                hcs = GBM.this.buildLayer(GBM.this._train, nbins, ktrees, leafs, hcs, false, false);
                if (hcs == null) break;
            }
            // Turn every terminal position into a GBMLeafNode.
            for (int k = 0; k < GBM.this._nclass; ++k) {
                DTree tree = ktrees[k];
                if (tree == null) continue;
                // Remember the node count before leaves are added; the code below and
                // GammaPass address the new leaves as node ids leafs[k] + i.
                int leaf = leafs[k] = tree.len();
                for (int nid = 0; nid < leaf; ++nid) {
                    if (!(tree.node(nid) instanceof DTree.DecidedNode)) continue;
                    DTree.DecidedNode dn = tree.decided(nid);
                    if (dn._split._col == -1) {
                        // A decided node without a split is only handled here when it is
                        // node 0: a leaf is then created with parent -1 and node id 0.
                        if (nid != 0) continue;
                        new GBMLeafNode(tree, -1, 0);
                        continue;
                    }
                    for (i = 0; i < dn._nids.length; ++i) {
                        int cnid = dn._nids[i];
                        // Keep the child only if it is a decided node with a real split;
                        // a missing child (-1), an undecided node, or a decided node whose
                        // split column is -1 is replaced by a new leaf.
                        if (cnid != -1 && !(tree.node(cnid) instanceof DTree.UndecidedNode) && (!(tree.node(cnid) instanceof DTree.DecidedNode) || ((DTree.DecidedNode)tree.node((int)cnid))._split.col() != -1)) continue;
                        dn._nids[i] = new GBMLeafNode(tree, nid).nid();
                    }
                }
            }
            // One pass over the data accumulating, per leaf, the sums _rss and _gss.
            GammaPass gp = (GammaPass)new GammaPass(ktrees, leafs, ((GBMModel.GBMParameters)GBM.this._parms)._distribution == GBMModel.GBMParameters.Family.bernoulli).doAll(GBM.this._train);
            // Scale factor (nclass - 1) / nclass for multi-class non-bernoulli models, else 1.
            double m1class = GBM.this._nclass > 1 && ((GBMModel.GBMParameters)GBM.this._parms)._distribution != GBMModel.GBMParameters.Family.bernoulli ? (double)(GBM.this._nclass - 1) / (double)GBM.this._nclass : 1.0;
            for (int k = 0; k < GBM.this._nclass; ++k) {
                DTree tree = ktrees[k];
                if (tree == null) continue;
                for (i = 0; i < tree._len - leafs[k]; ++i) {
                    // Leaf value = learn_rate * m1class * rss / gss.
                    float gf = (float)((double)((GBMModel.GBMParameters)GBM.this._parms)._learn_rate * m1class * gp._rss[k][i] / gp._gss[k][i]);
                    // A zero denominator would give NaN or infinity; use sign(rss) * 10000
                    // instead (which is 0 when rss is 0 as well).
                    if (gp._gss[k][i] == 0.0) {
                        gf = (float)(Math.signum(gp._rss[k][i]) * 10000.0);
                    }
                    // Multinomial leaf values are clamped to [-10000, 10000].
                    if (((GBMModel.GBMParameters)GBM.this._parms)._distribution == GBMModel.GBMParameters.Family.multinomial) {
                        if ((double)gf > 10000.0) {
                            gf = 10000.0f;
                        } else if ((double)gf < -10000.0) {
                            gf = -10000.0f;
                        }
                    }
                    assert (!Float.isNaN(gf) && !Float.isInfinite(gf));
                    ((DTree.LeafNode)tree.node((int)(leafs[k] + i)))._pred = gf;
                }
            }
            // For every row that has a non-negative node id, add its leaf value to the
            // running tree column and reset the node id to 0.
            new MRTask(){

                public void map(Chunk[] chks) {
                    for (int k = 0; k < GBM.this._nclass; ++k) {
                        DTree tree = ktrees[k];
                        if (tree == null) continue;
                        Chunk nids = GBM.this.chk_nids(chks, k);
                        Chunk ct = GBM.this.chk_tree(chks, k);
                        for (int row = 0; row < nids._len; ++row) {
                            int nid = (int)nids.at8(row);
                            if (nid < 0) continue;
                            ct.set(row, (float)(ct.atd(row) + (double)((DTree.LeafNode)tree.node((int)nid))._pred));
                            nids.set(row, 0L);
                        }
                    }
                }
            }.doAll(GBM.this._train);
            ((GBMModel.GBMOutput)((GBMModel)((GBM)GBM.this)._model)._output).addKTrees(ktrees);
        }

        /**
         * Creates the GBM model object together with a fresh output built from the
         * enclosing builder and the two supplied error values.
         *
         * @param modelKey  key passed to the model constructor
         * @param parms     parameters passed to the model constructor
         * @param mse_train value forwarded to the output constructor
         * @param mse_valid value forwarded to the output constructor
         * @return the new model
         */
        protected GBMModel makeModel(Key modelKey, GBMModel.GBMParameters parms, double mse_train, double mse_valid) {
            final GBMModel.GBMOutput output = new GBMModel.GBMOutput(GBM.this, mse_train, mse_valid);
            return new GBMModel(modelKey, parms, output);
        }

        /**
         * Data pass that moves every active row to its leaf and accumulates, per class
         * and per leaf, the sum of residuals ({@code _rss}) and a matching denominator
         * term ({@code _gss}). Leaves are indexed by {@code leafNodeId - _leafs[k]}.
         */
        private class GammaPass
        extends MRTask<GammaPass> {
            final DTree[] _trees;
            final int[] _leafs;
            final boolean _isBernoulli;
            double[][] _rss;
            double[][] _gss;

            /**
             * @param trees       one tree per class; {@code null} entries are skipped
             * @param leafs       per class, the node id of the first leaf
             * @param isBernoulli selects the bernoulli form of the denominator term
             */
            GammaPass(DTree[] trees, int[] leafs, boolean isBernoulli) {
                _trees = trees;
                _leafs = leafs;
                _isBernoulli = isBernoulli;
            }

            /**
             * Processes one set of chunks: rewrites each active row's node id to its leaf id
             * and adds the row's residual and denominator term to that leaf's sums.
             *
             * @param chks chunks of the training frame
             */
            public void map(Chunk[] chks) {
                final int nclass = GBM.this._nclass;
                _gss = new double[nclass][];
                _rss = new double[nclass][];
                final Chunk resp = GBM.this.chk_resp(chks);
                for (int k = 0; k < nclass; ++k) {
                    final DTree tree = _trees[k];
                    final int firstLeaf = _leafs[k];
                    if (tree == null) {
                        continue;
                    }
                    final double[] gs = new double[tree._len - firstLeaf];
                    _gss[k] = gs;
                    final double[] rs = new double[tree._len - firstLeaf];
                    _rss[k] = rs;
                    final Chunk nids = GBM.this.chk_nids(chks, k);
                    final Chunk ress = GBM.this.chk_work(chks, k);
                    // A tree whose root is already a leaf has nothing to route.
                    if (tree.root() instanceof DTree.LeafNode) {
                        continue;
                    }
                    for (int row = 0; row < nids._len; ++row) {
                        int nid = (int)nids.at8(row);
                        if (nid < 0) {
                            continue;
                        }
                        // Rows parked on an undecided node are routed from its parent.
                        if (tree.node(nid) instanceof DTree.UndecidedNode) {
                            nid = tree.node(nid)._pid;
                        }
                        DTree.DecidedNode dn = tree.decided(nid);
                        // A decided node without a split defers to its parent.
                        if (dn._split._col == -1) {
                            dn = tree.decided(dn._pid);
                        }
                        final int leafnid = dn.ns(chks, row);
                        assert firstLeaf <= leafnid && leafnid < tree._len : "leaf: " + firstLeaf + " leafnid: " + leafnid + " tree._len: " + tree._len + "\ndn: " + dn;
                        assert tree.node(leafnid) instanceof DTree.LeafNode;
                        nids.set(row, (long)leafnid);
                        assert !ress.isNA(row);
                        final double res = ress.atd(row);
                        final int slot = leafnid - firstLeaf;
                        final double weight;
                        if (_isBernoulli) {
                            final double prob = resp.atd(row) - res;
                            weight = prob * (1.0 - prob);
                        } else {
                            final double ares = Math.abs(res);
                            weight = nclass > 1 ? ares * (1.0 - ares) : 1.0;
                        }
                        gs[slot] += weight;
                        rs[slot] += res;
                    }
                }
            }

            /**
             * Merges the sums of another task into this one via {@code ArrayUtils.add}.
             *
             * @param gp the task whose sums are added
             */
            public void reduce(GammaPass gp) {
                ArrayUtils.add(_gss, gp._gss);
                ArrayUtils.add(_rss, gp._rss);
            }
        }

        /**
         * Data pass that overwrites each work column in place with a value derived
         * from the response and the work column's current content.
         */
        class ComputeRes
        extends MRTask<ComputeRes> {
            ComputeRes() {
            }

            /**
             * Rewrites the work column(s) row by row:
             * <ul>
             *   <li>bernoulli: {@code y - 1 + work} for rows with a non-missing response</li>
             *   <li>more than one class: {@code (y == k ? 1 : 0) - work} for every class
             *       {@code k} with a non-zero distribution entry, for rows with a non-missing
             *       response</li>
             *   <li>otherwise: {@code response - work} for every row</li>
             * </ul>
             *
             * @param chks chunks of the training frame
             */
            public void map(Chunk[] chks) {
                final Chunk ys = GBM.this.chk_resp(chks);
                final boolean bernoulli =
                    ((GBMModel.GBMParameters)GBM.this._parms)._distribution == GBMModel.GBMParameters.Family.bernoulli;

                if (bernoulli) {
                    for (int row = 0; row < ys._len; ++row) {
                        if (ys.isNA(row)) {
                            continue;
                        }
                        final int label = (int)ys.at8(row);
                        final Chunk work = GBM.this.chk_work(chks, 0);
                        final float shifted = (float)label - 1.0f;
                        work.set(row, shifted + (float)work.atd(row));
                    }
                    return;
                }

                if (GBM.this._nclass > 1) {
                    for (int row = 0; row < ys._len; ++row) {
                        if (ys.isNA(row)) {
                            continue;
                        }
                        final int label = (int)ys.at8(row);
                        for (int cls = 0; cls < GBM.this._nclass; ++cls) {
                            // Classes with a zero distribution entry are left untouched.
                            if (((GBMModel.GBMOutput)((GBMModel)GBM.this._model)._output)._distribution[cls] == 0L) {
                                continue;
                            }
                            final Chunk work = GBM.this.chk_work(chks, cls);
                            final float indicator = label == cls ? 1.0f : 0.0f;
                            work.set(row, indicator - (float)work.atd(row));
                        }
                    }
                    return;
                }

                final Chunk work = GBM.this.chk_work(chks, 0);
                for (int row = 0; row < ys._len; ++row) {
                    final double difference = ys.atd(row) - work.atd(row);
                    work.set(row, (float)difference);
                }
            }
        }

        /**
         * Data pass that fills the work column(s) from the current tree column(s).
         */
        class ComputeProb
        extends MRTask<ComputeProb> {
            ComputeProb() {
            }

            /**
             * Writes the work column(s) row by row:
             * <ul>
             *   <li>bernoulli: {@code 1 / (1 + exp(tree))}</li>
             *   <li>more than one class: {@code fs[k + 1] / sum} using the scores and sum from
             *       {@code score1}; when the sum is infinite, 1 for classes whose score is
             *       infinite and 0 for the rest</li>
             *   <li>otherwise: the tree value itself, narrowed to float</li>
             * </ul>
             *
             * @param chks chunks of the training frame
             */
            public void map(Chunk[] chks) {
                final Chunk ys = GBM.this.chk_resp(chks);
                final boolean bernoulli =
                    ((GBMModel.GBMParameters)GBM.this._parms)._distribution == GBMModel.GBMParameters.Family.bernoulli;

                if (bernoulli) {
                    final Chunk treeSum = GBM.this.chk_tree(chks, 0);
                    final Chunk work = GBM.this.chk_work(chks, 0);
                    for (int row = 0; row < ys._len; ++row) {
                        work.set(row, 1.0 / (1.0 + Math.exp(treeSum.atd(row))));
                    }
                    return;
                }

                if (GBM.this._nclass > 1) {
                    final double[] fs = new double[GBM.this._nclass + 1];
                    for (int row = 0; row < ys._len; ++row) {
                        final double sum = GBM.this.score1(chks, fs, row);
                        // An infinite sum cannot be divided through; fall back to an indicator.
                        final boolean overflow = Double.isInfinite(sum);
                        for (int cls = 0; cls < GBM.this._nclass; ++cls) {
                            final Chunk work = GBM.this.chk_work(chks, cls);
                            final float value;
                            if (overflow) {
                                value = Double.isInfinite(fs[cls + 1]) ? 1.0f : 0.0f;
                            } else {
                                value = (float)(fs[cls + 1] / sum);
                            }
                            work.set(row, value);
                        }
                    }
                    return;
                }

                final Chunk treeSum = GBM.this.chk_tree(chks, 0);
                final Chunk work = GBM.this.chk_work(chks, 0);
                for (int row = 0; row < ys._len; ++row) {
                    work.set(row, (float)treeSum.atd(row));
                }
            }
        }
    }
}

