/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.gbm;

import hex.Distribution;
import hex.genmodel.GenModel;
import hex.tree.SharedTreeModel;
import hex.tree.gbm.GBM;
import water.Key;
import water.util.SBPrintStream;

/**
 * Gradient boosting model built on the shared tree-model infrastructure.
 *
 * <p>The superclass scorer fills the prediction array; this class then applies the
 * GBM-specific post-processing: the initial prediction {@code _init_f}, the per-row
 * offset, and the distribution's inverse link (or a multinomial rescale).
 */
public class GBMModel
extends SharedTreeModel<GBMModel, GBMParameters, GBMOutput> {
    public GBMModel(Key selfKey, GBMParameters parms, GBMOutput output) {
        super(selfKey, parms, output);
    }

    /**
     * Scores a single row and turns the raw scores into final predictions, in place.
     *
     * <p>The superclass first fills {@code preds}. The values are presumably the raw
     * summed tree outputs; TODO confirm in SharedTreeModel. Then:
     * <ul>
     *   <li>bernoulli: {@code preds[1] + _init_f + offset} goes through the inverse link.
     *       The result is stored in {@code preds[2]}, and {@code preds[1]} becomes its
     *       complement.</li>
     *   <li>multinomial: with exactly two classes, {@code preds[1]} is shifted by
     *       {@code _init_f + offset} and {@code preds[2]} is set to its negation. The array
     *       is then passed to {@code GenModel.GBM_rescale}. That is presumably a softmax;
     *       TODO confirm.</li>
     *   <li>otherwise: {@code preds[0] + _init_f + offset} goes through the inverse link.</li>
     * </ul>
     *
     * @param data   the row's feature values
     * @param preds  output array, mutated in place
     * @param weight row weight, forwarded to the superclass
     * @param offset row offset, added to the link-scale score
     * @return the same {@code preds} array
     */
    @Override
    protected double[] score0(double[] data, double[] preds, double weight, double offset) {
        super.score0(data, preds, weight, offset);
        final Distribution.Family family = _parms._distribution;
        final GBMOutput out = _output;
        if (family == Distribution.Family.bernoulli) {
            // Grouping matters for floating point: (preds[1] + _init_f) + offset.
            final double linkScore = preds[1] + out._init_f + offset;
            final double positiveProb = new Distribution(_parms).linkInv(linkScore);
            preds[2] = positiveProb;
            preds[1] = 1.0 - positiveProb;
        } else if (family == Distribution.Family.multinomial) {
            if (out.nclasses() == 2) {
                // Compound assignment keeps the original grouping: preds[1] + (_init_f + offset).
                preds[1] += out._init_f + offset;
                preds[2] = -preds[1];
            }
            GenModel.GBM_rescale(preds);
        } else {
            final double linkScore = preds[0] + out._init_f + offset;
            preds[0] = new Distribution(_parms).linkInv(linkScore);
        }
        return preds;
    }

    /**
     * Emits generated-Java (POJO) statements that mirror the post-processing done in
     * {@link #score0}.
     *
     * <p>NOTE(review): the emitted code adds {@code _init_f} but no row offset, whereas
     * {@code score0} also adds {@code offset}. Confirm this is intended for POJO scoring.
     *
     * @param body sink receiving the generated source lines
     */
    @Override
    protected void toJavaUnifyPreds(SBPrintStream body) {
        final GBMOutput out = _output;
        final boolean isBernoulli = _parms._distribution == Distribution.Family.bernoulli;
        if (isBernoulli) {
            body.ip("preds[2] = preds[1] + ").p(out._init_f).p(";").nl();
            body.ip("preds[2] = " + new Distribution(_parms).linkInvString("preds[2]") + ";").nl();
            body.ip("preds[1] = 1.0-preds[2];").nl();
            emitGbmClassifierTail(body);
        } else if (out.nclasses() == 1) {
            // Regression: shift by the initial prediction, then apply the inverse link.
            body.ip("preds[0] += ").p(out._init_f).p(";").nl();
            body.ip("preds[0] = " + new Distribution(_parms).linkInvString("preds[0]") + ";").nl();
        } else {
            if (out.nclasses() == 2) {
                body.ip("preds[1] += ").p(out._init_f).p(";").nl();
                body.ip("preds[2] = - preds[1];").nl();
            }
            body.ip("hex.genmodel.GenModel.GBM_rescale(preds);").nl();
            emitGbmClassifierTail(body);
        }
    }

    /**
     * Emits the statements shared by every classification path.
     *
     * <p>These are the optional prior-correction call (when {@code _balance_classes} is
     * set) followed by the label assignment into {@code preds[0]}.
     */
    private void emitGbmClassifierTail(SBPrintStream body) {
        if (_parms._balance_classes) {
            body.ip("hex.genmodel.GenModel.correctProbabilities(preds, PRIOR_CLASS_DISTRIB, MODEL_CLASS_DISTRIB);").nl();
        }
        body.ip("preds[0] = hex.genmodel.GenModel.getPrediction(preds, PRIOR_CLASS_DISTRIB, data, " + defaultThreshold() + ");").nl();
    }

    /** GBM model output; adds nothing beyond the shared tree output. */
    public static class GBMOutput
    extends SharedTreeModel.SharedTreeOutput {
        /**
         * @param b         the GBM builder that produced this output
         * @param mse_train training mean squared error, forwarded to the superclass
         * @param mse_valid validation mean squared error, forwarded to the superclass
         */
        public GBMOutput(GBM b, double mse_train, double mse_valid) {
            super(b, mse_train, mse_valid);
        }
    }

    /** GBM-specific build parameters and GBM defaults for the shared tree parameters. */
    public static class GBMParameters
    extends SharedTreeModel.SharedTreeParameters {
        /** Learning rate (default 0.1); presumably the shrinkage per tree. TODO confirm in GBM. */
        public float _learn_rate = 0.1f;
        /** Column sampling rate (default 1.0, i.e. presumably all columns). */
        public float _col_sample_rate = 1.0f;

        /** Overrides the shared defaults: 50 trees, depth 5, row sample rate 1.0. */
        public GBMParameters() {
            _ntrees = 50;
            _max_depth = 5;
            _sample_rate = 1.0f;
        }

        /** @return the short algorithm name */
        public String algoName() {
            return "GBM";
        }

        /** @return the human-readable algorithm name */
        public String fullName() {
            return "Gradient Boosting Method";
        }

        /** @return the fully qualified class name of the model this parameter set builds */
        public String javaName() {
            return GBMModel.class.getName();
        }
    }
}

