/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.gbm;

import hex.genmodel.GenModel;
import hex.tree.SharedTreeModel;
import hex.tree.gbm.GBM;
import water.Key;
import water.util.SB;

/**
 * Gradient Boosting Machine model.
 *
 * <p>{@link SharedTreeModel#score0} fills {@code preds} with the raw per-class sums of the tree
 * ensemble; this class turns those sums into final predictions. It does so both at scoring time
 * ({@link #score0}) and in generated Java scoring code ({@link #toJavaUnifyPreds}).
 */
public class GBMModel
extends SharedTreeModel<GBMModel, GBMParameters, GBMOutput> {
    public GBMModel(Key selfKey, GBMParameters parms, GBMOutput output) {
        super(selfKey, parms, output);
    }

    /**
     * Scores a single row.
     *
     * <p>Post-processing depends on the model type:
     * <ul>
     *   <li><b>Bernoulli:</b> the raw value in {@code preds[1]} is shifted by {@code _init_f} and
     *       {@code offset}, then mapped through the logistic function into {@code preds[2]}, with
     *       {@code preds[1] = 1 - preds[2]}.</li>
     *   <li><b>Regression</b> ({@code nclasses() == 1}): {@code _init_f + offset} is added to
     *       {@code preds[0]}.</li>
     *   <li><b>Other classification:</b> the class scores are rescaled via
     *       {@link GenModel#GBM_rescale}.</li>
     * </ul>
     * Classification results optionally get the class-balance correction, and the predicted label
     * is stored in {@code preds[0]}.
     *
     * @param data   the input row
     * @param preds  output buffer, filled in place and returned
     * @param weight row weight, forwarded to the superclass
     * @param offset row offset; only the Bernoulli and regression branches add it
     * @return {@code preds}
     */
    @Override
    protected double[] score0(double[] data, double[] preds, double weight, double offset) {
        super.score0(data, preds, weight, offset);
        GBMParameters parms = (GBMParameters) this._parms;
        GBMOutput output = (GBMOutput) this._output;

        if (parms._distribution == GBMParameters.Family.bernoulli) {
            double fx = preds[1] + output._init_f + offset;
            preds[2] = 1.0 / (1.0 + Math.exp(-fx));
            preds[1] = 1.0 - preds[2];
            return finishClassPreds(data, preds, parms, output);
        }

        if (output.nclasses() == 1) {
            preds[0] = preds[0] + (output._init_f + offset);
            return preds;
        }

        // NOTE(review): unlike the Bernoulli and regression branches, offset is not applied here;
        // confirm this is intended for non-Bernoulli classification.
        if (output.nclasses() == 2) {
            // Only preds[1] receives _init_f; preds[2] is set to its negation before rescaling.
            preds[1] = preds[1] + output._init_f;
            preds[2] = -preds[1];
        }
        GenModel.GBM_rescale(preds);
        return finishClassPreds(data, preds, parms, output);
    }

    /**
     * Shared tail of classification scoring. Applies the optional class-balance correction to
     * {@code preds}, then stores the predicted label in {@code preds[0]}.
     */
    private double[] finishClassPreds(double[] data, double[] preds, GBMParameters parms, GBMOutput output) {
        if (parms._balance_classes) {
            GenModel.correctProbabilities(preds, output._priorClassDist, output._modelClassDist);
        }
        preds[0] = GenModel.getPrediction(preds, data, (double) this.defaultThreshold());
        return preds;
    }

    /**
     * Emits Java source that performs the same post-processing of {@code preds} as
     * {@link #score0}.
     *
     * <p>Unlike {@link #score0}, the emitted code does not add an offset term.
     *
     * @param body builder receiving the generated statements
     * @param file builder for file-level generated code (unused here)
     */
    @Override
    protected void toJavaUnifyPreds(SB body, SB file) {
        GBMParameters parms = (GBMParameters) this._parms;
        GBMOutput output = (GBMOutput) this._output;

        if (parms._distribution == GBMParameters.Family.bernoulli) {
            body.ip("double fx = preds[1] + ").p(output._init_f).p(";").nl();
            body.ip("preds[2] = 1.0/(1.0+Math.exp(-fx));").nl();
            body.ip("preds[1] = 1.0-preds[2];").nl();
            emitFinishClassPreds(body, parms);
            return;
        }

        if (output.nclasses() == 1) {
            // Terminate the line like every other emitted statement (previously missing .nl()).
            body.ip("preds[0] += ").p(output._init_f).p(";").nl();
            return;
        }

        if (output.nclasses() == 2) {
            body.ip("preds[1] += ").p(output._init_f).p(";").nl();
            body.ip("preds[2] = - preds[1];").nl();
        }
        body.ip("hex.genmodel.GenModel.GBM_rescale(preds);").nl();
        emitFinishClassPreds(body, parms);
    }

    /**
     * Emits generated-code counterpart of {@link #finishClassPreds}. Produces the optional
     * class-balance correction followed by the assignment of the predicted label to
     * {@code preds[0]}.
     */
    private void emitFinishClassPreds(SB body, GBMParameters parms) {
        if (parms._balance_classes) {
            body.ip("hex.genmodel.GenModel.correctProbabilities(preds, PRIOR_CLASS_DISTRIB, MODEL_CLASS_DISTRIB);").nl();
        }
        body.ip("preds[0] = hex.genmodel.GenModel.getPrediction(preds, data, " + this.defaultThreshold() + ");").nl();
    }

    /** Training output of a GBM model. */
    public static class GBMOutput
    extends SharedTreeModel.SharedTreeOutput {
        /**
         * @param b         the GBM builder
         * @param mse_train training MSE, passed to the superclass
         * @param mse_valid validation MSE, passed to the superclass
         */
        public GBMOutput(GBM b, double mse_train, double mse_valid) {
            super(b, mse_train, mse_valid);
        }
    }

    /** GBM-specific build parameters. */
    public static class GBMParameters
    extends SharedTreeModel.SharedTreeParameters {
        /** Distribution family; defaults to {@link Family#AUTO}. */
        public Family _distribution = Family.AUTO;
        /** Learning rate; defaults to 0.1. */
        public float _learn_rate = 0.1f;

        /** Supported distribution families for {@link #_distribution}. */
        public static enum Family {
            AUTO,
            bernoulli,
            multinomial,
            gaussian,
            poisson;

        }
    }
}

