/*
 * Decompiled with CFR 0.152.
 */
package hex.ensemble;

import hex.Distribution;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.ensemble.Metalearner;
import hex.ensemble.Metalearners;
import hex.ensemble.StackedEnsembleModel;
import hex.genmodel.utils.DistributionFamily;
import hex.glm.GLMModel;
import hex.grid.Grid;
import hex.tree.drf.DRFModel;
import hex.util.DistributionUtils;
import java.io.Serializable;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Stream;
import jsr166y.CountedCompleter;
import water.DKV;
import water.Iced;
import water.Job;
import water.Key;
import water.Keyed;
import water.Scope;
import water.exceptions.H2OIllegalArgumentException;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.ReflectionUtils;
import water.util.TwoDimTable;

public class StackedEnsemble
extends ModelBuilder<StackedEnsembleModel, StackedEnsembleModel.StackedEnsembleParameters, StackedEnsembleModel.StackedEnsembleOutput> {
    // Driver instance created by trainModelImpl(): CV-stacking when no blending frame is set, blending otherwise.
    StackedEnsembleDriver _driver;
    // Ensemble model whose parameters and output are read by the driver classes.
    // NOTE(review): not assigned in the visible part of this file; presumably set by the driver — confirm.
    protected StackedEnsembleModel _model;

    /**
     * Creates a builder for the given ensemble parameters and immediately runs
     * the non-expensive validation pass via {@code init(false)}.
     *
     * @param parms stacked ensemble parameters, forwarded to the {@code ModelBuilder} constructor
     */
    public StackedEnsemble(StackedEnsembleModel.StackedEnsembleParameters parms) {
        super((Model.Parameters)parms);
        // NOTE(review): init(boolean) is overridable and is invoked from the constructor.
        this.init(false);
    }

    /**
     * Creates a builder backed by a fresh, default {@code StackedEnsembleParameters} instance.
     * Unlike the parameter-taking constructor, this one does not call {@code init}.
     *
     * @param startup_once forwarded unchanged to the {@code ModelBuilder} constructor
     */
    public StackedEnsemble(boolean startup_once) {
        super((Model.Parameters)new StackedEnsembleModel.StackedEnsembleParameters(), startup_once);
    }

    /**
     * Lists the model categories this builder can produce.
     *
     * @return a new array holding Regression, Binomial and Multinomial, in that order
     */
    public ModelCategory[] can_build() {
        final ModelCategory[] supported = {
            ModelCategory.Regression,
            ModelCategory.Binomial,
            ModelCategory.Multinomial
        };
        return supported;
    }

    /**
     * Reports the visibility level of this builder.
     *
     * @return always {@code ModelBuilder.BuilderVisibility.Stable}
     */
    public ModelBuilder.BuilderVisibility builderVisibility() {
        return ModelBuilder.BuilderVisibility.Stable;
    }

    /**
     * Stacked ensembles are always supervised.
     *
     * @return always {@code true}
     */
    public boolean isSupervised() {
        return true;
    }

    /**
     * Drops from the training frame every column that is used neither by any base model
     * nor by the ensemble itself.
     *
     * <p>A column counts as used when it is a base model's response column, one of its
     * non-predictor columns, one of its original column names (or, when those are absent,
     * one of its output names), or one of the ensemble's own non-predictor columns.
     *
     * @param npredictors not used by this implementation
     * @param expensive   forwarded to the column filter
     */
    protected void ignoreBadColumns(int npredictors, boolean expensive) {
        final StackedEnsembleModel.StackedEnsembleParameters seParms = (StackedEnsembleModel.StackedEnsembleParameters)this._parms;
        final HashSet<String> columnsToKeep = new HashSet<String>();
        for (Key<Model> baseKey : seParms._base_models) {
            Model baseModel = (Model)DKV.getGet(baseKey);
            columnsToKeep.add(baseModel._parms._response_column);
            Collections.addAll(columnsToKeep, baseModel._parms.getNonPredictors());
            // Prefer the original (pre-transformation) names when the model recorded them.
            String[] featureNames = baseModel._output._origNames != null
                ? baseModel._output._origNames
                : baseModel._output._names;
            Collections.addAll(columnsToKeep, featureNames);
        }
        Collections.addAll(columnsToKeep, seParms.getNonPredictors());
        new ModelBuilder.FilterCols(0){

            protected boolean filter(Vec v, String name) {
                // A column is filtered out exactly when nobody references it.
                boolean referenced = columnsToKeep.contains(name);
                return !referenced;
            }
        }.doIt(this._train, "Dropping unused columns: ", expensive);
    }

    /**
     * Picks the training driver: cross-validation stacking when no blending frame is
     * configured, blending otherwise. The chosen driver is also stored in {@code _driver}.
     *
     * @return the newly created driver
     */
    protected StackedEnsembleDriver trainModelImpl() {
        StackedEnsembleModel.StackedEnsembleParameters seParms = (StackedEnsembleModel.StackedEnsembleParameters)this._parms;
        if (seParms._blending == null) {
            this._driver = new StackedEnsembleCVStackingDriver();
        } else {
            this._driver = new StackedEnsembleBlendingDriver();
        }
        return this._driver;
    }

    /**
     * Reports MOJO availability for this builder.
     *
     * @return always {@code true}
     */
    public boolean haveMojo() {
        return true;
    }

    /**
     * Returns the number of response classes, taking the metalearner distribution into account.
     *
     * <ul>
     *   <li>No metalearner parameters: defers to the superclass.</li>
     *   <li>multinomial, ordinal or AUTO distribution: the class count stored in {@code _nclass}.</li>
     *   <li>bernoulli, quasibinomial or fractionalbinomial: 2.</li>
     *   <li>Any other distribution: 1.</li>
     * </ul>
     *
     * <p>The AUTO entry must be {@code DistributionFamily.AUTO}. The previous code listed
     * {@code FoldAssignmentScheme.AUTO}, a constant of an unrelated enum that can never equal a
     * {@code DistributionFamily}, so an AUTO distribution fell through and was reported as 1 class.
     *
     * @return the number of classes as described above
     */
    public int nclasses() {
        if (((StackedEnsembleModel.StackedEnsembleParameters)this._parms)._metalearner_parameters == null) {
            return super.nclasses();
        }
        DistributionFamily distribution = ((StackedEnsembleModel.StackedEnsembleParameters)this._parms)._metalearner_parameters.getDistributionFamily();
        // Distributions for which the class count is determined by the response itself.
        if (Arrays.asList(DistributionFamily.multinomial, DistributionFamily.ordinal, DistributionFamily.AUTO).contains(distribution)) {
            return this._nclass;
        }
        // Inherently two-class distributions.
        if (Arrays.asList(DistributionFamily.bernoulli, DistributionFamily.quasibinomial, DistributionFamily.fractionalbinomial).contains(distribution)) {
            return 2;
        }
        return 1;
    }

    /**
     * Initializes and validates the builder.
     *
     * <p>Order of operations: expand grids in the base model list, run the superclass
     * initialization, reject an explicitly set ensemble-level distribution, verify that the
     * fold, weights and offset columns exist in the training, validation and blending frames,
     * and finally validate the base models against each other.
     *
     * @param expensive forwarded to the superclass {@code init}
     * @throws H2OIllegalArgumentException if {@code distribution} was set on the ensemble itself
     * @throws IllegalArgumentException    if a configured column is missing from a supplied frame
     */
    public void init(boolean expensive) {
        this.expandBaseModels();
        super.init(expensive);
        final StackedEnsembleModel.StackedEnsembleParameters seParms = (StackedEnsembleModel.StackedEnsembleParameters)this._parms;
        if (seParms._distribution != DistributionFamily.AUTO) {
            throw new H2OIllegalArgumentException("Setting \"distribution\" to StackedEnsemble is unsupported. Please set it in \"metalearner_parameters\".");
        }
        // {label, column id} pairs, checked in this order.
        final String[][] specialColumns = {
            {"fold", seParms._metalearner_fold_column},
            {"weights", seParms._weights_column},
            {"offset", seParms._offset_column}
        };
        for (String[] column : specialColumns) {
            StackedEnsemble.checkColumnPresent(column[0], column[1], this.train(), this.valid(), seParms.blending());
        }
        this.validateBaseModels();
    }

    /**
     * Replaces the configured base model list with a flat list of model keys.
     *
     * <p>Each configured key is looked up in DKV: a model key is kept as is, a grid key is
     * replaced by the keys of all models in that grid. Does nothing when no base models are set.
     *
     * @throws IllegalArgumentException if a key resolves to nothing, or to something that is
     *                                  neither a model nor a grid
     */
    private void expandBaseModels() {
        final StackedEnsembleModel.StackedEnsembleParameters seParms = (StackedEnsembleModel.StackedEnsembleParameters)this._parms;
        if (seParms._base_models == null) {
            return;
        }
        ArrayList<Key<Model>> expanded = new ArrayList<Key<Model>>();
        for (Key<Model> candidateKey : seParms._base_models) {
            Iced resolved = DKV.getGet(candidateKey);
            if (resolved == null) {
                throw new IllegalArgumentException(String.format("Specified id \"%s\" does not exist.", candidateKey));
            }
            if (resolved instanceof Model) {
                expanded.add(candidateKey);
            } else if (resolved instanceof Grid) {
                Grid grid = (Grid)resolved;
                Collections.addAll(expanded, grid.getModelKeys());
            } else {
                throw new IllegalArgumentException(String.format("Unsupported type \"%s\" as a base model.", resolved.getClass().toString()));
            }
        }
        seParms._base_models = expanded.toArray(new Key[0]);
    }

    /**
     * Checks offset and weights column consistency across the base models.
     *
     * <p>When the ensemble has no offset column, it adopts the one of the first base model.
     * Every base model must then use the same offset column as the ensemble. If all base models
     * share the same non-null weights column while the ensemble specifies none, a warning is
     * recorded. Does nothing when no base models are set.
     *
     * @throws IllegalArgumentException if a base model's offset column differs from the ensemble's
     */
    private void validateBaseModels() {
        final StackedEnsembleModel.StackedEnsembleParameters seParms = (StackedEnsembleModel.StackedEnsembleParameters)this._parms;
        if (seParms._base_models == null) {
            return;
        }
        boolean allShareWeights = true;
        String sharedWeights = null;
        boolean first = true;
        for (Key<Model> baseKey : seParms._base_models) {
            Model baseModel = (Model)DKV.getGet(baseKey);
            if (first) {
                first = false;
                if (seParms._offset_column == null) {
                    seParms._offset_column = baseModel._parms._offset_column;
                }
                // The first model's weights column is the reference; null means there is nothing to warn about.
                sharedWeights = baseModel._parms._weights_column;
                allShareWeights = sharedWeights != null;
            }
            allShareWeights &= Objects.equals(sharedWeights, baseModel._parms._weights_column);
            if (!Objects.equals(seParms._offset_column, baseModel._parms._offset_column)) {
                throw new IllegalArgumentException("All base models must have the same offset_column!");
            }
        }
        boolean ensembleHasNoWeights = seParms._weights_column == null;
        if (ensembleHasNoWeights && allShareWeights && seParms._base_models.length > 0) {
            this.warn("_weights_column", "All base models use weights_column=\"" + sharedWeights + "\" but Stacked Ensemble does not. If you want to use the same weights_column for the meta learner, please specify it as an argument in the h2o.stackedEnsemble call.");
        }
    }

    /**
     * Verifies that a column exists in every supplied, non-null frame.
     *
     * @param columnName human-readable role of the column, used only in the error message
     * @param columnId   name of the column to look for; {@code null} disables the check
     * @param frames     frames to inspect; {@code null} entries are skipped
     * @throws IllegalArgumentException if a non-null frame lacks the column
     */
    private static void checkColumnPresent(String columnName, String columnId, Frame ... frames) {
        if (columnId != null) {
            for (Frame candidate : frames) {
                boolean missing = candidate != null && candidate.vec(columnId) == null;
                if (missing) {
                    throw new IllegalArgumentException(String.format("Specified %s column '%s' not found in one of the supplied data frames. Available column names are: %s", columnName, columnId, Arrays.toString(candidate.names())));
                }
            }
        }
    }

    /**
     * Appends a base model's predictions to the level-one frame.
     *
     * <ul>
     *   <li>Binomial: the prediction column at index 2, named after the model key.</li>
     *   <li>Multinomial: every column except {@code "predict"}, each renamed to
     *       {@code <model key>/<original name>}.</li>
     *   <li>Other supervised, non-autoencoder models: the {@code "predict"} column, named after the model key.</li>
     * </ul>
     *
     * @param aModel             base model that produced the predictions
     * @param aModelsPredictions prediction frame of {@code aModel}
     * @param levelOneFrame      frame the columns are added to
     * @throws H2OIllegalArgumentException for autoencoders and unsupervised models
     */
    static void addModelPredictionsToLevelOneFrame(Model aModel, Frame aModelsPredictions, Frame levelOneFrame) {
        if (aModel._output.isBinomialClassifier()) {
            Vec positiveClass = aModelsPredictions.vec(2);
            levelOneFrame.add(aModel._key.toString(), positiveClass);
            return;
        }
        if (aModel._output.isMultinomialClassifier()) {
            Frame probabilities = aModelsPredictions.subframe(ArrayUtils.remove((String[])aModelsPredictions.names(), (String)"predict"));
            String[] plainNames = probabilities.names();
            String[] prefixedNames = new String[plainNames.length];
            for (int i = 0; i < plainNames.length; ++i) {
                prefixedNames[i] = aModel._key.toString().concat("/").concat(plainNames[i]);
            }
            probabilities.setNames(prefixedNames);
            levelOneFrame.add(probabilities);
            return;
        }
        if (aModel._output.isAutoencoder()) {
            throw new H2OIllegalArgumentException("Don't yet know how to stack autoencoders: " + aModel._key);
        }
        if (!aModel._output.isSupervised()) {
            throw new H2OIllegalArgumentException("Don't yet know how to stack unsupervised models: " + aModel._key);
        }
        Vec prediction = aModelsPredictions.vec("predict");
        levelOneFrame.add(aModel._key.toString(), prediction);
    }

    /**
     * Copies the non-predictor columns from {@code fr} into the level-one frame.
     *
     * <p>Columns are added in this order: metalearner fold column (training only), weights
     * column, offset column — each only when configured — and finally the response column,
     * which is always added.
     *
     * @param parms         ensemble parameters naming the columns
     * @param fr            frame the columns are taken from
     * @param levelOneFrame frame the columns are added to
     * @param training      whether the level-one frame is the training one; the fold column is added only then
     */
    static void addNonPredictorsToLevelOneFrame(StackedEnsembleModel.StackedEnsembleParameters parms, Frame fr, Frame levelOneFrame, boolean training) {
        final String[] optionalColumns = {
            training ? parms._metalearner_fold_column : null,
            parms._weights_column,
            parms._offset_column
        };
        for (String column : optionalColumns) {
            if (column == null) {
                continue;
            }
            levelOneFrame.add(column, fr.vec(column));
        }
        levelOneFrame.add(parms._response_column, fr.vec(parms._response_column));
    }

    /**
     * Sets the metalearner distribution from a base model and copies the matching
     * distribution-specific hyperparameter.
     *
     * <p>GLM base models contribute their family, converted to a distribution (a warning is
     * recorded if the conversion is rejected). DRF base models fall back to the basic
     * distribution inferred from the ensemble output. Every other base model contributes its
     * {@code _distribution} directly. Afterwards, for custom/huber/tweedie/quantile base
     * distributions, the corresponding parameter is copied to the metalearner.
     *
     * @param seModel        ensemble model, used for the DRF fallback
     * @param baseModelParms parameters of the base model to inherit from
     */
    private void inheritDistributionAndParms(StackedEnsembleModel seModel, Model.Parameters baseModelParms) {
        final StackedEnsembleModel.StackedEnsembleParameters seParms = (StackedEnsembleModel.StackedEnsembleParameters)this._parms;
        if (baseModelParms instanceof DRFModel.DRFParameters) {
            this.inferBasicDistribution(seModel);
        } else if (!(baseModelParms instanceof GLMModel.GLMParameters)) {
            seParms._metalearner_parameters.setDistributionFamily(baseModelParms._distribution);
        } else {
            GLMModel.GLMParameters glmParms = (GLMModel.GLMParameters)baseModelParms;
            try {
                seParms._metalearner_parameters.setDistributionFamily(DistributionUtils.familyToDistribution(glmParms._family));
            }
            catch (IllegalArgumentException e) {
                this.warn("distribution", "Stacked Ensemble is not able to inherit distribution from GLM's family " + glmParms._family + ".");
            }
        }
        // Carry over the single hyperparameter that belongs to the base model's distribution, if any.
        switch (baseModelParms._distribution) {
            case custom:
                seParms._metalearner_parameters._custom_distribution_func = baseModelParms._custom_distribution_func;
                break;
            case huber:
                seParms._metalearner_parameters._huber_alpha = baseModelParms._huber_alpha;
                break;
            case tweedie:
                seParms._metalearner_parameters._tweedie_power = baseModelParms._tweedie_power;
                break;
            case quantile:
                seParms._metalearner_parameters._quantile_alpha = baseModelParms._quantile_alpha;
                break;
            default:
                break;
        }
    }

    /**
     * Sets the metalearner distribution to the default for the ensemble's problem type:
     * bernoulli for binomial classification, multinomial for any other classification,
     * gaussian otherwise.
     *
     * @param seModel ensemble model whose output determines the problem type
     */
    void inferBasicDistribution(StackedEnsembleModel seModel) {
        StackedEnsembleModel.StackedEnsembleOutput seOutput = (StackedEnsembleModel.StackedEnsembleOutput)seModel._output;
        DistributionFamily basicFamily;
        if (seOutput.isBinomialClassifier()) {
            basicFamily = DistributionFamily.bernoulli;
        } else {
            basicFamily = seOutput.isClassifier() ? DistributionFamily.multinomial : DistributionFamily.gaussian;
        }
        ((StackedEnsembleModel.StackedEnsembleParameters)this._parms)._metalearner_parameters.setDistributionFamily(basicFamily);
    }

    /**
     * Sets the family of a GLM metalearner from a base model.
     *
     * <p>GLM base models contribute their family and link. DRF base models fall back to the
     * basic distribution inferred from the ensemble output. For any other base model its
     * distribution is applied; if that is rejected, a warning is recorded and the basic
     * distribution is used instead. When the resulting family is tweedie, the base model's
     * tweedie power is copied as well.
     *
     * @param seModel        ensemble model, used for the fallback inference
     * @param baseModelParms parameters of the base model to inherit from
     */
    private void inheritFamilyAndParms(StackedEnsembleModel seModel, Model.Parameters baseModelParms) {
        GLMModel.GLMParameters metalearnerGlm = (GLMModel.GLMParameters)((StackedEnsembleModel.StackedEnsembleParameters)this._parms)._metalearner_parameters;
        if (baseModelParms instanceof DRFModel.DRFParameters) {
            this.inferBasicDistribution(seModel);
        } else if (baseModelParms instanceof GLMModel.GLMParameters) {
            GLMModel.GLMParameters baseGlm = (GLMModel.GLMParameters)baseModelParms;
            metalearnerGlm._family = baseGlm._family;
            metalearnerGlm._link = baseGlm._link;
        } else {
            try {
                metalearnerGlm.setDistributionFamily(baseModelParms._distribution);
            }
            catch (H2OIllegalArgumentException e) {
                this.warn("distribution", "Stacked Ensemble is not able to inherit family from a distribution " + baseModelParms._distribution + ".");
                this.inferBasicDistribution(seModel);
            }
        }
        boolean isTweedie = metalearnerGlm._family == GLMModel.GLMParameters.Family.tweedie;
        if (isTweedie) {
            ((StackedEnsembleModel.StackedEnsembleParameters)this._parms)._metalearner_parameters._tweedie_power = baseModelParms._tweedie_power;
        }
    }

    /**
     * Inherits the metalearner's family (GLM metalearner) or distribution (any other
     * metalearner) from a base model, unless the user already chose one.
     *
     * @param seModel ensemble model being built
     * @param aModel  base model to inherit from
     * @return {@code false} if the metalearner family/distribution was already set to something
     *         other than AUTO and nothing was changed; {@code true} if inheritance was performed
     */
    boolean inferDistributionOrFamily(StackedEnsembleModel seModel, Model aModel) {
        final StackedEnsembleModel.StackedEnsembleParameters seParms = (StackedEnsembleModel.StackedEnsembleParameters)this._parms;
        boolean glmMetalearner = Metalearners.getActualMetalearnerAlgo(seParms._metalearner_algorithm) == Metalearner.Algorithm.glm;
        if (glmMetalearner) {
            GLMModel.GLMParameters.Family chosenFamily = ((GLMModel.GLMParameters)seParms._metalearner_parameters)._family;
            if (chosenFamily != GLMModel.GLMParameters.Family.AUTO) {
                return false;
            }
            this.inheritFamilyAndParms(seModel, aModel._parms);
            return true;
        }
        if (seParms._metalearner_parameters._distribution != DistributionFamily.AUTO) {
            return false;
        }
        this.inheritDistributionAndParms(seModel, aModel._parms);
        return true;
    }

    /**
     * Determines the effective distribution family of a model.
     *
     * <p>DRF models are mapped by problem type (bernoulli / multinomial / gaussian). Stacked
     * ensembles report their metalearner's family (GLM metalearner) or its explicitly set
     * distribution. For everything else — including ensembles whose metalearner distribution is
     * AUTO — the model is inspected reflectively: a {@code _family} field on the parameters
     * takes precedence, otherwise a {@code _dist} field on the model is used, with AUTO
     * resolved by problem type.
     *
     * @param aModel model to inspect
     * @return the resolved distribution family
     * @throws H2OIllegalArgumentException if neither field exists, or if any exception occurs
     *                                     during the reflective lookup (the original exception's
     *                                     string form is used as the message)
     */
    private DistributionFamily distributionFamily(Model aModel) {
        if (aModel instanceof DRFModel) {
            if (aModel._output.isBinomialClassifier()) {
                return DistributionFamily.bernoulli;
            }
            return aModel._output.isClassifier() ? DistributionFamily.multinomial : DistributionFamily.gaussian;
        }
        if (aModel instanceof StackedEnsembleModel) {
            StackedEnsembleModel.StackedEnsembleParameters nestedParms = (StackedEnsembleModel.StackedEnsembleParameters)((StackedEnsembleModel)aModel)._parms;
            if (Metalearners.getActualMetalearnerAlgo(nestedParms._metalearner_algorithm) == Metalearner.Algorithm.glm) {
                return DistributionUtils.familyToDistribution(((GLMModel.GLMParameters)nestedParms._metalearner_parameters)._family);
            }
            if (nestedParms._metalearner_parameters._distribution != DistributionFamily.AUTO) {
                return nestedParms._metalearner_parameters._distribution;
            }
            // AUTO on a non-GLM metalearner: fall through to the reflective lookup below.
        }
        try {
            Field familyField = ReflectionUtils.findNamedField((Object)aModel._parms, (String)"_family");
            if (familyField != null) {
                GLMModel.GLMParameters.Family glmFamily = (GLMModel.GLMParameters.Family)familyField.get(aModel._parms);
                return DistributionUtils.familyToDistribution(glmFamily);
            }
            Field distributionField = ReflectionUtils.findNamedField((Object)aModel, (String)"_dist");
            if (distributionField == null) {
                throw new H2OIllegalArgumentException("Don't know how to stack models that have neither a distribution hyperparameter nor a family hyperparameter.");
            }
            Distribution distribution = (Distribution)distributionField.get(aModel);
            DistributionFamily resolved = distribution != null ? distribution._family : aModel._parms._distribution;
            if (resolved != DistributionFamily.AUTO) {
                return resolved;
            }
            if (aModel._output.isBinomialClassifier()) {
                return DistributionFamily.bernoulli;
            }
            return aModel._output.isClassifier() ? DistributionFamily.multinomial : DistributionFamily.gaussian;
        }
        catch (Exception e) {
            throw new H2OIllegalArgumentException(e.toString(), e.toString());
        }
    }

    /**
     * Validates the configured base models against each other and against the ensemble
     * parameters, and copies shared properties onto {@code seModel}.
     *
     * <p>The first base model found in DKV acts as the reference: its model category and
     * response column are stored on {@code seModel}, its cross-validation settings are
     * remembered, and the metalearner distribution/family is inherited from it when the user
     * left it on AUTO. Every later base model is compared against that reference; hard
     * inconsistencies throw, softer ones (differing fold settings, differing distributions)
     * only record warnings. Base model keys that cannot be resolved are skipped with a warning.
     *
     * @param seModel ensemble model that receives the inherited properties
     * @throws H2OIllegalArgumentException if no base models are configured or none can be found,
     *         if both a metalearner fold column and nfolds are set, if a base model is not
     *         supervised, or if base models disagree on model category, response column,
     *         training frame size, or lack the required cross-validation setup
     */
    void checkAndInheritModelProperties(StackedEnsembleModel seModel) {
        if (null == ((StackedEnsembleModel.StackedEnsembleParameters)this._parms)._base_models || 0 == ((StackedEnsembleModel.StackedEnsembleParameters)this._parms)._base_models.length) {
            throw new H2OIllegalArgumentException("When creating a StackedEnsemble you must specify one or more models; found 0.");
        }
        // A metalearner fold column and a metalearner nfolds value are mutually exclusive.
        if (null != ((StackedEnsembleModel.StackedEnsembleParameters)this._parms)._metalearner_fold_column && 0 != ((StackedEnsembleModel.StackedEnsembleParameters)this._parms)._metalearner_nfolds) {
            throw new H2OIllegalArgumentException("Cannot specify fold_column and nfolds at the same time.");
        }
        Model aModel = null;
        boolean retrievedFirstModelParams = false;
        boolean inferredDistributionFromFirstModel = false;
        GLMModel firstGLM = null;
        // Blending mode is active whenever a blending frame is configured.
        boolean blending_mode = ((StackedEnsembleModel.StackedEnsembleParameters)this._parms)._blending != null;
        // Cross-validation on the base models is only required when not blending.
        boolean cv_required_on_base_model = !blending_mode;
        // Row-count consistency is only enforced when not blending and when this is not a CV model.
        boolean require_consistent_training_frames = !blending_mode && !((StackedEnsembleModel.StackedEnsembleParameters)this._parms)._is_cv_model;
        // Reference values captured from the first base model found.
        int basemodel_nfolds = -1;
        Model.Parameters.FoldAssignmentScheme basemodel_fold_assignment = null;
        String basemodel_fold_column = null;
        long seed = -1L;
        // Create the metalearner parameters if they have not been set yet.
        if (((StackedEnsembleModel.StackedEnsembleParameters)this._parms)._metalearner_parameters == null) {
            ((StackedEnsembleModel.StackedEnsembleParameters)this._parms).initMetalearnerParams();
        }
        for (Key<Model> k : ((StackedEnsembleModel.StackedEnsembleParameters)this._parms)._base_models) {
            aModel = (Model)DKV.getGet(k);
            if (null == aModel) {
                this.warn("base_models", "Failed to find base model; skipping: " + k);
                continue;
            }
            Log.debug((Object[])new Object[]{"Checking properties for model " + k});
            if (!aModel.isSupervised()) {
                throw new H2OIllegalArgumentException("Base model is not supervised: " + aModel._key.toString());
            }
            // Every model after the first is compared against the values captured from the first one.
            if (retrievedFirstModelParams) {
                if (seModel.modelCategory != aModel._output.getModelCategory()) {
                    throw new H2OIllegalArgumentException("Base models are inconsistent: there is a mix of different categories of models among " + Arrays.toString(((StackedEnsembleModel.StackedEnsembleParameters)this._parms)._base_models));
                }
                if (!seModel.responseColumn.equals(aModel._parms._response_column)) {
                    throw new H2OIllegalArgumentException("Base models are inconsistent: they use different response columns. Found: " + seModel.responseColumn + " (StackedEnsemble) and " + aModel._parms._response_column + " (model " + k + ").");
                }
                if (require_consistent_training_frames) {
                    long numOfRowsUsedToTrain;
                    // A negative value means the ensemble's training row count has not been recorded yet.
                    if (seModel.trainingFrameRows < 0L) {
                        seModel.trainingFrameRows = ((StackedEnsembleModel.StackedEnsembleParameters)this._parms).train().numRows();
                    }
                    // When the base model's training frame is unavailable, use the row count of its CV holdout predictions frame.
                    long l = numOfRowsUsedToTrain = aModel._parms.train() == null ? ((Frame)aModel._output._cross_validation_holdout_predictions_frame_id.get()).numRows() : aModel._parms.train().numRows();
                    if (seModel.trainingFrameRows != numOfRowsUsedToTrain) {
                        throw new H2OIllegalArgumentException("Base models are inconsistent: they use different size (number of rows) training frames. Found: " + seModel.trainingFrameRows + " (StackedEnsemble) and " + numOfRowsUsedToTrain + " (model " + k + ").");
                    }
                }
                if (cv_required_on_base_model) {
                    // AUTO on this model is accepted when the reference is Random (an AUTO reference was normalized to Random when captured).
                    if (aModel._parms._fold_assignment != basemodel_fold_assignment && (aModel._parms._fold_assignment != Model.Parameters.FoldAssignmentScheme.AUTO || basemodel_fold_assignment != Model.Parameters.FoldAssignmentScheme.Random)) {
                        this.warn("base_models", "Base models are inconsistent: they use different fold_assignments. This can lead to data leakage.");
                    }
                    if (aModel._parms._fold_column == null) {
                        if (aModel._parms._nfolds < 2) {
                            throw new H2OIllegalArgumentException("Base model does not use cross-validation: " + aModel._parms._nfolds);
                        }
                        if (basemodel_nfolds != aModel._parms._nfolds) {
                            this.warn("base_models", "Base models are inconsistent: they use different values for nfolds. This can lead to data leakage.");
                        }
                        if (basemodel_fold_assignment == Model.Parameters.FoldAssignmentScheme.Random && aModel._parms._seed != seed) {
                            this.warn("base_models", "Base models are inconsistent: they use random-seeded k-fold cross-validation but have different seeds. This can lead to data leakage.");
                        }
                    } else if (!aModel._parms._fold_column.equals(basemodel_fold_column)) {
                        this.warn("base_models", "Base models are inconsistent: they use different fold_columns. This can lead to data leakage.");
                    }
                    if (!aModel._parms._keep_cross_validation_predictions) {
                        throw new H2OIllegalArgumentException("Base model does not keep cross-validation predictions: " + aModel._parms._nfolds);
                    }
                }
                // Distribution checks only matter while the distribution is still the one inherited from the first model.
                if (!inferredDistributionFromFirstModel) continue;
                if (!(aModel instanceof DRFModel) && this.distributionFamily(aModel) == this.distributionFamily(seModel)) {
                    // Same distribution: also compare the distribution-specific hyperparameter.
                    boolean sameParams = true;
                    switch (((StackedEnsembleModel.StackedEnsembleParameters)this._parms)._metalearner_parameters._distribution) {
                        case custom: {
                            sameParams = ((StackedEnsembleModel.StackedEnsembleParameters)this._parms)._metalearner_parameters._custom_distribution_func.equals(aModel._parms._custom_distribution_func);
                            break;
                        }
                        case huber: {
                            sameParams = ((StackedEnsembleModel.StackedEnsembleParameters)this._parms)._metalearner_parameters._huber_alpha == aModel._parms._huber_alpha;
                            break;
                        }
                        case tweedie: {
                            sameParams = ((StackedEnsembleModel.StackedEnsembleParameters)this._parms)._metalearner_parameters._tweedie_power == aModel._parms._tweedie_power;
                            break;
                        }
                        case quantile: {
                            boolean bl = sameParams = ((StackedEnsembleModel.StackedEnsembleParameters)this._parms)._metalearner_parameters._quantile_alpha == aModel._parms._quantile_alpha;
                        }
                    }
                    if (aModel instanceof GLMModel && Metalearners.getActualMetalearnerAlgo(((StackedEnsembleModel.StackedEnsembleParameters)this._parms)._metalearner_algorithm) == Metalearner.Algorithm.glm) {
                        if (firstGLM == null) {
                            // First GLM base model seen: the GLM metalearner inherits family and link from it.
                            firstGLM = (GLMModel)aModel;
                            this.inheritFamilyAndParms(seModel, firstGLM._parms);
                        } else {
                            // Later GLM base models must use the same link as the metalearner; this overwrites the switch result above.
                            sameParams = ((GLMModel.GLMParameters)((StackedEnsembleModel.StackedEnsembleParameters)this._parms)._metalearner_parameters)._link.equals((Object)((GLMModel.GLMParameters)((GLMModel)aModel)._parms)._link);
                        }
                    }
                    if (sameParams) continue;
                    this.warn("distribution", "Base models are inconsistent; they use same distribution but different parameters of the distribution. Reverting to default distribution.");
                    this.inferBasicDistribution(seModel);
                    inferredDistributionFromFirstModel = false;
                    continue;
                }
                // Reached for DRF models and for models whose distribution differs; only the latter is warned about.
                if (this.distributionFamily(aModel) != this.distributionFamily(seModel)) {
                    this.warn("distribution", "Base models are inconsistent; they use different distributions: " + this.distributionFamily(seModel) + " and: " + this.distributionFamily(aModel) + ". Reverting to default distribution.");
                }
                this.inferBasicDistribution(seModel);
                inferredDistributionFromFirstModel = false;
                continue;
            }
            // First base model found: capture the reference values used for all later comparisons.
            seModel.modelCategory = aModel._output.getModelCategory();
            inferredDistributionFromFirstModel = this.inferDistributionOrFamily(seModel, aModel);
            firstGLM = aModel instanceof GLMModel && inferredDistributionFromFirstModel ? (GLMModel)aModel : null;
            seModel.responseColumn = aModel._parms._response_column;
            if (!((StackedEnsembleModel.StackedEnsembleParameters)this._parms)._response_column.equals(seModel.responseColumn)) {
                throw new H2OIllegalArgumentException("StackedModel response_column must match the response_column of each base model. Found: " + ((StackedEnsembleModel.StackedEnsembleParameters)this._parms)._response_column + "(StackedEnsemble) and: " + seModel.responseColumn + " (model " + k + ").");
            }
            basemodel_nfolds = aModel._parms._nfolds;
            basemodel_fold_assignment = aModel._parms._fold_assignment;
            // Normalize AUTO to Random so later comparisons treat the two as equivalent.
            if (basemodel_fold_assignment == Model.Parameters.FoldAssignmentScheme.AUTO) {
                basemodel_fold_assignment = Model.Parameters.FoldAssignmentScheme.Random;
            }
            basemodel_fold_column = aModel._parms._fold_column;
            seed = aModel._parms._seed;
            retrievedFirstModelParams = true;
        }
        // aModel holds the result of the last lookup, so this only fires when the last key could not be resolved.
        if (null == aModel) {
            throw new H2OIllegalArgumentException("When creating a StackedEnsemble you must specify one or more models; " + ((StackedEnsembleModel.StackedEnsembleParameters)this._parms)._base_models.length + " were specified but none of those were found: " + Arrays.toString(((StackedEnsembleModel.StackedEnsembleParameters)this._parms)._base_models));
        }
    }

    /**
     * Driver for blending mode: the metalearner is trained on base model predictions
     * computed on the blending frame.
     */
    private class StackedEnsembleBlendingDriver
    extends StackedEnsembleDriver {
        private StackedEnsembleBlendingDriver() {
        }

        /** @return always {@code StackingStrategy.blending} */
        @Override
        protected StackedEnsembleModel.StackingStrategy strategy() {
            return StackedEnsembleModel.StackingStrategy.blending;
        }

        /** @return the blending frame from the ensemble model's parameters */
        @Override
        protected Frame getActualTrainingFrame() {
            StackedEnsembleModel.StackedEnsembleParameters modelParms = (StackedEnsembleModel.StackedEnsembleParameters)StackedEnsemble.this._model._parms;
            return modelParms.blending();
        }

        /**
         * Scores the base model on the given frame.
         *
         * @throws Job.JobCancelledException if a stop was requested while handling the training frame
         */
        @Override
        protected Frame getPredictionsForBaseModel(Model model, Frame actualsFrame, boolean isTrainingFrame) {
            boolean stopRequested = StackedEnsemble.this.stop_requested();
            if (stopRequested && isTrainingFrame) {
                throw new Job.JobCancelledException();
            }
            Frame predictions = this.buildPredictionsForBaseModel(model, actualsFrame);
            return predictions;
        }
    }

    /**
     * Driver for cross-validation stacking: the metalearner is trained on the base models'
     * stored cross-validation holdout predictions.
     */
    private class StackedEnsembleCVStackingDriver
    extends StackedEnsembleDriver {
        private StackedEnsembleCVStackingDriver() {
        }

        /** @return always {@code StackingStrategy.cross_validation} */
        @Override
        protected StackedEnsembleModel.StackingStrategy strategy() {
            return StackedEnsembleModel.StackingStrategy.cross_validation;
        }

        /** @return the training frame from the ensemble model's parameters */
        @Override
        protected Frame getActualTrainingFrame() {
            StackedEnsembleModel.StackedEnsembleParameters modelParms = (StackedEnsembleModel.StackedEnsembleParameters)StackedEnsemble.this._model._parms;
            return modelParms.train();
        }

        /**
         * For training, returns the base model's stored CV holdout predictions frame;
         * otherwise scores the base model on the given frame.
         *
         * @throws H2OIllegalArgumentException if, for training, the holdout predictions frame id
         *                                     is missing or the frame cannot be found
         */
        @Override
        protected Frame getPredictionsForBaseModel(Model model, Frame actualsFrame, boolean isTraining) {
            if (!isTraining) {
                return this.buildPredictionsForBaseModel(model, actualsFrame);
            }
            if (model._output._cross_validation_holdout_predictions_frame_id == null) {
                throw new H2OIllegalArgumentException("Failed to find the xval predictions frame id. . .  Looks like keep_cross_validation_predictions wasn't set when building the models.");
            }
            Frame holdoutPredictions = (Frame)DKV.getGet((Key)model._output._cross_validation_holdout_predictions_frame_id);
            if (holdoutPredictions == null) {
                throw new H2OIllegalArgumentException("Failed to find the xval predictions frame. . .  Looks like keep_cross_validation_predictions wasn't set when building the models, or the frame was deleted.");
            }
            return holdoutPredictions;
        }
    }

    private abstract class StackedEnsembleDriver
    extends ModelBuilder.Driver {
        private StackedEnsembleDriver() {
            super((ModelBuilder)StackedEnsemble.this);
        }

        /**
         * Assembles the "level one" frame used for stacking from base-model predictions.
         *
         * <p>Steps, in order: validate the input arrays; resolve the metalearner transform
         * (a {@code null} or {@code NONE} transform means "no transform"); empty out any frame
         * already stored in the DKV under {@code levelOneKey}; add each base model's predictions
         * via {@code addModelPredictionsToLevelOneFrame}; apply the transform if one is set; add
         * the non-predictor columns from {@code actuals}; and finally store the frame in the DKV.
         *
         * @param levelOneKey          DKV key name for the resulting frame; when {@code null} a name is
         *                             derived from the ensemble model key and the transform
         * @param baseModels           base models, parallel to {@code baseModelPredictions}
         * @param baseModelPredictions prediction frames, one per base model
         * @param actuals              frame passed to {@code addNonPredictorsToLevelOneFrame}
         * @return the level-one frame that was put into the DKV
         * @throws H2OIllegalArgumentException if either array is null, empty, or of a different length
         *                                     than the other, or if a transform is requested for a
         *                                     model that is neither a binomial nor a multinomial classifier
         */
        private Frame prepareLevelOneFrame(String levelOneKey, Model[] baseModels, Frame[] baseModelPredictions, Frame actuals) {
            if (null == baseModels) {
                throw new H2OIllegalArgumentException("Base models array is null.");
            }
            if (null == baseModelPredictions) {
                throw new H2OIllegalArgumentException("Base model predictions array is null.");
            }
            if (baseModels.length == 0) {
                throw new H2OIllegalArgumentException("Base models array is empty.");
            }
            if (baseModelPredictions.length == 0) {
                throw new H2OIllegalArgumentException("Base model predictions array is empty.");
            }
            if (baseModels.length != baseModelPredictions.length) {
                throw new H2OIllegalArgumentException("Base models and prediction arrays are different lengths.");
            }

            StackedEnsembleModel.StackedEnsembleParameters.MetalearnerTransform requestedTransform =
                    ((StackedEnsembleModel.StackedEnsembleParameters)StackedEnsemble.this._parms)._metalearner_transform;
            // "transform" stays null unless a real (non-NONE) transform was requested.
            StackedEnsembleModel.StackedEnsembleParameters.MetalearnerTransform transform = null;
            if (requestedTransform != null && requestedTransform != StackedEnsembleModel.StackedEnsembleParameters.MetalearnerTransform.NONE) {
                if (!((StackedEnsembleModel.StackedEnsembleOutput)StackedEnsemble.this._model._output).isBinomialClassifier() && !((StackedEnsembleModel.StackedEnsembleOutput)StackedEnsemble.this._model._output).isMultinomialClassifier()) {
                    throw new H2OIllegalArgumentException("Metalearner transform is supported only for classification!");
                }
                transform = requestedTransform;
            }

            if (null == levelOneKey) {
                // A null transform is accepted above, so it must not be dereferenced here;
                // it is named like NONE in the generated key.
                StackedEnsembleModel.StackedEnsembleParameters.MetalearnerTransform keySuffix = requestedTransform == null
                        ? StackedEnsembleModel.StackedEnsembleParameters.MetalearnerTransform.NONE
                        : requestedTransform;
                levelOneKey = "levelone_" + StackedEnsemble.this._model._key.toString() + "_" + keySuffix.toString();
            }

            // If a frame already lives under this key, strip its columns before it is replaced.
            Frame old = (Frame)DKV.getGet((String)levelOneKey);
            if (old != null) {
                old.write_lock(StackedEnsemble.this._job);
                old.removeAll();
                old.update(StackedEnsemble.this._job);
                old.unlock(StackedEnsemble.this._job);
            }

            // Without a transform the frame is created directly under the final key; with a
            // transform it is created keyless and the transform receives the final key below.
            Frame levelOneFrame = transform == null ? new Frame(Key.make((String)levelOneKey)) : new Frame(new Vec[0]);
            for (int i = 0; i < baseModels.length; ++i) {
                Model baseModel = baseModels[i];
                Frame baseModelPreds = baseModelPredictions[i];
                if (null == baseModel) {
                    Log.warn((Object[])new Object[]{"Failed to find base model; skipping: " + baseModels[i]});
                    continue;
                }
                if (null == baseModelPreds) {
                    // The predictions frame is null here, so report the model's key rather than
                    // dereferencing the missing frame (which previously threw instead of skipping).
                    Log.warn((Object[])new Object[]{"Failed to find base model " + baseModel + " predictions; skipping: " + baseModel._key});
                    continue;
                }
                StackedEnsemble.addModelPredictionsToLevelOneFrame(baseModel, baseModelPreds, levelOneFrame);
                Scope.untrack((Frame[])baseModelPredictions);
            }
            if (transform != null) {
                levelOneFrame = transform.transform(StackedEnsemble.this._model, levelOneFrame, (Key<Frame>)Key.make((String)levelOneKey));
            }
            StackedEnsemble.addNonPredictorsToLevelOneFrame((StackedEnsembleModel.StackedEnsembleParameters)StackedEnsemble.this._model._parms, actuals, levelOneFrame, true);
            Log.info((Object[])new Object[]{"Finished creating \"level one\" frame for stacking: " + levelOneFrame.toString()});
            DKV.put((Keyed)levelOneFrame);
            return levelOneFrame;
        }

        /**
         * Resolves base-model keys to models and predictions, then builds the level-one frame
         * through the array-based overload.
         *
         * <p>Once the ensemble output has a metalearner, keys for which
         * {@code isUsefulBaseModel} returns false are skipped. When {@code isTraining} is true and
         * {@code _keep_levelone_frame} is set, the frame is replaced by the result of
         * {@code deepCopy} called with its own key string, updated under the job lock, and its
         * keys are untracked from the current {@code Scope}.
         *
         * @param levelOneKey   key name forwarded to the array-based overload
         * @param baseModelKeys keys of the candidate base models
         * @param actuals       frame forwarded both to the predictions lookup and to the overload
         * @param isTraining    forwarded to {@code getPredictionsForBaseModel}; also gates the copy
         * @return the level-one frame (the copy when one was made)
         * @throws H2OIllegalArgumentException if a non-skipped base model is not found in the DKV
         */
        private Frame prepareLevelOneFrame(String levelOneKey, Key<Model>[] baseModelKeys, Frame actuals, boolean isTraining) {
            ArrayList<Model> resolvedModels = new ArrayList<>();
            ArrayList<Frame> resolvedPredictions = new ArrayList<>();
            for (Key<Model> modelKey : baseModelKeys) {
                boolean metalearnerAvailable = ((StackedEnsembleModel.StackedEnsembleOutput)StackedEnsemble.this._model._output)._metalearner != null;
                if (metalearnerAvailable && !StackedEnsemble.this._model.isUsefulBaseModel(modelKey)) {
                    continue;
                }
                Model resolved = (Model)DKV.getGet(modelKey);
                if (resolved == null) {
                    throw new H2OIllegalArgumentException("Failed to find base model: " + modelKey);
                }
                Frame resolvedPreds = this.getPredictionsForBaseModel(resolved, actuals, isTraining);
                resolvedModels.add(resolved);
                resolvedPredictions.add(resolvedPreds);
            }

            boolean copyAndKeep = isTraining && ((StackedEnsembleModel.StackedEnsembleParameters)StackedEnsemble.this._parms)._keep_levelone_frame;
            Model[] modelArray = resolvedModels.toArray(new Model[0]);
            Frame[] predictionArray = resolvedPredictions.toArray(new Frame[0]);
            Frame assembled = this.prepareLevelOneFrame(levelOneKey, modelArray, predictionArray, actuals);
            if (!copyAndKeep) {
                return assembled;
            }

            Frame kept = assembled.deepCopy(assembled._key.toString());
            kept.write_lock(StackedEnsemble.this._job);
            kept.update(StackedEnsemble.this._job);
            kept.unlock(StackedEnsemble.this._job);
            Scope.untrack((Iterable)kept.keysList());
            return kept;
        }

        /**
         * Deletes the ensemble model, if one has been created, and then delegates to the
         * superclass handler.
         *
         * @param ex     the failure being reported
         * @param caller the completer that reported it
         * @return whatever the superclass handler returns
         */
        public boolean onExceptionalCompletion(Throwable ex, CountedCompleter caller) {
            StackedEnsembleModel partialModel = StackedEnsemble.this._model;
            if (partialModel != null) {
                partialModel.delete();
            }
            return super.onExceptionalCompletion(ex, caller);
        }

        /**
         * Returns the predictions of {@code model} on {@code frame}, scoring only when no frame
         * is already stored in the DKV under the key from {@code buildPredsKey}.
         *
         * <p>Freshly scored predictions have their keys untracked from the current
         * {@code Scope}. In every case the predictions key is recorded in the ensemble output's
         * {@code _base_model_predictions_keys} if it is not there yet.
         *
         * @param model base model to score with
         * @param frame frame to score on
         * @return the cached or newly scored predictions frame
         */
        protected Frame buildPredictionsForBaseModel(Model model, Frame frame) {
            final Key<Frame> predsKey = this.buildPredsKey(model, frame);
            Frame predictions = (Frame)DKV.getGet(predsKey);
            if (predictions == null) {
                // The null third argument and 'false' flag are passed to score() unchanged.
                predictions = model.score(frame, predsKey.toString(), null, false);
                Scope.untrack((Iterable)predictions.keysList());
            }

            StackedEnsembleModel.StackedEnsembleOutput ensembleOutput =
                    (StackedEnsembleModel.StackedEnsembleOutput)StackedEnsemble.this._model._output;
            if (ensembleOutput._base_model_predictions_keys == null) {
                ensembleOutput._base_model_predictions_keys = new Key[0];
            }
            boolean alreadyRecorded = ArrayUtils.contains((Object[])ensembleOutput._base_model_predictions_keys, predsKey);
            if (!alreadyRecorded) {
                ensembleOutput._base_model_predictions_keys = (Key[])ArrayUtils.append((Object[])ensembleOutput._base_model_predictions_keys, (Object[])new Key[]{predsKey});
            }
            return predictions;
        }

        /**
         * Builds the one-column "Model Summary for Stacked Ensemble" table.
         *
         * <p>Rows, in order: stacking strategy; used/total base-model counts overall and per
         * algorithm name; then the metalearner's algorithm, fold assignment ("AUTO" when unset),
         * nfolds and fold column; and finally the custom metalearner hyperparameters ("None"
         * when the string is empty).
         *
         * @return the populated summary table
         */
        TwoDimTable generateModelSummary() {
            final StackedEnsembleModel.StackedEnsembleParameters ensembleParms =
                    (StackedEnsembleModel.StackedEnsembleParameters)StackedEnsemble.this._model._parms;
            final StackedEnsembleModel.StackedEnsembleOutput ensembleOutput =
                    (StackedEnsembleModel.StackedEnsembleOutput)StackedEnsemble.this._model._output;

            // Tally base models per algorithm name: all of them, and the "useful" subset.
            HashMap<String, Integer> totalPerAlgo = new HashMap<>();
            HashMap<String, Integer> usedPerAlgo = new HashMap<>();
            for (Key<Model> baseKey : ensembleParms._base_models) {
                Model baseModel = (Model)baseKey.get();
                if (StackedEnsemble.this._model.isUsefulBaseModel(baseKey)) {
                    usedPerAlgo.merge(baseModel._parms.algoName(), 1, Integer::sum);
                }
                totalPerAlgo.merge(baseModel._parms.algoName(), 1, Integer::sum);
            }

            ArrayList<String> labels = new ArrayList<>();
            ArrayList<String> cells = new ArrayList<>();

            labels.add("Stacking strategy");
            cells.add(ensembleOutput._stacking_strategy.toString());

            long usedCount = Arrays.stream(ensembleParms._base_models)
                    .filter(StackedEnsemble.this._model::isUsefulBaseModel)
                    .count();
            labels.add("Number of base models (used / total)");
            cells.add(usedCount + "/" + ensembleParms._base_models.length);

            for (Map.Entry<String, Integer> perAlgo : totalPerAlgo.entrySet()) {
                labels.add("# " + perAlgo.getKey() + " base models (used / total)");
                cells.add(usedPerAlgo.getOrDefault(perAlgo.getKey(), 0) + "/" + perAlgo.getValue());
            }

            labels.add("Metalearner algorithm");
            cells.add(ensembleOutput._metalearner._parms.algoName());

            labels.add("Metalearner fold assignment scheme");
            if (ensembleOutput._metalearner._parms._fold_assignment == null) {
                cells.add("AUTO");
            } else {
                cells.add(ensembleOutput._metalearner._parms._fold_assignment.name());
            }

            labels.add("Metalearner nfolds");
            cells.add(String.valueOf(ensembleOutput._metalearner._parms._nfolds));

            labels.add("Metalearner fold_column");
            cells.add(ensembleOutput._metalearner._parms._fold_column);

            labels.add("Custom metalearner hyperparameters");
            cells.add(ensembleParms._metalearner_params.isEmpty() ? "None" : ensembleParms._metalearner_params);

            TwoDimTable summary = new TwoDimTable("Model Summary for Stacked Ensemble", "", labels.toArray(new String[0]), new String[]{"Value"}, new String[]{"string"}, new String[]{"%s"}, "Key");
            for (int row = 0; row < cells.size(); ++row) {
                summary.set(row, 0, (Object)cells.get(row));
            }
            return summary;
        }

        /**
         * Stacking strategy implemented by the concrete driver; {@code computeImpl} stores the
         * returned value in the ensemble output's {@code _stacking_strategy}.
         */
        protected abstract StackedEnsembleModel.StackingStrategy strategy();

        /**
         * Frame that {@code computeImpl} passes as the actuals when it builds the level-one
         * training frame.
         */
        protected abstract Frame getActualTrainingFrame();

        /**
         * Supplies the predictions of one base model for the level-one frame.
         *
         * @param var1 the base model (called with the model resolved from the DKV)
         * @param var2 the actuals frame being stacked on
         * @param var3 the caller's {@code isTraining} flag
         * @return the predictions frame for that base model
         */
        protected abstract Frame getPredictionsForBaseModel(Model var1, Frame var2, boolean var3);

        /**
         * Builds the predictions key {@code "preds_<model_checksum>_on_<frame_checksum>"}.
         * Only the two checksums contribute to the name; {@code model_key} and
         * {@code frame_key} are accepted but not used.
         */
        private Key<Frame> buildPredsKey(Key model_key, long model_checksum, Key frame_key, long frame_checksum) {
            String keyName = "preds_" + model_checksum + "_on_" + frame_checksum;
            return Key.make((String)keyName);
        }

        /**
         * Builds the predictions key for {@code model} scored on {@code frame} from their
         * checksums.
         *
         * @return the key, or {@code null} when either argument is {@code null}
         */
        protected Key<Frame> buildPredsKey(Model model, Frame frame) {
            if (frame == null || model == null) {
                return null;
            }
            return this.buildPredsKey(model._key, model.checksum(), frame._key, frame.checksum());
        }

        /**
         * Trains the stacked ensemble.
         *
         * <p>Sequence: run {@code init(true)} and fail on validation errors; create the ensemble
         * model, lock it, inherit properties via {@code checkAndInheritModelProperties}, update
         * and unlock it; build the level-one training frame and, when a validation frame is set,
         * the level-one validation frame; resolve the metalearner algorithm; initialise and run
         * the metalearner; finally generate and store the model summary.
         *
         * @throws H2OModelBuilderIllegalArgumentException if {@code init(true)} reported errors
         * @throws H2OIllegalArgumentException if the requested metalearner algorithm cannot be resolved
         */
        public void computeImpl() {
            StackedEnsemble.this.init(true);
            if (StackedEnsemble.this.error_count() > 0) {
                throw H2OModelBuilderIllegalArgumentException.makeFromBuilder((ModelBuilder)StackedEnsemble.this);
            }

            StackedEnsemble.this._model = new StackedEnsembleModel(StackedEnsemble.this.dest(), (StackedEnsembleModel.StackedEnsembleParameters)StackedEnsemble.this._parms, new StackedEnsembleModel.StackedEnsembleOutput(StackedEnsemble.this));
            ((StackedEnsembleModel.StackedEnsembleOutput)StackedEnsemble.this._model._output)._stacking_strategy = this.strategy();
            try {
                StackedEnsemble.this._model.delete_and_lock(StackedEnsemble.this._job);
                StackedEnsemble.this.checkAndInheritModelProperties(StackedEnsemble.this._model);
                StackedEnsemble.this._model.update(StackedEnsemble.this._job);
            }
            finally {
                StackedEnsemble.this._model.unlock(StackedEnsemble.this._job);
            }

            StackedEnsembleModel.StackedEnsembleParameters modelParms =
                    (StackedEnsembleModel.StackedEnsembleParameters)StackedEnsemble.this._model._parms;

            // Level-one frames: training always, validation only when a validation frame is set.
            String trainingFrameKey = "levelone_training_" + StackedEnsemble.this._model._key.toString();
            Frame levelOneTrainingFrame = this.prepareLevelOneFrame(trainingFrameKey, modelParms._base_models, this.getActualTrainingFrame(), true);
            Frame levelOneValidationFrame = null;
            Frame validationActuals = modelParms.valid();
            if (validationActuals != null) {
                String validationFrameKey = "levelone_validation_" + StackedEnsemble.this._model._key.toString();
                levelOneValidationFrame = this.prepareLevelOneFrame(validationFrameKey, modelParms._base_models, validationActuals, false);
            }

            // "Spec" is what was requested; "Impl" is what getActualMetalearnerAlgo resolves it to.
            Metalearner.Algorithm metalearnerAlgoSpec = modelParms._metalearner_algorithm;
            Metalearner.Algorithm metalearnerAlgoImpl = Metalearners.getActualMetalearnerAlgo(metalearnerAlgoSpec);
            if (metalearnerAlgoImpl == null) {
                throw new H2OIllegalArgumentException("Invalid `metalearner_algorithm`. Passed in " + metalearnerAlgoSpec + " but must be one of " + Arrays.toString((Object[])Metalearner.Algorithm.values()));
            }

            Key metalearnerKey = Key.make((String)("metalearner_" + metalearnerAlgoSpec + "_" + StackedEnsemble.this._model._key));
            Job metalearnerJob = new Job(metalearnerKey, ModelBuilder.javaName((String)metalearnerAlgoImpl.toString()), "StackingEnsemble metalearner (" + metalearnerAlgoSpec + ")");
            boolean hasMetaLearnerParams = modelParms._metalearner_parameters != null;
            long metalearnerSeed = modelParms._seed;
            // A max runtime of exactly 0.0 is passed on as 0; otherwise at least one second is granted.
            long metalearnerRuntimeSecs = ((StackedEnsembleModel.StackedEnsembleParameters)StackedEnsemble.this._parms)._max_runtime_secs == 0.0
                    ? 0L
                    : Math.max(StackedEnsemble.this.remainingTimeSecs(), 1L);

            Metalearner metalearner = Metalearners.createInstance(metalearnerAlgoSpec.name());
            metalearner.init(levelOneTrainingFrame, levelOneValidationFrame, modelParms._metalearner_parameters, StackedEnsemble.this._model, StackedEnsemble.this._job, (Key<Model>)metalearnerKey, metalearnerJob, (StackedEnsembleModel.StackedEnsembleParameters)StackedEnsemble.this._parms, hasMetaLearnerParams, metalearnerSeed, metalearnerRuntimeSecs);
            metalearner.compute();

            // With auto-parameter evaluation enabled, record the resolved algorithm in place of AUTO.
            StackedEnsembleModel.StackedEnsembleParameters trainedParms =
                    (StackedEnsembleModel.StackedEnsembleParameters)StackedEnsemble.this._model._parms;
            if (StackedEnsemble.this._model.evalAutoParamsEnabled && trainedParms._metalearner_algorithm == Metalearner.Algorithm.AUTO) {
                trainedParms._metalearner_algorithm = metalearnerAlgoImpl;
            }
            ((StackedEnsembleModel.StackedEnsembleOutput)StackedEnsemble.this._model._output)._model_summary = this.generateModelSummary();
        }
    }
}

