/*
 * Decompiled with CFR 0.152.
 */
package hex.ensemble;

import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.StackedEnsembleModel;
import hex.glm.GLM;
import hex.glm.GLMModel;
import water.DKV;
import water.Job;
import water.Key;
import water.exceptions.H2OIllegalArgumentException;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.Log;

/**
 * Model builder for stacked ensembles.
 *
 * <p>Training collects the cross-validation holdout predictions of every base model listed in
 * the parameters into a "level one" frame, appends the response column of the training frame,
 * and fits a GLM metalearner on that frame. The trained metalearner is stored in the ensemble
 * model's output.
 */
public class StackedEnsemble
extends ModelBuilder<StackedEnsembleModel, StackedEnsembleModel.StackedEnsembleParameters, StackedEnsembleModel.StackedEnsembleOutput> {
    /** Driver created by the most recent {@link #trainModelImpl()} call. */
    StackedEnsembleDriver _driver;
    /** Ensemble model under construction; assigned by the driver's {@code computeImpl()}. */
    protected StackedEnsembleModel _model;

    /**
     * Creates a builder with default stacked-ensemble parameters.
     *
     * @param startup_once forwarded unchanged to the {@code ModelBuilder} constructor
     */
    public StackedEnsemble(boolean startup_once) {
        super((Model.Parameters)new StackedEnsembleModel.StackedEnsembleParameters(), startup_once);
    }

    /**
     * @return the model categories this builder supports: regression and binomial classification
     */
    public ModelCategory[] can_build() {
        return new ModelCategory[]{ModelCategory.Regression, ModelCategory.Binomial};
    }

    /**
     * @return {@code Experimental}
     */
    public ModelBuilder.BuilderVisibility builderVisibility() {
        return ModelBuilder.BuilderVisibility.Experimental;
    }

    /**
     * @return always {@code true}
     */
    public boolean isSupervised() {
        return true;
    }

    /**
     * Creates a fresh driver, remembers it in {@link #_driver} and returns it.
     *
     * @return the newly created driver
     */
    protected StackedEnsembleDriver trainModelImpl() {
        this._driver = new StackedEnsembleDriver();
        return this._driver;
    }

    /**
     * Adds one base model's prediction column to the level-one frame. The new column is named
     * after the model's key.
     *
     * <ul>
     *   <li>Binomial classifiers contribute the column at index 2 of the predictions frame
     *       (presumably the positive-class probability; confirm against the prediction frame
     *       layout).</li>
     *   <li>Supervised non-classifier models contribute the column named {@code "predict"}.</li>
     * </ul>
     *
     * @param aModel             the base model whose predictions are being added
     * @param aModelsPredictions frame holding the base model's predictions
     * @param levelOneFrame      frame that receives the new column
     * @throws H2OIllegalArgumentException if the model is a non-binomial classifier, an
     *                                     autoencoder, or unsupervised
     */
    public static void addModelPredictionsToLevelOneFrame(Model aModel, Frame aModelsPredictions, Frame levelOneFrame) {
        if (aModel._output.isBinomialClassifier()) {
            Vec preds = aModelsPredictions.vec(2);
            levelOneFrame.add(aModel._key.toString(), preds);
            return;
        }
        // Everything below handles the non-binomial case; unsupported model kinds are rejected first.
        if (aModel._output.isClassifier()) {
            throw new H2OIllegalArgumentException("Don't yet know how to stack multinomial classifiers: " + aModel._key);
        }
        if (aModel._output.isAutoencoder()) {
            throw new H2OIllegalArgumentException("Don't yet know how to stack autoencoders: " + aModel._key);
        }
        if (!aModel._output.isSupervised()) {
            throw new H2OIllegalArgumentException("Don't yet know how to stack unsupervised models: " + aModel._key);
        }
        levelOneFrame.add(aModel._key.toString(), aModelsPredictions.vec("predict"));
    }

    /**
     * Driver that performs the actual ensemble training for the enclosing builder.
     */
    private class StackedEnsembleDriver
    extends ModelBuilder.Driver {
        // Decompiler-rendered constructor: hands the enclosing builder instance to the superclass.
        private StackedEnsembleDriver() {
            super((ModelBuilder)StackedEnsemble.this);
        }

        /**
         * Builds the level-one frame: one prediction column per available base model plus the
         * response column taken from the ensemble model's training frame.
         *
         * <p>Base models that cannot be found are skipped with a warning. Any frame already
         * stored under the level-one key is emptied before the new frame is written.
         *
         * @param parms ensemble parameters supplying the base model keys
         * @return the level-one frame, written under its key and unlocked
         * @throws H2OIllegalArgumentException if a base model has no cross-validation holdout
         *                                     predictions frame id, or that frame cannot be found
         */
        private Frame prepareLevelOneFrame(StackedEnsembleModel.StackedEnsembleParameters parms) {
            Frame levelOneFrame = new Frame(Key.make((String)("levelone_" + StackedEnsemble.this._model._key.toString())));
            for (Key<Model> k : parms._base_models) {
                Model aModel = (Model)DKV.getGet(k);
                if (null == aModel) {
                    Log.warn((Object[])new Object[]{"Failed to find base model; skipping: " + k});
                    continue;
                }
                if (null == aModel._output._cross_validation_holdout_predictions_frame_id) {
                    throw new H2OIllegalArgumentException("Failed to find the xval predictions frame id. . .  Looks like keep_cross_validation_predictions wasn't set when building the models.");
                }
                Frame aModelsPredictions = (Frame)aModel._output._cross_validation_holdout_predictions_frame_id.get();
                if (null == aModelsPredictions) {
                    // The id is recorded but the frame itself is gone; fail with a clear message
                    // instead of a NullPointerException further down.
                    throw new H2OIllegalArgumentException("Failed to find the xval predictions frame for base model " + k + ": " + aModel._output._cross_validation_holdout_predictions_frame_id);
                }
                StackedEnsemble.addModelPredictionsToLevelOneFrame(aModel, aModelsPredictions, levelOneFrame);
            }
            levelOneFrame.add(StackedEnsemble.this._model.responseColumn, ((StackedEnsembleModel.StackedEnsembleParameters)StackedEnsemble.this._model._parms).train().vec(StackedEnsemble.this._model.responseColumn));
            // Empty out any frame left under the same key before publishing the new one.
            Frame old = (Frame)DKV.getGet((Key)levelOneFrame._key);
            if (old != null) {
                old.removeAll();
                old.write_lock(StackedEnsemble.this._job);
                old.update(StackedEnsemble.this._job);
                old.unlock(StackedEnsemble.this._job);
            }
            levelOneFrame.delete_and_lock(StackedEnsemble.this._job);
            levelOneFrame.unlock(StackedEnsemble.this._job);
            Log.info((Object[])new Object[]{"Finished creating \"level one\" frame for stacking: " + levelOneFrame.toString()});
            return levelOneFrame;
        }

        /**
         * Trains the ensemble: validates parameters, creates and locks the ensemble model,
         * builds the level-one frame, trains a GLM metalearner on it, stores the metalearner in
         * the model output, scores the model and writes it back.
         *
         * <p>The model is unlocked in a {@code finally} block, so a failure after locking does
         * not leave it locked. An interrupt received while waiting for the metalearner does not
         * abort the wait; the thread's interrupt status is restored once this method finishes.
         *
         * @throws H2OModelBuilderIllegalArgumentException if parameter validation reports errors
         */
        public void computeImpl() {
            StackedEnsemble.this.init(true);
            if (StackedEnsemble.this.error_count() > 0) {
                throw H2OModelBuilderIllegalArgumentException.makeFromBuilder((ModelBuilder)StackedEnsemble.this);
            }
            StackedEnsemble.this._model = new StackedEnsembleModel(StackedEnsemble.this.dest(), (StackedEnsembleModel.StackedEnsembleParameters)StackedEnsemble.this._parms, new StackedEnsembleModel.StackedEnsembleOutput(StackedEnsemble.this));
            StackedEnsemble.this._model.delete_and_lock(StackedEnsemble.this._job);
            boolean interrupted = false;
            try {
                StackedEnsemble.this._model.checkAndInheritModelProperties();
                Frame levelOneFrame = this.prepareLevelOneFrame((StackedEnsembleModel.StackedEnsembleParameters)StackedEnsemble.this._parms);

                // Configure the GLM metalearner on the level-one frame.
                Key metalearnerKey = Key.make((String)("metalearner_" + StackedEnsemble.this._model._key));
                Job job = new Job(metalearnerKey, ModelBuilder.javaName((String)"glm"), "StackingEnsemble metalearner (GLM)");
                GLM metaBuilder = (GLM)ModelBuilder.make((String)"GLM", (Job)job, (Key)metalearnerKey);
                ((GLMModel.GLMParameters)metaBuilder._parms)._non_negative = true;
                ((GLMModel.GLMParameters)metaBuilder._parms)._train = levelOneFrame._key;
                ((GLMModel.GLMParameters)metaBuilder._parms)._response_column = StackedEnsemble.this._model.responseColumn;
                ((GLMModel.GLMParameters)metaBuilder._parms)._family = StackedEnsemble.this._model.modelCategory == ModelCategory.Regression ? GLMModel.GLMParameters.Family.gaussian : GLMModel.GLMParameters.Family.binomial;
                metaBuilder.init(false);

                // Poll until the metalearner job stops running, reporting progress on our own job.
                Job j = metaBuilder.trainModel();
                while (j.isRunning()) {
                    try {
                        // NOTE(review): j._work is passed on every poll; confirm that this is the
                        // intended progress increment for Job.update.
                        StackedEnsemble.this._job.update(j._work, "training metalearner");
                        Thread.sleep(100L);
                    }
                    catch (InterruptedException ignored) {
                        // Keep waiting for the metalearner, but remember the interrupt so it can
                        // be re-asserted once the driver is done.
                        interrupted = true;
                    }
                }
                Log.info((Object[])new Object[]{"Finished training metalearner model."});
                ((StackedEnsembleModel.StackedEnsembleOutput)StackedEnsemble.this._model._output)._metalearner = metaBuilder.get();
                StackedEnsemble.this._model.doScoreMetrics(StackedEnsemble.this._job);
                StackedEnsemble.this._model.update(StackedEnsemble.this._job);
            }
            finally {
                // Always release the write lock taken by delete_and_lock above.
                StackedEnsemble.this._model.unlock(StackedEnsemble.this._job);
                if (interrupted) {
                    Thread.currentThread().interrupt();
                }
            }
        }
    }
}

