/*
 * Decompiled with CFR 0.152.
 */
package hex.modelselection;

import hex.DataInfo;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelBuilderHelper;
import hex.ModelCategory;
import hex.genmodel.utils.MathUtils;
import hex.glm.GLM;
import hex.glm.GLMModel;
import hex.modelselection.ModelSelectionModel;
import hex.modelselection.ModelSelectionUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;
import java.util.stream.IntStream;
import water.DKV;
import water.H2O;
import water.Key;
import water.Lockable;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;

public class ModelSelection
extends ModelBuilder<ModelSelectionModel, ModelSelectionModel.ModelSelectionParameters, ModelSelectionModel.ModelSelectionModelOutput> {
    public String[][] _bestModelPredictors;
    public double[] _bestR2Values;
    DataInfo _dinfo;
    public int _numPredictors;
    public String[] _predictorNames;
    public int _glmNFolds = 0;
    Model.Parameters.FoldAssignmentScheme _foldAssignment = null;
    String _foldColumn = null;

    public ModelSelection(boolean startup_once) {
        super(new ModelSelectionModel.ModelSelectionParameters(), startup_once);
    }

    public ModelSelection(ModelSelectionModel.ModelSelectionParameters parms) {
        super(parms);
        this.init(false);
    }

    public ModelSelection(ModelSelectionModel.ModelSelectionParameters parms, Key<ModelSelectionModel> key) {
        super(parms, key);
        this.init(false);
    }

    @Override
    protected int nModelsInParallel(int folds) {
        return this.nModelsInParallel(1, 2);
    }

    @Override
    protected ModelSelectionDriver trainModelImpl() {
        return new ModelSelectionDriver();
    }

    @Override
    public ModelCategory[] can_build() {
        return new ModelCategory[]{ModelCategory.Regression};
    }

    @Override
    public boolean isSupervised() {
        return true;
    }

    @Override
    public boolean haveMojo() {
        return false;
    }

    @Override
    public boolean havePojo() {
        return false;
    }

    @Override
    public void init(boolean expensive) {
        if (((ModelSelectionModel.ModelSelectionParameters)this._parms)._nfolds > 0 || ((ModelSelectionModel.ModelSelectionParameters)this._parms)._fold_column != null) {
            if (ModelSelectionModel.ModelSelectionParameters.Mode.backward.equals((Object)((ModelSelectionModel.ModelSelectionParameters)this._parms)._mode)) {
                this.error("nfolds/fold_column", "cross-validation is not supported for backward selection.");
            } else {
                this._glmNFolds = ((ModelSelectionModel.ModelSelectionParameters)this._parms)._nfolds;
                if (((ModelSelectionModel.ModelSelectionParameters)this._parms)._fold_assignment != null) {
                    this._foldAssignment = ((ModelSelectionModel.ModelSelectionParameters)this._parms)._fold_assignment;
                    ((ModelSelectionModel.ModelSelectionParameters)this._parms)._fold_assignment = null;
                }
                if (((ModelSelectionModel.ModelSelectionParameters)this._parms)._fold_column != null) {
                    this._foldColumn = ((ModelSelectionModel.ModelSelectionParameters)this._parms)._fold_column;
                    ((ModelSelectionModel.ModelSelectionParameters)this._parms)._fold_column = null;
                }
                ((ModelSelectionModel.ModelSelectionParameters)this._parms)._nfolds = 0;
            }
        }
        super.init(expensive);
        if (this.error_count() > 0) {
            return;
        }
        if (expensive) {
            this.initModelSelectionParameters();
            if (this.error_count() > 0) {
                return;
            }
            this.initModelParameters();
        }
    }

    private void initModelParameters() {
        if (!ModelSelectionModel.ModelSelectionParameters.Mode.backward.equals((Object)((ModelSelectionModel.ModelSelectionParameters)this._parms)._mode)) {
            this._bestR2Values = new double[((ModelSelectionModel.ModelSelectionParameters)this._parms)._max_predictor_number];
            this._bestModelPredictors = new String[((ModelSelectionModel.ModelSelectionParameters)this._parms)._max_predictor_number][];
        }
    }

    private void initModelSelectionParameters() {
        this._predictorNames = ModelSelectionUtils.extractPredictorNames((ModelSelectionModel.ModelSelectionParameters)this._parms, this._dinfo, this._foldColumn);
        this._numPredictors = this._predictorNames.length;
        if (ModelSelectionModel.ModelSelectionParameters.Mode.maxr.equals((Object)((ModelSelectionModel.ModelSelectionParameters)this._parms)._mode) || ModelSelectionModel.ModelSelectionParameters.Mode.allsubsets.equals((Object)((ModelSelectionModel.ModelSelectionParameters)this._parms)._mode)) {
            if (((ModelSelectionModel.ModelSelectionParameters)this._parms)._lambda == null && !((ModelSelectionModel.ModelSelectionParameters)this._parms)._lambda_search && ((ModelSelectionModel.ModelSelectionParameters)this._parms)._alpha == null) {
                ((ModelSelectionModel.ModelSelectionParameters)this._parms)._lambda = new double[]{0.0};
            }
            if (this.nclasses() > 1) {
                this.error("response", "'allsubsets' and 'maxr' only works with regression.");
            }
            if (!GLMModel.GLMParameters.Family.AUTO.equals((Object)((ModelSelectionModel.ModelSelectionParameters)this._parms)._family) && !GLMModel.GLMParameters.Family.gaussian.equals((Object)((ModelSelectionModel.ModelSelectionParameters)this._parms)._family)) {
                this.error("_family", "ModelSelection only supports Gaussian family for 'allsubset' and 'maxr' mode.");
            }
            if (GLMModel.GLMParameters.Family.AUTO.equals((Object)((ModelSelectionModel.ModelSelectionParameters)this._parms)._family)) {
                ((ModelSelectionModel.ModelSelectionParameters)this._parms)._family = GLMModel.GLMParameters.Family.gaussian;
            }
            if (((ModelSelectionModel.ModelSelectionParameters)this._parms)._max_predictor_number < 1 || ((ModelSelectionModel.ModelSelectionParameters)this._parms)._max_predictor_number > this._numPredictors) {
                this.error("max_predictor_number", "max_predictor_number must exceed 0 and be no greater than the number of predictors of the training frame.");
            }
        } else {
            ((ModelSelectionModel.ModelSelectionParameters)this._parms)._compute_p_values = true;
            if (((ModelSelectionModel.ModelSelectionParameters)this._parms)._valid != null) {
                this.error("validation_frame", " is not supported for ModelSelection mode='backward'");
            }
            if (((ModelSelectionModel.ModelSelectionParameters)this._parms)._lambda_search) {
                this.error("lambda_search", "backward selection does not support lambda_search.");
            }
            if (((ModelSelectionModel.ModelSelectionParameters)this._parms)._lambda != null) {
                if (((ModelSelectionModel.ModelSelectionParameters)this._parms)._lambda.length > 1) {
                    this.error("lambda", "if set must be set to 0 and cannot be an array or more than length one for backward selection.");
                }
                if (((ModelSelectionModel.ModelSelectionParameters)this._parms)._lambda[0] != 0.0) {
                    this.error("lambda", "must be set to 0 for backward selection");
                }
            } else {
                ((ModelSelectionModel.ModelSelectionParameters)this._parms)._lambda = new double[]{0.0};
            }
            if (GLMModel.GLMParameters.Family.multinomial.equals((Object)((ModelSelectionModel.ModelSelectionParameters)this._parms)._family) || GLMModel.GLMParameters.Family.ordinal.equals((Object)((ModelSelectionModel.ModelSelectionParameters)this._parms)._family)) {
                this.error("family", "backward selection does not support multinomial or ordinal");
            }
            if (((ModelSelectionModel.ModelSelectionParameters)this._parms)._min_predictor_number <= 0) {
                this.error("min_predictor_number", "must be >= 1.");
            }
            if (((ModelSelectionModel.ModelSelectionParameters)this._parms)._min_predictor_number > this._numPredictors) {
                this.error("min_predictor_number", "cannot exceed the total number of predictors (" + this._numPredictors + ")in the dataset.");
            }
        }
        if (((ModelSelectionModel.ModelSelectionParameters)this._parms)._nparallelism < 0) {
            this.error("nparallelism", "must be >= 0.");
        }
        if (((ModelSelectionModel.ModelSelectionParameters)this._parms)._nparallelism == 0) {
            ((ModelSelectionModel.ModelSelectionParameters)this._parms)._nparallelism = H2O.NUMCPUS;
        }
    }

    public static GLMModel buildExtractBestR2Model(Frame[] trainingFrames, ModelSelectionModel.ModelSelectionParameters parms, int glmNFolds, String foldColumn, Model.Parameters.FoldAssignmentScheme foldAssignment) {
        GLMModel.GLMParameters[] trainingParams = ModelSelectionUtils.generateGLMParameters(trainingFrames, parms, glmNFolds, foldColumn, foldAssignment);
        ModelBuilder[] glmBuilder = ModelSelectionUtils.buildGLMBuilders(trainingParams);
        GLM[] glmResults = (GLM[])ModelBuilderHelper.trainModelsParallel((ModelBuilder[])glmBuilder, (int)parms._nparallelism);
        return ModelSelectionUtils.findBestModel(glmResults);
    }

    public static GLMModel forwardStep(List<Integer> currSubsetIndices, List<String> coefNames, int predPos, List<Integer> validSubsets, ModelSelectionModel.ModelSelectionParameters parms, String foldColumn, int glmNFolds, Model.Parameters.FoldAssignmentScheme foldAssignment, Set<BitSet> usedCombo) {
        String[] predictorNames = (String[])coefNames.stream().toArray(String[]::new);
        Frame[] trainingFrames = ModelSelectionUtils.generateMaxRTrainingFrames(parms, predictorNames, foldColumn, currSubsetIndices, predPos, validSubsets, usedCombo);
        GLMModel bestModel = ModelSelection.buildExtractBestR2Model(trainingFrames, parms, glmNFolds, foldColumn, foldAssignment);
        List<String> coefUsed = ModelSelectionUtils.extraModelColumnNames(coefNames, bestModel);
        for (int predIndex = coefUsed.size() - 1; predIndex >= 0; --predIndex) {
            int index = coefNames.indexOf(coefUsed.get(predIndex));
            if (currSubsetIndices.contains(index)) continue;
            currSubsetIndices.add(predPos, index);
            break;
        }
        ModelSelectionUtils.removeTrainingFrames(trainingFrames);
        return bestModel;
    }

    public static GLMModel forwardStep(List<Integer> currSubsetIndices, List<String> coefNames, int predPos, List<Integer> validSubsets, ModelSelectionModel.ModelSelectionParameters parms, String foldColumn, int glmNFolds, Model.Parameters.FoldAssignmentScheme foldAssignment) {
        return ModelSelection.forwardStep(currSubsetIndices, coefNames, predPos, validSubsets, parms, foldColumn, glmNFolds, foldAssignment, null);
    }

    public static GLMModel replacement(List<Integer> currSubsetIndices, List<String> coefNames, double bestR2, ModelSelectionModel.ModelSelectionParameters parms, int glmNFolds, String foldColumn, List<Integer> validSubset, Model.Parameters.FoldAssignmentScheme foldAssignment) {
        int currSubsetSize = currSubsetIndices.size();
        Object[] sortedCurrSubset = currSubsetIndices.toArray(new Integer[0]);
        Arrays.sort(sortedCurrSubset);
        HashSet<BitSet> usedCombos = new HashSet<BitSet>();
        usedCombos.add(ModelSelectionUtils.setBitSet(currSubsetIndices.stream().mapToInt(i2 -> i2).toArray(), coefNames.size()));
        int lastBestR2PosIndex = -1;
        GLMModel bestR2Model = null;
        GLMModel[] bestR2Models = new GLMModel[currSubsetSize];
        int[] r2PredPosIndex = new int[currSubsetSize];
        int[][] subsetsCombo = new int[currSubsetSize][];
        List<Integer> originalSubset = new ArrayList<Integer>(currSubsetSize);
        while (true) {
            for (int index = 0; index < currSubsetSize; ++index) {
                if (index == lastBestR2PosIndex) continue;
                ArrayList<Integer> oneLessSubset = new ArrayList<Integer>(currSubsetIndices);
                int predIndexRemoved = oneLessSubset.remove(index);
                validSubset.add(predIndexRemoved);
                bestR2Models[index] = ModelSelection.forwardStep(oneLessSubset, coefNames, index, validSubset, parms, foldColumn, glmNFolds, foldAssignment, usedCombos);
                subsetsCombo[index] = oneLessSubset.stream().mapToInt(i2 -> i2).toArray();
                r2PredPosIndex[index] = index;
                validSubset.remove(validSubset.indexOf(predIndexRemoved));
            }
            int bestR2ModelIndex = ModelSelectionUtils.findBestR2Model(bestR2, bestR2Models);
            if (bestR2ModelIndex < 0) break;
            bestR2Model = bestR2Models[bestR2ModelIndex];
            bestR2 = bestR2Model.r2();
            currSubsetIndices = Arrays.stream(subsetsCombo[bestR2ModelIndex]).boxed().collect(Collectors.toList());
            lastBestR2PosIndex = r2PredPosIndex[bestR2ModelIndex];
            ModelSelectionUtils.updateValidSubset(validSubset, originalSubset, currSubsetIndices);
            originalSubset = currSubsetIndices;
        }
        return bestR2Model;
    }

    public class ModelSelectionDriver
    extends ModelBuilder.Driver {
        public ModelSelectionDriver() {
            super(ModelSelection.this);
        }

        public final void buildModel() {
            Lockable model = null;
            try {
                int numModelBuilt = 0;
                model = new ModelSelectionModel(ModelSelection.this.dest(), (ModelSelectionModel.ModelSelectionParameters)ModelSelection.this._parms, new ModelSelectionModel.ModelSelectionModelOutput(ModelSelection.this, ModelSelection.this._dinfo));
                model.write_lock(ModelSelection.this._job);
                if (ModelSelectionModel.ModelSelectionParameters.Mode.backward.equals((Object)((ModelSelectionModel.ModelSelectionParameters)ModelSelection.this._parms)._mode)) {
                    ((ModelSelectionModel.ModelSelectionModelOutput)((ModelSelectionModel)model)._output)._best_model_ids = new Key[ModelSelection.this._numPredictors];
                    ((ModelSelectionModel.ModelSelectionModelOutput)((ModelSelectionModel)model)._output)._coef_p_values = new double[ModelSelection.this._numPredictors][];
                    ((ModelSelectionModel.ModelSelectionModelOutput)((ModelSelectionModel)model)._output)._z_values = new double[ModelSelection.this._numPredictors][];
                    ((ModelSelectionModel.ModelSelectionModelOutput)((ModelSelectionModel)model)._output)._best_model_predictors = new String[ModelSelection.this._numPredictors][];
                    ((ModelSelectionModel.ModelSelectionModelOutput)((ModelSelectionModel)model)._output)._coefficient_names = new String[ModelSelection.this._numPredictors][];
                } else {
                    ((ModelSelectionModel.ModelSelectionModelOutput)((ModelSelectionModel)model)._output)._best_model_ids = new Key[((ModelSelectionModel.ModelSelectionParameters)ModelSelection.this._parms)._max_predictor_number];
                    ((ModelSelectionModel.ModelSelectionModelOutput)((ModelSelectionModel)model)._output)._best_r2_values = new double[((ModelSelectionModel.ModelSelectionParameters)ModelSelection.this._parms)._max_predictor_number];
                    ((ModelSelectionModel.ModelSelectionModelOutput)((ModelSelectionModel)model)._output)._best_model_predictors = new String[((ModelSelectionModel.ModelSelectionParameters)ModelSelection.this._parms)._max_predictor_number][];
                    ((ModelSelectionModel.ModelSelectionModelOutput)((ModelSelectionModel)model)._output)._coefficient_names = new String[((ModelSelectionModel.ModelSelectionParameters)ModelSelection.this._parms)._max_predictor_number][];
                }
                if (ModelSelectionModel.ModelSelectionParameters.Mode.allsubsets.equals((Object)((ModelSelectionModel.ModelSelectionParameters)ModelSelection.this._parms)._mode)) {
                    this.buildAllSubsetsModels((ModelSelectionModel)model);
                } else if (ModelSelectionModel.ModelSelectionParameters.Mode.maxr.equals((Object)((ModelSelectionModel.ModelSelectionParameters)ModelSelection.this._parms)._mode)) {
                    this.buildMaxRModels((ModelSelectionModel)model);
                } else if (ModelSelectionModel.ModelSelectionParameters.Mode.backward.equals((Object)((ModelSelectionModel.ModelSelectionParameters)ModelSelection.this._parms)._mode)) {
                    numModelBuilt = this.buildBackwardModels((ModelSelectionModel)model);
                }
                ModelSelection.this._job.update(0L, "Completed GLM model building.  Extracting results now.");
                model.update(ModelSelection.this._job);
                if (ModelSelectionModel.ModelSelectionParameters.Mode.backward.equals((Object)((ModelSelectionModel.ModelSelectionParameters)ModelSelection.this._parms)._mode)) {
                    ((ModelSelectionModel.ModelSelectionModelOutput)((ModelSelectionModel)model)._output).shrinkArrays(numModelBuilt);
                    ((ModelSelectionModel.ModelSelectionModelOutput)((ModelSelectionModel)model)._output).generateSummary(numModelBuilt);
                } else {
                    ((ModelSelectionModel.ModelSelectionModelOutput)((ModelSelectionModel)model)._output).generateSummary();
                }
            }
            finally {
                model.update(ModelSelection.this._job);
                model.unlock(ModelSelection.this._job);
            }
        }

        void buildMaxRModels(ModelSelectionModel model) {
            ArrayList<Integer> currSubsetIndices = new ArrayList<Integer>();
            ArrayList<String> coefNames = new ArrayList<String>(Arrays.asList(ModelSelection.this._predictorNames));
            List<Integer> validSubset = IntStream.rangeClosed(0, coefNames.size() - 1).boxed().collect(Collectors.toList());
            for (int predNum = 1; predNum <= ((ModelSelectionModel.ModelSelectionParameters)ModelSelection.this._parms)._max_predictor_number; ++predNum) {
                GLMModel currBestR2Model;
                GLMModel bestR2Model = ModelSelection.forwardStep(currSubsetIndices, coefNames, predNum - 1, validSubset, (ModelSelectionModel.ModelSelectionParameters)ModelSelection.this._parms, ModelSelection.this._foldColumn, ModelSelection.this._glmNFolds, ModelSelection.this._foldAssignment);
                validSubset.removeAll(currSubsetIndices);
                ModelSelection.this._job.update(predNum, "Finished building all models with " + predNum + " predictors.");
                if (predNum < ModelSelection.this._numPredictors && predNum > 1 && (currBestR2Model = ModelSelection.replacement(currSubsetIndices, coefNames, bestR2Model.r2(), (ModelSelectionModel.ModelSelectionParameters)ModelSelection.this._parms, ModelSelection.this._glmNFolds, ModelSelection.this._foldColumn, validSubset, ModelSelection.this._foldAssignment)) != null) {
                    bestR2Model.delete();
                    bestR2Model = currBestR2Model;
                }
                DKV.put(bestR2Model);
                ((ModelSelectionModel.ModelSelectionModelOutput)model._output).updateBestModels(bestR2Model, predNum - 1);
            }
        }

        private int buildBackwardModels(ModelSelectionModel model) {
            ArrayList<String> coefNames = new ArrayList<String>(Arrays.asList(ModelSelection.this._predictorNames));
            List<Integer> coefIndice = IntStream.rangeClosed(0, coefNames.size() - 1).boxed().collect(Collectors.toList());
            int numModelsBuilt = 0;
            String[] coefName = coefNames.toArray(new String[0]);
            for (int predNum = ModelSelection.this._numPredictors; predNum >= ((ModelSelectionModel.ModelSelectionParameters)ModelSelection.this._parms)._min_predictor_number; --predNum) {
                int modelIndex = predNum - 1;
                int[] coefInd = coefIndice.stream().mapToInt(Integer::intValue).toArray();
                Frame trainingFrame = ModelSelectionUtils.generateOneFrame(coefInd, (ModelSelectionModel.ModelSelectionParameters)ModelSelection.this._parms, coefName, ModelSelection.this._foldColumn);
                DKV.put(trainingFrame);
                GLMModel.GLMParameters[] glmParam = ModelSelectionUtils.generateGLMParameters(new Frame[]{trainingFrame}, (ModelSelectionModel.ModelSelectionParameters)ModelSelection.this._parms, ModelSelection.this._glmNFolds, ModelSelection.this._foldColumn, ModelSelection.this._foldAssignment);
                GLMModel glmModel = (GLMModel)new GLM(glmParam[0]).trainModel().get();
                DKV.put(glmModel);
                ((ModelSelectionModel.ModelSelectionModelOutput)model._output).extractPredictors4NextModel(glmModel, modelIndex, coefNames, coefIndice);
                ++numModelsBuilt;
                DKV.remove(trainingFrame._key);
                ModelSelection.this._job.update(predNum, "Finished building all models with " + predNum + " predictors.");
                if (((ModelSelectionModel.ModelSelectionParameters)ModelSelection.this._parms)._p_values_threshold > 0.0 && DoubleStream.of(((ModelSelectionModel.ModelSelectionModelOutput)model._output)._coef_p_values[modelIndex]).limit(((ModelSelectionModel.ModelSelectionModelOutput)model._output)._coef_p_values[modelIndex].length - 1).allMatch(x2 -> x2 <= ((ModelSelectionModel.ModelSelectionParameters)ModelSelection.this._parms)._p_values_threshold)) break;
            }
            return numModelsBuilt;
        }

        void buildAllSubsetsModels(ModelSelectionModel model) {
            for (int predNum = 1; predNum <= ((ModelSelectionModel.ModelSelectionParameters)ModelSelection.this._parms)._max_predictor_number; ++predNum) {
                int numModels = MathUtils.combinatorial(ModelSelection.this._numPredictors, predNum);
                Frame[] trainingFrames = ModelSelectionUtils.generateTrainingFrames((ModelSelectionModel.ModelSelectionParameters)ModelSelection.this._parms, predNum, ModelSelection.this._predictorNames, numModels, ModelSelection.this._foldColumn);
                GLMModel bestModel = ModelSelection.buildExtractBestR2Model(trainingFrames, (ModelSelectionModel.ModelSelectionParameters)ModelSelection.this._parms, ModelSelection.this._glmNFolds, ModelSelection.this._foldColumn, ModelSelection.this._foldAssignment);
                DKV.put(bestModel);
                int index = predNum - 1;
                ((ModelSelectionModel.ModelSelectionModelOutput)model._output).updateBestModels(bestModel, index);
                ModelSelectionUtils.removeTrainingFrames(trainingFrames);
                ModelSelection.this._job.update(predNum, "Finished building all models with " + predNum + " predictors.");
            }
        }

        @Override
        public void computeImpl() {
            ModelSelection.this._dinfo = new DataInfo((Frame)ModelSelection.this._train.clone(), ModelSelection.this._valid, 1, false, DataInfo.TransformType.NONE, DataInfo.TransformType.NONE, ((ModelSelectionModel.ModelSelectionParameters)ModelSelection.this._parms).missingValuesHandling() == GLMModel.GLMParameters.MissingValuesHandling.Skip, ((ModelSelectionModel.ModelSelectionParameters)ModelSelection.this._parms).imputeMissing(), ((ModelSelectionModel.ModelSelectionParameters)ModelSelection.this._parms).makeImputer(), false, ModelSelection.this.hasWeightCol(), ModelSelection.this.hasOffsetCol(), ModelSelection.this.hasFoldCol(), null);
            ModelSelection.this.init(true);
            if (ModelSelection.this.error_count() > 0) {
                throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(ModelSelection.this);
            }
            ModelSelection.this._job.update(0L, "finished init and ready to build models");
            this.buildModel();
        }
    }
}

