/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.gbm;

import hex.Distribution;
import hex.DistributionFactory;
import hex.LinkFunction;
import hex.Model;
import hex.genmodel.CategoricalEncoding;
import hex.genmodel.utils.DistributionFamily;
import hex.tree.CompressedTree;
import hex.tree.SharedTreePojoWriter;
import hex.tree.gbm.GBMModel;
import water.util.SBPrintStream;

/**
 * POJO writer for GBM models.
 *
 * <p>Tree rendering is inherited from {@link SharedTreePojoWriter}. This class contributes
 * {@link #toJavaUnifyPreds(SBPrintStream)}, which emits the Java statements that adjust the raw
 * scores accumulated in {@code preds} according to the distribution family and link function.
 */
class GbmPojoWriter extends SharedTreePojoWriter {
    /** Offset added to the raw score in the generated code (presumably the model's initial prediction f0 — confirm). */
    private final double _init_f;
    /** When set, the generated code calls {@code GenModel.correctProbabilities} before computing preds[0]. */
    private final boolean _balance_classes;
    /** Distribution family; selects which unification code is emitted. */
    private final DistributionFamily _distribution_family;
    /** Link function whose inverse expression is rendered into the generated code. */
    private final LinkFunction _link_function;

    /**
     * Creates a writer from a trained GBM model.
     *
     * <p>The initial offset and tree statistics come from the model's {@code GBMOutput}. The
     * balance-classes flag comes from its {@code GBMParameters}. The family and link function come
     * from the {@link Distribution} built by {@link DistributionFactory} for the model's parameters.
     *
     * @param model the GBM model to export
     * @param trees the compressed trees to render
     */
    GbmPojoWriter(GBMModel model, CompressedTree[][] trees) {
        super(model._key, model._output, model.getGenModelEncoding(), model.binomialOpt(), trees,
                ((GBMModel.GBMOutput) model._output)._treeStats);
        final GBMModel.GBMOutput gbmOutput = (GBMModel.GBMOutput) model._output;
        _init_f = gbmOutput._init_f;
        final GBMModel.GBMParameters gbmParams = (GBMModel.GBMParameters) model._parms;
        _balance_classes = gbmParams._balance_classes;
        final Distribution dist = DistributionFactory.getDistribution((Model.Parameters) model._parms);
        _distribution_family = dist._family;
        _link_function = dist._linkFunction;
    }

    /**
     * Creates a writer from explicitly supplied settings instead of reading them from a GBM model.
     *
     * <p>Unlike the model-based constructor, this one passes {@code null} as the superclass's
     * tree-statistics argument.
     *
     * @param model              model whose key and output are handed to the superclass
     * @param encoding           categorical encoding used by the generated code
     * @param binomialOpt        binomial-optimization flag forwarded to the superclass
     * @param trees              the compressed trees to render
     * @param initF              offset added to the raw score in the generated code
     * @param balanceClasses     whether the generated code corrects class probabilities
     * @param distributionFamily distribution family selecting the unification code
     * @param linkFunction       link function whose inverse is rendered
     */
    GbmPojoWriter(Model<?, ?, ?> model, CategoricalEncoding encoding, boolean binomialOpt, CompressedTree[][] trees,
            double initF, boolean balanceClasses, DistributionFamily distributionFamily, LinkFunction linkFunction) {
        super(model._key, model._output, encoding, binomialOpt, trees, null);
        _init_f = initF;
        _balance_classes = balanceClasses;
        _distribution_family = distributionFamily;
        _link_function = linkFunction;
    }

    /**
     * Emits the statements that turn the raw scores in {@code preds} into final predictions.
     *
     * <p>One of three shapes is generated:
     * <ul>
     *   <li>bernoulli / quasibinomial / modified_huber:
     *     <ul>
     *       <li>{@code preds[2] = linkInv(preds[1] + init_f)};</li>
     *       <li>{@code preds[1] = 1.0 - preds[2]};</li>
     *       <li>then the shared correction/prediction tail.</li>
     *     </ul></li>
     *   <li>single-class output (presumably regression): {@code preds[0] = linkInv(preds[0] + init_f)}.</li>
     *   <li>anything else:
     *     <ul>
     *       <li>with exactly two classes, {@code preds[1] += init_f} and {@code preds[2] = -preds[1]};</li>
     *       <li>then {@code GenModel.GBM_rescale(preds)};</li>
     *       <li>then the shared correction/prediction tail.</li>
     *     </ul></li>
     * </ul>
     *
     * <p>The emitted code refers to {@code data}, {@code PRIOR_CLASS_DISTRIB} and
     * {@code MODEL_CLASS_DISTRIB}. These are assumed to be in scope in the generated class — confirm
     * against {@link SharedTreePojoWriter}.
     *
     * @param body stream receiving the generated source lines
     */
    @Override
    protected void toJavaUnifyPreds(SBPrintStream body) {
        if (usesGbmBinomialLink()) {
            body.ip("preds[2] = preds[1] + ").p(_init_f).p(";").nl();
            body.ip("preds[2] = " + _link_function.linkInvString("preds[2]") + ";").nl();
            body.ip("preds[1] = 1.0-preds[2];").nl();
            emitGbmPredictionTail(body);
        } else if (_output.nclasses() == 1) {
            body.ip("preds[0] += ").p(_init_f).p(";").nl();
            body.ip("preds[0] = " + _link_function.linkInvString("preds[0]") + ";").nl();
        } else {
            if (_output.nclasses() == 2) {
                // Only preds[1] carries the tree sum here; mirror it into preds[2] before rescaling.
                body.ip("preds[1] += ").p(_init_f).p(";").nl();
                body.ip("preds[2] = - preds[1];").nl();
            }
            body.ip("hex.genmodel.GenModel.GBM_rescale(preds);").nl();
            emitGbmPredictionTail(body);
        }
    }

    /**
     * Tells whether the family takes the inverse-link path that fills preds[1]/preds[2].
     * Uses identity comparison, so a {@code null} family simply yields {@code false}.
     */
    private boolean usesGbmBinomialLink() {
        return _distribution_family == DistributionFamily.bernoulli
                || _distribution_family == DistributionFamily.quasibinomial
                || _distribution_family == DistributionFamily.modified_huber;
    }

    /**
     * Emits the tail shared by the classification paths: the optional probability correction,
     * then the assignment of preds[0] via {@code GenModel.getPrediction} using the output's
     * default threshold.
     */
    private void emitGbmPredictionTail(SBPrintStream body) {
        if (_balance_classes) {
            body.ip("hex.genmodel.GenModel.correctProbabilities(preds, PRIOR_CLASS_DISTRIB, MODEL_CLASS_DISTRIB);").nl();
        }
        body.ip("preds[0] = hex.genmodel.GenModel.getPrediction(preds, PRIOR_CLASS_DISTRIB, data, " + _output.defaultThreshold() + ");").nl();
    }
}

