/*
 * Decompiled with CFR 0.152.
 */
package hex.modelselection;

import hex.DataInfo;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelBuilderHelper;
import hex.ModelCategory;
import hex.genmodel.utils.MathUtils;
import hex.glm.GLM;
import hex.glm.GLMModel;
import hex.modelselection.ModelSelectionModel;
import hex.modelselection.ModelSelectionUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import water.DKV;
import water.H2O;
import water.Key;
import water.Keyed;
import water.Scope;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;

public class ModelSelection
extends ModelBuilder<ModelSelectionModel, ModelSelectionModel.ModelSelectionParameters, ModelSelectionModel.ModelSelectionModelOutput> {
    /** Allocated in initModelParameters() with one (initially null) row per subset size up to _max_predictor_number. */
    public String[][] _bestModelPredictors;
    /** Allocated in initModelParameters() with one slot per subset size up to _max_predictor_number. */
    public double[] _bestR2Values;
    /** Data layout of the training frame; assigned in ModelSelectionDriver.computeImpl() before init(true) runs. */
    DataInfo _dinfo;
    /** Value returned by ModelSelectionUtils.calculateModelNumber(_numPredictors, _max_predictor_number). */
    public int _numModelBuilt;
    /** Length of _predictorNames. */
    public int _numPredictors;
    /** Predictor names returned by ModelSelectionUtils.extractPredictorNames(); set in initModelSelectionParameters(). */
    public String[] _predictorNames;
    /** Copy of _parms._nfolds captured in init() before the parameter is reset to 0; forwarded to the GLM parameter generator. */
    public int _glmNFolds = 0;
    /** Copy of _parms._fold_assignment captured in init() before the parameter is cleared; forwarded to the GLM parameter generator. */
    Model.Parameters.FoldAssignmentScheme _foldAssignment = null;
    /** Copy of _parms._fold_column captured in init() before the parameter is cleared; forwarded to the GLM parameter generator. */
    String _foldColumn = null;

    /**
     * Creates a builder with default parameters. Unlike the other constructors,
     * this one does not call {@code init(false)}.
     *
     * @param startup_once forwarded unchanged to the superclass constructor
     */
    public ModelSelection(boolean startup_once) {
        super((Model.Parameters)new ModelSelectionModel.ModelSelectionParameters(), startup_once);
    }

    /**
     * Creates a builder for the given parameters and runs the cheap
     * ({@code expensive == false}) part of {@link #init(boolean)}.
     *
     * @param parms model-selection parameters; note that init() may clear the
     *              cross-validation fields on this object
     */
    public ModelSelection(ModelSelectionModel.ModelSelectionParameters parms) {
        super((Model.Parameters)parms);
        this.init(false);
    }

    /**
     * Creates a builder for the given parameters and destination key, then runs
     * the cheap ({@code expensive == false}) part of {@link #init(boolean)}.
     *
     * @param parms model-selection parameters; note that init() may clear the
     *              cross-validation fields on this object
     * @param key   key forwarded to the superclass constructor
     */
    public ModelSelection(ModelSelectionModel.ModelSelectionParameters parms, Key<ModelSelectionModel> key) {
        super((Model.Parameters)parms, key);
        this.init(false);
    }

    /**
     * Delegates to the two-argument overload with fixed arguments (1, 2).
     *
     * @param folds ignored; the fixed value 1 is passed on instead
     * @return whatever {@code nModelsInParallel(1, 2)} returns
     */
    protected int nModelsInParallel(int folds) {
        return this.nModelsInParallel(1, 2);
    }

    /**
     * Creates the driver that performs the actual model-selection work.
     *
     * @return a new {@link ModelSelectionDriver} bound to this builder
     */
    protected ModelSelectionDriver trainModelImpl() {
        return new ModelSelectionDriver();
    }

    /**
     * Lists the model categories this builder handles.
     *
     * @return a fresh one-element array containing {@code ModelCategory.Regression}
     */
    public ModelCategory[] can_build() {
        return new ModelCategory[]{ModelCategory.Regression};
    }

    /**
     * @return always {@code true}
     */
    public boolean isSupervised() {
        return true;
    }

    /**
     * @return always {@code false} (presumably signalling that no MOJO export is
     *         available for this model type — confirm against ModelBuilder)
     */
    public boolean haveMojo() {
        return false;
    }

    /**
     * @return always {@code false} (presumably signalling that no POJO export is
     *         available for this model type — confirm against ModelBuilder)
     */
    public boolean havePojo() {
        return false;
    }

    /**
     * Initializes the builder.
     *
     * <p>When the parameters request cross-validation ({@code _nfolds > 0} or a
     * fold column is set), the cross-validation settings are copied into
     * {@link #_glmNFolds}, {@link #_foldAssignment} and {@link #_foldColumn} and
     * then cleared on the parameter object before {@code super.init} runs. The
     * saved values are later handed to the GLM parameter generator.
     *
     * <p>When {@code expensive} is true the model-selection parameters are
     * validated and, only if no validation error has been recorded, the
     * per-subset-size result arrays are allocated. Skipping the allocation on
     * error keeps an invalid {@code _max_predictor_number} (for example a
     * negative value) from surfacing as a {@code NegativeArraySizeException}
     * instead of the recorded validation error.
     *
     * @param expensive whether to perform the full validation and allocation
     */
    public void init(boolean expensive) {
        ModelSelectionModel.ModelSelectionParameters parms = (ModelSelectionModel.ModelSelectionParameters)this._parms;
        if (parms._nfolds > 0 || parms._fold_column != null) {
            this._glmNFolds = parms._nfolds;
            if (parms._fold_assignment != null) {
                this._foldAssignment = parms._fold_assignment;
                parms._fold_assignment = null;
            }
            if (parms._fold_column != null) {
                this._foldColumn = parms._fold_column;
                parms._fold_column = null;
            }
            parms._nfolds = 0;
        }
        super.init(expensive);
        if (expensive) {
            this.initModelSelectionParameters();
            // Do not size arrays from parameters that have just failed validation.
            if (this.error_count() == 0) {
                this.initModelParameters();
            }
        }
    }

    /**
     * Computes the number of models to be built and allocates the result arrays,
     * one slot per subset size from 1 to {@code _max_predictor_number}.
     */
    private void initModelParameters() {
        final int maxPredictors = ((ModelSelectionModel.ModelSelectionParameters)this._parms)._max_predictor_number;
        this._numModelBuilt = ModelSelectionUtils.calculateModelNumber(this._numPredictors, maxPredictors);
        this._bestR2Values = new double[maxPredictors];
        this._bestModelPredictors = new String[maxPredictors][];
    }

    /**
     * Validates the model-selection parameters and derives the predictor list.
     *
     * <p>Records an error when the response has more than one class, when the
     * family is neither AUTO nor gaussian, when {@code _nparallelism} is
     * negative, or when {@code _max_predictor_number} lies outside
     * {@code [1, number of predictors]}. An {@code _nparallelism} of 0 is
     * replaced by {@code H2O.NUMCPUS}. Also fills {@link #_predictorNames} and
     * {@link #_numPredictors}.
     */
    private void initModelSelectionParameters() {
        final ModelSelectionModel.ModelSelectionParameters settings = (ModelSelectionModel.ModelSelectionParameters)this._parms;

        if (this.nclasses() > 1) {
            this.error("response", "ModelSelection only works with regression.");
        }

        final boolean familySupported = GLMModel.GLMParameters.Family.AUTO.equals((Object)settings._family)
                || GLMModel.GLMParameters.Family.gaussian.equals((Object)settings._family);
        if (!familySupported) {
            this.error("_family", "ModelSelection only supports Gaussian family");
        }

        if (settings._nparallelism < 0) {
            this.error("nparallelism", "must be >= 0.");
        } else if (settings._nparallelism == 0) {
            // 0 means "pick for me": use the CPU count reported by H2O.
            settings._nparallelism = H2O.NUMCPUS;
        }

        this._predictorNames = ModelSelectionUtils.extractPredictorNames(settings, this._dinfo, this._foldColumn);
        this._numPredictors = this._predictorNames.length;

        final int requested = settings._max_predictor_number;
        final boolean withinRange = requested >= 1 && requested <= this._numPredictors;
        if (!withinRange) {
            this.error("max_predictor_number", "max_predictor_number must exceed 0 and be no greater than the number of predictors of the training frame.");
        }
    }

    /**
     * Trains one GLM per training frame and returns the model selected by
     * {@code ModelSelectionUtils.findBestModel}.
     *
     * <p>Every trained model is registered with {@code Scope.track_generic}, so a
     * caller that wants to keep the returned model must untrack it.
     *
     * @param trainingFrames one frame per candidate predictor subset
     * @param parms          model-selection parameters; {@code _nparallelism} bounds the parallel training
     * @param glmNFolds      number of folds forwarded to the GLM parameter generator
     * @param foldColumn     fold column name forwarded to the GLM parameter generator, may be null
     * @param foldAssignment fold assignment scheme forwarded to the GLM parameter generator, may be null
     * @return the model chosen by {@code findBestModel} among the trained GLMs
     */
    public static GLMModel buildExtractBestR2Model(Frame[] trainingFrames, ModelSelectionModel.ModelSelectionParameters parms, int glmNFolds, String foldColumn, Model.Parameters.FoldAssignmentScheme foldAssignment) {
        GLMModel.GLMParameters[] glmParams = ModelSelectionUtils.generateGLMParameters(trainingFrames, parms, glmNFolds, foldColumn, foldAssignment);
        ModelBuilder[] builders = ModelSelectionUtils.buildGLMBuilders(glmParams);
        GLM[] trained = (GLM[])ModelBuilderHelper.trainModelsParallel((ModelBuilder[])builders, (int)parms._nparallelism);
        for (GLM glm : trained) {
            Scope.track_generic(glm.get());
        }
        return ModelSelectionUtils.findBestModel(trained);
    }

    /**
     * Grows the current predictor subset by one predictor.
     *
     * <p>Training frames are generated from the current subset and the still
     * available predictors, one GLM is trained per frame, and the best one is
     * kept. The first predictor of that model (scanning its column names from
     * last to first) that is not yet in {@code currSubsetIndices} is inserted
     * into {@code currSubsetIndices} at position {@code predPos}; the list is
     * modified in place.
     *
     * <p>The temporary training frames are removed even when model building
     * fails, so a failed step does not leave them behind.
     *
     * @param currSubsetIndices indices (into {@code coefNames}) of the current subset; updated in place
     * @param coefNames         all predictor names
     * @param predPos           position in {@code currSubsetIndices} at which the new predictor is inserted
     * @param validSubsets      indices of predictors that may be added
     * @param parms             model-selection parameters
     * @param foldColumn        fold column name, may be null
     * @param glmNFolds         number of cross-validation folds for the GLMs
     * @param foldAssignment    fold assignment scheme, may be null
     * @return the best GLM found for the enlarged subset
     */
    public static GLMModel forwardStep(List<Integer> currSubsetIndices, List<String> coefNames, int predPos, List<Integer> validSubsets, ModelSelectionModel.ModelSelectionParameters parms, String foldColumn, int glmNFolds, Model.Parameters.FoldAssignmentScheme foldAssignment) {
        String[] predictorNames = coefNames.toArray(new String[0]);
        Frame[] trainingFrames = ModelSelectionUtils.generateMaxRTrainingFrames(parms, predictorNames, foldColumn, currSubsetIndices, predPos, validSubsets);
        try {
            GLMModel bestModel = ModelSelection.buildExtractBestR2Model(trainingFrames, parms, glmNFolds, foldColumn, foldAssignment);
            List<String> coefUsed = ModelSelectionUtils.extraModelColumnNames(coefNames, bestModel);
            for (int predIndex = coefUsed.size() - 1; predIndex >= 0; --predIndex) {
                int index = coefNames.indexOf(coefUsed.get(predIndex));
                // Skip names that are unknown (-1 must never enter the subset) or already selected.
                if (index < 0 || currSubsetIndices.contains(index)) continue;
                currSubsetIndices.add(predPos, index);
                break;
            }
            return bestModel;
        } finally {
            ModelSelectionUtils.removeTrainingFrames(trainingFrames);
        }
    }

    /**
     * Tries to improve the current predictor subset by swapping out one
     * predictor at a time.
     *
     * <p>For every position of the subset (except the position replaced in the
     * previous round) the predictor at that position is removed, temporarily
     * made available again in {@code validSubset}, and {@link #forwardStep} picks
     * the best predictor for the freed position. If
     * {@code ModelSelectionUtils.findBestR2Model} reports a candidate that beats
     * the current best R2, that candidate becomes the new subset and the search
     * repeats; otherwise the search stops.
     *
     * <p>{@code currSubsetIndices} is updated in place whenever a better subset
     * is found, mirroring how {@link #forwardStep} updates it, so the caller's
     * subset stays in sync with {@code validSubset} and with the returned model.
     * The subset handed to {@code updateValidSubset} as the "original" starts as
     * a copy of the incoming subset (not an empty list) and is afterwards an
     * independent snapshot of each accepted subset.
     *
     * @param currSubsetIndices indices (into {@code coefNames}) of the current subset; updated in place on improvement
     * @param coefNames         all predictor names
     * @param bestR2            R2 of the model built from the incoming subset
     * @param parms             model-selection parameters
     * @param glmNFolds         number of cross-validation folds for the GLMs
     * @param foldColumn        fold column name, may be null
     * @param validSubset       indices of predictors not in the current subset; updated in place
     * @param foldAssignment    fold assignment scheme, may be null
     * @return the best improved model, or {@code null} when no replacement beat {@code bestR2}
     */
    public static GLMModel replacement(List<Integer> currSubsetIndices, List<String> coefNames, double bestR2, ModelSelectionModel.ModelSelectionParameters parms, int glmNFolds, String foldColumn, List<Integer> validSubset, Model.Parameters.FoldAssignmentScheme foldAssignment) {
        int currSubsetSize = currSubsetIndices.size();
        int lastBestR2PosIndex = -1;
        GLMModel bestR2Model = null;
        GLMModel[] bestR2Models = new GLMModel[currSubsetSize];
        int[] r2PredPosIndex = new int[currSubsetSize];
        int[][] subsetsCombo = new int[currSubsetSize][];
        // Copy constructor on purpose: ArrayList(int) would only set the capacity and leave the list empty.
        List<Integer> originalSubset = new ArrayList<Integer>(currSubsetIndices);
        while (true) {
            for (int index = 0; index < currSubsetSize; ++index) {
                // The position replaced last round already holds its best predictor.
                if (index == lastBestR2PosIndex) continue;
                ArrayList<Integer> oneLessSubset = new ArrayList<Integer>(currSubsetIndices);
                int predIndexRemoved = oneLessSubset.remove(index);
                // Let the removed predictor compete for its old position as well.
                validSubset.add(predIndexRemoved);
                bestR2Models[index] = ModelSelection.forwardStep(oneLessSubset, coefNames, index, validSubset, parms, foldColumn, glmNFolds, foldAssignment);
                subsetsCombo[index] = oneLessSubset.stream().mapToInt(i -> i).toArray();
                r2PredPosIndex[index] = index;
                validSubset.remove(validSubset.indexOf(predIndexRemoved));
            }
            int bestR2ModelIndex = ModelSelectionUtils.findBestR2Model(bestR2, bestR2Models);
            if (bestR2ModelIndex < 0) break;
            bestR2Model = bestR2Models[bestR2ModelIndex];
            bestR2 = bestR2Model.r2();
            List<Integer> improvedSubset = Arrays.stream(subsetsCombo[bestR2ModelIndex]).boxed().collect(Collectors.toList());
            // Update the caller's list rather than rebinding the parameter, which the caller would never see.
            currSubsetIndices.clear();
            currSubsetIndices.addAll(improvedSubset);
            lastBestR2PosIndex = r2PredPosIndex[bestR2ModelIndex];
            ModelSelectionUtils.updateValidSubset(validSubset, originalSubset, currSubsetIndices);
            // improvedSubset is a separate list, so it stays a stable snapshot for the next round.
            originalSubset = improvedSubset;
        }
        return bestR2Model;
    }

    public class ModelSelectionDriver
    extends ModelBuilder.Driver {
        /** Creates a driver bound to the enclosing {@link ModelSelection} builder. */
        public ModelSelectionDriver() {
            super((ModelBuilder)ModelSelection.this);
        }

        /**
         * Creates the output model, runs the search selected by {@code _mode}
         * ({@code allsubsets} or {@code maxr}) and generates the summary.
         *
         * <p>The model is write-locked for the duration of the build. The final
         * update and unlock happen in {@code finally}, but only when the model
         * object was actually created; otherwise a failure while constructing the
         * model would be masked by a {@code NullPointerException}.
         */
        public final void buildModel() {
            final ModelSelectionModel.ModelSelectionParameters parms = (ModelSelectionModel.ModelSelectionParameters)ModelSelection.this._parms;
            ModelSelectionModel model = null;
            try {
                model = new ModelSelectionModel((Key<ModelSelectionModel>)ModelSelection.this.dest(), parms, new ModelSelectionModel.ModelSelectionModelOutput(ModelSelection.this, ModelSelection.this._dinfo));
                model.write_lock(ModelSelection.this._job);
                // One result slot per subset size 1.._max_predictor_number.
                final int maxPredictors = parms._max_predictor_number;
                ((ModelSelectionModel.ModelSelectionModelOutput)model._output)._best_model_ids = new Key[maxPredictors];
                ((ModelSelectionModel.ModelSelectionModelOutput)model._output)._best_r2_values = new double[maxPredictors];
                ((ModelSelectionModel.ModelSelectionModelOutput)model._output)._best_model_predictors = new String[maxPredictors][];
                ((ModelSelectionModel.ModelSelectionModelOutput)model._output)._coefficient_names = new String[maxPredictors][];
                if (ModelSelectionModel.ModelSelectionParameters.Mode.allsubsets.equals((Object)parms._mode)) {
                    this.buildAllSubsetsModels(model);
                } else if (ModelSelectionModel.ModelSelectionParameters.Mode.maxr.equals((Object)parms._mode)) {
                    this.buildMaxRModels(model);
                }
                ModelSelection.this._job.update(0L, "Completed GLM model building.  Extracting results now.");
                model.update(ModelSelection.this._job);
                ((ModelSelectionModel.ModelSelectionModelOutput)model._output).generateSummary();
            }
            finally {
                if (model != null) {
                    model.update(ModelSelection.this._job);
                    model.unlock(ModelSelection.this._job);
                }
            }
        }

        /**
         * Runs the {@code maxr} search: for every subset size from 1 to
         * {@code _max_predictor_number} a forward step adds one predictor, then
         * (while the subset is smaller than the full predictor set) a
         * replacement pass tries to improve it. The winning model of each size
         * is untracked from the Scope, stored in the DKV and recorded in the
         * output.
         *
         * @param model the model whose output receives the best model per subset size
         */
        void buildMaxRModels(ModelSelectionModel model) {
            final ModelSelectionModel.ModelSelectionParameters parms = (ModelSelectionModel.ModelSelectionParameters)ModelSelection.this._parms;
            List<Integer> selected = new ArrayList<Integer>();
            List<String> allNames = new ArrayList<String>(Arrays.asList(ModelSelection.this._predictorNames));
            List<Integer> candidates = IntStream.range(0, allNames.size()).boxed().collect(Collectors.toList());

            int subsetSize = 1;
            while (subsetSize <= parms._max_predictor_number) {
                GLMModel winner = ModelSelection.forwardStep(selected, allNames, subsetSize - 1, candidates, parms, ModelSelection.this._foldColumn, ModelSelection.this._glmNFolds, ModelSelection.this._foldAssignment);
                candidates.removeAll(selected);
                ModelSelection.this._job.update((long)subsetSize, "Finished building all models with " + subsetSize + " predictors.");

                if (subsetSize < ModelSelection.this._numPredictors) {
                    GLMModel improved = ModelSelection.replacement(selected, allNames, winner.r2(), parms, ModelSelection.this._glmNFolds, ModelSelection.this._foldColumn, candidates, ModelSelection.this._foldAssignment);
                    if (improved != null) {
                        winner = improved;
                    }
                }

                // Keep the winner beyond the Scope that tracked it during training.
                Scope.untrack((Key[])new Key[]{winner.getKey()});
                DKV.put((Keyed)winner);
                ((ModelSelectionModel.ModelSelectionModelOutput)model._output).updateBestModels(winner, subsetSize - 1);
                ++subsetSize;
            }
        }

        /**
         * Runs the {@code allsubsets} search: for every subset size from 1 to
         * {@code _max_predictor_number}, training frames are generated for all
         * combinations of that many predictors, one GLM is trained per frame, and
         * the best model is untracked from the Scope, stored in the DKV and
         * recorded in the output.
         *
         * <p>The temporary training frames of a subset size are removed even when
         * model building for that size fails, so a failed run does not leave them
         * behind.
         *
         * @param model the model whose output receives the best model per subset size
         */
        void buildAllSubsetsModels(ModelSelectionModel model) {
            final ModelSelectionModel.ModelSelectionParameters parms = (ModelSelectionModel.ModelSelectionParameters)ModelSelection.this._parms;
            for (int predNum = 1; predNum <= parms._max_predictor_number; ++predNum) {
                int numModels = MathUtils.combinatorial((int)ModelSelection.this._numPredictors, (int)predNum);
                Frame[] trainingFrames = ModelSelectionUtils.generateTrainingFrames(parms, predNum, ModelSelection.this._predictorNames, numModels, ModelSelection.this._foldColumn);
                try {
                    GLMModel bestModel = ModelSelection.buildExtractBestR2Model(trainingFrames, parms, ModelSelection.this._glmNFolds, ModelSelection.this._foldColumn, ModelSelection.this._foldAssignment);
                    // Keep the winner beyond the Scope that tracked it during training.
                    Scope.untrack((Key[])new Key[]{bestModel.getKey()});
                    DKV.put((Keyed)bestModel);
                    int index = predNum - 1;
                    ((ModelSelectionModel.ModelSelectionModelOutput)model._output).updateBestModels(bestModel, index);
                } finally {
                    ModelSelectionUtils.removeTrainingFrames(trainingFrames);
                }
                ModelSelection.this._job.update((long)predNum, "Finished building all models with " + predNum + " predictors.");
            }
        }

        /**
         * Entry point of the driver: builds the {@code DataInfo} for a clone of
         * the training frame, runs the expensive initialization, aborts with an
         * {@code H2OModelBuilderIllegalArgumentException} if any validation error
         * was recorded, and otherwise builds the models.
         */
        public void computeImpl() {
            final ModelSelection builder = ModelSelection.this;
            final ModelSelectionModel.ModelSelectionParameters parms = (ModelSelectionModel.ModelSelectionParameters)builder._parms;

            final Frame trainCopy = (Frame)builder._train.clone();
            final boolean skipMissing = parms.missingValuesHandling() == GLMModel.GLMParameters.MissingValuesHandling.Skip;
            builder._dinfo = new DataInfo(trainCopy, builder._valid, 1, false, DataInfo.TransformType.NONE, DataInfo.TransformType.NONE, skipMissing, parms.imputeMissing(), parms.makeImputer(), false, builder.hasWeightCol(), builder.hasOffsetCol(), builder.hasFoldCol(), null);

            builder.init(true);
            final boolean hasErrors = builder.error_count() > 0;
            if (hasErrors) {
                throw H2OModelBuilderIllegalArgumentException.makeFromBuilder((ModelBuilder)builder);
            }

            builder._job.update(0L, "finished init and ready to build models");
            this.buildModel();
        }
    }
}

