/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.algos.drf;

import hex.ModelCategory;
import hex.genmodel.GenModel;
import hex.genmodel.PredictContributions;
import hex.genmodel.algos.tree.SharedTreeGraphConverter;
import hex.genmodel.algos.tree.SharedTreeMojoModelWithContributions;
import hex.genmodel.algos.tree.TreeSHAPPredictor;

/**
 * MOJO scoring model for the tree-ensemble algorithm in {@code hex.genmodel.algos.drf}.
 *
 * <p>Scoring runs every tree over the input row and then post-processes the raw tree outputs in
 * {@link #unifyPreds(double[], double, double[])}. Feature contributions (TreeSHAP) are supported
 * for regression and binomial models only; other categories are rejected with
 * {@link UnsupportedOperationException}.
 */
public final class DrfMojoModel
extends SharedTreeMojoModelWithContributions
implements SharedTreeGraphConverter {
    /**
     * When {@code true}, a two-class model is post-processed like a multinomial model (per-class
     * outputs rescaled to sum to 1) instead of via the single-tree binomial path. Presumably
     * populated by the MOJO reader from the model artifact -- confirm against the reader.
     */
    protected boolean _binomial_double_trees;

    public DrfMojoModel(String[] columns, String[][] domains, String responseColumn) {
        super(columns, domains, responseColumn);
    }

    /** Returns a contributions predictor that applies the DRF-specific rescaling of TreeSHAP output. */
    @Override
    protected final PredictContributions getContributionsPredictor(TreeSHAPPredictor<double[]> treeSHAPPredictor) {
        return new ContributionsPredictorDRF(this, treeSHAPPredictor);
    }

    /**
     * Scores {@code row}: runs all trees into {@code preds} via {@code scoreAllTrees}, then
     * post-processes the result with {@link #unifyPreds(double[], double, double[])}.
     *
     * @param row    input row
     * @param offset forwarded to {@code unifyPreds}, which does not use it
     * @param preds  output buffer, filled in place
     * @return {@code preds}
     */
    @Override
    public final double[] score0(double[] row, double offset, double[] preds) {
        scoreAllTrees(row, preds);
        return unifyPreds(row, offset, preds);
    }

    /**
     * Turns the raw tree outputs held in {@code preds} into final predictions, in place.
     *
     * <ul>
     *   <li>Regression ({@code _nclasses == 1}): {@code preds[0]} is divided by
     *       {@code _ntree_groups}.</li>
     *   <li>Binomial without double trees ({@code _nclasses == 2} and not
     *       {@code _binomial_double_trees}): {@code preds[1]} is divided by {@code _ntree_groups}
     *       and {@code preds[2]} is set to its complement {@code 1 - preds[1]}.</li>
     *   <li>Any other classification model: {@code preds[1..._nclasses]} are rescaled to sum to 1;
     *       they are left untouched when their sum is not positive.</li>
     * </ul>
     * For classification, the probabilities are then adjusted with
     * {@code GenModel.correctProbabilities} when {@code _balanceClasses} is set, and
     * {@code preds[0]} is set from {@code GenModel.getPrediction} using the prior class
     * distribution and the default threshold.
     *
     * @param row    input row; forwarded to {@code GenModel.getPrediction}
     * @param offset not used by this model
     * @param preds  raw tree outputs; overwritten with the final predictions
     * @return {@code preds}
     */
    @Override
    public final double[] unifyPreds(double[] row, double offset, double[] preds) {
        if (_nclasses == 1) {
            // Regression.
            preds[0] /= _ntree_groups;
            return preds;
        }

        // Classification.
        if (_nclasses == 2 && !_binomial_double_trees) {
            // Single tree per group: preds[1] carries one class, preds[2] is its complement.
            preds[1] /= _ntree_groups;
            preds[2] = 1.0 - preds[1];
        } else {
            double sum = 0.0;
            for (int c = 1; c <= _nclasses; c++) {
                sum += preds[c];
            }
            // Guard against a zero (or negative) total so we never divide by zero.
            if (sum > 0.0) {
                for (int c = 1; c <= _nclasses; c++) {
                    preds[c] /= sum;
                }
            }
        }
        if (_balanceClasses) {
            GenModel.correctProbabilities(preds, _priorClassDistrib, _modelClassDistrib);
        }
        preds[0] = GenModel.getPrediction(preds, _priorClassDistrib, row, _defaultThreshold);
        return preds;
    }

    /** Scores {@code row} with a zero offset; see {@link #score0(double[], double, double[])}. */
    @Override
    public final double[] score0(double[] row, double[] preds) {
        return score0(row, 0.0, preds);
    }

    /**
     * Post-processes raw TreeSHAP contributions for DRF models.
     *
     * <p>Each raw contribution {@code c} is replaced by
     * {@code _featurePlusBiasRatio + c / _normalizer}:
     * <ul>
     *   <li>Regression: ratio {@code 0}, normalizer {@code _ntree_groups}.</li>
     *   <li>Binomial: ratio {@code 1 / (_nfeatures + 1)}, normalizer {@code -_ntree_groups}.
     *       Presumably {@code contribs} has one slot per feature plus one bias slot, so the added
     *       constants sum to 1 and the result mirrors {@code preds[2] = 1 - preds[1]} in
     *       {@code unifyPreds} -- confirm against the base predictor.</li>
     * </ul>
     * Any other model category is rejected in the constructor.
     */
    static class ContributionsPredictorDRF
    extends SharedTreeMojoModelWithContributions.SharedTreeContributionsPredictor {
        /** Constant added to every contribution slot. */
        private final float _featurePlusBiasRatio;
        /** Divisor applied to every raw contribution (negative for binomial models). */
        private final int _normalizer;

        /**
         * @throws UnsupportedOperationException if the model is neither regression nor binomial
         */
        private ContributionsPredictorDRF(DrfMojoModel model, TreeSHAPPredictor<double[]> treeSHAPPredictor) {
            super(model, treeSHAPPredictor);
            if (model._category == ModelCategory.Regression) {
                _featurePlusBiasRatio = 0.0f;
                _normalizer = model._ntree_groups;
            } else if (model._category == ModelCategory.Binomial) {
                // Equal share per slot; presumably _nfeatures feature slots + 1 bias slot.
                _featurePlusBiasRatio = 1.0f / (model._nfeatures + 1);
                // Negated so the result describes the complementary class (see class Javadoc).
                _normalizer = -model._ntree_groups;
            } else {
                throw new UnsupportedOperationException("Model category " + model._category
                        + " cannot be used to calculate feature contributions.");
            }
        }

        /**
         * Rescales {@code contribs} in place as described in the class documentation.
         *
         * @param contribs raw contributions; overwritten
         * @return {@code contribs}
         */
        @Override
        public float[] getContribs(float[] contribs) {
            for (int i = 0; i < contribs.length; i++) {
                contribs[i] = _featurePlusBiasRatio + contribs[i] / _normalizer;
            }
            return contribs;
        }
    }
}

