/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.algos.gbm;

import hex.genmodel.GenModel;
import hex.genmodel.PredictContributions;
import hex.genmodel.algos.tree.SharedTreeGraphConverter;
import hex.genmodel.algos.tree.SharedTreeMojoModelWithContributions;
import hex.genmodel.algos.tree.TreeSHAPPredictor;
import hex.genmodel.utils.DistributionFamily;
import hex.genmodel.utils.LinkFunctionType;

/**
 * Tree-ensemble MOJO scoring model for the {@code hex.genmodel.algos.gbm} package.
 *
 * <p>Scoring is a two-step process: the inherited {@code scoreAllTrees} fills {@code preds} with
 * raw tree sums, and {@link #unifyPreds(double[], double, double[])} then adds {@link #_init_f}
 * plus the caller-supplied offset and post-processes the result according to {@link #_family}.
 */
public final class GbmMojoModel
extends SharedTreeMojoModelWithContributions
implements SharedTreeGraphConverter {
    /**
     * Smallest magnitude allowed for the denominator of the inverse link. Used symmetrically on
     * both sides of zero so that {@code 1/f} stays finite.
     */
    private static final double INVERSE_LINK_EPS = 1.0E-5;

    /** Distribution family; selects which post-processing branch {@link #unifyPreds} takes. */
    public DistributionFamily _family;
    /** Link function whose inverse is applied to the adjusted raw score. */
    public LinkFunctionType _link_function;
    /** Constant added (together with the offset) to the raw tree sum before the inverse link. */
    public double _init_f;

    /**
     * Creates a model over the given columns.
     *
     * @param columns        column names, forwarded to the superclass
     * @param domains        per-column categorical domains, forwarded to the superclass
     * @param responseColumn name of the response column, forwarded to the superclass
     */
    public GbmMojoModel(String[] columns, String[][] domains, String responseColumn) {
        super(columns, domains, responseColumn);
    }

    /**
     * Wraps the given TreeSHAP predictor in a {@code SharedTreeContributionsPredictor} bound to
     * this model.
     */
    @Override
    protected PredictContributions getContributionsPredictor(TreeSHAPPredictor<double[]> treeSHAPPredictor) {
        return new SharedTreeMojoModelWithContributions.SharedTreeContributionsPredictor(this, treeSHAPPredictor);
    }

    /** @return the value of {@link #_init_f} */
    @Override
    public double getInitF() {
        return this._init_f;
    }

    /**
     * Scores one row: accumulates all trees into {@code preds}, then post-processes them.
     *
     * @param row    input row
     * @param offset offset added to the raw score before the inverse link
     * @param preds  output buffer; overwritten and returned
     * @return {@code preds}
     */
    @Override
    public final double[] score0(double[] row, double offset, double[] preds) {
        super.scoreAllTrees(row, preds);
        return this.unifyPreds(row, offset, preds);
    }

    /**
     * Converts raw tree sums held in {@code preds} into final predictions, in place.
     *
     * <ul>
     *   <li>bernoulli / quasibinomial / modified_huber: {@code preds[2] = linkInv(preds[1] +
     *       init_f + offset)} and {@code preds[1] = 1 - preds[2]}.</li>
     *   <li>multinomial: for two classes, {@code preds[1]} is shifted by {@code init_f + offset}
     *       and {@code preds[2]} is set to its negation; then {@code GenModel.GBM_rescale} is
     *       applied.</li>
     *   <li>any other family: {@code preds[0] = linkInv(preds[0] + init_f + offset)}, and the
     *       method returns immediately.</li>
     * </ul>
     *
     * For the classification branches, probabilities are optionally corrected when
     * {@code _balanceClasses} is set, and {@code preds[0]} is then set from
     * {@code GenModel.getPrediction}.
     *
     * @param row    input row (forwarded to {@code GenModel.getPrediction})
     * @param offset offset added to the raw score
     * @param preds  buffer holding raw tree sums; overwritten and returned
     * @return {@code preds}
     */
    @Override
    public final double[] unifyPreds(double[] row, double offset, double[] preds) {
        if (this._family == DistributionFamily.bernoulli || this._family == DistributionFamily.quasibinomial || this._family == DistributionFamily.modified_huber) {
            double f = preds[1] + this._init_f + offset;
            preds[2] = this.linkInv(this._link_function, f);
            preds[1] = 1.0 - preds[2];
        } else if (this._family == DistributionFamily.multinomial) {
            if (this._nclasses == 2) {
                // Only one tree's sum is held in preds[1]; the second class mirrors it.
                preds[1] = preds[1] + (this._init_f + offset);
                preds[2] = -preds[1];
            }
            GenModel.GBM_rescale(preds);
        } else {
            // Non-classification families: single output in preds[0], no class post-processing.
            double f = preds[0] + this._init_f + offset;
            preds[0] = this.linkInv(this._link_function, f);
            return preds;
        }
        if (this._balanceClasses) {
            GenModel.correctProbabilities(preds, this._priorClassDistrib, this._modelClassDistrib);
        }
        preds[0] = GenModel.getPrediction(preds, this._priorClassDistrib, row, this._defaultThreshold);
        return preds;
    }

    /**
     * Applies the inverse of {@code linkFunction} to {@code f}.
     *
     * @param linkFunction link to invert; links not handled explicitly (e.g. identity) return
     *                     {@code f} unchanged
     * @param f            adjusted raw score
     * @return the value mapped back to the response scale
     */
    private double linkInv(LinkFunctionType linkFunction, double f) {
        switch (linkFunction) {
            case log: {
                return GbmMojoModel.exp(f);
            }
            case logit:
            case ologit: {
                return 1.0 / (1.0 + GbmMojoModel.exp(-f));
            }
            case ologlog: {
                return 1.0 - GbmMojoModel.exp(-1.0 * GbmMojoModel.exp(f));
            }
            case oprobit: {
                // NOTE(review): placeholder - no inverse-probit is implemented here; this always
                // yields 0.0 regardless of f.
                return 0.0;
            }
            case inverse: {
                // Keep the denominator at least INVERSE_LINK_EPS away from zero on BOTH sides.
                // (Previously the positive side used max(-eps, f), which is just f, so f == 0
                // produced Infinity and tiny positive f produced huge values.)
                double denom = f < 0.0
                        ? Math.min(-INVERSE_LINK_EPS, f)
                        : Math.max(INVERSE_LINK_EPS, f);
                return 1.0 / denom;
            }
            default: {
                return f;
            }
        }
    }

    /**
     * Overflow-capped exponential.
     *
     * @param x exponent
     * @return {@code Math.exp(x)}, capped at {@code 1e19}
     */
    public static double exp(double x) {
        return Math.min(1.0E19, Math.exp(x));
    }

    /**
     * Floor-capped natural logarithm.
     *
     * @param x argument; negative values are treated as 0
     * @return {@code -19} for non-positive input, otherwise {@code max(-19, ln x)}
     *     (NaN input yields NaN)
     */
    public static double log(double x) {
        double clamped = Math.max(0.0, x);
        if (clamped == 0.0) {
            return -19.0;
        }
        return Math.max(-19.0, Math.log(clamped));
    }

    /** Scores {@code row} with a zero offset. */
    @Override
    public double[] score0(double[] row, double[] preds) {
        return this.score0(row, 0.0, preds);
    }

    /**
     * Returns the decision path of {@code row} through each tree.
     *
     * @param row input row
     * @return the result of {@code getDecisionPath(row)}
     */
    public String[] leaf_node_assignment(double[] row) {
        return this.getDecisionPath(row);
    }

    /**
     * Output column names. A quasibinomial model whose response has no categorical domain reports
     * {@code predict, pVal0, pVal1}; every other case defers to the superclass.
     */
    @Override
    public String[] getOutputNames() {
        if (this._family == DistributionFamily.quasibinomial && this.getDomainValues(this.getResponseIdx()) == null) {
            return new String[]{"predict", "pVal0", "pVal1"};
        }
        return super.getOutputNames();
    }
}

