/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.drf;

import hex.tree.SharedTreeModel;
import hex.tree.drf.DRF;
import water.Key;
import water.util.MathUtils;
import water.util.SBPrintStream;

/**
 * Distributed Random Forest model.
 *
 * <p>Scoring first calls {@code SharedTreeModel.score0} to fill {@code preds}. This class then
 * post-processes the values as follows:
 * <ul>
 *   <li>Regression ({@code nclasses() == 1}): {@code preds[0]} is divided by the tree count.</li>
 *   <li>Binomial with {@link #binomialOpt()} enabled: {@code preds[1]} is divided by the tree
 *       count and {@code preds[2]} is set to {@code 1 - preds[1]}.</li>
 *   <li>Any other case: {@code preds} is normalized so its entries sum to 1, when the sum is
 *       positive.</li>
 * </ul>
 * {@link #toJavaUnifyPreds(SBPrintStream)} emits equivalent post-processing into generated
 * Java (POJO) scoring code.
 */
public class DRFModel
extends SharedTreeModel<DRFModel, DRFParameters, DRFOutput> {
    /**
     * Creates a DRF model.
     *
     * @param selfKey key under which this model is stored
     * @param parms   the parameters the model was built with
     * @param output  the model's training output
     */
    public DRFModel(Key selfKey, DRFParameters parms, DRFOutput output) {
        super(selfKey, parms, output);
    }

    /**
     * Reports whether binomial scoring uses the single-probability optimization. In that mode
     * only {@code preds[1]} is averaged over the trees, and {@code preds[2]} is derived as its
     * complement.
     *
     * @return {@code true} unless {@link DRFParameters#_binomial_double_trees} is set
     */
    @Override
    protected boolean binomialOpt() {
        return !_parms._binomial_double_trees;
    }

    /**
     * Scores a single row and converts the values produced by the shared tree scorer into final
     * DRF predictions. See the class documentation for the per-problem-type rules.
     *
     * @param data   the row's feature values
     * @param preds  output buffer, overwritten in place
     * @param weight row weight, forwarded to the superclass
     * @param offset row offset, forwarded to the superclass
     * @return {@code preds}, for call chaining
     */
    @Override
    protected double[] score0(double[] data, double[] preds, double weight, double offset) {
        super.score0(data, preds, weight, offset);
        final int ntrees = _output._ntrees;
        final int nclasses = _output.nclasses();
        if (nclasses == 1) {
            // Regression: average over the trees. Skipped for a tree-less model to avoid 0/0.
            if (ntrees >= 1) {
                preds[0] /= ntrees;
            }
        } else if (nclasses == 2 && binomialOpt()) {
            // Binomial optimization: average one column and derive the other as its complement.
            if (ntrees >= 1) {
                preds[1] /= ntrees;
            }
            preds[2] = 1.0 - preds[1];
        } else {
            // Multinomial, or binomial with double trees: normalize to a distribution.
            double sum = MathUtils.sum(preds);
            if (sum > 0.0) {
                MathUtils.div(preds, sum);
            }
        }
        return preds;
    }

    /**
     * Emits the Java source that converts the values in the POJO's {@code preds} array into final
     * predictions. The emitted code mirrors {@link #score0}. For classification it also applies
     * the prior-distribution correction (when {@code _balance_classes} is set) and computes the
     * predicted label into {@code preds[0]}.
     *
     * <p>The division by the tree count is emitted only when the model has at least one tree.
     * {@code score0} applies the same guard. Without it, a tree-less model's POJO would compute
     * {@code 0.0/0 = NaN} where in-cluster scoring keeps the value unchanged.
     *
     * @param body the stream the generated source is written to
     */
    @Override
    protected void toJavaUnifyPreds(SBPrintStream body) {
        final int ntrees = _output._ntrees;
        if (_output.nclasses() == 1) {
            if (ntrees >= 1) {
                body.ip("preds[0] /= " + ntrees + ";").nl();
            }
            return;
        }
        if (_output.nclasses() == 2 && binomialOpt()) {
            if (ntrees >= 1) {
                body.ip("preds[1] /= " + ntrees + ";").nl();
            }
            body.ip("preds[2] = 1.0 - preds[1];").nl();
        } else {
            // Normalization skips index 0, the label slot filled in by getPrediction below.
            body.ip("double sum = 0;").nl();
            body.ip("for(int i=1; i<preds.length; i++) { sum += preds[i]; }").nl();
            body.ip("if (sum>0) for(int i=1; i<preds.length; i++) { preds[i] /= sum; }").nl();
        }
        if (_parms._balance_classes) {
            body.ip("hex.genmodel.GenModel.correctProbabilities(preds, PRIOR_CLASS_DISTRIB, MODEL_CLASS_DISTRIB);").nl();
        }
        body.ip("preds[0] = hex.genmodel.GenModel.getPrediction(preds, PRIOR_CLASS_DISTRIB, data, " + defaultThreshold() + " );").nl();
    }

    /** Training output of a DRF model; adds nothing beyond the shared tree output. */
    public static class DRFOutput
    extends SharedTreeModel.SharedTreeOutput {
        /**
         * @param b          the DRF builder that produced this output
         * @param mse_train  training MSE, forwarded to the superclass
         * @param mse_valid  validation MSE, forwarded to the superclass
         */
        public DRFOutput(DRF b, double mse_train, double mse_valid) {
            super(b, mse_train, mse_valid);
        }
    }

    /** DRF-specific build parameters and DRF defaults for the shared tree parameters. */
    public static class DRFParameters
    extends SharedTreeModel.SharedTreeParameters {
        /**
         * When {@code true}, disables the binomial single-probability optimization
         * (see {@link DRFModel#binomialOpt()}).
         */
        public boolean _binomial_double_trees = false;
        /**
         * Column-sampling setting for DRF.
         * NOTE(review): -1 presumably selects a builder-chosen default; confirm in {@code DRF}.
         */
        public int _mtries = -1;

        /** Creates parameters with DRF's defaults for row sampling, depth and leaf size. */
        public DRFParameters() {
            // NOTE(review): 0.632 presumably matches the ~(1 - 1/e) unique-row fraction of a
            // bootstrap sample; confirm.
            this._sample_rate = 0.632f;
            this._max_depth = 20;
            this._min_rows = 1.0;
        }
    }
}

