/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.algos.gam;

import hex.genmodel.algos.gam.GamMojoModelBase;
import hex.genmodel.utils.DistributionFamily;

/**
 * MOJO scorer for GAM models whose response has several classes.
 *
 * <p>The prediction array filled by this class is laid out as
 * {@code preds[0]} = index of the predicted class and
 * {@code preds[1 .. ]} = one value per class. Depending on the model family
 * the per-class linear predictors are post-processed either as a softmax
 * ({@link #postPredMultinomial(double[])}) or as differences of logistic
 * cumulative probabilities ({@link #postPredOrdinal(double[])}).
 */
public class GamMojoMultinomialModel extends GamMojoModelBase {
    /** True when the model family is {@code DistributionFamily.multinomial}; set in {@link #init()}. */
    private boolean _trueMultinomial;

    GamMojoMultinomialModel(String[] columns, String[][] domains, String responseColumn) {
        super(columns, domains, responseColumn);
    }

    /**
     * Runs the base-class initialisation and then caches whether the family
     * is multinomial, so the check is not repeated for every scored row.
     */
    @Override
    void init() {
        super.init();
        this._trueMultinomial = this._family.equals(DistributionFamily.multinomial);
    }

    /**
     * Scores one row.
     *
     * @param row   input values for the row; its length decides which
     *              coefficient set is used
     * @param preds output array, written in place: slot 0 receives the
     *              predicted class index, slots {@code 1.._nclasses} the
     *              per-class values
     * @return the same {@code preds} array
     */
    @Override
    double[] gamScore0(double[] row, double[] preds) {
        // A row whose length equals nfeatures() is scored with
        // _beta_multinomial_center, any other length with
        // _beta_multinomial_no_center.
        // NOTE(review): this writes a shared field on every call; concurrent
        // scoring of rows with different lengths could race - confirm callers.
        this._beta_multinomial = row.length == this.nfeatures()
                ? this._beta_multinomial_center
                : this._beta_multinomial_no_center;

        // Linear predictor (eta) for every class, stored after the label slot.
        for (int c = 0; c < this._nclasses; ++c) {
            preds[c + 1] = this.generateEta(this._beta_multinomial[c], row);
        }

        return this._trueMultinomial
                ? this.postPredMultinomial(preds)
                : this.postPredOrdinal(preds);
    }

    /**
     * Converts the linear predictors in {@code preds[1..]} into softmax
     * probabilities (in place) and stores the index of the most probable
     * class in {@code preds[0]}.
     *
     * <p>The exponentials are shifted by the true maximum linear predictor.
     * Seeding the maximum with {@code 0.0} (as was done previously) performs
     * no shift at all when every predictor is negative; for strongly negative
     * predictors every exponential then underflows to zero, the sum becomes
     * zero and all probabilities turn into NaN. Starting from negative
     * infinity guarantees the largest shifted exponent is {@code exp(0) = 1}.
     * The result is mathematically unchanged; it may differ in the last bits
     * when all predictors are negative.
     *
     * @param preds array holding the linear predictors in slots {@code 1..length-1}
     * @return the same {@code preds} array
     */
    double[] postPredMultinomial(double[] preds) {
        double maxEta = Double.NEGATIVE_INFINITY;
        for (int c = 1; c < preds.length; ++c) {
            if (preds[c] > maxEta) {
                maxEta = preds[c];
            }
        }

        double sumExp = 0.0;
        for (int c = 1; c < preds.length; ++c) {
            preds[c] = Math.exp(preds[c] - maxEta);
            sumExp += preds[c];
        }

        final double invSum = 1.0 / sumExp;
        double bestProb = 0.0;
        for (int c = 1; c < preds.length; ++c) {
            preds[c] *= invSum;
            // Strict comparison: on ties the lowest class index wins.
            if (preds[c] > bestProb) {
                bestProb = preds[c];
                preds[0] = c - 1;
            }
        }
        return preds;
    }

    /**
     * Post-processes the linear predictors for the non-multinomial (ordinal)
     * case, in place.
     *
     * <p>For each of the first {@code _lastClass} classes the logistic
     * function of its linear predictor is taken as a cumulative probability,
     * and the class value becomes the difference to the previous cumulative
     * probability. Slot {@code preds[_nclasses]} receives one minus the last
     * cumulative probability. The predicted class is the first class whose
     * linear predictor is strictly positive, or {@code _lastClass} when there
     * is none.
     *
     * @param preds array holding the linear predictors in slots {@code 1.._lastClass}
     *              (presumably {@code _lastClass == _nclasses - 1}; verify in the base class)
     * @return the same {@code preds} array
     */
    double[] postPredOrdinal(double[] preds) {
        double previousCdf = 0.0;
        boolean labelChosen = false;
        preds[0] = this._lastClass;

        for (int c = 0; c < this._lastClass; ++c) {
            // Read eta before its slot is overwritten with the class value.
            final double eta = preds[c + 1];
            final double cdf = 1.0 / (1.0 + Math.exp(-eta));
            preds[c + 1] = cdf - previousCdf;
            previousCdf = cdf;
            // Only the first class with a positive eta determines the label.
            if (!labelChosen && eta > 0.0) {
                preds[0] = c;
                labelChosen = true;
            }
        }

        preds[this._nclasses] = 1.0 - previousCdf;
        return preds;
    }
}

