/*
 * Decompiled with CFR 0.152.
 */
package hex.gam;

import hex.DataInfo;
import hex.Model;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.StringPair;
import hex.VarImp;
import hex.deeplearning.DeepLearningModel;
import hex.gam.GAM;
import hex.gam.GAMMojoWriter;
import hex.gam.MatrixFrameUtils.AddCSGamColumns;
import hex.gam.MatrixFrameUtils.AddISGamColumns;
import hex.gam.MatrixFrameUtils.AddTPKnotsGamColumns;
import hex.gam.MatrixFrameUtils.GamUtils;
import hex.gam.MetricBuilderGAM;
import hex.genmodel.utils.DistributionFamily;
import hex.glm.GLM;
import hex.glm.GLMModel;
import hex.util.DistributionUtils;
import hex.util.EffectiveParametersUtils;
import java.io.Serializable;
import java.util.Arrays;
import water.AutoBuffer;
import water.Futures;
import water.Job;
import water.Key;
import water.Keyed;
import water.MRTask;
import water.MemoryManager;
import water.Scope;
import water.exceptions.H2OColumnNotFoundArgumentException;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.udf.CFuncRef;
import water.util.ArrayUtils;
import water.util.IcedHashSet;
import water.util.Log;
import water.util.TwoDimTable;
import water.util.VecUtils;

public class GAMModel
extends Model<GAMModel, GAMParameters, GAMModelOutput> {
    // Response domain used for the fractionalbinomial family (see makeMetricBuilder / scoringDomains).
    private static final String[] BINOMIAL_CLASS_NAMES = new String[]{"0", "1"};
    // Gamified column names before centering; presumably one array per gam column -- TODO confirm.
    public String[][] _gamColNamesNoCentering;
    // Gamified column names, one array per gam column; used to name the generated spline columns
    // when test frames are adapted (passed to cleanUpInputFrame).
    public String[][] _gamColNames;
    public int[] _gamPredSize;
    // NOTE(review): _m / _M look like thin-plate spline settings; confirm against the GAM builder.
    public int[] _m;
    public int[] _M;
    // Number of cubic-spline gam columns; passed as numCSGamCols when adapting test frames.
    public int _cubicSplineNum;
    public int _iSplineNum;
    public int _thinPlateSmoothersWithKnotsNum;
    public Key<Frame>[] _gamFrameKeysCenter;
    public double[] _gamColMeans;
    // Number of response classes; forwarded to the GAM metric builder.
    public int _nclass;
    // Presumably the response mean(s); forwarded to the GAM metric builder.
    public double[] _ymu;
    public long _nobs;
    public long _nullDOF;
    // Presumably the model rank; forwarded to the GAM metric builder.
    public int _rank;
    // Frame keys deleted together with this model in remove_impl (presumably adapted validation frames).
    public IcedHashSet<Key<Frame>> _validKeys = null;

    /**
     * Returns the prediction column names of the base model, followed by a "StdErr" column when the
     * GLM variance-covariance matrix is available on the output.
     *
     * @return the names of the columns produced by scoring
     */
    @Override
    public String[] makeScoringNames() {
        final String[] baseNames = super.makeScoringNames();
        final boolean hasVcov = ((GAMModelOutput)this._output)._glm_vcov != null;
        return hasVcov ? ArrayUtils.append(baseNames, "StdErr") : baseNames;
    }

    /**
     * Creates the GAM metric builder.
     *
     * <p>When no domain is supplied, fractionalbinomial uses {"0","1"} and the binomial,
     * quasibinomial and negativebinomial families fall back to the stored response domain.
     *
     * @param domain response domain to use, or {@code null} to derive it from the family
     * @return a new {@link MetricBuilderGAM}
     */
    @Override
    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
        final GAMParameters p = (GAMParameters)this._parms;
        final GLMModel.GLMParameters.Family fam = p._family;
        String[] metricDomain = domain;
        if (metricDomain == null) {
            if (fam == GLMModel.GLMParameters.Family.fractionalbinomial) {
                metricDomain = BINOMIAL_CLASS_NAMES;
            } else if (fam == GLMModel.GLMParameters.Family.binomial
                    || fam == GLMModel.GLMParameters.Family.quasibinomial
                    || fam == GLMModel.GLMParameters.Family.negativebinomial) {
                metricDomain = ((GAMModelOutput)this._output)._responseDomains;
            }
        }
        final GLMModel.GLMWeightsFun weightsFun = new GLMModel.GLMWeightsFun(fam, p._link, p._tweedie_variance_power, p._tweedie_link_power, p._theta);
        return new MetricBuilderGAM(metricDomain, this._ymu, weightsFun, this._rank, true, p._intercept, this._nclass, p._auc_type);
    }

    /**
     * Creates a GAM model stored under {@code selfKey}.
     *
     * @param selfKey key of this model
     * @param parms   training parameters
     * @param output  model output
     */
    public GAMModel(Key<GAMModel> selfKey, GAMParameters parms, GAMModelOutput output) {
        super(selfKey, parms, output);
        // The key recorded by the base class must be the one we were given.
        assert Arrays.equals(_key._kb, selfKey._kb);
    }

    /** Resolves the effective fold-assignment parameter value on this model's parameters. */
    public void initActualParamValuesAfterGlmCreation() {
        final Model.Parameters params = this._parms;
        EffectiveParametersUtils.initFoldAssignment(params);
    }

    /**
     * Builds a table of coefficient magnitudes for a multi-class model.
     *
     * <p>For every coefficient except the one named "Intercept", the absolute values across all
     * classes are summed. Rows are ordered by {@link GamUtils#sortCoeffMags} (presumably by
     * decreasing magnitude).
     *
     * @param colHeaders       headers of the two output columns (magnitude, sign)
     * @param coefficients     coefficients indexed as [class][coefficient]
     * @param coefficientNames names aligned with the second index of {@code coefficients}
     * @param tableHeader      title of the table
     * @return the populated table
     */
    public TwoDimTable genCoefficientMagTableMultinomial(String[] colHeaders, double[][] coefficients, String[] coefficientNames, String tableHeader) {
        String[] colTypes = new String[]{"double", "string"};
        String[] colFormat = new String[]{"%5f", ""};
        int nCoeff = coefficients[0].length;
        // Size by the actual number of non-intercept coefficients rather than assuming exactly one
        // "Intercept" entry is present (nCoeff - 1 overflows when it is absent).
        int numMags = 0;
        for (int index = 0; index < nCoeff; ++index) {
            if (!"Intercept".equals(coefficientNames[index])) {
                ++numMags;
            }
        }
        String[] coeffNames = new String[numMags];
        String[] coeffNames2 = new String[numMags];
        double[] coeffMags = new double[numMags];
        double[] coeffMags2 = new double[numMags];
        String[] coeffSigns = new String[numMags];
        Log.info("genCoefficientMagTableMultinomial", String.format("coeffNames length: %d.  coeffMags length: %d, coeffSigns length: %d", coeffNames.length, coeffMags.length, coeffSigns.length));
        int countIndex = 0;
        for (int index = 0; index < nCoeff; ++index) {
            if ("Intercept".equals(coefficientNames[index])) {
                continue;
            }
            double magSum = 0.0;
            for (double[] classCoeffs : coefficients) {
                magSum += Math.abs(classCoeffs[index]);
            }
            coeffMags[countIndex] = magSum;
            coeffNames[countIndex] = coefficientNames[index];
            // The value is a sum of absolute values over classes, so every row is labelled POS.
            coeffSigns[countIndex] = "POS";
            ++countIndex;
        }
        Integer[] indices = GamUtils.sortCoeffMags(coeffMags.length, coeffMags);
        for (int index = 0; index < coeffMags.length; ++index) {
            coeffMags2[index] = coeffMags[indices[index]];
            coeffNames2[index] = coeffNames[indices[index]];
        }
        Log.info("genCoefficientMagTableMultinomial", String.format("coeffNames2 length: %d.  coeffMags2 length: %d, coeffSigns length: %d", coeffNames2.length, coeffMags2.length, coeffSigns.length));
        TwoDimTable table = new TwoDimTable(tableHeader, "Standardized Coefficient Magnitutes", coeffNames2, colHeaders, colTypes, colFormat, "names");
        this.fillUpCoeffsMag(coeffMags2, coeffSigns, table, 0);
        return table;
    }

    /**
     * Builds a table of absolute coefficient values together with their signs.
     *
     * <p>The coefficient named "Intercept" is skipped. Rows are ordered by
     * {@link GamUtils#sortCoeffMags} (presumably by decreasing magnitude).
     *
     * @param colHeaders       headers of the two output columns (magnitude, sign)
     * @param coefficients     coefficient values aligned with {@code coefficientNames}
     * @param coefficientNames coefficient names
     * @param tableHeader      title of the table
     * @return a table with one row per non-intercept coefficient: |value| and "POS"/"NEG"
     */
    public TwoDimTable genCoefficientMagTable(String[] colHeaders, double[] coefficients, String[] coefficientNames, String tableHeader) {
        String[] colTypes = new String[]{"double", "string"};
        String[] colFormat = new String[]{"%5f", ""};
        int nCoeff = coefficients.length;
        // Size by the actual number of non-intercept coefficients rather than assuming exactly one
        // "Intercept" entry is present (nCoeff - 1 overflows when it is absent).
        int numMags = 0;
        for (int index = 0; index < nCoeff; ++index) {
            if (!"Intercept".equals(coefficientNames[index])) {
                ++numMags;
            }
        }
        String[] coeffNames = new String[numMags];
        double[] coeffMags = new double[numMags];
        String[] coeffSigns = new String[numMags];
        int countMagIndex = 0;
        for (int index = 0; index < nCoeff; ++index) {
            if ("Intercept".equals(coefficientNames[index])) {
                continue;
            }
            coeffMags[countMagIndex] = Math.abs(coefficients[index]);
            coeffSigns[countMagIndex] = coefficients[index] > 0.0 ? "POS" : "NEG";
            coeffNames[countMagIndex++] = coefficientNames[index];
        }
        Integer[] indices = GamUtils.sortCoeffMags(coeffMags.length, coeffMags);
        String[] names2 = new String[numMags];
        double[] mag2 = new double[numMags];
        String[] sign2 = new String[numMags];
        for (int i2 = 0; i2 < numMags; ++i2) {
            names2[i2] = coeffNames[indices[i2]];
            mag2[i2] = coeffMags[indices[i2]];
            sign2[i2] = coeffSigns[indices[i2]];
        }
        Log.info("genCoefficientMagTable", String.format("coeffNames length: %d.  coeffMags length: %d, coeffSigns length: %d", coeffNames.length, coeffMags.length, coeffSigns.length));
        TwoDimTable table = new TwoDimTable(tableHeader, "", names2, colHeaders, colTypes, colFormat, "names");
        this.fillUpCoeffsMag(mag2, sign2, table, 0);
        return table;
    }

    /**
     * Writes magnitudes into column 0 and signs into column 1 of {@code tdt}, starting at row
     * {@code rowStart}.
     */
    private void fillUpCoeffsMag(double[] coeffMags, String[] coeffSigns, TwoDimTable tdt, int rowStart) {
        for (int k = 0; k < coeffMags.length; k++) {
            final int row = rowStart + k;
            tdt.set(row, 0, coeffMags[k]);
            tdt.set(row, 1, coeffSigns[k]);
        }
    }

    /**
     * Runs a metrics-only scoring pass over {@code adaptFrm} (no prediction columns are generated).
     *
     * @return the reduced metric builder
     */
    @Override
    protected ModelMetrics.MetricBuilder scoreMetrics(Frame adaptFrm) {
        final GAMScore task = this.makeScoringTask(adaptFrm, false, null, true);
        assert task._dinfo._valid : "_valid flag should be set on data info when doing scoring";
        final Frame scoringFrame = task._dinfo._adaptedFrame;
        return task.doAll(scoringFrame)._mb;
    }

    /**
     * Returns the output domains.
     *
     * <p>For binomial-like families whose response domain is missing, a copy is returned with the
     * response domain filled in: {"0","1"} for fractionalbinomial, otherwise the stored response
     * domain.
     */
    @Override
    protected String[][] scoringDomains() {
        final GAMModelOutput out = (GAMModelOutput)this._output;
        final int respIdx = out._dinfo.responseChunkId(0);
        final GLMModel.GLMParameters.Family fam = ((GAMParameters)this._parms)._family;
        final boolean binaryFamily = fam == GLMModel.GLMParameters.Family.binomial
                || fam == GLMModel.GLMParameters.Family.quasibinomial
                || fam == GLMModel.GLMParameters.Family.fractionalbinomial;
        if (!binaryFamily || out._domains[respIdx] != null) {
            return out._domains;
        }
        final String[][] patched = out._domains.clone();
        patched[respIdx] = fam == GLMModel.GLMParameters.Family.fractionalbinomial ? BINOMIAL_CLASS_NAMES : out._responseDomains;
        return patched;
    }

    /**
     * Adapts {@code test} for scoring.
     *
     * <p>If its column names do not already match the training frame's adapted columns, the
     * frame's columns are replaced in place with the gamified columns. The frame is then handed to
     * the base implementation.
     */
    @Override
    public String[] adaptTestForTrain(Frame test, boolean expensive, boolean computeMetrics) {
        final String[] testNames = test.names();
        final String[] trainNames = ((GAMModelOutput)this._output)._dinfo._adaptedFrame.names();
        final boolean matchesTraining = GamUtils.equalColNames(testNames, trainNames, ((GAMParameters)this._parms)._response_column);
        if (!matchesTraining) {
            final Frame gamified = this.cleanUpInputFrame(test);
            // Drop every original column, then re-add the gamified columns in order.
            for (int remaining = test.numCols(); remaining > 0; --remaining) {
                test.remove(0);
            }
            final int gamifiedCols = gamified.numCols();
            for (int col = 0; col < gamifiedCols; ++col) {
                test.add(gamified.name(col), gamified.vec(col));
            }
        }
        return super.adaptTestForTrain(test, expensive, computeMetrics);
    }

    /**
     * Gamifies a shallow copy of {@code test} (new Frame over a cloned Vec array) using this
     * model's stored spline state.
     *
     * @param test input frame; its column list is not modified
     * @return a new frame with the gamified columns
     */
    public Frame cleanUpInputFrame(Frame test) {
        final Frame working = new Frame(Key.make(), test.names(), (Vec[])test.vecs().clone());
        final GAMModelOutput out = (GAMModelOutput)this._output;
        return GAMModel.cleanUpInputFrame(working, (GAMParameters)this._parms, this._gamColNames,
                out._binvD, out._zTranspose, out._knots, out._zTransposeCS, out._allPolyBasisList,
                out._gamColMeansRaw, out._oneOGamColStd, this._cubicSplineNum);
    }

    /**
     * Gamifies {@code adptedF} in place and returns it.
     *
     * <p>The spline columns are generated from the untouched input before any column is removed.
     * Ignored columns present in the frame are then dropped, and the generated columns are
     * appended. The weights and response columns (when present) are moved to the end.
     *
     * @return {@code adptedF}, mutated
     */
    public static Frame cleanUpInputFrame(Frame adptedF, GAMParameters parms, String[][] gamColNames, double[][][] binvD, double[][][] zTranspose, double[][][] knots, double[][][] zTransposeCS, int[][][] polyBasisList, double[][] gamColMeansRaw, double[][] oneOGamColStd, int numCSGamCols) {
        final String[] originalNames = adptedF.names();
        final Frame singleVarCols = GAMModel.addSingleVariableGamColumns(adptedF, parms, gamColNames, binvD, zTranspose, knots, numCSGamCols);
        final Frame thinPlateCols = GAMModel.addTPGamColumns(adptedF, parms, zTransposeCS, zTranspose, polyBasisList, knots, gamColMeansRaw, oneOGamColStd);
        Frame gamCols;
        if (singleVarCols == null) {
            gamCols = thinPlateCols;
        } else {
            if (thinPlateCols != null) {
                singleVarCols.add(thinPlateCols.names(), thinPlateCols.removeAll());
            }
            gamCols = singleVarCols;
        }
        final String[] ignored = parms._ignored_columns;
        if (ignored != null) {
            for (String colName : ignored) {
                if (ArrayUtils.contains(originalNames, colName)) {
                    adptedF.remove(colName);
                }
            }
        }
        final Vec weights = parms._weights_column != null ? adptedF.remove(parms._weights_column) : null;
        final Vec response = ArrayUtils.contains(originalNames, parms._response_column) ? adptedF.remove(parms._response_column) : null;
        adptedF.add(gamCols.names(), gamCols.removeAll());
        Scope.track(gamCols);
        if (weights != null) {
            adptedF.add(parms._weights_column, weights);
        }
        if (response != null) {
            adptedF.add(parms._response_column, response);
        }
        return adptedF;
    }

    /**
     * Generates gamified columns from {@code adptedF} and appends them to {@code valid}.
     *
     * <p>The weights and response columns of {@code valid} (when present) are moved to the end.
     *
     * @return {@code valid}, mutated
     */
    public static Frame adaptValidFrame(Frame adptedF, Frame valid, GAMParameters parms, String[][] gamColNames, double[][][] binvD, double[][][] zTranspose, double[][][] knots, double[][][] zTransposeCS, int[][][] polyBasisList, double[][] gamColMeansRaw, double[][] oneOGamColStd, int numCSGam) {
        final Frame singleVarCols = GAMModel.addSingleVariableGamColumns(adptedF, parms, gamColNames, binvD, zTranspose, knots, numCSGam);
        final Frame thinPlateCols = GAMModel.addTPGamColumns(adptedF, parms, zTransposeCS, zTranspose, polyBasisList, knots, gamColMeansRaw, oneOGamColStd);
        Frame gamCols;
        if (singleVarCols == null) {
            gamCols = thinPlateCols;
        } else {
            if (thinPlateCols != null) {
                singleVarCols.add(thinPlateCols.names(), thinPlateCols.removeAll());
            }
            gamCols = singleVarCols;
        }
        final Vec weights = parms._weights_column != null ? valid.remove(parms._weights_column) : null;
        final Vec response = ArrayUtils.contains(valid.names(), parms._response_column) ? valid.remove(parms._response_column) : null;
        valid.add(gamCols.names(), gamCols.removeAll());
        Scope.track(gamCols);
        if (weights != null) {
            valid.add(parms._weights_column, weights);
        }
        if (response != null) {
            valid.add(parms._response_column, response);
        }
        return valid;
    }

    /**
     * Generates the thin-plate gam columns for {@code adaptedF}.
     *
     * @return a frame of the generated columns, or {@code null} when {@code parms._M} is null or empty
     */
    public static Frame addTPGamColumns(Frame adaptedF, GAMParameters parms, double[][][] zTransposeCS, double[][][] zTranspose, int[][][] polyBasisList, double[][][] knots, double[][] gamColMeansRaw, double[][] oneOColStd) {
        final boolean hasThinPlate = parms._M != null && parms._M.length > 0;
        if (!hasThinPlate) {
            return null;
        }
        final AddTPKnotsGamColumns tpGenerator = new AddTPKnotsGamColumns(parms, zTransposeCS, zTranspose, polyBasisList, knots, adaptedF);
        tpGenerator.addTPGamCols(gamColMeansRaw, oneOColStd);
        return GamUtils.concateGamVecs(tpGenerator._gamFrameKeysCenter);
    }

    /**
     * Generates the gamified columns for single-variable gam columns (bs == 0 and bs == 2 in
     * {@code parms._bs_sorted}) of {@code adptedF}.
     *
     * @return a frame with the CS columns followed by the I-spline columns, or {@code null} when there
     *         are no single-variable gam columns
     * @throws H2OColumnNotFoundArgumentException if a required gam column is missing from {@code adptedF}
     */
    public static Frame addSingleVariableGamColumns(Frame adptedF, GAMParameters parms, String[][] gamColNames, double[][][] binvD, double[][][] zTranspose, double[][][] knots, int numCSGamCol) {
        int offset;
        int numGamCols = parms._gam_columns.length;
        // Everything that is neither CS nor thin-plate (one per _M entry) is counted as I-spline.
        // NOTE(review): only bs == 2 is collected below; any other bs value would leave null slots.
        int numISplineGamCol = numGamCols - numCSGamCol - (parms._M == null ? 0 : parms._M.length);
        int numSingleVariableGamCols = numISplineGamCol + numCSGamCol;
        if (numSingleVariableGamCols == 0) {
            return null;
        }
        Vec[] gamColCSSplines = new Vec[numCSGamCol];
        Vec[] gamColISplines = new Vec[numISplineGamCol];
        String[] gamColCSNames = new String[numCSGamCol];
        String[] gamColISplineNames = new String[numISplineGamCol];
        int countCS = 0;
        int countIS = 0;
        // Assumes _gam_columns_sorted lists single-variable columns before thin-plate ones -- TODO confirm.
        for (int vind = 0; vind < numSingleVariableGamCols; ++vind) {
            if (adptedF.vec(parms._gam_columns_sorted[vind][0]) == null) {
                throw new H2OColumnNotFoundArgumentException("gam_columns", adptedF, parms._gam_columns_sorted[vind][0]);
            }
            if (parms._bs_sorted[vind] == 0) {
                gamColCSSplines[countCS] = (Vec)adptedF.vec(parms._gam_columns_sorted[vind][0]).clone();
                gamColCSNames[countCS++] = parms._gam_columns_sorted[vind][0];
                continue;
            }
            if (parms._bs_sorted[vind] != 2) continue;
            gamColISplines[countIS] = (Vec)adptedF.vec(parms._gam_columns_sorted[vind][0]).clone();
            gamColISplineNames[countIS++] = parms._gam_columns_sorted[vind][0];
        }
        Frame gamifiedCSCols = null;
        Frame gamifiedISCols = null;
        if (numCSGamCol > 0) {
            Frame onlyGamColsCS = new Frame(gamColCSNames, gamColCSSplines);
            AddCSGamColumns genCSGamCols = new AddCSGamColumns(binvD, zTranspose, knots, parms._num_knots_sorted, onlyGamColsCS, parms._bs_sorted);
            // (byte)3 is the output column type -- presumably Vec.T_NUM.
            genCSGamCols.doAll(genCSGamCols._gamCols2Add, (byte)3, onlyGamColsCS);
            // Concatenate the per-column basis names of every bs == 0 column, in sorted order.
            String[] gamColsNamesCS = new String[genCSGamCols._gamCols2Add];
            offset = 0;
            for (int ind = 0; ind < numGamCols; ++ind) {
                if (parms._bs_sorted[ind] != 0) continue;
                System.arraycopy(gamColNames[ind], 0, gamColsNamesCS, offset, gamColNames[ind].length);
                offset += gamColNames[ind].length;
            }
            gamifiedCSCols = genCSGamCols.outputFrame(Key.make(), gamColsNamesCS, null);
        }
        if (numISplineGamCol > 0) {
            Frame onlyGamColsIS = new Frame(gamColISplineNames, gamColISplines);
            AddISGamColumns genISGamCols = new AddISGamColumns(knots, parms._num_knots_sorted, parms._bs_sorted, parms._spline_orders_sorted, onlyGamColsIS);
            genISGamCols.doAll(genISGamCols._totGamifiedColCentered, (byte)3, onlyGamColsIS);
            // Concatenate the per-column basis names of every bs == 2 column, in sorted order.
            String[] gamColsNamesIS = new String[genISGamCols._totGamifiedColCentered];
            offset = 0;
            for (int index = 0; index < numGamCols; ++index) {
                if (parms._bs_sorted[index] != 2) continue;
                System.arraycopy(gamColNames[index], 0, gamColsNamesIS, offset, gamColNames[index].length);
                offset += gamColNames[index].length;
            }
            gamifiedISCols = genISGamCols.outputFrame(Key.make(), gamColsNamesIS, null);
        }
        if (gamifiedCSCols == null) {
            return gamifiedISCols;
        }
        if (gamifiedISCols == null) {
            return gamifiedCSCols;
        }
        // Move the I-spline vecs into the CS frame; the emptied I-spline frame is left to Scope.
        gamifiedCSCols.add(gamifiedISCols.names(), gamifiedISCols.removeAll());
        Scope.track(gamifiedISCols);
        return gamifiedCSCols;
    }

    /**
     * Removes every non-null Vec in {@code vecs2Remove}.
     *
     * @param vecs2Remove vecs to delete; a {@code null} array is a no-op
     */
    public static void removeVec(Vec[] vecs2Remove) {
        if (vecs2Remove != null) {
            for (Vec v : vecs2Remove) {
                if (v != null) {
                    v.remove();
                }
            }
        }
    }

    /**
     * Scores {@code adaptFrm}, producing the prediction frame and, when metrics are computed, the
     * metric builder and a raw output frame.
     */
    @Override
    protected Model.PredictScoreResult predictScoreImpl(Frame fr, Frame adaptFrm, String destination_key, Job j2, boolean computeMetrics, CFuncRef customMetricFunc) {
        final String[] outputNames = this.makeScoringNames();
        final String[][] outputDomains = new String[outputNames.length][];
        final GAMScore scorer = this.makeScoringTask(adaptFrm, true, j2, computeMetrics);
        // One output column per scoring name; (byte)3 is presumably Vec.T_NUM.
        scorer.doAll(outputNames.length, (byte)3, scorer._dinfo._adaptedFrame);
        final ModelMetrics.MetricBuilder metrics = scorer._computeMetrics ? scorer._mb : null;
        final Frame raw = scorer._computeMetrics ? scorer.outputFrame() : null;
        outputDomains[0] = scorer._predDomains;
        final Frame predictions = scorer.outputFrame(Key.make(destination_key), outputNames, outputDomains);
        return new Model.PredictScoreResult(this, metrics, raw, predictions);
    }

    /**
     * Builds the scoring task for {@code adaptFrm}.
     *
     * <p>A response column that is entirely bad is dropped from a copy of the frame, so the caller's
     * frame is not mutated. Metrics are computed only if requested and a usable response column
     * exists.
     *
     * @param adaptFrm        adapted frame to score
     * @param makePredictions whether prediction columns are emitted
     * @param j2              job used for progress/cancellation, may be {@code null}
     * @param computeMetrics  whether metrics were requested
     * @return the configured scoring task
     */
    private GAMScore makeScoringTask(Frame adaptFrm, boolean makePredictions, Job j2, boolean computeMetrics) {
        final GAMModelOutput out = (GAMModelOutput)this._output;
        final String responseName = out.responseName();
        int responseId = adaptFrm.find(responseName);
        if (responseId > -1 && adaptFrm.vec(responseId).isBad()) {
            adaptFrm = new Frame(adaptFrm.names(), adaptFrm.vecs());
            adaptFrm.remove(responseId);
        }
        final Vec responseVec = adaptFrm.vec(responseName);
        final boolean detectedComputeMetrics = computeMetrics && responseVec != null && !responseVec.isBad();
        // Every branch assigns `domain`. The previous form left it unassigned when nclasses() <= 1
        // and the family was not quasibinomial, which javac rejects.
        final String[] domain;
        if (((GAMParameters)this._parms)._family == GLMModel.GLMParameters.Family.quasibinomial) {
            domain = out._responseDomains;
        } else if (out.nclasses() <= 1) {
            domain = null;
        } else if (detectedComputeMetrics) {
            domain = adaptFrm.lastVec().domain();
        } else {
            domain = out._domains[out._domains.length - 1];
        }
        return new GAMScore(j2, this, out._dinfo.scoringInfo(out._names, adaptFrm), domain, detectedComputeMetrics, makePredictions);
    }

    /**
     * Not supported: GAM scores through {@link GAMScore} over whole frames.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public double[] score0(double[] data, double[] preds) {
        final String reason = "GAMModel.score0 should never be called";
        throw new UnsupportedOperationException(reason);
    }

    /** Returns a MOJO writer for this model. */
    @Override
    public GAMMojoWriter getMojo() {
        final GAMMojoWriter writer = new GAMMojoWriter(this);
        return writer;
    }

    /**
     * Removes the model and its companion objects.
     *
     * <p>The companion objects are: the gamified training frame, the tracked frames in
     * {@code _validKeys}, and (when kept) the cross-validation predictions, fold assignment and
     * models.
     */
    @Override
    protected Futures remove_impl(Futures fs, boolean cascade) {
        super.remove_impl(fs, cascade);
        final GAMModelOutput out = (GAMModelOutput)this._output;
        final GAMParameters p = (GAMParameters)this._parms;
        Keyed.remove(out._gamTransformedTrainCenter, fs, true);
        if (_validKeys != null) {
            for (Key<Frame> frameKey : _validKeys) {
                Keyed.remove(frameKey, fs, true);
            }
        }
        if (p._keep_cross_validation_predictions) {
            Keyed.remove(out._cross_validation_holdout_predictions_frame_id, fs, true);
        }
        if (p._keep_cross_validation_fold_assignment) {
            Keyed.remove(out._cross_validation_fold_assignment_frame_id, fs, true);
        }
        if (p._keep_cross_validation_models && out._cross_validation_models != null) {
            for (Key cvModelKey : out._cross_validation_models) {
                Keyed.remove(cvModelKey, fs, true);
            }
        }
        return fs;
    }

    /**
     * Writes the keys of the companion objects (gamified training frame and any kept CV artifacts),
     * then delegates to the base implementation.
     */
    @Override
    protected AutoBuffer writeAll_impl(AutoBuffer ab) {
        final GAMModelOutput out = (GAMModelOutput)this._output;
        final GAMParameters p = (GAMParameters)this._parms;
        if (out._gamTransformedTrainCenter != null) {
            ab.putKey(out._gamTransformedTrainCenter);
        }
        if (p._keep_cross_validation_predictions) {
            ab.putKey(out._cross_validation_holdout_predictions_frame_id);
        }
        if (p._keep_cross_validation_fold_assignment) {
            ab.putKey(out._cross_validation_fold_assignment_frame_id);
        }
        if (p._keep_cross_validation_models && out._cross_validation_models != null) {
            for (Key cvModelKey : out._cross_validation_models) {
                ab.putKey(cvModelKey);
            }
        }
        return super.writeAll_impl(ab);
    }

    /**
     * Reads the companion-object keys in the same order {@code writeAll_impl} writes them, then
     * delegates to the base implementation.
     */
    @Override
    protected Keyed readAll_impl(AutoBuffer ab, Futures fs) {
        final GAMModelOutput out = (GAMModelOutput)this._output;
        final GAMParameters p = (GAMParameters)this._parms;
        if (out._gamTransformedTrainCenter != null) {
            ab.getKey(out._gamTransformedTrainCenter, fs);
        }
        if (p._keep_cross_validation_predictions) {
            ab.getKey(out._cross_validation_holdout_predictions_frame_id, fs);
        }
        if (p._keep_cross_validation_fold_assignment) {
            ab.getKey(out._cross_validation_fold_assignment_frame_id, fs);
        }
        if (p._keep_cross_validation_models && out._cross_validation_models != null) {
            for (Key cvModelKey : out._cross_validation_models) {
                ab.getKey(cvModelKey, fs);
            }
        }
        return super.readAll_impl(ab, fs);
    }

    private class GAMScore
    extends MRTask<GAMScore> {
        private DataInfo _dinfo;
        private double[] _coeffs;
        private double[][] _coeffs_multinomial;
        private int _nclass;
        private boolean _computeMetrics;
        private final Job _j;
        private GLMModel.GLMParameters.Family _family;
        private transient double[] _eta;
        private String[] _predDomains;
        private final GAMModel _m;
        private final double _defaultThreshold;
        private int _lastClass;
        private ModelMetrics.MetricBuilder _mb;
        final boolean _generatePredictions;
        private transient double[][] _vcov;
        private transient double[] _tmp;
        private boolean _classifier2class;

        private GAMScore(Job j2, GAMModel m4, DataInfo dinfo, String[] domain, boolean computeMetrics, boolean makePredictions) {
            this._j = j2;
            this._m = m4;
            this._computeMetrics = computeMetrics;
            this._predDomains = domain;
            this._nclass = ((GAMModelOutput)m4._output).nclasses();
            this._generatePredictions = makePredictions;
            boolean bl = this._classifier2class = ((GAMParameters)this._m._parms)._family == GLMModel.GLMParameters.Family.binomial || ((GAMParameters)this._m._parms)._family == GLMModel.GLMParameters.Family.quasibinomial || ((GAMParameters)this._m._parms)._family == GLMModel.GLMParameters.Family.fractionalbinomial;
            if (((GAMParameters)this._m._parms)._family == GLMModel.GLMParameters.Family.multinomial || ((GAMParameters)this._m._parms)._family == GLMModel.GLMParameters.Family.ordinal) {
                this._coeffs = null;
                this._coeffs_multinomial = ((GAMModelOutput)m4._output)._model_beta_multinomial;
            } else {
                double[] beta = ((GAMModelOutput)m4._output)._model_beta;
                int[] ids = new int[beta.length - 1];
                int k2 = 0;
                for (int i2 = 0; i2 < beta.length - 1; ++i2) {
                    if (beta[i2] == 0.0) continue;
                    ids[k2++] = i2;
                }
                if (k2 < beta.length - 1) {
                    ids = Arrays.copyOf(ids, k2);
                    dinfo = dinfo.filterExpandedColumns(ids);
                    double[] beta2 = MemoryManager.malloc8d(ids.length + 1);
                    int l2 = 0;
                    for (int x2 : ids) {
                        beta2[l2++] = beta[x2];
                    }
                    beta2[l2] = beta[beta.length - 1];
                    beta = beta2;
                }
                this._coeffs_multinomial = null;
                this._coeffs = beta;
            }
            this._dinfo = dinfo;
            this._dinfo._valid = true;
            this._defaultThreshold = m4.defaultThreshold();
            this._family = ((GAMParameters)m4._parms)._family;
            this._lastClass = this._nclass - 1;
        }

        @Override
        public void map(Chunk[] chks, NewChunk[] nc) {
            if (this.isCancelled() || this._j != null && this._j.stop_requested()) {
                return;
            }
            if (this._family.equals((Object)GLMModel.GLMParameters.Family.ordinal) || this._family.equals((Object)GLMModel.GLMParameters.Family.multinomial)) {
                this._eta = MemoryManager.malloc8d(this._nclass);
            }
            this._vcov = ((GAMModelOutput)this._m._output)._glm_vcov;
            if (this._vcov != null) {
                this._tmp = MemoryManager.malloc8d(this._vcov.length);
            }
            int numPredVals = this._nclass <= 1 ? 1 : this._nclass + 1;
            double[] predictVals = MemoryManager.malloc8d(numPredVals);
            float[] trueResponse = null;
            if (this._computeMetrics) {
                this._mb = this._m.makeMetricBuilder(this._predDomains);
                trueResponse = new float[1];
            }
            DataInfo.Row r2 = this._dinfo.newDenseRow();
            int chkLen = chks[0]._len;
            for (int rid = 0; rid < chkLen; ++rid) {
                this._dinfo.extractDenseRow(chks, rid, r2);
                this.processRow(r2, predictVals, nc, numPredVals);
                if (!this._computeMetrics || r2.response_bad) continue;
                trueResponse[0] = (float)r2.response[0];
                this._mb.perRow(predictVals, trueResponse, r2.weight, r2.offset, this._m);
            }
            if (this._j != null) {
                this._j.update(1L);
            }
        }

        private void processRow(DataInfo.Row r2, double[] ps, NewChunk[] preds, int ncols) {
            if (r2.predictors_bad) {
                Arrays.fill(ps, Double.NaN);
            } else if (r2.weight == 0.0) {
                Arrays.fill(ps, 0.0);
            }
            switch (this._family) {
                case multinomial: {
                    ps = this.scoreMultinomialRow(r2, r2.offset, ps);
                    break;
                }
                case ordinal: {
                    ps = this.scoreOrdinalRow(r2, r2.offset, ps);
                    break;
                }
                default: {
                    ps = this.scoreRow(r2, r2.offset, ps);
                }
            }
            if (this._generatePredictions) {
                for (int predCol = 0; predCol < ncols; ++predCol) {
                    preds[predCol].addNum(ps[predCol]);
                }
                if (this._vcov != null) {
                    preds[ncols].addNum(Math.sqrt(r2.innerProduct(r2.mtrxMul(this._vcov, this._tmp))));
                }
            }
        }

        public double[] scoreRow(DataInfo.Row r2, double offset, double[] preds) {
            double mu = ((GAMParameters)this._m._parms).linkInv(r2.innerProduct(this._coeffs) + offset);
            if (this._classifier2class) {
                preds[0] = mu >= this._defaultThreshold ? 1.0 : 0.0;
                preds[1] = 1.0 - mu;
                preds[2] = mu;
            } else {
                preds[0] = mu;
            }
            return preds;
        }

        public double[] scoreOrdinalRow(DataInfo.Row r2, double offset, double[] preds) {
            int cInd;
            double[][] bm = this._coeffs_multinomial;
            Arrays.fill(preds, 0.0);
            preds[0] = this._lastClass;
            double previousCDF = 0.0;
            for (cInd = 0; cInd < this._lastClass; ++cInd) {
                double eta = r2.innerProduct(bm[cInd]) + offset;
                double currCDF = 1.0 / (1.0 + Math.exp(-eta));
                preds[cInd + 1] = currCDF - previousCDF;
                previousCDF = currCDF;
                if (!(eta > 0.0)) continue;
                preds[0] = cInd;
                break;
            }
            for (cInd = (int)preds[0] + 1; cInd < this._lastClass; ++cInd) {
                double currCDF = 1.0 / (1.0 + Math.exp(-r2.innerProduct(bm[cInd]) + offset));
                preds[cInd + 1] = currCDF - previousCDF;
                previousCDF = currCDF;
            }
            preds[this._nclass] = 1.0 - previousCDF;
            return preds;
        }

        public double[] scoreMultinomialRow(DataInfo.Row r2, double offset, double[] preds) {
            int c2;
            double[] eta = this._eta;
            double[][] bm = this._coeffs_multinomial;
            double sumExp = 0.0;
            double maxRow = Double.NEGATIVE_INFINITY;
            for (c2 = 0; c2 < bm.length; ++c2) {
                eta[c2] = r2.innerProduct(bm[c2]) + offset;
                if (!(eta[c2] > maxRow)) continue;
                maxRow = eta[c2];
            }
            for (c2 = 0; c2 < bm.length; ++c2) {
                eta[c2] = Math.exp(eta[c2] - maxRow);
                sumExp += eta[c2];
            }
            sumExp = 1.0 / sumExp;
            for (c2 = 0; c2 < bm.length; ++c2) {
                preds[c2 + 1] = eta[c2] * sumExp;
            }
            preds[0] = ArrayUtils.maxIndex(eta);
            return preds;
        }

        /**
         * Merges another task's metric builder into this one.
         * Nothing is merged when this task has no metric builder.
         */
        @Override
        public void reduce(GAMScore other) {
            if (this._mb == null) {
                return;
            }
            this._mb.reduce(other._mb);
        }

        /**
         * Finalises the metric builder once all partial results have been reduced.
         * Does nothing when this task has no metric builder.
         */
        @Override
        protected void postGlobal() {
            if (this._mb == null) {
                return;
            }
            this._mb.postGlobal();
        }
    }

    /**
     * Output state of a trained GAM. It holds GLM-derived summaries, metrics and coefficients,
     * plus spline-related state such as knots, penalty matrices and transposed Z matrices.
     */
    public static class GAMModelOutput
    extends Model.Output {
        public String[] _coefficient_names_no_centering;
        public String[] _coefficient_names;
        public TwoDimTable _glm_model_summary;
        public ModelMetrics _glm_training_metrics;
        public ModelMetrics _glm_validation_metrics;
        public double _glm_dispersion;
        public double[] _glm_zvalues;
        public double[] _glm_pvalues;
        public double[][] _glm_vcov;
        public double[] _glm_stdErr;
        public double _glm_best_lamda_value;
        public TwoDimTable _glm_scoring_history;
        public TwoDimTable[] _glm_cv_scoring_history;
        public TwoDimTable _coefficients_table;
        public TwoDimTable _coefficients_table_no_centering;
        public TwoDimTable _standardized_coefficient_magnitudes;
        public TwoDimTable _variable_importances;
        public VarImp _varimp;
        public double[] _model_beta_no_centering;
        public double[] _standardized_model_beta_no_centering;
        public double[] _model_beta;
        public double[] _standardized_model_beta;
        public double[][] _model_beta_multinomial_no_centering;
        public double[][] _standardized_model_beta_multinomial_no_centering;
        public double[][] _model_beta_multinomial;
        public double[][] _standardized_model_beta_multinomial;
        public double _best_alpha;
        public double _best_lambda;
        public double _devianceValid = Double.NaN;
        public double _devianceTrain = Double.NaN;
        private double[] _zvalues;
        private double _dispersion;
        private boolean _dispersionEstimated;
        public String[][] _gamColNames;
        public double[][][] _zTranspose;
        public double[][][] _penaltyMatricesCenter;
        public double[][][] _penaltyMatrices;
        public double[][][] _binvD;
        public double[][][] _knots;
        int[][][] _allPolyBasisList;
        double[][][] _penaltyMatCS;
        double[][][] _zTransposeCS;
        public int[] _numKnots;
        public double[][][] _starT;
        public double[][] _gamColMeansRaw;
        public double[][] _oneOGamColStd;
        public double[] _penaltyScale;
        public Key<Frame> _gamTransformedTrainCenter;
        public DataInfo _dinfo;
        public String[] _responseDomains;
        public String _gam_transformed_center_key;
        final GLMModel.GLMParameters.Family _family;

        /**
         * Returns the number of response classes.
         * Multinomial and ordinal defer to the superclass, binomial-like families report 2,
         * and all other families report 1.
         */
        @Override
        public int nclasses() {
            final GLMModel.GLMParameters.Family fam = this._family;
            final boolean multiClass = fam == GLMModel.GLMParameters.Family.multinomial
                    || fam == GLMModel.GLMParameters.Family.ordinal;
            if (multiClass) {
                return super.nclasses();
            }
            final boolean twoClass = fam == GLMModel.GLMParameters.Family.binomial
                    || fam == GLMModel.GLMParameters.Family.quasibinomial
                    || fam == GLMModel.GLMParameters.Family.fractionalbinomial;
            return twoClass ? 2 : 1;
        }

        /**
         * Returns the class labels.
         * Binomial and quasibinomial use the response domain captured at construction,
         * fractionalbinomial uses the fixed labels {"0", "1"}, and every other family defers
         * to the superclass.
         */
        @Override
        public String[] classNames() {
            final GLMModel.GLMParameters.Family fam = this._family;
            if (fam == GLMModel.GLMParameters.Family.fractionalbinomial) {
                return BINOMIAL_CLASS_NAMES;
            }
            final boolean usesResponseDomain = fam == GLMModel.GLMParameters.Family.binomial
                    || fam == GLMModel.GLMParameters.Family.quasibinomial;
            return usesResponseDomain ? this._responseDomains : super.classNames();
        }

        /**
         * Builds the output holder from the builder and its data info.
         *
         * <p>For quasibinomial, the response domain is collected from the numeric response
         * column, with at most 2 distinct values. For any other family, it is the domain of the
         * adapted frame's last column.
         *
         * @param b2    the GAM builder whose parameters supply the family and response column
         * @param dinfo data info whose adapted frame provides the predictor domains
         */
        public GAMModelOutput(GAM b2, DataInfo dinfo) {
            super(b2, dinfo._adaptedFrame);
            final GAMParameters parms = (GAMParameters) b2._parms;
            final Frame adapted = dinfo._adaptedFrame;
            this._dinfo = dinfo;
            this._domains = adapted.domains();
            this._family = parms._family;
            if (this._family.equals((Object) GLMModel.GLMParameters.Family.quasibinomial)) {
                final Vec response = adapted.vec(parms._response_column);
                final VecUtils.CollectDoubleDomain collector =
                        (VecUtils.CollectDoubleDomain) new VecUtils.CollectDoubleDomain(null, 2).doAll(response);
                this._responseDomains = collector.stringDomain(response.isInt());
            } else {
                this._responseDomains = adapted.lastVec().domain();
            }
        }

        /**
         * Maps the GLM family onto a model category.
         * Families without an explicit mapping are treated as regression.
         */
        @Override
        public ModelCategory getModelCategory() {
            final ModelCategory category;
            switch (this._family) {
                case binomial:
                case quasibinomial:
                case fractionalbinomial:
                    category = ModelCategory.Binomial;
                    break;
                case multinomial:
                    category = ModelCategory.Multinomial;
                    break;
                case ordinal:
                    category = ModelCategory.Ordinal;
                    break;
                default:
                    category = ModelCategory.Regression;
                    break;
            }
            return category;
        }

        /**
         * Clones {@code glmMetrics} against {@code gamModel} and {@code train}, then stores the
         * clone on {@code gamModel}'s output.
         *
         * @param forTrain {@code true} to set the training metrics, {@code false} for validation
         */
        public void copyMetrics(GAMModel gamModel, Frame train, boolean forTrain, ModelMetrics glmMetrics) {
            final ModelMetrics cloned = glmMetrics.deepCloneWithDifferentModelAndFrame(gamModel, train);
            final GAMModelOutput target = (GAMModelOutput) gamModel._output;
            if (!forTrain) {
                target._validation_metrics = cloned;
                return;
            }
            target._training_metrics = cloned;
        }
    }

    /**
     * User-facing parameters for building a GAM.
     * Many of them (family, link, solver, regularisation, missing-value handling) mirror the
     * corresponding GLM settings.
     */
    public static class GAMParameters
    extends Model.Parameters {
        public boolean _standardize = false;
        public GLMModel.GLMParameters.Family _family = GLMModel.GLMParameters.Family.AUTO;
        public GLMModel.GLMParameters.Link _link = GLMModel.GLMParameters.Link.family_default;
        public GLMModel.GLMParameters.Solver _solver = GLMModel.GLMParameters.Solver.AUTO;
        public double _tweedie_variance_power;
        public double _tweedie_link_power;
        public double _theta;
        public double[] _alpha;
        public double[] _lambda;
        public double[] _startval;
        public Serializable _missing_values_handling = GLMModel.GLMParameters.MissingValuesHandling.MeanImputation;
        public boolean _lambda_search = false;
        public boolean _use_all_factor_levels = false;
        public int _max_iterations = -1;
        public boolean _intercept = true;
        public double _beta_epsilon = 1.0E-4;
        public double _objective_epsilon = -1.0;
        public double _obj_reg = -1.0;
        public boolean _compute_p_values = false;
        public boolean _scale_tp_penalty_mat = false;
        public boolean _standardize_tp_gam_cols = false;
        public String[] _interactions = null;
        public StringPair[] _interaction_pairs = null;
        public Key<Frame> _plug_values = null;
        public int _max_active_predictors = -1;
        public boolean _generate_scoring_history = false;
        public int[] _num_knots;
        public int[] _spline_orders;
        public int[] _spline_orders_sorted;
        public int[] _num_knots_sorted;
        public int[] _num_knots_tp;
        public String[] _knot_ids;
        public String[][] _gam_columns;
        public String[][] _gam_columns_sorted;
        public int[] _gamPredSize;
        public int[] _m;
        public int[] _M;
        public int[] _bs;
        public int[] _bs_sorted;
        public double[] _scale;
        public double[] _scale_sorted;
        public boolean _saveZMatrix = false;
        public boolean _keep_gam_cols = false;
        public boolean _savePenaltyMat = false;
        public double _prior = -1.0;
        public boolean _cold_start = false;
        public int _nlambdas = -1;
        public boolean _non_negative = false;
        public boolean _remove_collinear_columns = false;
        public double _gradient_epsilon = -1.0;
        public boolean _early_stopping = true;
        public Key<Frame> _beta_constraints = null;
        public double _lambda_min_ratio = -1.0;
        public boolean _betaConstraintsOff = false;
        public long _seed = -1L;

        @Override
        public String algoName() {
            return "GAM";
        }

        @Override
        public String fullName() {
            return "Generalized Additive Model";
        }

        @Override
        public String javaName() {
            return GAMModel.class.getName();
        }

        @Override
        public long progressUnits() {
            return 1L;
        }

        /** Builds the interaction specification from {@code _interactions} and {@code _interaction_pairs}. */
        public Model.InteractionSpec interactionSpec() {
            return Model.InteractionSpec.create(this._interactions, this._interaction_pairs);
        }

        /**
         * Returns {@code _missing_values_handling} as the GLM enum.
         * A GLM value is returned as-is. A DeepLearning value (MeanImputation or Skip) is mapped
         * onto the GLM constant of the same name.
         *
         * @throws IllegalStateException for a DeepLearning value without a GLM counterpart
         */
        public GLMModel.GLMParameters.MissingValuesHandling missingValuesHandling() {
            if (this._missing_values_handling instanceof GLMModel.GLMParameters.MissingValuesHandling) {
                return (GLMModel.GLMParameters.MissingValuesHandling)((Object)this._missing_values_handling);
            }
            assert (this._missing_values_handling instanceof DeepLearningModel.DeepLearningParameters.MissingValuesHandling);
            switch ((DeepLearningModel.DeepLearningParameters.MissingValuesHandling)((Object)this._missing_values_handling)) {
                case MeanImputation: {
                    return GLMModel.GLMParameters.MissingValuesHandling.MeanImputation;
                }
                case Skip: {
                    return GLMModel.GLMParameters.MissingValuesHandling.Skip;
                }
            }
            throw new IllegalStateException("Unsupported missing values handling value: " + this._missing_values_handling);
        }

        /**
         * Creates the imputer for the configured missing-value handling.
         * PlugValues uses the frame referenced by {@code _plug_values}; every other mode uses
         * mean imputation.
         *
         * @throws IllegalStateException if PlugValues is selected but no plug-values frame resolves
         */
        public DataInfo.Imputer makeImputer() {
            if (this.missingValuesHandling() != GLMModel.GLMParameters.MissingValuesHandling.PlugValues) {
                return new DataInfo.MeanImputer();
            }
            // Resolve the key a single time so the null check and the imputer use the same Frame
            // and the lookup is not repeated.
            final Frame plugValues = this._plug_values == null ? null : this._plug_values.get();
            if (plugValues == null) {
                throw new IllegalStateException("Plug values frame needs to be specified when Missing Value Handling = PlugValues.");
            }
            return new GLM.PlugValuesImputer(plugValues);
        }

        /**
         * Applies the inverse of the configured link function to {@code x2}.
         *
         * <p>The inverse link clamps its argument to a magnitude of at least 1e-5 before inverting.
         * Tweedie with link power 0 uses {@code max(2e-16, exp(x))}; any other link power p uses
         * {@code x^(1/p)}.
         *
         * @throws RuntimeException for links that are not handled here (for example, the
         *         unresolved {@code family_default})
         */
        public double linkInv(double x2) {
            switch (this._link) {
                case identity: {
                    return x2;
                }
                case ologlog: {
                    return 1.0 - Math.exp(-1.0 * Math.exp(x2));
                }
                case ologit: 
                case logit: {
                    return 1.0 / (Math.exp(-x2) + 1.0);
                }
                case log: {
                    return Math.exp(x2);
                }
                case inverse: {
                    double xx = x2 < 0.0 ? Math.min(-1.0E-5, x2) : Math.max(1.0E-5, x2);
                    return 1.0 / xx;
                }
                case tweedie: {
                    return this._tweedie_link_power == 0.0 ? Math.max(2.0E-16, Math.exp(x2)) : Math.pow(x2, 1.0 / this._tweedie_link_power);
                }
            }
            throw new RuntimeException("unexpected link function  " + this._link.toString());
        }

        /** Sets the family from a distribution and resets the link to the family default. */
        @Override
        public void setDistributionFamily(DistributionFamily distributionFamily) {
            this._family = DistributionUtils.distributionToFamily(distributionFamily);
            this._link = GLMModel.GLMParameters.Link.family_default;
        }

        /** Returns the distribution that corresponds to the configured family. */
        @Override
        public DistributionFamily getDistributionFamily() {
            return DistributionUtils.familyToDistribution(this._family);
        }
    }
}

