/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.algos.gbm;

import hex.genmodel.GenModel;
import hex.genmodel.PredictContributions;
import hex.genmodel.algos.tree.SharedTreeGraphConverter;
import hex.genmodel.algos.tree.SharedTreeMojoModelWithContributions;
import hex.genmodel.algos.tree.TreeSHAPPredictor;
import hex.genmodel.utils.DistributionFamily;
import hex.genmodel.utils.LinkFunctionType;

/**
 * MOJO scoring model that turns raw tree outputs into final predictions.
 *
 * <p>Scoring first calls the inherited {@code scoreAllTrees} to fill the prediction
 * array and then {@link #unifyPreds(double[], double, double[])} to apply the initial
 * value {@link #_init_f}, the caller-supplied offset and the inverse of
 * {@link #_link_function}, branching on {@link #_family}.
 */
public final class GbmMojoModel
extends SharedTreeMojoModelWithContributions
implements SharedTreeGraphConverter {
    /** Distribution family; selects the post-processing branch in {@link #unifyPreds}. */
    public DistributionFamily _family;
    /** Link function whose inverse is applied to the raw score. */
    public LinkFunctionType _link_function;
    /** Initial value added to the raw tree output before the inverse link is applied. */
    public double _init_f;

    /**
     * Creates the model by forwarding all arguments to the superclass constructor.
     *
     * @param columns        passed through to the superclass
     * @param domains        passed through to the superclass
     * @param responseColumn passed through to the superclass
     */
    public GbmMojoModel(String[] columns, String[][] domains, String responseColumn) {
        super(columns, domains, responseColumn);
    }

    /**
     * Wraps the given TreeSHAP predictor in a {@code SharedTreeContributionsPredictor}
     * bound to this model.
     *
     * @param treeSHAPPredictor predictor to wrap
     * @return a new contributions predictor for this model
     */
    @Override
    protected final PredictContributions getContributionsPredictor(TreeSHAPPredictor<double[]> treeSHAPPredictor) {
        return new SharedTreeMojoModelWithContributions.SharedTreeContributionsPredictor(this, treeSHAPPredictor);
    }

    /**
     * @return the value of {@link #_init_f}
     */
    @Override
    public final double getInitF() {
        return _init_f;
    }

    /**
     * Scores one row: fills {@code preds} via the inherited {@code scoreAllTrees} and
     * then post-processes it with {@link #unifyPreds(double[], double, double[])}.
     *
     * @param row    input row
     * @param offset value added to the raw score before the inverse link is applied
     * @param preds  output array, modified in place
     * @return the same {@code preds} array
     */
    @Override
    public final double[] score0(double[] row, double offset, double[] preds) {
        super.scoreAllTrees(row, preds);
        return unifyPreds(row, offset, preds);
    }

    /**
     * Converts raw tree outputs in {@code preds} into final predictions, in place.
     *
     * <ul>
     *   <li>{@code bernoulli}, {@code quasibinomial}, {@code modified_huber}:
     *       {@code preds[2]} becomes the inverse link of
     *       {@code preds[1] + _init_f + offset} and {@code preds[1]} becomes
     *       {@code 1 - preds[2]}.</li>
     *   <li>{@code multinomial}: when {@code _nclasses == 2}, {@code preds[1]} is shifted
     *       by {@code _init_f + offset} and {@code preds[2]} is set to its negation; then
     *       {@code GenModel.GBM_rescale} is applied.</li>
     *   <li>any other family: {@code preds[0]} becomes the inverse link of
     *       {@code preds[0] + _init_f + offset} and the method returns immediately.</li>
     * </ul>
     *
     * <p>For the first two cases, {@code GenModel.correctProbabilities} is applied when
     * {@code _balanceClasses} is set, and {@code preds[0]} is then assigned the result of
     * {@code GenModel.getPrediction}.
     *
     * @param row    input row, forwarded to {@code GenModel.getPrediction}
     * @param offset value added to the raw score
     * @param preds  raw outputs on entry, final predictions on exit
     * @return the same {@code preds} array
     */
    @Override
    public final double[] unifyPreds(double[] row, double offset, double[] preds) {
        if (_family == DistributionFamily.bernoulli
                || _family == DistributionFamily.quasibinomial
                || _family == DistributionFamily.modified_huber) {
            double f = preds[1] + _init_f + offset;
            preds[2] = linkInv(_link_function, f);
            preds[1] = 1.0 - preds[2];
        } else if (_family == DistributionFamily.multinomial) {
            if (_nclasses == 2) {
                // Only preds[1] carries a raw score here; preds[2] is derived as its mirror.
                preds[1] += _init_f + offset;
                preds[2] = -preds[1];
            }
            GenModel.GBM_rescale(preds);
        } else {
            double f = preds[0] + _init_f + offset;
            preds[0] = linkInv(_link_function, f);
            // No class-probability handling for this branch.
            return preds;
        }
        if (_balanceClasses) {
            GenModel.correctProbabilities(preds, _priorClassDistrib, _modelClassDistrib);
        }
        preds[0] = GenModel.getPrediction(preds, _priorClassDistrib, row, _defaultThreshold);
        return preds;
    }

    /**
     * Applies the inverse of the given link function to a raw score.
     *
     * <p>NOTE(review): if a training-side implementation of these inverse links exists,
     * confirm it stays consistent with this one so both produce the same predictions.
     *
     * @param linkFunction link function to invert; must not be {@code null}
     * @param f            raw score
     * @return the inverse-link value; {@code f} itself for link types not listed below
     */
    private double linkInv(LinkFunctionType linkFunction, double f) {
        switch (linkFunction) {
            case log:
                return exp(f);
            case logit:
            case ologit:
                return 1.0 / (1.0 + exp(-f));
            case ologlog:
                return 1.0 - exp(-1.0 * exp(f));
            case oprobit:
                // NOTE(review): constant 0 looks like a placeholder - confirm intent.
                return 0.0;
            case inverse: {
                // Keep the denominator at least 1e-5 away from zero on either side.
                // The non-negative side previously used Math.max(-1e-5, f), which never
                // clamps (it always yields f) and so allowed division by zero for f == 0.
                double denominator = f < 0.0 ? Math.min(-1.0E-5, f) : Math.max(1.0E-5, f);
                return 1.0 / denominator;
            }
            default:
                return f;
        }
    }

    /**
     * Exponential function whose result is capped at {@code 1e19}.
     *
     * @param x2 exponent
     * @return {@code min(1e19, e^x2)}
     */
    public static double exp(double x2) {
        return Math.min(1.0E19, Math.exp(x2));
    }

    /**
     * Natural logarithm whose result is floored at {@code -19}.
     *
     * @param x2 argument; values at or below zero yield {@code -19}
     * @return {@code max(-19, ln(x2))}, or {@code -19} when {@code x2 <= 0}
     */
    public static double log(double x2) {
        double clamped = Math.max(0.0, x2);
        if (clamped == 0.0) {
            return -19.0;
        }
        return Math.max(-19.0, Math.log(clamped));
    }

    /**
     * Scores one row with a zero offset.
     *
     * @param row   input row
     * @param preds output array, modified in place
     * @return the same {@code preds} array
     */
    @Override
    public final double[] score0(double[] row, double[] preds) {
        return score0(row, 0.0, preds);
    }

    /**
     * Returns the result of the inherited {@code getDecisionPath} for the given row.
     *
     * @param row input row
     * @return whatever {@code getDecisionPath(row)} returns
     */
    public final String[] leaf_node_assignment(double[] row) {
        return getDecisionPath(row);
    }

    /**
     * Returns the output column names. For the {@code quasibinomial} family with no
     * domain values on the response column, fixed names are used; otherwise the
     * superclass decides.
     *
     * @return output column names
     */
    @Override
    public final String[] getOutputNames() {
        if (_family == DistributionFamily.quasibinomial
                && getDomainValues(getResponseIdx()) == null) {
            return new String[]{"predict", "pVal0", "pVal1"};
        }
        return super.getOutputNames();
    }
}

