/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.algos.gam;

import hex.genmodel.ConverterFactoryProvidingModel;
import hex.genmodel.GenModel;
import hex.genmodel.MojoModel;
import hex.genmodel.algos.gam.GamRowToRawDataConverter;
import hex.genmodel.algos.gam.GamUtilsCubicRegression;
import hex.genmodel.algos.gam.GamUtilsThinPlateRegression;
import hex.genmodel.easy.CategoricalEncoder;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.RowToRawDataConverter;
import hex.genmodel.utils.ArrayUtils;
import hex.genmodel.utils.DistributionFamily;
import hex.genmodel.utils.LinkFunctionType;
import java.util.Map;

/**
 * Shared scoring logic for GAM MOJO models.
 *
 * <p>Holds the coefficients and the spline metadata (cubic regression and thin plate regression) needed
 * to expand raw GAM input columns into their basis-function columns, plus per-row scratch buffers that
 * {@link #init()} allocates. Concrete subclasses supply the actual scoring via {@link #gamScore0}.
 */
public abstract class GamMojoModelBase
extends MojoModel
implements ConverterFactoryProvidingModel,
Cloneable {
    /** Value of {@code _bs_sorted[i]} for a cubic regression spline column. */
    private static final int CUBIC_REGRESSION_SPLINE = 0;
    /** Value of {@code _bs_sorted[i]} for a thin plate regression spline column. */
    private static final int THIN_PLATE_REGRESSION_SPLINE = 1;

    public LinkFunctionType _link_function;
    boolean _useAllFactorLevels;
    int _cats;
    int[] _catNAFills;
    // _catOffsets[i] .. _catOffsets[i + 1] is the beta index range of categorical column i (see generateEta).
    int[] _catOffsets;
    int _nums;
    int _numsCenter;
    double[] _numNAFillsCenter;
    boolean _meanImputation;
    double[] _beta;
    double[] _beta_no_center;
    double[] _beta_center;
    double[][] _beta_multinomial;
    double[][] _beta_multinomial_no_center;
    double[][] _beta_multinomial_center;
    DistributionFamily _family;
    String[][] _gam_columns;
    String[][] _gam_columns_sorted;
    int[] _d;
    int[] _m;
    int[] _M;
    int[] _gamPredSize;
    int _num_gam_columns;
    int[] _bs;
    // Spline type per sorted GAM column: CUBIC_REGRESSION_SPLINE or THIN_PLATE_REGRESSION_SPLINE.
    // All cubic regression columns precede the thin plate columns in sorted order (see init()).
    int[] _bs_sorted;
    int[] _num_knots;
    int[] _num_knots_sorted;
    int[] _num_knots_sorted_minus1;
    int[] _num_knots_TP;
    double[][][] _knots;
    double[][][] _binvD;
    double[][][] _zTranspose;
    double[][][] _zTransposeCS;
    String[][] _gamColNames;
    String[][] _gamColNamesCenter;
    String[] _names_no_centering;
    int _totFeatureSize;
    int _betaSizePerClass;
    int _betaCenterSizePerClass;
    double _tweedieLinkPower;
    // Per-row scratch buffers; (re)allocated by init().
    double[][] _basisVals;
    double[][] _basisValsCenter;
    double[][] _hj;
    int _numExpandedGamCols;
    int _numExpandedGamColsCenter;
    int _lastClass;
    int[][][] _allPolyBasisList;
    int _num_TP_col;
    int _num_CS_col;
    double[][] _tpRowVals;
    double[][] _tpDistance;
    double[][] _tpDistzCS;
    double[][] _tpPoly;
    double[][] _tpDistzCSPoly;
    double[][] _tpDistzCSPolyzT;
    boolean[] _dEven;
    double[] _constantTerms;
    double[][] _gamColMeansRaw;
    double[][] _oneOGamColStd;
    boolean _standardize;

    GamMojoModelBase(String[] columns, String[][] domains, String responseColumn) {
        super(columns, domains, responseColumn);
    }

    /**
     * Scores one row, optionally imputing missing predictors with stored fill values first.
     *
     * <p>Note: when mean imputation is enabled, {@code row} is modified in place.
     *
     * @param row   input row (modified in place when {@code _meanImputation} is set)
     * @param preds output prediction buffer
     * @return the result of {@link #gamScore0(double[], double[])}
     */
    @Override
    public double[] score0(double[] row, double[] preds) {
        if (_meanImputation) {
            imputeMissingWithMeans(row);
        }
        return gamScore0(row, preds);
    }

    /**
     * Derives helper sizes and allocates the per-row scratch buffers used during GAM column expansion.
     *
     * <p>Cubic regression (CS) buffers are indexed by CS ordinal. Thin plate (TP) buffers are indexed by
     * TP ordinal; TP ordinal {@code tp} corresponds to sorted GAM column {@code tp + _num_CS_col}.
     */
    void init() {
        int numSorted = _num_knots_sorted.length;
        _num_knots_sorted_minus1 = new int[numSorted];
        for (int i = 0; i < numSorted; ++i) {
            _num_knots_sorted_minus1[i] = _num_knots_sorted[i] - 1;
        }
        if (_num_CS_col > 0) {
            _basisVals = new double[_num_CS_col][];
            _basisValsCenter = new double[_num_CS_col][];
            _hj = new double[_num_CS_col][];
            for (int cs = 0; cs < _num_CS_col; ++cs) {
                _basisVals[cs] = new double[_num_knots_sorted[cs]];
                _basisValsCenter[cs] = new double[_num_knots_sorted_minus1[cs]];
                _hj[cs] = ArrayUtils.eleDiff(_knots[cs][0]);
            }
        }
        if (_num_TP_col > 0) {
            _tpRowVals = new double[_num_TP_col][];
            _tpDistance = new double[_num_TP_col][];
            _tpDistzCS = new double[_num_TP_col][];
            _tpPoly = new double[_num_TP_col][];
            _tpDistzCSPoly = new double[_num_TP_col][];
            _tpDistzCSPolyzT = new double[_num_TP_col][];
            _dEven = new boolean[_num_TP_col];
            _constantTerms = new double[_num_TP_col];
            for (int tp = 0; tp < _num_TP_col; ++tp) {
                int col = tp + _num_CS_col; // sorted GAM column index of this TP column
                int numKnots = _num_knots_sorted[col];
                _tpRowVals[tp] = new double[_d[col]];
                _tpDistance[tp] = new double[numKnots];
                _tpDistzCS[tp] = new double[numKnots - _M[tp]];
                _tpPoly[tp] = new double[_M[tp]];
                // Holds the _tpDistzCS values followed by the _M[tp] polynomial terms.
                _tpDistzCSPoly[tp] = new double[numKnots];
                _tpDistzCSPolyzT[tp] = new double[numKnots - 1];
                _dEven[tp] = _d[col] % 2 == 0;
                _constantTerms[tp] = GamUtilsThinPlateRegression.calTPConstantTerm(_m[tp], _d[col], _dEven[tp]);
            }
        }
        _lastClass = _nclasses - 1;
    }

    /**
     * Returns a shallow clone of this model whose scratch buffers are freshly allocated by {@link #init()},
     * so the copy does not share per-row working arrays with this instance.
     *
     * @throws RuntimeException wrapping {@link CloneNotSupportedException} if cloning fails
     */
    @Override
    public GenModel internal_threadSafeInstance() {
        try {
            GamMojoModelBase copy = (GamMojoModelBase) clone();
            copy.init();
            return copy;
        }
        catch (CloneNotSupportedException e) {
            throw new RuntimeException(e);
        }
    }

    abstract double[] gamScore0(double[] var1, double[] var2);

    /**
     * Replaces NaN categorical and (centered) numeric inputs in {@code data} with the stored fill values.
     * Categoricals occupy indices {@code [0, _cats)}; numerics follow them.
     */
    private void imputeMissingWithMeans(double[] data) {
        for (int i = 0; i < _cats; ++i) {
            if (Double.isNaN(data[i])) {
                data[i] = _catNAFills[i];
            }
        }
        for (int i = 0; i < _numsCenter; ++i) {
            int index = i + _cats;
            if (Double.isNaN(data[index])) {
                data[index] = _numNAFillsCenter[i];
            }
        }
    }

    /**
     * Maps a linear predictor through the inverse of the configured link function.
     *
     * @param val linear predictor value
     * @return the value on the response scale
     * @throws UnsupportedOperationException for link functions not handled here
     */
    double evalLink(double val) {
        switch (_link_function) {
            case identity:
                return GenModel.GLM_identityInv(val);
            case logit:
                return GenModel.GLM_logitInv(val);
            case log:
                return GenModel.GLM_logInv(val);
            case inverse:
                return GenModel.GLM_inverseInv(val);
            case tweedie:
                return GenModel.GLM_tweedieInv(val, _tweedieLinkPower);
            default:
                break;
        }
        throw new UnsupportedOperationException("Unexpected link function " + _link_function);
    }

    /**
     * Converts a categorical level to its index in the expanded beta vector.
     *
     * @param data      categorical level as stored in the row (truncated to int)
     * @param dataIndex categorical column index
     * @return beta index, or -1 when the level maps below zero (the dropped reference level when
     *         {@code _useAllFactorLevels} is false)
     */
    int readCatVal(double data, int dataIndex) {
        int level = _useAllFactorLevels ? (int) data : (int) data - 1;
        if (level < 0) {
            return -1;
        }
        return level + _catOffsets[dataIndex];
    }

    /**
     * Computes the linear predictor for one row.
     *
     * <p>Each categorical column contributes the coefficient of its level, when that level falls inside
     * the column's beta range. Each numeric column contributes coefficient times value. The last
     * coefficient is added as a constant (intercept) term.
     */
    double generateEta(double[] beta, double[] data) {
        double eta = 0.0;
        int numCatCols = _catOffsets.length - 1;
        for (int i = 0; i < numCatCols; ++i) {
            int betaIndex = readCatVal(data[i], i);
            if (betaIndex >= 0 && betaIndex < _catOffsets[i + 1]) {
                eta += beta[betaIndex];
            }
        }
        // Numeric data index i maps to beta index (expanded categorical size) + (i - _cats).
        int numericBetaShift = _catOffsets[_cats] - _cats;
        int numericEnd = beta.length - 1 - numericBetaShift;
        for (int i = _cats; i < numericEnd; ++i) {
            eta += beta[numericBetaShift + i] * data[i];
        }
        return eta + beta[beta.length - 1];
    }

    /**
     * Returns true when every entry of {@code rawData} from {@code gamColStart} onward is NaN, i.e. the
     * expanded GAM columns have not been filled in yet.
     */
    private boolean gamificationNeeded(double[] rawData, int gamColStart) {
        for (int i = gamColStart; i < rawData.length; ++i) {
            if (!Double.isNaN(rawData[i])) {
                return false;
            }
        }
        return true;
    }

    /**
     * Builds the full feature vector by appending the expanded (centered) GAM basis columns to
     * {@code rawData}.
     *
     * <p>If the GAM section of {@code rawData} is already populated (any non-NaN entry), {@code rawData} is
     * returned unchanged. Otherwise a new array of length {@code _nfeatures} is returned. Its first
     * {@code _nfeatures - _numExpandedGamColsCenter} entries are copied from {@code rawData}. The rest are
     * filled column by column from the raw GAM inputs read out of {@code rowData}. A GAM column whose input
     * is absent from {@code rowData} keeps the {@code ArrayUtils.nanArray} fill (presumably NaN).
     *
     * @param rawData converted row without GAM expansions
     * @param rowData original row, used to look up the raw GAM input columns
     * @return the feature vector including expanded GAM columns
     * @throws IllegalArgumentException for an unsupported spline type
     */
    double[] addExpandGamCols(double[] rawData, RowData rowData) {
        int destPos = _nfeatures - _numExpandedGamColsCenter;
        if (!gamificationNeeded(rawData, destPos)) {
            return rawData;
        }
        double[] expanded = ArrayUtils.nanArray(_nfeatures);
        System.arraycopy(rawData, 0, expanded, 0, destPos);
        for (int col = 0; col < _num_gam_columns; ++col) {
            int splineType = _bs_sorted[col];
            if (splineType == CUBIC_REGRESSION_SPLINE) {
                Object value = rowData.get(_gam_columns_sorted[col][0]);
                if (value != null) {
                    expandCubicRegressionCol(col, gamValueToDouble(value), expanded, destPos);
                }
            } else if (splineType == THIN_PLATE_REGRESSION_SPLINE) {
                // TP ordinal is derived from the sorted index (TP columns follow all CS columns, as in
                // init()). It must NOT depend on whether earlier TP columns were present in this row,
                // otherwise a missing TP input would shift every later TP column onto the wrong metadata.
                int tp = col - _num_CS_col;
                double[] predictors = grabPredictorVals(_gam_columns_sorted[col], rowData, _tpRowVals[tp]);
                if (predictors != null) {
                    expandThinPlateCol(col, tp, predictors, expanded, destPos);
                }
            } else {
                throw new IllegalArgumentException("spline type not implemented!");
            }
            destPos += _num_knots_sorted_minus1[col];
        }
        return expanded;
    }

    /** Expands one cubic regression GAM value into centered basis values and copies them to {@code dest}. */
    private void expandCubicRegressionCol(int col, double value, double[] dest, int destPos) {
        GamUtilsCubicRegression.expandOneGamCol(value, _binvD[col], _basisVals[col], _hj[col], _knots[col][0]);
        ArrayUtils.multArray(_basisVals[col], _zTranspose[col], _basisValsCenter[col]);
        System.arraycopy(_basisValsCenter[col], 0, dest, destPos, _num_knots_sorted_minus1[col]);
    }

    /**
     * Expands one thin plate GAM column (distance plus polynomial basis, then transformed) and copies the
     * result to {@code dest}.
     *
     * @param col        sorted GAM column index
     * @param tp         thin plate ordinal ({@code col - _num_CS_col})
     * @param predictors predictor values for this column
     */
    private void expandThinPlateCol(int col, int tp, double[] predictors, double[] dest, int destPos) {
        double[] distance = _tpDistance[tp];
        double[] distzCS = _tpDistzCS[tp];
        double[] poly = _tpPoly[tp];
        double[] distzCSPoly = _tpDistzCSPoly[tp];
        double[] distzCSPolyzT = _tpDistzCSPolyzT[tp];
        GamUtilsThinPlateRegression.calculateDistance(distance, predictors, _num_knots_sorted[col], _knots[col],
                _d[col], _m[tp], _dEven[tp], _constantTerms[tp], _oneOGamColStd[tp], _standardize);
        ArrayUtils.multArray(distance, _zTransposeCS[tp], distzCS);
        GamUtilsThinPlateRegression.calculatePolynomialBasis(poly, predictors, _d[col], _M[tp],
                _allPolyBasisList[tp], _gamColMeansRaw[tp], _oneOGamColStd[tp], _standardize);
        // Concatenate [distzCS | poly] before the final transform.
        System.arraycopy(distzCS, 0, distzCSPoly, 0, distzCS.length);
        System.arraycopy(poly, 0, distzCSPoly, distzCS.length, _M[tp]);
        ArrayUtils.multArray(distzCSPoly, _zTranspose[col], distzCSPolyzT);
        System.arraycopy(distzCSPolyzT, 0, dest, destPos, distzCSPolyzT.length);
    }

    /**
     * Reads the predictor values for a multi-column GAM term from {@code rowData} into {@code predVals}.
     *
     * @return {@code predVals}, or {@code null} if any of {@code gamCols} is absent from the row (in which
     *         case {@code predVals} may have been partially overwritten)
     */
    double[] grabPredictorVals(String[] gamCols, RowData rowData, double[] predVals) {
        for (int i = 0; i < gamCols.length; ++i) {
            Object value = rowData.get(gamCols[i]);
            if (value == null) {
                return null;
            }
            predVals[i] = gamValueToDouble(value);
        }
        return predVals;
    }

    /**
     * Converts a raw GAM input value from {@link RowData} to a double. Strings are parsed and any
     * {@link Number} is widened.
     *
     * @throws NumberFormatException if a String value is not a number
     * @throws ClassCastException    for any other value type
     */
    private static double gamValueToDouble(Object value) {
        if (value instanceof String) {
            return Double.parseDouble((String) value);
        }
        if (value instanceof Number) {
            return ((Number) value).doubleValue();
        }
        throw new ClassCastException("Unsupported value type for GAM column: " + value.getClass().getName());
    }

    @Override
    public RowToRawDataConverter makeConverterFactory(Map<String, Integer> modelColumnNameToIndexMap, Map<Integer, CategoricalEncoder> domainMap, EasyPredictModelWrapper.ErrorConsumer errorConsumer, EasyPredictModelWrapper.Config config) {
        return new GamRowToRawDataConverter(this, modelColumnNameToIndexMap, domainMap, errorConsumer, config);
    }
}

