/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.drf;

import hex.genmodel.GenModel;
import hex.tree.SharedTreeModel;
import hex.tree.drf.DRF;
import water.Key;
import water.fvec.Chunk;
import water.util.MathUtils;
import water.util.SB;

/**
 * Distributed Random Forest (DRF) model.
 *
 * <p>Tree traversal is delegated to the {@link SharedTreeModel} {@code score0} implementation;
 * this class post-processes the accumulated per-tree output into final predictions. The same
 * post-processing is emitted into generated POJO code by {@link #toJavaUnifyPreds(SB, SB)}, so the
 * in-memory scoring path and the exported scoring path must be kept in sync.
 */
public class DRFModel
extends SharedTreeModel<DRFModel, DRFModel.DRFParameters, DRFModel.DRFOutput> {
    public DRFModel(Key selfKey, DRFParameters parms, DRFOutput output) {
        super(selfKey, parms, output);
    }

    /**
     * Scores a single row read directly from a set of chunks.
     *
     * <p>The first {@code tmp.length} chunks are read at {@code row_in_chunk} into {@code tmp},
     * which is then scored by {@link #score0(double[], double[])}.
     *
     * @param chks chunks holding the row; must have at least {@code tmp.length} entries
     * @param row_in_chunk row index within each chunk
     * @param tmp scratch buffer that receives the row's feature values
     * @param preds output buffer for the predictions
     * @return {@code preds}, filled in
     */
    public double[] score0(Chunk[] chks, int row_in_chunk, double[] tmp, double[] preds) {
        assert (chks.length >= tmp.length);
        for (int i = 0; i < tmp.length; ++i) {
            tmp[i] = chks[i].atd(row_in_chunk);
        }
        return this.score0(tmp, preds);
    }

    /**
     * Scores one row and converts the accumulated tree output into final predictions.
     *
     * <p>Regression: {@code preds[0]} is divided by the number of trees and returned.
     * Binomial: {@code preds[1]} is divided by the number of trees and {@code preds[2]} is set to
     * its complement. Multinomial: {@code preds[1..]} are normalized to sum to one, when their
     * sum is positive. For classification, probabilities are then optionally corrected for class
     * balancing, and {@code preds[0]} receives the predicted label.
     *
     * @param data the row's feature values
     * @param preds output buffer, populated first by the superclass
     * @return {@code preds}, filled in
     */
    @Override
    protected double[] score0(double[] data, double[] preds) {
        super.score0(data, preds);
        // Average over the trees actually recorded in the model output. This is the same value
        // toJavaUnifyPreds bakes into the generated POJO; the requested _parms._ntrees is only
        // what was asked for and presumably can differ from what was built.
        double ntrees = _output._ntrees;
        if (_output.nclasses() == 1) {
            preds[0] /= ntrees;
            return preds;
        }
        if (_output.nclasses() == 2) {
            preds[1] /= ntrees;
            preds[2] = 1.0 - preds[1];
        } else {
            normalizeClassProbabilities(preds);
        }
        if (_parms._balance_classes) {
            GenModel.correctProbabilities(preds, _output._priorClassDist, _output._modelClassDist);
        }
        preds[0] = GenModel.getPrediction(preds, data, defaultThreshold());
        return preds;
    }

    /**
     * Normalizes the class slots {@code preds[1..]} in place so that they sum to one.
     *
     * <p>Slot 0, which is later overwritten by the predicted label, is excluded, matching the
     * loop emitted into the POJO by {@link #toJavaUnifyPreds(SB, SB)}. A non-positive sum leaves
     * the values untouched.
     */
    private static void normalizeClassProbabilities(double[] preds) {
        double sum = 0;
        for (int i = 1; i < preds.length; i++) {
            sum += preds[i];
        }
        if (sum > 0) {
            for (int i = 1; i < preds.length; i++) {
                preds[i] /= sum;
            }
        }
    }

    /**
     * Emits the Java source that post-processes tree output inside the generated POJO.
     *
     * <p>This must mirror {@link #score0(double[], double[])}: averaging by the tree count,
     * normalizing multinomial probabilities over {@code preds[1..]}, optional class-balance
     * correction, and the final label assignment to {@code preds[0]}.
     *
     * @param body builder receiving the scoring method body
     * @param file builder for file-level code (unused here)
     */
    @Override
    protected void toJavaUnifyPreds(SB body, SB file) {
        if (_output.nclasses() == 1) {
            body.ip("preds[0] /= " + _output._ntrees + ";").nl();
        } else {
            if (_output.nclasses() == 2) {
                body.ip("preds[1] /= " + _output._ntrees + ";").nl();
                body.ip("preds[2] = 1.0 - preds[1];").nl();
            } else {
                body.ip("double sum = 0;").nl();
                body.ip("for(int i=1; i<preds.length; i++) { sum += preds[i]; }").nl();
                body.ip("if (sum>0) for(int i=1; i<preds.length; i++) { preds[i] /= sum; }").nl();
            }
            if (_parms._balance_classes) {
                body.ip("hex.genmodel.GenModel.correctProbabilities(preds, PRIOR_CLASS_DISTRIB, MODEL_CLASS_DISTRIB);").nl();
            }
            body.ip("preds[0] = hex.genmodel.GenModel.getPrediction(preds, data, " + this.defaultThreshold() + " );").nl();
        }
    }

    /** Model output for DRF; adds nothing beyond {@link SharedTreeModel.SharedTreeOutput}. */
    public static class DRFOutput
    extends SharedTreeModel.SharedTreeOutput {
        public DRFOutput(DRF b, double mse_train, double mse_valid) {
            super(b, mse_train, mse_valid);
        }
    }

    /** DRF-specific training parameters and DRF defaults for the shared tree parameters. */
    public static class DRFParameters
    extends SharedTreeModel.SharedTreeParameters {
        // NOTE(review): -1 presumably means "let the algorithm pick mtries"; confirm in DRF.
        int _mtries = -1;
        // NOTE(review): presumably the per-tree row sampling fraction; confirm in DRF.
        float _sample_rate = 0.632f;
        public boolean _build_tree_one_node = false;

        /** Creates parameters with DRF defaults: 50 trees, max depth 20, min rows 10. */
        public DRFParameters() {
            this._ntrees = 50;
            this._max_depth = 20;
            this._min_rows = 10;
        }
    }
}

