/*
 * Decompiled with CFR 0.152.
 */
package ai.h2o.automl;

import ai.h2o.automl.Algo;
import ai.h2o.automl.AutoML;
import ai.h2o.automl.AutoMLBuildSpec;
import ai.h2o.automl.IAlgo;
import ai.h2o.automl.ModelSelectionStrategies;
import ai.h2o.automl.ModelSelectionStrategy;
import ai.h2o.automl.ModelingStepsExecutor;
import ai.h2o.automl.Models;
import ai.h2o.automl.StepDefinition;
import ai.h2o.automl.StepResultState;
import ai.h2o.automl.WorkAllocations;
import ai.h2o.automl.events.EventLog;
import ai.h2o.automl.events.EventLogEntry;
import ai.h2o.automl.leaderboard.Leaderboard;
import ai.h2o.automl.preprocessing.PreprocessingConfig;
import ai.h2o.automl.preprocessing.PreprocessingStep;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelContainer;
import hex.ScoreKeeper;
import hex.genmodel.utils.DistributionFamily;
import hex.grid.Grid;
import hex.grid.GridSearch;
import hex.grid.HyperSpaceSearchCriteria;
import hex.grid.HyperSpaceWalker;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Date;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.function.Predicate;
import jsr166y.CountedCompleter;
import org.apache.commons.lang.builder.ToStringBuilder;
import water.Futures;
import water.H2O;
import water.Iced;
import water.Job;
import water.Key;
import water.Keyed;
import water.exceptions.H2OIllegalArgumentException;
import water.util.ArrayUtils;
import water.util.Countdown;
import water.util.EnumUtils;
import water.util.Log;

/**
 * Base class for a single step of an AutoML run.
 *
 * <p>A step is identified by its provider and id (see {@link #getGlobalId()}), belongs to a priority group and
 * carries a weight that is used when work is allocated through {@link WorkAllocations}. Subclasses implement
 * {@link #startJob()}; the nested base classes cover training a single model ({@link ModelStep}), running a
 * hyperparameter search ({@link GridStep}), selecting among already trained models ({@link SelectionStep}) and
 * producing sub-steps on demand ({@link DynamicStep}).
 *
 * @param <M> the model type produced by this step
 */
public abstract class ModelingStep<M extends Model>
extends Iced<ModelingStep> {
    /** Matches work allocated to training a single model. */
    static Predicate<WorkAllocations.Work> isDefaultModel = w -> w._type == WorkAllocations.JobType.ModelBuild;
    /** Matches exploration work: single model builds and hyperparameter searches. */
    static Predicate<WorkAllocations.Work> isExplorationWork = w -> w._type == WorkAllocations.JobType.ModelBuild || w._type == WorkAllocations.JobType.HyperparamSearch;
    /** Matches exploitation work, i.e. selection steps. */
    static Predicate<WorkAllocations.Work> isExploitationWork = w -> w._type == WorkAllocations.JobType.Selection;

    /** Owning AutoML instance; marked transient so it is not serialized with the step. Access it via {@link #aml()}. */
    private transient AutoML _aml;
    protected final IAlgo _algo;
    protected final String _provider;
    protected final String _id;
    protected int _weight;
    protected int _priorityGroup;
    /** AutoML constraints this step opts out of (see {@link #ignores(AutoML.Constraint)}). */
    protected AutoML.Constraint[] _ignoredConstraints = new AutoML.Constraint[0];
    protected String _description;
    /** Work allocated to this step, resolved lazily by {@link #getAllocatedWork()}. */
    protected WorkAllocations.Work _work;
    /** Callbacks executed (then discarded) by {@link #onDone(Job)}; filled by {@link #applyPreprocessing}. */
    private final transient List<Consumer<Job>> _onDone = new ArrayList<Consumer<Job>>();
    StepDefinition _fromDef;
    /** Matches work belonging to the same priority group as this step (read at evaluation time). */
    final transient Predicate<WorkAllocations.Work> _isSamePriorityGroup = w -> w._priorityGroup == this._priorityGroup;

    /**
     * Starts a grid (hyperparameter) search for this step, after applying the AutoML preprocessing to the base
     * parameters and logging the start time in the event log.
     *
     * @param resultKey      key of the resulting grid, must not be null
     * @param baseParams     parameters shared by every model of the grid, must not be null
     * @param hyperParams    non-empty map of hyperparameter name to candidate values
     * @param searchCriteria search strategy and its limits, must not be null
     * @return the started grid search job
     */
    protected <MP extends Model.Parameters> Job<Grid> startSearch(Key<Grid> resultKey, MP baseParams, Map<String, Object[]> hyperParams, HyperSpaceSearchCriteria searchCriteria) {
        assert resultKey != null;
        assert baseParams != null;
        assert hyperParams.size() > 0;
        assert searchCriteria != null;
        applyPreprocessing(baseParams);
        aml().eventLog()
                .info(EventLogEntry.Stage.ModelTraining, "AutoML: starting " + resultKey + " hyperparameter search")
                .setNamedValue("start_" + _provider + "_" + _id, new Date(), EventLogEntry.epochFormat.get());
        HyperSpaceWalker walker = HyperSpaceWalker.BaseWalker.WalkerFactory.create(
                baseParams, hyperParams, new GridSearch.SimpleParametersBuilderFactory(), searchCriteria);
        return GridSearch.create(resultKey, walker)
                .withParallelism(1)
                .withMaxConsecutiveFailures(aml()._maxConsecutiveModelFailures)
                .start();
    }

    /**
     * Starts training a single model with the algorithm of this step.
     *
     * <p>Preprocessing is applied to {@code params}, the builder is initialized, and its validation messages
     * of level WARN/ERRR are forwarded to the AutoML event log before training is launched.
     *
     * @param resultKey key of the resulting model, must not be null
     * @param params    training parameters, must not be null
     * @return the training job
     */
    protected <MP extends Model.Parameters> Job<M> startModel(Key<M> resultKey, MP params) {
        assert resultKey != null;
        assert params != null;
        Job<M> job = new Job<M>(resultKey, ModelBuilder.javaName(_algo.urlName()), _description);
        applyPreprocessing(params);
        ModelBuilder builder = ModelBuilder.make(_algo.urlName(), job, resultKey);
        builder._parms = params;
        aml().eventLog()
                .info(EventLogEntry.Stage.ModelTraining, "AutoML: starting " + resultKey + " model training")
                .setNamedValue("start_" + _provider + "_" + _id, new Date(), EventLogEntry.epochFormat.get());
        builder.init(false);
        for (ModelBuilder.ValidationMessage vm : builder._messages) {
            if (vm.log_level() == Log.WARN) {
                aml().eventLog().warn(EventLogEntry.Stage.ModelTraining, vm.field() + " param, " + vm.message());
            } else if (vm.log_level() == Log.ERRR) {
                aml().eventLog().error(EventLogEntry.Stage.ModelTraining, vm.field() + " param, " + vm.message());
            }
        }
        return builder.trainModelOnH2ONode();
    }

    /**
     * Checks whether the given parameters would pass builder validation for the given fields.
     * The check is performed on a clone completed with the common AutoML parameters, so {@code parms} is untouched.
     *
     * @return false if any of {@code fields} has an ERRR-level validation message, or if building fails
     *         with an {@link H2OIllegalArgumentException}
     */
    private boolean validParameters(Model.Parameters parms, String[] fields) {
        try {
            Model.Parameters params = (Model.Parameters) parms.clone();
            setCommonModelBuilderParams(params);
            ModelBuilder mb = ModelBuilder.make(params);
            mb.init(false);
            return Arrays.stream(fields).allMatch(field -> mb.getMessagesByFieldAndSeverity(field, Log.ERRR).length == 0);
        } catch (H2OIllegalArgumentException e) {
            return false;
        }
    }

    /**
     * Applies the AutoML distribution (and its distribution-specific options) to {@code parms}, falling back to
     * {@link DistributionFamily#AUTO} when the algorithm rejects it; the fallback is reported in the event log.
     */
    protected void setDistributionParameters(Model.Parameters parms) {
        DistributionFamily amlDistribution = aml().getDistributionFamily();
        switch (amlDistribution) {
            case custom: {
                parms._custom_distribution_func = aml().getBuildSpec().build_control.custom_distribution_func;
                break;
            }
            case huber: {
                parms._huber_alpha = aml().getBuildSpec().build_control.huber_alpha;
                break;
            }
            case tweedie: {
                parms._tweedie_power = aml().getBuildSpec().build_control.tweedie_power;
                break;
            }
            case quantile: {
                parms._quantile_alpha = aml().getBuildSpec().build_control.quantile_alpha;
                break;
            }
        }
        try {
            parms.setDistributionFamily(amlDistribution);
        } catch (H2OIllegalArgumentException e) {
            parms.setDistributionFamily(DistributionFamily.AUTO);
        }
        // Even when accepted by the setter, the builder may still reject the distribution during validation.
        if (!validParameters(parms, new String[]{"_distribution", "_family"})) {
            parms.setDistributionFamily(DistributionFamily.AUTO);
        }
        if (!amlDistribution.equals(parms.getDistributionFamily())) {
            aml().eventLog().info(EventLogEntry.Stage.ModelTraining, "Algo " + parms.algoName() + " doesn't support " + amlDistribution.name() + " distribution. Using AUTO distribution instead.");
        }
    }

    /**
     * @param provider      name of the provider defining this step
     * @param algo          algorithm used by this step
     * @param id            identifier of the step within its provider
     * @param priorityGroup priority group of the step, must be non-negative
     * @param weight        relative amount of work assigned to the step
     * @param autoML        owning AutoML instance
     */
    protected ModelingStep(String provider, IAlgo algo, String id, int priorityGroup, int weight, AutoML autoML) {
        assert priorityGroup >= 0;
        _provider = provider;
        _algo = algo;
        _id = id;
        _priorityGroup = priorityGroup;
        _weight = weight;
        _aml = autoML;
        _description = provider + " " + id;
    }

    /** @return the name of the provider defining this step */
    public String getProvider() {
        return _provider;
    }

    /** @return the identifier of this step within its provider */
    public String getId() {
        return _id;
    }

    /** @return the identifier of this step across providers, formatted as {@code provider:id} */
    public String getGlobalId() {
        return _provider + ":" + _id;
    }

    /** @return the algorithm used by this step */
    public IAlgo getAlgo() {
        return _algo;
    }

    /** @return the relative amount of work assigned to this step */
    public int getWeight() {
        return _weight;
    }

    /** @return the priority group of this step */
    public int getPriorityGroup() {
        return _priorityGroup;
    }

    /** @return whether the result key of this step is registered as resumable when it runs; false by default */
    public boolean isResumable() {
        return false;
    }

    /** @return true if this step opts out of the given AutoML constraint */
    public boolean ignores(AutoML.Constraint constraint) {
        return ArrayUtils.contains(_ignoredConstraints, constraint);
    }

    /**
     * @return true iff this step does not ignore the TIMEOUT constraint and the AutoML run has no
     *         {@code max_models} limit (i.e. {@code max_models() == 0})
     */
    public boolean limitModelTrainingTime() {
        return !ignores(AutoML.Constraint.TIMEOUT) && aml().getBuildSpec().build_control.stopping_criteria.max_models() == 0;
    }

    /** @return true iff some work with a positive weight is allocated to this step */
    public boolean canRun() {
        WorkAllocations.Work work = getAllocatedWork();
        return work != null && work._weight > 0;
    }

    /**
     * Starts this step and registers its result key with the AutoML session (also as a resumable key when
     * {@link #isResumable()}).
     *
     * @return the started job, or null if {@link #startJob()} returned null
     */
    public Job run() {
        Job job = startJob();
        if (job != null && job._result != null) {
            register(job._result);
            if (isResumable()) {
                aml().session().addResumableKey(job._result);
            }
        }
        return job;
    }

    /** @return an iterator over the sub-steps of this step; empty by default */
    public Iterator<? extends ModelingStep> iterateSubSteps() {
        return Collections.emptyIterator();
    }

    /** @return the sub-step with the given id, if any; always empty by default */
    protected Optional<? extends ModelingStep> getSubStep(String id) {
        return Optional.empty();
    }

    /** @return the kind of work performed by this step */
    protected abstract WorkAllocations.JobType getJobType();

    /** Starts the job performing this step, or returns null if there is nothing to run. */
    protected abstract Job startJob();

    /** Runs, then discards, every callback registered for completion of this step (e.g. preprocessing completers). */
    protected void onDone(Job job) {
        for (Consumer<Job> exec : _onDone) {
            exec.accept(job);
        }
        _onDone.clear();
    }

    /** Registers this step as the source of {@code key} in the AutoML session. */
    protected void register(Key key) {
        aml().session().registerKeySource(key, this);
    }

    /** @return the owning AutoML instance */
    protected AutoML aml() {
        return _aml;
    }

    /** @return the work allocated to this step, looked up once by id and algo, then cached (may be null) */
    protected WorkAllocations.Work getAllocatedWork() {
        if (_work == null) {
            _work = getWorkAllocations().getAllocation(_id, _algo);
        }
        return _work;
    }

    /** @return a new work description built from this step's id, algo, job type, priority group and weight */
    protected WorkAllocations.Work makeWork() {
        return new WorkAllocations.Work(getId(), getAlgo(), getJobType(), getPriorityGroup(), getWeight());
    }

    /** Creates a key through the AutoML instance, without a type qualifier. */
    protected Key makeKey(String name, boolean withCounter) {
        return aml().makeKey(name, null, withCounter);
    }

    /** @return the work allocations of the owning AutoML instance */
    protected WorkAllocations getWorkAllocations() {
        return aml()._workAllocations;
    }

    /** @return the models currently in the AutoML leaderboard */
    protected Model[] getTrainedModels() {
        return aml().leaderboard().getModels();
    }

    /** @return the keys of the models currently in the AutoML leaderboard */
    protected Key<Model>[] getTrainedModelsKeys() {
        return aml().leaderboard().getModelKeys();
    }

    /** @return whether cross-validation is enabled for the AutoML run */
    protected boolean isCVEnabled() {
        return aml().isCVEnabled();
    }

    /** Two steps are equal when they have the same concrete class, provider and id. */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        ModelingStep that = (ModelingStep) o;
        return _provider.equals(that._provider) && _id.equals(that._id);
    }

    @Override
    public int hashCode() {
        return Objects.hash(_provider, _id);
    }

    /**
     * Fills {@code params} with the settings shared by every model of the AutoML run: frames, response and
     * ignored columns, cross-validation, weights, class balancing, checkpoint export and time budget factor.
     */
    protected void setCommonModelBuilderParams(Model.Parameters params) {
        params._train = aml()._trainingFrame._key;
        if (aml()._validationFrame != null) {
            params._valid = aml()._validationFrame._key;
        }
        AutoMLBuildSpec buildSpec = aml().getBuildSpec();
        params._response_column = buildSpec.input_spec.response_column;
        params._ignored_columns = buildSpec.input_spec.ignored_columns;
        setCrossValidationParams(params);
        setWeightingParams(params);
        setClassBalancingParams(params);
        params._keep_cross_validation_models = buildSpec.build_control.keep_cross_validation_models;
        params._keep_cross_validation_fold_assignment = buildSpec.build_control.nfolds != 0 && buildSpec.build_control.keep_cross_validation_fold_assignment;
        params._export_checkpoints_dir = buildSpec.build_control.export_checkpoints_dir;
        params._main_model_time_budget_factor = 2.0;
    }

    /**
     * Applies the cross-validation settings: fold column if provided, otherwise {@code nfolds} with a Modulo fold
     * assignment when {@code nfolds > 1}.
     */
    protected void setCrossValidationParams(Model.Parameters params) {
        AutoMLBuildSpec buildSpec = aml().getBuildSpec();
        // CV predictions are always kept when there is no blending frame.
        // NOTE(review): presumably because later steps (e.g. ensembles) rely on them — confirm.
        params._keep_cross_validation_predictions = aml().getBlendingFrame() == null || buildSpec.build_control.keep_cross_validation_predictions;
        params._fold_column = buildSpec.input_spec.fold_column;
        if (buildSpec.input_spec.fold_column == null) {
            params._nfolds = buildSpec.build_control.nfolds;
            if (buildSpec.build_control.nfolds > 1) {
                params._fold_assignment = Model.Parameters.FoldAssignmentScheme.Modulo;
            }
        }
    }

    /** Applies the weights column of the AutoML input spec. */
    protected void setWeightingParams(Model.Parameters params) {
        AutoMLBuildSpec buildSpec = aml().getBuildSpec();
        params._weights_column = buildSpec.input_spec.weights_column;
    }

    /** Copies the class balancing settings, only when class balancing is enabled for the AutoML run. */
    protected void setClassBalancingParams(Model.Parameters params) {
        AutoMLBuildSpec buildSpec = aml().getBuildSpec();
        if (buildSpec.build_control.balance_classes) {
            params._balance_classes = buildSpec.build_control.balance_classes;
            params._class_sampling_factors = buildSpec.build_control.class_sampling_factors;
            params._max_after_balance_size = buildSpec.build_control.max_after_balance_size;
        }
    }

    /** Applies the user-provided algorithm parameters, if any, for this step's algorithm. */
    protected void setCustomParams(Model.Parameters params) {
        AutoMLBuildSpec.AutoMLCustomParameters customParams = aml().getBuildSpec().build_models.algo_parameters;
        if (customParams == null) {
            return;
        }
        customParams.applyCustomParameters(_algo, params);
    }

    /**
     * Applies every AutoML preprocessing step to {@code params}; each returned completer is scheduled to run
     * when this step is done (see {@link #onDone(Job)}).
     */
    protected void applyPreprocessing(Model.Parameters params) {
        if (aml().getPreprocessing() == null) {
            return;
        }
        for (PreprocessingStep preprocessingStep : aml().getPreprocessing()) {
            PreprocessingStep.Completer complete = preprocessingStep.apply(params, getPreprocessingConfig());
            _onDone.add(j -> complete.run());
        }
    }

    /** @return the preprocessing configuration for this step; a default instance unless overridden */
    protected PreprocessingConfig getPreprocessingConfig() {
        return new PreprocessingConfig();
    }

    /**
     * Replaces stopping parameters left at their algorithm defaults with the AutoML stopping criteria.
     * An AUTO stopping metric is resolved to deviance or logloss depending on the response column.
     *
     * @param parms    parameters to update
     * @param defaults default parameters of the same algorithm, used to detect untouched values
     */
    protected void setStoppingCriteria(Model.Parameters parms, Model.Parameters defaults) {
        AutoMLBuildSpec buildSpec = aml().getBuildSpec();
        if (parms._stopping_metric == defaults._stopping_metric) {
            parms._stopping_metric = buildSpec.build_control.stopping_criteria.stopping_metric();
        }
        if (parms._stopping_metric == ScoreKeeper.StoppingMetric.AUTO) {
            // NOTE(review): a cardinality of -1 presumably denotes a non-categorical response — confirm.
            parms._stopping_metric = aml().getResponseColumn().cardinality() == -1
                    ? ScoreKeeper.StoppingMetric.deviance
                    : ScoreKeeper.StoppingMetric.logloss;
        }
        if (parms._stopping_rounds == defaults._stopping_rounds) {
            parms._stopping_rounds = buildSpec.build_control.stopping_criteria.stopping_rounds();
        }
        if (parms._stopping_tolerance == defaults._stopping_tolerance) {
            parms._stopping_tolerance = buildSpec.build_control.stopping_criteria.stopping_tolerance();
        }
    }

    /**
     * Sets the seed when it is still at its default value, according to {@code seedPolicy}:
     * {@code Global} uses the AutoML seed, {@code Incremental} draws from the AutoML incremental seed counter
     * (unless that counter equals the default seed), {@code None} leaves it unchanged.
     */
    protected void setSeed(Model.Parameters parms, Model.Parameters defaults, SeedPolicy seedPolicy) {
        AutoMLBuildSpec buildSpec = aml().getBuildSpec();
        if (parms._seed == defaults._seed) {
            switch (seedPolicy) {
                case Global: {
                    parms._seed = buildSpec.build_control.stopping_criteria.seed();
                    break;
                }
                case Incremental: {
                    parms._seed = aml()._incrementalSeed.get() == defaults._seed ? defaults._seed : aml()._incrementalSeed.getAndIncrement();
                    break;
                }
            }
        }
    }

    /**
     * When no runtime limit is set on {@code parms}, applies {@code max_runtime_secs_per_model},
     * capped by {@code upperLimit} if the latter is positive.
     */
    protected void initTimeConstraints(Model.Parameters parms, double upperLimit) {
        AutoMLBuildSpec buildSpec = aml().getBuildSpec();
        if (parms._max_runtime_secs == 0.0) {
            double maxPerModel = buildSpec.build_control.stopping_criteria.max_runtime_secs_per_model();
            parms._max_runtime_secs = upperLimit <= 0.0 ? maxPerModel : Math.min(maxPerModel, upperLimit);
        }
    }

    /**
     * @return the sort metric of the AutoML leaderboard, or null if there is no leaderboard.
     *         NOTE(review): not referenced anywhere in this class.
     */
    private String getSortMetric() {
        Leaderboard leaderboard = aml().leaderboard();
        return leaderboard == null ? null : leaderboard.getSortMetric();
    }

    /**
     * Maps a leaderboard metric name to a stopping metric, falling back to AUTO for null or unknown names.
     * NOTE(review): not referenced anywhere in this class.
     */
    private static ScoreKeeper.StoppingMetric metricValueOf(String name) {
        if (name == null) {
            return ScoreKeeper.StoppingMetric.AUTO;
        }
        if ("mean_residual_deviance".equals(name)) {
            return ScoreKeeper.StoppingMetric.deviance;
        }
        try {
            return EnumUtils.valueOf(ScoreKeeper.StoppingMetric.class, name);
        } catch (IllegalArgumentException ignored) {
            return ScoreKeeper.StoppingMetric.AUTO;
        }
    }

    /**
     * A step that never runs itself but lazily provides sub-steps built by {@link #prepareModelingSteps()}.
     */
    public static abstract class DynamicStep<M extends Model>
    extends ModelingStep<M> {
        public static final int DEFAULT_DYNAMIC_TRAINING_WEIGHT = 20;
        public static final int DEFAULT_DYNAMIC_GROUP = 100;
        /** Sub-steps, built on first access; transient, so not serialized. */
        private transient Collection<ModelingStep> _subSteps;

        public DynamicStep(String provider, String id, AutoML autoML) {
            this(provider, id, DEFAULT_DYNAMIC_GROUP, DEFAULT_DYNAMIC_TRAINING_WEIGHT, autoML);
        }

        public DynamicStep(String provider, String id, int priorityGroup, int weight, AutoML autoML) {
            super(provider, new VirtualAlgo(), id, priorityGroup, weight, autoML);
        }

        /** A dynamic step is never run directly. */
        @Override
        public boolean canRun() {
            return false;
        }

        /** @return always null: a dynamic step starts no job of its own */
        @Override
        protected Job<M> startJob() {
            return null;
        }

        @Override
        protected WorkAllocations.JobType getJobType() {
            return WorkAllocations.JobType.Dynamic;
        }

        @Override
        protected Key<Models> makeKey(String name, boolean withCounter) {
            return aml().makeKey(name, "decision", withCounter);
        }

        /** Builds the sub-steps once, on first use. */
        private void initSubSteps() {
            if (_subSteps == null) {
                _subSteps = prepareModelingSteps();
            }
        }

        @Override
        public Iterator<? extends ModelingStep> iterateSubSteps() {
            initSubSteps();
            return _subSteps.iterator();
        }

        @Override
        protected Optional<? extends ModelingStep> getSubStep(String id) {
            initSubSteps();
            return _subSteps.stream().filter(step -> step._id.equals(id)).findFirst();
        }

        /** @return the sub-steps of this dynamic step; called at most once per instance */
        protected abstract Collection<ModelingStep> prepareModelingSteps();

        /** Placeholder algorithm for dynamic steps, named {@code "virtual"}. */
        public static class VirtualAlgo
        implements IAlgo {
            @Override
            public String name() {
                return "virtual";
            }
        }
    }

    /**
     * A step that trains candidate models in a temporary leaderboard and then lets a
     * {@link ModelSelectionStrategy} decide which models to add to or remove from the AutoML leaderboard.
     */
    public static abstract class SelectionStep<M extends Model>
    extends ModelingStep<M> {
        public static final int DEFAULT_SELECTION_TRAINING_WEIGHT = 20;
        public static final int DEFAULT_SELECTION_GROUP = 3;

        public SelectionStep(String provider, IAlgo algo, String id, AutoML autoML) {
            this(provider, algo, id, DEFAULT_SELECTION_GROUP, DEFAULT_SELECTION_TRAINING_WEIGHT, autoML);
        }

        public SelectionStep(String provider, IAlgo algo, String id, int priorityGroup, int weight, AutoML autoML) {
            super(provider, algo, id, priorityGroup, weight, autoML);
        }

        @Override
        protected WorkAllocations.JobType getJobType() {
            return WorkAllocations.JobType.Selection;
        }

        @Override
        protected Key<Models> makeKey(String name, boolean withCounter) {
            return aml().makeKey(name, "selection", withCounter);
        }

        /**
         * Creates a leaderboard named {@code name} with the same leaderboard frame and sort metric as the AutoML one.
         *
         * @param eventLog event log to use; if null, a dedicated one is created and removed again on cleanup
         */
        private ModelSelectionStrategies.LeaderboardHolder makeLeaderboard(String name, final EventLog eventLog) {
            Leaderboard amlLeaderboard = aml().leaderboard();
            final EventLog tmpEventLog = eventLog == null ? EventLog.getOrMake(Key.make(name)) : eventLog;
            final Leaderboard tmpLeaderboard = Leaderboard.getOrMake(name, tmpEventLog, amlLeaderboard.leaderboardFrame(), amlLeaderboard.getSortMetric());
            return new ModelSelectionStrategies.LeaderboardHolder() {

                @Override
                public Leaderboard get() {
                    return tmpLeaderboard;
                }

                @Override
                public void cleanup() {
                    tmpLeaderboard.removeModels(tmpLeaderboard.getModelKeys(), false);
                    tmpLeaderboard.remove(false);
                    if (eventLog == null) {
                        tmpEventLog.remove();
                    }
                }
            };
        }

        /** Creates a temporary leaderboard named {@code "tmp_" + name} with its own event log. */
        protected ModelSelectionStrategies.LeaderboardHolder makeTmpLeaderboard(String name) {
            return makeLeaderboard("tmp_" + name, null);
        }

        /**
         * Starts the selection job: candidates are trained via {@link #startTraining(Key, double)} into a
         * dedicated selection leaderboard, then the selection strategy decides which models are removed from and
         * added to the AutoML leaderboard. The resulting {@link Models} holds the added models.
         */
        @Override
        protected Job<Models> startJob() {
            final Key[] trainedModelKeys = getTrainedModelsKeys();
            final Key<Models> key = makeKey(_provider + "_" + _id, false);
            aml().trackKeys(key);
            final Job<Models> job = new Job<Models>(key, Models.class.getName(), _description);
            WorkAllocations.Work work = getAllocatedWork();
            final double maxAssignedTimeSecs = limitModelTrainingTime()
                    ? (double) ((float) aml().timeRemainingMs() * getWorkAllocations().remainingWorkRatio(work)) / 1000.0
                    : 0.0;
            aml().eventLog().debug(EventLogEntry.Stage.ModelTraining, maxAssignedTimeSecs == 0.0 ? "No time limitation for " + key : "Time assigned for " + key + ": " + maxAssignedTimeSecs + "s");
            return job.start(new H2O.H2OCountedCompleter() {
                Models result;
                Key<Models> selectionKey;
                EventLog selectionEventLog;
                ModelSelectionStrategies.LeaderboardHolder selectionLeaderboard;
                {
                    this.result = new Models<Model>(key, Model.class, job);
                    this.selectionKey = Key.make(key + "_select");
                    this.selectionEventLog = EventLog.getOrMake(this.selectionKey);
                    // Inside this anonymous completer, bare `this` is the completer: step methods need SelectionStep.this.
                    this.selectionLeaderboard = SelectionStep.this.makeLeaderboard(this.selectionKey.toString(), this.selectionEventLog);
                    this.result.delete_and_lock(job);
                }

                @Override
                public void compute2() {
                    Countdown countdown = Countdown.fromSeconds(maxAssignedTimeSecs);
                    ModelSelectionStrategy.Selection<Model> selection = null;
                    try {
                        Leaderboard selectionLb = this.selectionLeaderboard.get();
                        ModelingStepsExecutor localExecutor = new ModelingStepsExecutor(selectionLb, this.selectionEventLog, countdown);
                        localExecutor.start();
                        Job<Models> innerTraining = SelectionStep.this.startTraining(this.selectionKey, maxAssignedTimeSecs);
                        StepResultState state = localExecutor.monitor(innerTraining, SelectionStep.this, job);
                        if (state.is(StepResultState.ResultStatus.success)) {
                            Log.debug("Selection leaderboard " + selectionLb._key, selectionLb.toLogString());
                            selection = SelectionStep.this.getSelectionStrategy().select(trainedModelKeys, selectionLb.getModelKeys());
                            Leaderboard lb = SelectionStep.this.aml().leaderboard();
                            Log.debug("Selection result for job " + key, ToStringBuilder.reflectionToString(selection));
                            lb.removeModels(selection._remove, false);
                            SelectionStep.this.aml().trackKeys(selection._remove);
                            lb.addModels(selection._add);
                        } else if (state.is(StepResultState.ResultStatus.failed)) {
                            // Rethrow the original failure instead of masking non-runtime errors with a ClassCastException.
                            Throwable error = state.error();
                            if (error instanceof RuntimeException) {
                                throw (RuntimeException) error;
                            }
                            if (error instanceof Error) {
                                throw (Error) error;
                            }
                            throw new RuntimeException(error);
                        } else if (state.is(StepResultState.ResultStatus.cancelled)) {
                            throw new Job.JobCancelledException();
                        }
                    } finally {
                        this.result.unlock(job);
                        if (selection != null) {
                            this.result.addModels(selection._add);
                        }
                    }
                    this.tryComplete();
                }

                @Override
                public void onCompletion(CountedCompleter caller) {
                    Keyed.remove(this.selectionKey, new Futures(), false);
                    Leaderboard selectionLb = this.selectionLeaderboard.get();
                    selectionLb.removeModels(trainedModelKeys, false);
                    // Models trained for the selection but not kept in the result are removed with `true`
                    // (the models above use `false`). NOTE(review): presumably `true` also deletes them — confirm.
                    final Key[] keptKeys = this.result.getModelKeys();
                    Key[] discarded = Arrays.stream(selectionLb.getModelKeys())
                            .filter(k -> !ArrayUtils.contains(keptKeys, k))
                            .toArray(Key[]::new);
                    selectionLb.removeModels(discarded, true);
                    this.selectionLeaderboard.cleanup();
                    if (!SelectionStep.this.aml().eventLog()._key.equals(this.selectionEventLog._key)) {
                        this.selectionEventLog.remove();
                    }
                    super.onCompletion(caller);
                }

                @Override
                public boolean onExceptionalCompletion(Throwable ex, CountedCompleter caller) {
                    this.result.unlock(job._key, false);
                    Keyed.remove(this.selectionKey);
                    this.selectionLeaderboard.get().remove();
                    if (!SelectionStep.this.aml().eventLog()._key.equals(this.selectionEventLog._key)) {
                        this.selectionEventLog.remove();
                    }
                    return super.onExceptionalCompletion(ex, caller);
                }
            }, work._weight, maxAssignedTimeSecs);
        }

        /**
         * Starts the training of the selection candidates.
         *
         * @param var1 key under which the trained models are collected
         * @param var2 maximum time assigned to the training in seconds (0 for no limit)
         */
        protected abstract Job<Models> startTraining(Key<Models> var1, double var2);

        /** @return the strategy deciding which candidates enter the AutoML leaderboard */
        protected abstract ModelSelectionStrategy getSelectionStrategy();

        /**
         * Wraps a job producing a single {@link Model} or a {@link ModelContainer} into a job producing
         * {@link Models} under {@code result}. A container result is removed after its model keys are collected.
         *
         * @throws H2OIllegalArgumentException (from the wrapping job) if the wrapped job produces anything else
         */
        protected Job<Models> asModelsJob(final Job job, final Key<Models> result) {
            final Job<Models> jModels = new Job<Models>(result, Models.class.getName(), job._description);
            return jModels.start(new H2O.H2OCountedCompleter() {
                Models models;
                {
                    this.models = new Models<Model>(result, Model.class, jModels);
                    this.models.delete_and_lock(jModels);
                }

                @Override
                public void compute2() {
                    ModelingStepsExecutor.ensureStopRequestPropagated(job, jModels, 1000);
                    Object res = job.get();
                    this.models.unlock(jModels);
                    if (res instanceof Model) {
                        this.models.addModel(((Keyed) res).getKey());
                    } else if (res instanceof ModelContainer) {
                        this.models.addModels(((ModelContainer) res).getModelKeys());
                        ((Keyed) res).remove(false);
                    } else {
                        throw new H2OIllegalArgumentException("Can only convert jobs producing a single Model or ModelContainer.");
                    }
                    this.tryComplete();
                }
            // NOTE(review): startJob() passes a duration in seconds as the last argument of Job.start, whereas this
            // passes _max_runtime_msecs (milliseconds) — possible unit mismatch; confirm against Job.start.
            }, job._work, job._max_runtime_msecs);
        }
    }

    /** A step running a hyperparameter search over a set of base parameters. */
    public static abstract class GridStep<M extends Model>
    extends ModelingStep<M> {
        public static final int DEFAULT_GRID_TRAINING_WEIGHT = 30;
        public static final int DEFAULT_GRID_GROUP = 2;
        /** Multiplier applied to the base models' stopping rounds to obtain the grid's stopping rounds. */
        protected static final int GRID_STOPPING_ROUND_FACTOR = 2;

        public GridStep(String provider, IAlgo algo, String id, AutoML autoML) {
            this(provider, algo, id, DEFAULT_GRID_GROUP, DEFAULT_GRID_TRAINING_WEIGHT, autoML);
        }

        public GridStep(String provider, IAlgo algo, String id, int priorityGroup, int weight, AutoML autoML) {
            super(provider, algo, id, priorityGroup, weight, autoML);
        }

        @Override
        protected WorkAllocations.JobType getJobType() {
            return WorkAllocations.JobType.HyperparamSearch;
        }

        @Override
        public boolean isResumable() {
            return true;
        }

        /** @return the base parameters shared by every model of the grid */
        public abstract Model.Parameters prepareModelParameters();

        /** @return the hyperparameter space to explore, by parameter name */
        public abstract Map<String, Object[]> prepareSearchParameters();

        @Override
        protected Job<Grid> startJob() {
            return hyperparameterSearch(prepareModelParameters(), prepareSearchParameters());
        }

        @Override
        protected Key<Grid> makeKey(String name, boolean withCounter) {
            return aml().makeKey(name, "grid", withCounter);
        }

        /** Same as {@link #hyperparameterSearch(Key, Model.Parameters, Map)} with a generated grid key. */
        protected Job<Grid> hyperparameterSearch(Model.Parameters baseParms, Map<String, Object[]> searchParms) {
            return hyperparameterSearch(null, baseParms, searchParms);
        }

        /**
         * Completes {@code baseParms} with the AutoML settings, derives the search criteria from the AutoML
         * stopping criteria and the allocated work, and starts the grid search.
         *
         * @param key         grid key, or null to generate one from the provider name
         * @param baseParms   base model parameters (modified in place)
         * @param searchParms hyperparameter space to explore
         * @throws H2OIllegalArgumentException if default parameters cannot be instantiated for {@code baseParms}' class
         */
        protected Job<Grid> hyperparameterSearch(Key<Grid> key, Model.Parameters baseParms, Map<String, Object[]> searchParms) {
            Model.Parameters defaults;
            try {
                defaults = baseParms.getClass().getDeclaredConstructor().newInstance();
            } catch (Exception e) {
                aml().eventLog().error(EventLogEntry.Stage.ModelTraining, "Internal error doing hyperparameter search");
                throw new H2OIllegalArgumentException("Hyperparameter search can't create a new instance of Model.Parameters subclass: " + baseParms.getClass());
            }
            initTimeConstraints(baseParms, 0.0);
            setCommonModelBuilderParams(baseParms);
            setStoppingCriteria(baseParms, defaults);
            setCustomParams(baseParms);
            setDistributionParameters(baseParms);
            AutoMLBuildSpec buildSpec = aml().getBuildSpec();
            HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria searchCriteria =
                    (HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria) buildSpec.build_control.stopping_criteria.getSearchCriteria().clone();
            setSearchCriteria(searchCriteria, baseParms);
            if (key == null) {
                key = makeKey(_provider, true);
            }
            aml().trackKeys(key);
            Log.debug("Hyperparameter search: " + _provider + ", time remaining (ms): " + aml().timeRemainingMs());
            aml().eventLog().debug(EventLogEntry.Stage.ModelTraining, searchCriteria.max_runtime_secs() == 0.0 ? "No time limitation for " + key : "Time assigned for " + key + ": " + searchCriteria.max_runtime_secs() + "s");
            return startSearch(key, baseParms, searchParms, searchCriteria);
        }

        /**
         * Caps the search criteria with the time and number of models allocated to this step, and derives the grid
         * stopping rounds from the base parameters.
         */
        protected void setSearchCriteria(HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria searchCriteria, Model.Parameters baseParms) {
            WorkAllocations.Work work = getAllocatedWork();
            double maxAssignedTimeSecs = limitModelTrainingTime()
                    ? (double) ((float) aml().timeRemainingMs() * getWorkAllocations().remainingWorkRatio(work, _isSamePriorityGroup)) / 1000.0
                    : 0.0;
            int maxAssignedModels = (int) Math.ceil((float) aml().remainingModels() * getWorkAllocations().remainingWorkRatio(work, isExplorationWork.and(w -> w._algo != Algo.StackedEnsemble)));
            searchCriteria.set_max_runtime_secs(searchCriteria.max_runtime_secs() == 0.0 ? maxAssignedTimeSecs : Math.min(searchCriteria.max_runtime_secs(), maxAssignedTimeSecs));
            searchCriteria.set_max_models(searchCriteria.max_models() == 0 ? maxAssignedModels : Math.min(searchCriteria.max_models(), maxAssignedModels));
            searchCriteria.set_stopping_rounds(baseParms._stopping_rounds * GRID_STOPPING_ROUND_FACTOR);
        }
    }

    /** A step training a single model. */
    public static abstract class ModelStep<M extends Model>
    extends ModelingStep<M> {
        public static final int DEFAULT_MODEL_TRAINING_WEIGHT = 10;
        public static final int DEFAULT_MODEL_GROUP = 1;

        public ModelStep(String provider, IAlgo algo, String id, AutoML autoML) {
            this(provider, algo, id, DEFAULT_MODEL_GROUP, DEFAULT_MODEL_TRAINING_WEIGHT, autoML);
        }

        public ModelStep(String provider, IAlgo algo, String id, int priorityGroup, int weight, AutoML autoML) {
            super(provider, algo, id, priorityGroup, weight, autoML);
        }

        @Override
        protected WorkAllocations.JobType getJobType() {
            return WorkAllocations.JobType.ModelBuild;
        }

        /** @return the parameters of the model to train */
        public abstract Model.Parameters prepareModelParameters();

        @Override
        protected Job<M> startJob() {
            return trainModel(prepareModelParameters());
        }

        /** Same as {@link #trainModel(Key, Model.Parameters)} with a generated model key. */
        protected Job<M> trainModel(Model.Parameters parms) {
            return trainModel(null, parms);
        }

        /**
         * Completes {@code parms} with the AutoML settings (common params, incremental seed, stopping criteria,
         * custom params, distribution, time limit) and starts training.
         *
         * @param key   model key, or null to generate one from the algorithm name
         * @param parms model parameters (modified in place)
         */
        protected Job<M> trainModel(Key<M> key, Model.Parameters parms) {
            String algoName = ModelBuilder.algoName(_algo.urlName());
            if (key == null) {
                key = makeKey(algoName, true);
            }
            ModelBuilder defaultBuilder = ModelBuilder.make(_algo.urlName(), null, null);
            Model.Parameters defaults = defaultBuilder._parms;
            initTimeConstraints(parms, 0.0);
            setCommonModelBuilderParams(parms);
            setSeed(parms, defaults, SeedPolicy.Incremental);
            setStoppingCriteria(parms, defaults);
            setCustomParams(parms);
            setDistributionParameters(parms);
            if (limitModelTrainingTime()) {
                WorkAllocations.Work work = getAllocatedWork();
                double maxAssignedTimeSecs = (double) ((float) aml().timeRemainingMs() * getWorkAllocations().remainingWorkRatio(work, _isSamePriorityGroup)) / 1000.0;
                parms._max_runtime_secs = parms._max_runtime_secs == 0.0 ? maxAssignedTimeSecs : Math.min(parms._max_runtime_secs, maxAssignedTimeSecs);
            } else {
                // NOTE(review): this also discards any per-model limit applied by initTimeConstraints above — confirm intended.
                parms._max_runtime_secs = 0.0;
            }
            Log.debug("Training model: " + algoName + ", time remaining (ms): " + aml().timeRemainingMs());
            aml().eventLog().debug(EventLogEntry.Stage.ModelTraining, parms._max_runtime_secs == 0.0 ? "No time limitation for " + key : "Time assigned for " + key + ": " + parms._max_runtime_secs + "s");
            return startModel(key, parms);
        }
    }

    /** How {@link #setSeed} assigns a seed to parameters still at their default seed. */
    protected enum SeedPolicy {
        /** Leave the seed unchanged. */
        None,
        /** Use the seed of the AutoML stopping criteria. */
        Global,
        /** Draw from the AutoML incremental seed counter. */
        Incremental;

    }
}

