/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.gbm;

import hex.Distribution;
import hex.KeyValue;
import hex.Model;
import hex.genmodel.GenModel;
import hex.genmodel.utils.DistributionFamily;
import hex.tree.CompressedTree;
import hex.tree.Constraint;
import hex.tree.Constraints;
import hex.tree.Score;
import hex.tree.SharedTreeModel;
import hex.tree.gbm.GBM;
import hex.tree.gbm.GbmMojoWriter;
import java.util.Arrays;
import water.DKV;
import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.util.SBPrintStream;

public class GBMModel
extends SharedTreeModel<GBMModel, GBMParameters, GBMOutput>
implements Model.StagedPredictions {
    public GBMModel(Key<GBMModel> selfKey, GBMParameters parms, GBMOutput output) {
        super(selfKey, parms, output);
    }

    public Frame scoreStagedPredictions(Frame frame, Key<Frame> destination_key) {
        Frame adaptFrm = new Frame(frame);
        this.adaptTestForTrain(adaptFrm, true, false);
        String[] names = this.makeAllTreeColumnNames();
        int outputcols = names.length;
        return ((StagedPredictionsTask)new StagedPredictionsTask(this).doAll(outputcols, (byte)3, adaptFrm)).outputFrame(destination_key, names, null);
    }

    @Override
    protected final double[] score0Incremental(Score.ScoreIncInfo sii, Chunk[] chks, double offset, int row_in_chunk, double[] tmp, double[] preds) {
        int i;
        assert (((GBMOutput)this._output).nfeatures() == tmp.length);
        for (i = 0; i < tmp.length; ++i) {
            tmp[i] = chks[i].atd(row_in_chunk);
        }
        if (sii._startTree == 0) {
            Arrays.fill(preds, 0.0);
        } else {
            for (i = 0; i < sii._workspaceColCnt; ++i) {
                preds[sii._predsAryOffset + i] = chks[sii._workspaceColIdx + i].atd(row_in_chunk);
            }
        }
        this.score0(tmp, preds, offset, sii._startTree, ((GBMOutput)this._output)._treeKeys.length);
        for (i = 0; i < sii._workspaceColCnt; ++i) {
            chks[sii._workspaceColIdx + i].set(row_in_chunk, preds[sii._predsAryOffset + i]);
        }
        this.score0Probabilities(preds, offset);
        this.score0PostProcessSupervised(preds, tmp);
        return preds;
    }

    @Override
    protected double[] score0(double[] data, double[] preds, double offset, int ntrees) {
        super.score0(data, preds, offset, ntrees);
        return this.score0Probabilities(preds, offset);
    }

    private double[] score0Probabilities(double[] preds, double offset) {
        if (((GBMParameters)this._parms)._distribution == DistributionFamily.bernoulli || ((GBMParameters)this._parms)._distribution == DistributionFamily.quasibinomial || ((GBMParameters)this._parms)._distribution == DistributionFamily.modified_huber) {
            double f = preds[1] + ((GBMOutput)this._output)._init_f + offset;
            preds[2] = new Distribution(this._parms).linkInv(f);
            preds[1] = 1.0 - preds[2];
        } else if (((GBMParameters)this._parms)._distribution == DistributionFamily.multinomial) {
            if (((GBMOutput)this._output).nclasses() == 2) {
                preds[1] = preds[1] + (((GBMOutput)this._output)._init_f + offset);
                preds[2] = -preds[1];
            }
            GenModel.GBM_rescale((double[])preds);
        } else {
            double f = preds[0] + ((GBMOutput)this._output)._init_f + offset;
            preds[0] = new Distribution(this._parms).linkInv(f);
        }
        return preds;
    }

    @Override
    protected void toJavaUnifyPreds(SBPrintStream body) {
        if (((GBMParameters)this._parms)._distribution == DistributionFamily.bernoulli || ((GBMParameters)this._parms)._distribution == DistributionFamily.quasibinomial || ((GBMParameters)this._parms)._distribution == DistributionFamily.modified_huber) {
            body.ip("preds[2] = preds[1] + ").p(((GBMOutput)this._output)._init_f).p(";").nl();
            body.ip("preds[2] = " + new Distribution(this._parms).linkInvString("preds[2]") + ";").nl();
            body.ip("preds[1] = 1.0-preds[2];").nl();
            if (((GBMParameters)this._parms)._balance_classes) {
                body.ip("hex.genmodel.GenModel.correctProbabilities(preds, PRIOR_CLASS_DISTRIB, MODEL_CLASS_DISTRIB);").nl();
            }
            body.ip("preds[0] = hex.genmodel.GenModel.getPrediction(preds, PRIOR_CLASS_DISTRIB, data, " + this.defaultThreshold() + ");").nl();
            return;
        }
        if (((GBMOutput)this._output).nclasses() == 1) {
            body.ip("preds[0] += ").p(((GBMOutput)this._output)._init_f).p(";").nl();
            body.ip("preds[0] = " + new Distribution(this._parms).linkInvString("preds[0]") + ";").nl();
            return;
        }
        if (((GBMOutput)this._output).nclasses() == 2) {
            body.ip("preds[1] += ").p(((GBMOutput)this._output)._init_f).p(";").nl();
            body.ip("preds[2] = - preds[1];").nl();
        }
        body.ip("hex.genmodel.GenModel.GBM_rescale(preds);").nl();
        if (((GBMParameters)this._parms)._balance_classes) {
            body.ip("hex.genmodel.GenModel.correctProbabilities(preds, PRIOR_CLASS_DISTRIB, MODEL_CLASS_DISTRIB);").nl();
        }
        body.ip("preds[0] = hex.genmodel.GenModel.getPrediction(preds, PRIOR_CLASS_DISTRIB, data, " + this.defaultThreshold() + ");").nl();
    }

    public GbmMojoWriter getMojo() {
        return new GbmMojoWriter(this);
    }

    private static class StagedPredictionsTask
    extends MRTask<StagedPredictionsTask> {
        private final Key<GBMModel> _modelKey;
        private transient GBMModel _model;

        private StagedPredictionsTask(GBMModel model) {
            this._modelKey = model._key;
        }

        protected void setupLocal() {
            this._model = (GBMModel)this._modelKey.get();
            assert (this._model != null);
        }

        public void map(Chunk[] chks, NewChunk[] nc) {
            double[] input = new double[chks.length];
            int contribOffset = ((GBMOutput)this._model._output).nclasses() == 1 ? 0 : 1;
            for (int row = 0; row < chks[0]._len; ++row) {
                for (int i = 0; i < chks.length; ++i) {
                    input[i] = chks[i].atd(row);
                }
                double[] contribs = new double[contribOffset + ((GBMOutput)this._model._output).nclasses()];
                double[] preds = new double[contribs.length];
                int col = 0;
                for (int tidx = 0; tidx < ((GBMOutput)this._model._output)._treeKeys.length; ++tidx) {
                    int i;
                    Key[] keys = ((GBMOutput)this._model._output)._treeKeys[tidx];
                    for (i = 0; i < keys.length; ++i) {
                        if (keys[i] != null) {
                            int n = contribOffset + i;
                            contribs[n] = contribs[n] + ((CompressedTree)DKV.get((Key)keys[i]).get()).score(input, ((GBMOutput)this._model._output)._domains);
                        }
                        preds[contribOffset + i] = contribs[contribOffset + i];
                    }
                    this._model.score0Probabilities(preds, 0.0);
                    this._model.score0PostProcessSupervised(preds, input);
                    for (i = 0; i < keys.length; ++i) {
                        if (keys[i] == null) continue;
                        nc[col++].addNum(preds[contribOffset + i]);
                    }
                }
                assert (col == nc.length);
            }
        }
    }

    public static class GBMOutput
    extends SharedTreeModel.SharedTreeOutput {
        boolean _quasibinomial;
        int _nclasses;

        public int nclasses() {
            return this._nclasses;
        }

        public GBMOutput(GBM b) {
            super(b);
            this._quasibinomial = ((GBMParameters)b._parms)._distribution == DistributionFamily.quasibinomial;
            this._nclasses = b.nclasses();
        }

        public String[] classNames() {
            String[] res = super.classNames();
            if (res == null && this._quasibinomial) {
                return new String[]{"0", "1"};
            }
            return res;
        }
    }

    public static class GBMParameters
    extends SharedTreeModel.SharedTreeParameters {
        public double _learn_rate = 0.1;
        public double _learn_rate_annealing = 1.0;
        public double _col_sample_rate = 1.0;
        public double _max_abs_leafnode_pred;
        public double _pred_noise_bandwidth;
        public KeyValue[] _monotone_constraints;

        public GBMParameters() {
            this._sample_rate = 1.0;
            this._ntrees = 50;
            this._max_depth = 5;
            this._max_abs_leafnode_pred = Double.MAX_VALUE;
            this._pred_noise_bandwidth = 0.0;
        }

        public String algoName() {
            return "GBM";
        }

        public String fullName() {
            return "Gradient Boosting Machine";
        }

        public String javaName() {
            return GBMModel.class.getName();
        }

        Constraints constraints(Frame f) {
            if (this._monotone_constraints == null || this._monotone_constraints.length == 0) {
                return null;
            }
            Constraint[] cs = new Constraint[f.numCols()];
            for (KeyValue spec : this._monotone_constraints) {
                if (spec.getValue() == 0.0) continue;
                int col = f.find(spec.getKey());
                if (col < 0) {
                    throw new IllegalStateException("Invalid constraint specification, column '" + spec.getKey() + "' doesn't exist.");
                }
                int direction = spec.getValue() < 0.0 ? -1 : 1;
                cs[col] = new Constraint(direction);
            }
            return new Constraints(cs);
        }
    }
}

