/*
 * Decompiled with CFR 0.152.
 */
package ai.h2o.automl.modeling;

import ai.h2o.automl.Algo;
import ai.h2o.automl.AutoML;
import ai.h2o.automl.AutoMLBuildSpec;
import ai.h2o.automl.ModelingStep;
import ai.h2o.automl.ModelingSteps;
import ai.h2o.automl.ModelingStepsProvider;
import ai.h2o.automl.WorkAllocations;
import ai.h2o.automl.events.EventLogEntry;
import hex.Model;
import hex.ensemble.StackedEnsembleModel;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.stream.Stream;
import water.Job;
import water.Key;

/**
 * Provides the {@link Algo#StackedEnsemble} modeling steps to AutoML.
 * <p>
 * Two default steps are defined: a "best of family" ensemble that stacks one non-ensemble model
 * per model type, and an "all models" ensemble that stacks every non-ensemble model trained so far.
 * No grid steps are provided.
 */
public class StackedEnsembleStepsProvider
implements ModelingStepsProvider<StackedEnsembleStepsProvider.StackedEnsembleSteps> {

    /** Returns the provider name, which is the name of {@link Algo#StackedEnsemble}. */
    @Override
    public String getName() {
        return Algo.StackedEnsemble.name();
    }

    /**
     * Creates the StackedEnsemble modeling steps bound to the given AutoML instance.
     *
     * @param aml the AutoML instance the steps belong to
     * @return a new {@link StackedEnsembleSteps}
     */
    @Override
    public StackedEnsembleSteps newInstance(AutoML aml) {
        return new StackedEnsembleSteps(aml);
    }

    /** The StackedEnsemble modeling steps: two default ensemble steps and no grids. */
    public static class StackedEnsembleSteps
    extends ModelingSteps {

        /** Weight passed to each default StackedEnsemble step. */
        private static final int SE_STEP_WEIGHT = 10;

        // Field initializers run after super(autoML) in the constructor, so aml() is usable here.
        private ModelingStep[] defaults = new StackedEnsembleModelStep[]{
            new StackedEnsembleModelStep("best", SE_STEP_WEIGHT, this.aml()) {
                {
                    this._description = this._description + " (built using top model from each algorithm type)";
                }

                /**
                 * Keeps the first non-ensemble model encountered for each model type, in the order
                 * returned by getTrainedModelsKeys() (presumably best-first — TODO confirm).
                 */
                @Override
                @SuppressWarnings("unchecked") // generic arrays cannot be created; Key[] is the erasure
                protected Key<Model>[] getBaseModels() {
                    ArrayList<Key<Model>> bestModelsOfEachType = new ArrayList<Key<Model>>();
                    HashSet<String> typesOfGatheredModels = new HashSet<String>();
                    for (Key<Model> key : this.getTrainedModelsKeys()) {
                        if (this.isStackedEnsemble(key)) {
                            continue;
                        }
                        // Set.add returns false when this model type was already gathered.
                        if (typesOfGatheredModels.add(this.getModelType(key))) {
                            bestModelsOfEachType.add(key);
                        }
                    }
                    return bestModelsOfEachType.toArray(new Key[0]);
                }

                @Override
                protected Job<StackedEnsembleModel> startJob() {
                    return this.stack(this._algo + "_BestOfFamily", this.getBaseModels(), false);
                }
            },
            new StackedEnsembleModelStep("all", SE_STEP_WEIGHT, this.aml()) {
                {
                    this._description = this._description + " (built using all AutoML models)";
                }

                /** Uses every trained model that is not itself a StackedEnsemble. */
                @Override
                @SuppressWarnings("unchecked") // generic arrays cannot be created; Key[] is the erasure
                protected Key<Model>[] getBaseModels() {
                    return Stream.of(this.getTrainedModelsKeys())
                            .filter(k -> !this.isStackedEnsemble(k))
                            .toArray(Key[]::new);
                }

                @Override
                protected Job<StackedEnsembleModel> startJob() {
                    // Last ensemble step: base-model predictions need not be kept afterwards.
                    return this.stack(this._algo + "_AllModels", this.getBaseModels(), true);
                }
            }
        };
        private ModelingStep[] grids = new ModelingStep[0];

        public StackedEnsembleSteps(AutoML autoML) {
            super(autoML);
        }

        @Override
        protected ModelingStep[] getDefaultModels() {
            return this.defaults;
        }

        @Override
        protected ModelingStep[] getGrids() {
            return this.grids;
        }

        /** Base class for a single StackedEnsemble step; subclasses choose the base models. */
        static abstract class StackedEnsembleModelStep
        extends ModelingStep.ModelStep<StackedEnsembleModel> {
            StackedEnsembleModelStep(String id, int weight, AutoML autoML) {
                super(Algo.StackedEnsemble, id, weight, autoML);
                this._ignoreConstraints = true;
            }

            /**
             * Decides whether this ensemble can be built. On every skip path it reports progress on
             * the AutoML job and logs an info event explaining why.
             *
             * @return false if no work is allocated to this step, if there are fewer than two base
             *         models, or if cross-validation is disabled and no blending frame was given;
             *         true otherwise
             */
            @Override
            protected boolean canRun() {
                WorkAllocations.Work seWork = this.getAllocatedWork();
                if (seWork == null) {
                    this.aml().job().update(0L, "Skipping this StackedEnsemble");
                    this.aml().eventLog().info(EventLogEntry.Stage.ModelTraining, String.format("Skipping StackedEnsemble '%s' due to the exclude_algos option.", this._id));
                    return false;
                }
                // Only gather base models once we know this step has work allocated.
                Key<Model>[] keys = this.getBaseModels();
                if (keys.length == 0) {
                    this.aml().job().update((long)seWork.consume(), "No base models; skipping this StackedEnsemble");
                    this.aml().eventLog().info(EventLogEntry.Stage.ModelTraining, String.format("No base models, due to timeouts or the exclude_algos option. Skipping StackedEnsemble '%s'.", this._id));
                    return false;
                }
                if (keys.length == 1) {
                    this.aml().job().update((long)seWork.consume(), "Only one base model; skipping this StackedEnsemble");
                    this.aml().eventLog().info(EventLogEntry.Stage.ModelTraining, String.format("Skipping StackedEnsemble '%s' since there is only one model to stack", this._id));
                    return false;
                }
                if (!this.isCVEnabled() && this.aml().getBlendingFrame() == null) {
                    this.aml().job().update((long)seWork.consume(), "Cross-validation disabled by the user and no blending frame provided; Skipping this StackedEnsemble");
                    this.aml().eventLog().info(EventLogEntry.Stage.ModelTraining, String.format("Cross-validation is disabled by the user and no blending frame was provided; skipping StackedEnsemble '%s'.", this._id));
                    return false;
                }
                return true;
            }

            /** Returns the keys of the models to stack in this ensemble. */
            protected abstract Key<Model>[] getBaseModels();

            /**
             * Returns the model type encoded as the key prefix before the first underscore
             * (e.g. "GBM" for "GBM_1_AutoML_..."). If the key has no underscore, the whole
             * key string is returned instead of throwing.
             */
            protected String getModelType(Key<Model> key) {
                String keyStr = key.toString();
                int separator = keyStr.indexOf('_');
                return separator < 0 ? keyStr : keyStr.substring(0, separator);
            }

            /** Returns true if the key's name starts with this step's algo name. */
            protected boolean isStackedEnsemble(Key<Model> key) {
                return key.toString().startsWith(this._algo.name());
            }

            /**
             * Configures and launches a StackedEnsemble training job.
             *
             * @param modelName  name used to build the model key (no counter suffix)
             * @param baseModels models to stack
             * @param isLast     when false, base-model predictions are kept (presumably so a later
             *                   ensemble step can reuse them — TODO confirm)
             * @return the training job
             */
            @SuppressWarnings({"rawtypes", "unchecked"}) // makeKey/trainModel key generics differ here
            Job<StackedEnsembleModel> stack(String modelName, Key<Model>[] baseModels, boolean isLast) {
                AutoMLBuildSpec buildSpec = this.aml().getBuildSpec();
                StackedEnsembleModel.StackedEnsembleParameters stackedEnsembleParameters = new StackedEnsembleModel.StackedEnsembleParameters();
                stackedEnsembleParameters._base_models = baseModels;
                stackedEnsembleParameters._valid = this.aml().getValidationFrame() == null ? null : this.aml().getValidationFrame()._key;
                stackedEnsembleParameters._blending = this.aml().getBlendingFrame() == null ? null : this.aml().getBlendingFrame()._key;
                stackedEnsembleParameters._keep_levelone_frame = true;
                stackedEnsembleParameters._keep_base_model_predictions = !isLast;
                stackedEnsembleParameters._metalearner_fold_column = buildSpec.input_spec.fold_column;
                stackedEnsembleParameters._metalearner_nfolds = buildSpec.build_control.nfolds;
                // Must run before the metalearner parameters below are set.
                stackedEnsembleParameters.initMetalearnerParams();
                stackedEnsembleParameters._metalearner_parameters._keep_cross_validation_models = buildSpec.build_control.keep_cross_validation_models;
                stackedEnsembleParameters._metalearner_parameters._keep_cross_validation_predictions = buildSpec.build_control.keep_cross_validation_predictions;
                Key modelKey = this.makeKey(modelName, false);
                return this.trainModel(modelKey, stackedEnsembleParameters);
            }
        }
    }
}

