/*
 * Decompiled with CFR 0.152.
 */
package hex.ensemble;

import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.ensemble.Metalearner;
import hex.ensemble.StackedEnsembleModel;
import java.util.ArrayList;
import java.util.Arrays;
import water.DKV;
import water.Job;
import water.Key;
import water.Keyed;
import water.Scope;
import water.exceptions.H2OIllegalArgumentException;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;

/**
 * Model builder for {@link StackedEnsembleModel}.
 *
 * <p>Training collects one predictions frame per base model, assembles them (plus the
 * response column and, optionally, the metalearner fold column) into a "level one" frame,
 * and then hands that frame to a {@link Metalearner}.
 *
 * <p>Two strategies are provided by the inner drivers: when the {@code _blending} parameter
 * is {@code null} the cross-validation driver is used, otherwise the blending driver.
 */
public class StackedEnsemble
extends ModelBuilder<StackedEnsembleModel, StackedEnsembleModel.StackedEnsembleParameters, StackedEnsembleModel.StackedEnsembleOutput> {
    /** Driver instantiated by the most recent call to {@link #trainModelImpl()}. */
    StackedEnsembleDriver _driver;
    /** Ensemble model being built; assigned by the driver's {@code computeImpl()}. */
    protected StackedEnsembleModel _model;

    /**
     * Creates a builder for the given parameters and runs the non-expensive {@code init(false)} pass.
     *
     * @param parms ensemble parameters
     */
    public StackedEnsemble(StackedEnsembleModel.StackedEnsembleParameters parms) {
        super((Model.Parameters)parms);
        this.init(false);
    }

    /**
     * Creates a builder with default parameters, forwarding {@code startup_once} to the superclass.
     *
     * @param startup_once passed through unchanged to the {@link ModelBuilder} constructor
     */
    public StackedEnsemble(boolean startup_once) {
        super((Model.Parameters)new StackedEnsembleModel.StackedEnsembleParameters(), startup_once);
    }

    /** @return the model categories this builder supports: regression, binomial and multinomial */
    public ModelCategory[] can_build() {
        return new ModelCategory[]{ModelCategory.Regression, ModelCategory.Binomial, ModelCategory.Multinomial};
    }

    /** @return {@link ModelBuilder.BuilderVisibility#Stable} */
    public ModelBuilder.BuilderVisibility builderVisibility() {
        return ModelBuilder.BuilderVisibility.Stable;
    }

    /** @return always {@code true} */
    public boolean isSupervised() {
        return true;
    }

    /**
     * Selects the stacking driver and remembers it in {@link #_driver}.
     *
     * @return a cross-validation driver when {@code _blending} is {@code null}, a blending driver otherwise
     */
    protected StackedEnsembleDriver trainModelImpl() {
        StackedEnsembleModel.StackedEnsembleParameters builderParms = (StackedEnsembleModel.StackedEnsembleParameters)this._parms;
        if (builderParms._blending == null) {
            this._driver = new StackedEnsembleCVStackingDriver();
        } else {
            this._driver = new StackedEnsembleBlendingDriver();
        }
        return this._driver;
    }

    /** @return always {@code true} */
    public boolean haveMojo() {
        return true;
    }

    /**
     * Appends the relevant prediction column(s) of one base model to the level one frame.
     *
     * <ul>
     *   <li>Binomial: the column at index 2 of the predictions frame, named after the model key
     *       (presumably the second-class probability; confirm against the scoring output layout).</li>
     *   <li>Multinomial: every column except {@code "predict"}.</li>
     *   <li>Otherwise: the {@code "predict"} column, named after the model key.</li>
     * </ul>
     *
     * @param aModel             base model that produced the predictions
     * @param aModelsPredictions predictions frame of {@code aModel}
     * @param levelOneFrame      frame being assembled; modified in place
     * @throws H2OIllegalArgumentException if the model is an autoencoder or is not supervised
     */
    static void addModelPredictionsToLevelOneFrame(Model aModel, Frame aModelsPredictions, Frame levelOneFrame) {
        if (aModel._output.isBinomialClassifier()) {
            Vec preds = aModelsPredictions.vec(2);
            levelOneFrame.add(aModel._key.toString(), preds);
        } else if (aModel._output.isMultinomialClassifier()) {
            Frame probabilities = aModelsPredictions.subframe(ArrayUtils.remove((String[])aModelsPredictions.names(), "predict"));
            levelOneFrame.add(probabilities);
        } else {
            if (aModel._output.isAutoencoder()) {
                throw new H2OIllegalArgumentException("Don't yet know how to stack autoencoders: " + aModel._key);
            }
            if (!aModel._output.isSupervised()) {
                throw new H2OIllegalArgumentException("Don't yet know how to stack unsupervised models: " + aModel._key);
            }
            Vec preds = aModelsPredictions.vec("predict");
            levelOneFrame.add(aModel._key.toString(), preds);
        }
    }

    /** Driver for the blending strategy: base models are always scored on the supplied frame. */
    private class StackedEnsembleBlendingDriver
    extends StackedEnsembleDriver {
        private StackedEnsembleBlendingDriver() {
        }

        @Override
        protected StackedEnsembleModel.StackingStrategy strategy() {
            return StackedEnsembleModel.StackingStrategy.blending;
        }

        /** @return the frame returned by the model parameters' {@code blending()} accessor */
        @Override
        protected Frame getActualTrainingFrame() {
            return ((StackedEnsembleModel.StackedEnsembleParameters)StackedEnsemble.this._model._parms).blending();
        }

        /** Scores (or reuses cached scores of) the base model on {@code actualsFrame}; the flag is ignored. */
        @Override
        protected Frame getPredictionsForBaseModel(Model model, Frame actualsFrame, boolean isTrainingFrame) {
            return this.buildPredictionsForBaseModel(model, actualsFrame);
        }
    }

    /**
     * Driver for the cross-validation strategy: for the training frame the base models'
     * stored cross-validation holdout predictions are used instead of fresh scoring.
     */
    private class StackedEnsembleCVStackingDriver
    extends StackedEnsembleDriver {
        private StackedEnsembleCVStackingDriver() {
        }

        @Override
        protected StackedEnsembleModel.StackingStrategy strategy() {
            return StackedEnsembleModel.StackingStrategy.cross_validation;
        }

        /** @return the frame returned by the model parameters' {@code train()} accessor */
        @Override
        protected Frame getActualTrainingFrame() {
            return ((StackedEnsembleModel.StackedEnsembleParameters)StackedEnsemble.this._model._parms).train();
        }

        /**
         * Returns the holdout predictions frame of {@code model} when {@code isTraining} is set,
         * otherwise scores the model on {@code actualsFrame}.
         *
         * @throws H2OIllegalArgumentException if the holdout predictions frame id is missing or
         *         the frame cannot be fetched from the DKV
         */
        @Override
        protected Frame getPredictionsForBaseModel(Model model, Frame actualsFrame, boolean isTraining) {
            if (!isTraining) {
                return this.buildPredictionsForBaseModel(model, actualsFrame);
            }
            if (null == model._output._cross_validation_holdout_predictions_frame_id) {
                throw new H2OIllegalArgumentException("Failed to find the xval predictions frame id. . .  Looks like keep_cross_validation_predictions wasn't set when building the models.");
            }
            Frame holdoutPreds = (Frame)DKV.getGet((Key)model._output._cross_validation_holdout_predictions_frame_id);
            if (null == holdoutPreds) {
                throw new H2OIllegalArgumentException("Failed to find the xval predictions frame. . .  Looks like keep_cross_validation_predictions wasn't set when building the models, or the frame was deleted.");
            }
            return holdoutPreds;
        }
    }

    /**
     * Shared training logic of both stacking strategies. Subclasses decide which frame is the
     * actual training frame and how base model predictions are obtained.
     */
    private abstract class StackedEnsembleDriver
    extends ModelBuilder.Driver {
        private StackedEnsembleDriver() {
            super((ModelBuilder)StackedEnsemble.this);
        }

        /**
         * Builds the level one frame from already-resolved base models and their predictions.
         *
         * @param levelOneKey          key name for the new frame; when {@code null} a name derived
         *                             from the ensemble model key is used
         * @param baseModels           base models, parallel to {@code baseModelPredictions}
         * @param baseModelPredictions one predictions frame per base model
         * @param actuals              frame supplying the response (and fold) column
         * @return the level one frame, stored in the DKV
         * @throws H2OIllegalArgumentException if either array is null, empty, or the lengths differ
         */
        private Frame prepareLevelOneFrame(String levelOneKey, Model[] baseModels, Frame[] baseModelPredictions, Frame actuals) {
            if (null == baseModels) {
                throw new H2OIllegalArgumentException("Base models array is null.");
            }
            if (null == baseModelPredictions) {
                throw new H2OIllegalArgumentException("Base model predictions array is null.");
            }
            if (baseModels.length == 0) {
                throw new H2OIllegalArgumentException("Base models array is empty.");
            }
            if (baseModelPredictions.length == 0) {
                throw new H2OIllegalArgumentException("Base model predictions array is empty.");
            }
            if (baseModels.length != baseModelPredictions.length) {
                throw new H2OIllegalArgumentException("Base models and prediction arrays are different lengths.");
            }
            if (null == levelOneKey) {
                levelOneKey = "levelone_" + StackedEnsemble.this._model._key.toString();
            }
            StackedEnsembleModel.StackedEnsembleParameters modelParms = (StackedEnsembleModel.StackedEnsembleParameters)StackedEnsemble.this._model._parms;
            Frame levelOneFrame = new Frame(Key.make(levelOneKey));
            for (int i = 0; i < baseModels.length; ++i) {
                Model baseModel = baseModels[i];
                Frame baseModelPreds = baseModelPredictions[i];
                if (null == baseModel) {
                    Log.warn(new Object[]{"Failed to find base model; skipping: " + baseModels[i]});
                    continue;
                }
                if (null == baseModelPreds) {
                    // The predictions frame is null in this branch, so it must not be dereferenced;
                    // identify the entry by the base model's key instead.
                    Log.warn(new Object[]{"Failed to find predictions for base model " + baseModel._key + "; skipping."});
                    continue;
                }
                StackedEnsemble.addModelPredictionsToLevelOneFrame(baseModel, baseModelPreds, levelOneFrame);
            }
            if (modelParms._metalearner_fold_column != null) {
                Vec foldColumn = actuals.vec(modelParms._metalearner_fold_column);
                levelOneFrame.add(modelParms._metalearner_fold_column, foldColumn);
            }
            Vec responseColumn = actuals.vec(StackedEnsemble.this._model.responseColumn);
            levelOneFrame.add(StackedEnsemble.this._model.responseColumn, responseColumn);
            // If a frame is already stored under this key, strip its columns and publish the
            // emptied frame before the new one takes over the key.
            Frame old = (Frame)DKV.getGet((Key)levelOneFrame._key);
            if (old != null) {
                old.removeAll();
                old.write_lock(StackedEnsemble.this._job);
                old.update(StackedEnsemble.this._job);
                old.unlock(StackedEnsemble.this._job);
            }
            levelOneFrame.delete_and_lock(StackedEnsemble.this._job);
            levelOneFrame.unlock(StackedEnsemble.this._job);
            Log.info(new Object[]{"Finished creating \"level one\" frame for stacking: " + levelOneFrame.toString()});
            DKV.put((Keyed)levelOneFrame);
            return levelOneFrame;
        }

        /**
         * Resolves the base models from their keys, obtains their predictions via
         * {@link #getPredictionsForBaseModel(Model, Frame, boolean)} and builds the level one frame.
         *
         * <p>When {@code isTraining} is set and {@code _keep_levelone_frame} is enabled, the frame
         * is deep-copied under the same key and its keys are removed from {@link Scope} tracking.
         *
         * @param levelOneKey   key name for the level one frame
         * @param baseModelKeys keys of the base models
         * @param actuals       frame supplying the response (and fold) column
         * @param isTraining    whether this is the training (as opposed to validation) frame
         * @return the level one frame
         * @throws H2OIllegalArgumentException if {@code baseModelKeys} is null or a base model
         *         cannot be fetched from the DKV
         */
        private Frame prepareLevelOneFrame(String levelOneKey, Key<Model>[] baseModelKeys, Frame actuals, boolean isTraining) {
            if (null == baseModelKeys) {
                // Same message as the array-based overload, instead of a bare NPE from the loop.
                throw new H2OIllegalArgumentException("Base models array is null.");
            }
            ArrayList<Model> baseModels = new ArrayList<Model>();
            ArrayList<Frame> baseModelPredictions = new ArrayList<Frame>();
            for (Key<Model> k : baseModelKeys) {
                Model aModel = (Model)DKV.getGet(k);
                if (null == aModel) {
                    throw new H2OIllegalArgumentException("Failed to find base model: " + k);
                }
                Frame predictions = this.getPredictionsForBaseModel(aModel, actuals, isTraining);
                baseModels.add(aModel);
                baseModelPredictions.add(predictions);
            }
            boolean keepLevelOneFrame = isTraining && ((StackedEnsembleModel.StackedEnsembleParameters)StackedEnsemble.this._parms)._keep_levelone_frame;
            Frame levelOneFrame = this.prepareLevelOneFrame(levelOneKey, baseModels.toArray(new Model[0]), baseModelPredictions.toArray(new Frame[0]), actuals);
            if (keepLevelOneFrame) {
                levelOneFrame = levelOneFrame.deepCopy(levelOneFrame._key.toString());
                levelOneFrame.write_lock(StackedEnsemble.this._job);
                levelOneFrame.update(StackedEnsemble.this._job);
                levelOneFrame.unlock(StackedEnsemble.this._job);
                Scope.untrack((Iterable)levelOneFrame.keysList());
            }
            return levelOneFrame;
        }

        /**
         * Returns the predictions of {@code model} on {@code frame}, reusing a frame already stored
         * under the checksum-derived predictions key and scoring only when none exists. The key is
         * recorded in the ensemble output's {@code _base_model_predictions_keys} if not yet present.
         *
         * @param model base model to score
         * @param frame frame to score on
         * @return the predictions frame
         */
        protected Frame buildPredictionsForBaseModel(Model model, Frame frame) {
            Key<Frame> predsKey = this.buildPredsKey(model, frame);
            Frame preds = (Frame)DKV.getGet(predsKey);
            if (preds == null) {
                preds = model.score(frame, predsKey.toString());
                Scope.untrack((Iterable)preds.keysList());
            }
            StackedEnsembleModel.StackedEnsembleOutput ensembleOutput = (StackedEnsembleModel.StackedEnsembleOutput)StackedEnsemble.this._model._output;
            if (ensembleOutput._base_model_predictions_keys == null) {
                ensembleOutput._base_model_predictions_keys = new Key[0];
            }
            if (!ArrayUtils.contains((Object[])ensembleOutput._base_model_predictions_keys, predsKey)) {
                ensembleOutput._base_model_predictions_keys = (Key[])ArrayUtils.append((Object[])ensembleOutput._base_model_predictions_keys, (Object[])new Key[]{predsKey});
            }
            return preds;
        }

        /** @return the stacking strategy implemented by this driver */
        protected abstract StackedEnsembleModel.StackingStrategy strategy();

        /** @return the frame whose response column is used to train the metalearner */
        protected abstract Frame getActualTrainingFrame();

        /** Supplies the predictions of one base model for the given frame. */
        protected abstract Frame getPredictionsForBaseModel(Model var1, Frame var2, boolean var3);

        /**
         * Builds the predictions key from the two checksums only; {@code model_key} and
         * {@code frame_key} are accepted but not used in the key name.
         */
        private Key<Frame> buildPredsKey(Key model_key, long model_checksum, Key frame_key, long frame_checksum) {
            return Key.make("preds_" + model_checksum + "_on_" + frame_checksum);
        }

        /** @return the predictions key for the pair, or {@code null} if either argument is {@code null} */
        protected Key<Frame> buildPredsKey(Model model, Frame frame) {
            if (frame == null || model == null) {
                return null;
            }
            return this.buildPredsKey(model._key, model.checksum(), frame._key, frame.checksum());
        }

        /**
         * Runs the full training: validates the builder, creates and locks the ensemble model,
         * builds the level one training (and optional validation) frame, then configures and
         * runs the metalearner.
         *
         * @throws H2OModelBuilderIllegalArgumentException if {@code init(true)} reported errors
         * @throws H2OIllegalArgumentException if the metalearner algorithm cannot be resolved
         */
        public void computeImpl() {
            StackedEnsemble.this.init(true);
            if (StackedEnsemble.this.error_count() > 0) {
                throw H2OModelBuilderIllegalArgumentException.makeFromBuilder((ModelBuilder)StackedEnsemble.this);
            }
            StackedEnsemble.this._model = new StackedEnsembleModel(StackedEnsemble.this.dest(), (StackedEnsembleModel.StackedEnsembleParameters)StackedEnsemble.this._parms, new StackedEnsembleModel.StackedEnsembleOutput(StackedEnsemble.this));
            ((StackedEnsembleModel.StackedEnsembleOutput)StackedEnsemble.this._model._output)._stacking_strategy = this.strategy();
            StackedEnsemble.this._model.delete_and_lock(StackedEnsemble.this._job);
            StackedEnsemble.this._model.checkAndInheritModelProperties();

            StackedEnsembleModel.StackedEnsembleParameters modelParms = (StackedEnsembleModel.StackedEnsembleParameters)StackedEnsemble.this._model._parms;
            String modelKeyName = StackedEnsemble.this._model._key.toString();

            String levelOneTrainKey = "levelone_training_" + modelKeyName;
            Frame levelOneTrainingFrame = this.prepareLevelOneFrame(levelOneTrainKey, modelParms._base_models, this.getActualTrainingFrame(), true);
            Frame levelOneValidationFrame = null;
            if (modelParms.valid() != null) {
                String levelOneValidKey = "levelone_validation_" + modelKeyName;
                levelOneValidationFrame = this.prepareLevelOneFrame(levelOneValidKey, modelParms._base_models, modelParms.valid(), false);
            }

            Metalearner.Algorithm metalearnerAlgoSpec = modelParms._metalearner_algorithm;
            Metalearner.Algorithm metalearnerAlgoImpl = Metalearner.getActualMetalearnerAlgo(metalearnerAlgoSpec);
            if (metalearnerAlgoImpl == null) {
                throw new H2OIllegalArgumentException("Invalid `metalearner_algorithm`. Passed in " + metalearnerAlgoSpec + " but must be one of " + Arrays.toString((Object[])Metalearner.Algorithm.values()));
            }
            Key metalearnerKey = Key.make("metalearner_" + metalearnerAlgoSpec + "_" + StackedEnsemble.this._model._key);
            Job metalearnerJob = new Job(metalearnerKey, ModelBuilder.javaName((String)metalearnerAlgoImpl.toString()), "StackingEnsemble metalearner (" + metalearnerAlgoSpec + ")");
            boolean hasMetaLearnerParams = modelParms._metalearner_parameters != null;
            long metalearnerSeed = modelParms._seed;
            Metalearner metalearner = Metalearner.createInstance(metalearnerAlgoSpec);
            // Note: the builder's own _parms (not the model's) is passed as the ensemble parameters argument.
            metalearner.init(levelOneTrainingFrame, levelOneValidationFrame, modelParms._metalearner_parameters, StackedEnsemble.this._model, StackedEnsemble.this._job, (Key<Model>)metalearnerKey, metalearnerJob, (StackedEnsembleModel.StackedEnsembleParameters)StackedEnsemble.this._parms, hasMetaLearnerParams, metalearnerSeed);
            metalearner.compute();
        }
    }
}

