/*
 * Decompiled with CFR 0.152.
 */
package ai.h2o.automl;

import ai.h2o.automl.Algo;
import ai.h2o.automl.AutoML;
import ai.h2o.automl.AutoMLBuildSpec;
import ai.h2o.automl.IAlgo;
import ai.h2o.automl.ModelSelectionStrategies;
import ai.h2o.automl.ModelSelectionStrategy;
import ai.h2o.automl.ModelingStepsExecutor;
import ai.h2o.automl.Models;
import ai.h2o.automl.StepDefinition;
import ai.h2o.automl.WorkAllocations;
import ai.h2o.automl.events.EventLog;
import ai.h2o.automl.events.EventLogEntry;
import ai.h2o.automl.leaderboard.Leaderboard;
import ai.h2o.automl.preprocessing.PreprocessingConfig;
import ai.h2o.automl.preprocessing.PreprocessingStep;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelContainer;
import hex.ModelParametersBuilderFactory;
import hex.ScoreKeeper;
import hex.grid.Grid;
import hex.grid.GridSearch;
import hex.grid.HyperSpaceSearchCriteria;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Predicate;
import jsr166y.CountedCompleter;
import org.apache.commons.lang.builder.ToStringBuilder;
import water.Futures;
import water.H2O;
import water.Iced;
import water.Job;
import water.Key;
import water.Keyed;
import water.exceptions.H2OIllegalArgumentException;
import water.util.ArrayUtils;
import water.util.Countdown;
import water.util.EnumUtils;
import water.util.Log;

/**
 * Base class for one unit of AutoML work: training a single model ({@link ModelStep}),
 * running a hyperparameter search ({@link GridStep}), or selecting among already trained
 * models ({@link SelectionStep}).
 *
 * <p>A step carries the algorithm it uses, an id, a weight used when allocating work, and a
 * list of AutoML constraints it is allowed to ignore. Preprocessing completers registered by
 * {@link #applyPreprocessing(Model.Parameters)} are run by {@link #onDone(Job)}.
 *
 * @param <M> the model type produced by this step
 */
public abstract class ModelingStep<M extends Model>
extends Iced<ModelingStep> {
    /** Matches work allocations of type {@code ModelBuild}. */
    static Predicate<WorkAllocations.Work> isDefaultModel = w -> w._type == WorkAllocations.JobType.ModelBuild;
    /** Matches work allocations of type {@code ModelBuild} or {@code HyperparamSearch}. */
    static Predicate<WorkAllocations.Work> isExplorationWork = w -> w._type == WorkAllocations.JobType.ModelBuild || w._type == WorkAllocations.JobType.HyperparamSearch;
    /** Matches work allocations of type {@code Selection}. */
    static Predicate<WorkAllocations.Work> isExploitationWork = w -> w._type == WorkAllocations.JobType.Selection;

    // Owning AutoML run; transient, so it is excluded from serialization of the step.
    private transient AutoML _aml;
    protected final IAlgo _algo;
    protected final String _id;
    protected int _weight;
    protected AutoML.Constraint[] _ignoredConstraints = new AutoML.Constraint[0];
    protected String _description;
    // Callbacks run (then cleared) by onDone; filled by applyPreprocessing.
    private final transient List<Consumer<Job>> _onDone = new ArrayList<Consumer<Job>>();
    StepDefinition _fromDef;

    /**
     * Applies preprocessing to {@code baseParams}, logs the start of the search in the AutoML
     * event log, and launches a grid search over {@code hyperParams} with parallelism 1.
     *
     * @param resultKey      destination key of the grid
     * @param baseParams     parameters shared by every model of the grid
     * @param hyperParams    hyperparameter name to candidate values
     * @param searchCriteria search strategy and limits
     * @return the running grid-search job
     */
    protected <MP extends Model.Parameters> Job<Grid> startSearch(Key<Grid> resultKey, MP baseParams, Map<String, Object[]> hyperParams, HyperSpaceSearchCriteria searchCriteria) {
        this.applyPreprocessing(baseParams);
        this.aml().eventLog().info(EventLogEntry.Stage.ModelTraining, "AutoML: starting " + resultKey + " hyperparameter search").setNamedValue("start_" + this._algo + "_" + this._id, new Date(), EventLogEntry.epochFormat.get());
        return GridSearch.startGridSearch(resultKey, baseParams, hyperParams, (ModelParametersBuilderFactory)new GridSearch.SimpleParametersBuilderFactory(), searchCriteria, 1);
    }

    /**
     * Applies preprocessing to {@code params}, builds a model builder for this step's algorithm,
     * and starts training.
     *
     * <p>If {@code init} or training start throws an {@link H2OIllegalArgumentException}, the
     * failure is logged as a warning, registered completion callbacks are run via
     * {@link #onDone(Job)}, and {@code null} is returned.
     *
     * @param resultKey destination key of the model
     * @param params    model parameters
     * @return the running training job, or {@code null} if the model was skipped
     */
    protected <MP extends Model.Parameters> Job<M> startModel(Key<M> resultKey, MP params) {
        Job job = new Job(resultKey, ModelBuilder.javaName(this._algo.urlName()), this._description);
        this.applyPreprocessing(params);
        ModelBuilder builder = ModelBuilder.make((String)this._algo.urlName(), (Job)job, resultKey);
        builder._parms = params;
        this.aml().eventLog().info(EventLogEntry.Stage.ModelTraining, "AutoML: starting " + resultKey + " model training").setNamedValue("start_" + this._algo + "_" + this._id, new Date(), EventLogEntry.epochFormat.get());
        try {
            builder.init(false);
            return builder.trainModelOnH2ONode();
        }
        catch (H2OIllegalArgumentException exception) {
            this.aml().eventLog().warn(EventLogEntry.Stage.ModelTraining, "Skipping training of model " + resultKey + " due to exception: " + exception);
            // Run the preprocessing completers now, since no job will ever call onDone for us.
            this.onDone(null);
            return null;
        }
    }

    /**
     * @param algo   algorithm used by this step
     * @param id     step identifier
     * @param weight weight used for work allocation
     * @param autoML owning AutoML run
     */
    protected ModelingStep(IAlgo algo, String id, int weight, AutoML autoML) {
        this._algo = algo;
        this._id = id;
        this._weight = weight;
        this._aml = autoML;
        this._description = algo.name() + " " + id;
    }

    /** Returns the work allocation assigned to this step, or {@code null} if none. */
    protected abstract WorkAllocations.Work getAllocatedWork();

    /**
     * Builds the destination key for this step's result.
     *
     * @param name        base name of the key
     * @param withCounter whether the AutoML run should append a counter to the name
     */
    protected abstract Key makeKey(String name, boolean withCounter);

    /** Creates the work description this step contributes to the work allocations. */
    protected abstract WorkAllocations.Work makeWork();

    /** Starts the job performing this step. */
    protected abstract Job startJob();

    /**
     * Runs every registered completion callback with {@code job}, then clears the list so the
     * callbacks run at most once per registration.
     *
     * @param job the finished job; may be {@code null} when no job was started
     */
    protected void onDone(Job job) {
        for (Consumer<Job> exec : this._onDone) {
            exec.accept(job);
        }
        this._onDone.clear();
    }

    /** Returns the owning AutoML run. */
    protected AutoML aml() {
        return this._aml;
    }

    /** Returns {@code true} when this step has allocated work with a positive weight. */
    protected boolean canRun() {
        WorkAllocations.Work work = this.getAllocatedWork();
        return work != null && work._weight > 0;
    }

    /** Returns the work allocations of the owning AutoML run. */
    protected WorkAllocations getWorkAllocations() {
        return this.aml()._workAllocations;
    }

    /** Returns the models currently on the AutoML leaderboard. */
    protected Model[] getTrainedModels() {
        return this.aml().leaderboard().getModels();
    }

    /** Returns the keys of the models currently on the AutoML leaderboard. */
    protected Key<Model>[] getTrainedModelsKeys() {
        return this.aml().leaderboard().getModelKeys();
    }

    /** Returns whether cross-validation is enabled for the AutoML run. */
    protected boolean isCVEnabled() {
        return this.aml().isCVEnabled();
    }

    /**
     * Returns whether this step is configured to ignore {@code constraint}.
     *
     * @param constraint the AutoML constraint to look up
     */
    protected boolean isConstraintIgnored(AutoML.Constraint constraint) {
        return ArrayUtils.contains((Object[])this._ignoredConstraints, (Object)constraint);
    }

    /**
     * Computes the time budget, in seconds, for {@code work}: AutoML's remaining time (ms)
     * multiplied by the remaining-work ratio reported for {@code work}, converted to seconds.
     * The float arithmetic mirrors the original per-call computations exactly.
     *
     * @param work the allocated work of this step
     * @return the assigned time in seconds
     */
    protected double computeAssignedTimeSecs(WorkAllocations.Work work) {
        return (double)((float)this.aml().timeRemainingMs() * this.getWorkAllocations().remainingWorkRatio(work)) / 1000.0;
    }

    /**
     * Copies the frame, column, cross-validation, weighting, class-balancing and checkpoint
     * settings from the AutoML build spec into {@code params}.
     */
    protected void setCommonModelBuilderParams(Model.Parameters params) {
        params._train = this.aml()._trainingFrame._key;
        if (null != this.aml()._validationFrame) {
            params._valid = this.aml()._validationFrame._key;
        }
        AutoMLBuildSpec buildSpec = this.aml().getBuildSpec();
        params._response_column = buildSpec.input_spec.response_column;
        params._ignored_columns = buildSpec.input_spec.ignored_columns;
        this.setCrossValidationParams(params);
        this.setWeightingParams(params);
        this.setClassBalancingParams(params);
        params._keep_cross_validation_models = buildSpec.build_control.keep_cross_validation_models;
        // Fold assignments are only kept when CV is requested (nfolds != 0).
        params._keep_cross_validation_fold_assignment = buildSpec.build_control.nfolds != 0 && buildSpec.build_control.keep_cross_validation_fold_assignment;
        params._export_checkpoints_dir = buildSpec.build_control.export_checkpoints_dir;
    }

    /**
     * Sets cross-validation parameters from the build spec. CV predictions are always kept when
     * there is no blending frame. {@code nfolds} is only applied when no fold column is given,
     * and in that case Modulo fold assignment is used for {@code nfolds > 1}.
     */
    protected void setCrossValidationParams(Model.Parameters params) {
        AutoMLBuildSpec buildSpec = this.aml().getBuildSpec();
        params._keep_cross_validation_predictions = this.aml().getBlendingFrame() == null || buildSpec.build_control.keep_cross_validation_predictions;
        params._fold_column = buildSpec.input_spec.fold_column;
        if (buildSpec.input_spec.fold_column == null) {
            params._nfolds = buildSpec.build_control.nfolds;
            if (buildSpec.build_control.nfolds > 1) {
                params._fold_assignment = Model.Parameters.FoldAssignmentScheme.Modulo;
            }
        }
    }

    /** Copies the weights column from the build spec into {@code params}. */
    protected void setWeightingParams(Model.Parameters params) {
        AutoMLBuildSpec buildSpec = this.aml().getBuildSpec();
        params._weights_column = buildSpec.input_spec.weights_column;
    }

    /**
     * When class balancing is enabled in the build spec, copies the balancing flag, the class
     * sampling factors and the max-after-balance size into {@code params}.
     */
    protected void setClassBalancingParams(Model.Parameters params) {
        AutoMLBuildSpec buildSpec = this.aml().getBuildSpec();
        if (buildSpec.build_control.balance_classes) {
            params._balance_classes = buildSpec.build_control.balance_classes;
            params._class_sampling_factors = buildSpec.build_control.class_sampling_factors;
            params._max_after_balance_size = buildSpec.build_control.max_after_balance_size;
        }
    }

    /** Applies user-supplied per-algorithm parameters from the build spec, if any. */
    protected void setCustomParams(Model.Parameters params) {
        AutoMLBuildSpec.AutoMLCustomParameters customParams = this.aml().getBuildSpec().build_models.algo_parameters;
        if (customParams == null) {
            return;
        }
        customParams.applyCustomParameters(this._algo, params);
    }

    /**
     * Applies every AutoML preprocessing step to {@code params} and registers each step's
     * completer to run in {@link #onDone(Job)}.
     */
    protected void applyPreprocessing(Model.Parameters params) {
        if (this.aml().getPreprocessing() == null) {
            return;
        }
        for (PreprocessingStep preprocessingStep : this.aml().getPreprocessing()) {
            PreprocessingStep.Completer complete = preprocessingStep.apply(params, this.getPreprocessingConfig());
            this._onDone.add(j -> complete.run());
        }
    }

    /** Returns the preprocessing configuration for this step; subclasses may override. */
    protected PreprocessingConfig getPreprocessingConfig() {
        return new PreprocessingConfig();
    }

    /**
     * Fills stopping metric, rounds and tolerance from the build spec wherever {@code parms}
     * still holds the algorithm default. If the stopping metric ends up {@code AUTO} and the
     * leaderboard has a sort metric, that metric is used instead; {@code "auc"} maps to
     * {@code logloss}.
     *
     * @param parms    parameters to update
     * @param defaults the algorithm's default parameters
     */
    protected void setStoppingCriteria(Model.Parameters parms, Model.Parameters defaults) {
        AutoMLBuildSpec buildSpec = this.aml().getBuildSpec();
        if (parms._stopping_metric == defaults._stopping_metric) {
            parms._stopping_metric = buildSpec.build_control.stopping_criteria.stopping_metric();
        }
        if (parms._stopping_metric == ScoreKeeper.StoppingMetric.AUTO) {
            String sortMetric = this.getSortMetric();
            // Without a sort metric the stopping metric simply stays AUTO.
            if (sortMetric != null) {
                parms._stopping_metric = sortMetric.equals("auc") ? ScoreKeeper.StoppingMetric.logloss : ModelingStep.metricValueOf(sortMetric);
            }
        }
        if (parms._stopping_rounds == defaults._stopping_rounds) {
            parms._stopping_rounds = buildSpec.build_control.stopping_criteria.stopping_rounds();
        }
        if (parms._stopping_tolerance == defaults._stopping_tolerance) {
            parms._stopping_tolerance = buildSpec.build_control.stopping_criteria.stopping_tolerance();
        }
    }

    /**
     * Sets the seed when {@code parms} still holds the default seed.
     *
     * <p>{@code Global} uses the build-spec seed. {@code Incremental} uses the AutoML
     * incremental seed and advances it, unless that seed equals the default. {@code None} leaves
     * the seed unchanged.
     */
    protected void setSeed(Model.Parameters parms, Model.Parameters defaults, SeedPolicy seedPolicy) {
        AutoMLBuildSpec buildSpec = this.aml().getBuildSpec();
        if (parms._seed == defaults._seed) {
            switch (seedPolicy) {
                case Global: {
                    parms._seed = buildSpec.build_control.stopping_criteria.seed();
                    break;
                }
                case Incremental: {
                    parms._seed = this._aml._incrementalSeed.get() == defaults._seed ? defaults._seed : this._aml._incrementalSeed.getAndIncrement();
                    break;
                }
            }
        }
    }

    /**
     * When {@code parms._max_runtime_secs} is unset (0), sets it to the build spec's per-model
     * runtime limit, capped by {@code upperLimit} when {@code upperLimit > 0}.
     */
    protected void initTimeConstraints(Model.Parameters parms, double upperLimit) {
        AutoMLBuildSpec buildSpec = this.aml().getBuildSpec();
        if (parms._max_runtime_secs == 0.0) {
            double maxPerModel = buildSpec.build_control.stopping_criteria.max_runtime_secs_per_model();
            parms._max_runtime_secs = upperLimit <= 0.0 ? maxPerModel : Math.min(maxPerModel, upperLimit);
        }
    }

    /** Returns the leaderboard's sort metric, or {@code null} when there is no leaderboard. */
    private String getSortMetric() {
        Leaderboard leaderboard = this.aml().leaderboard();
        return leaderboard == null ? null : leaderboard.getSortMetric();
    }

    /**
     * Maps a leaderboard metric name to a stopping metric. {@code "mean_residual_deviance"}
     * maps to {@code deviance}; null or unknown names map to {@code AUTO}.
     */
    private static ScoreKeeper.StoppingMetric metricValueOf(String name) {
        if (name == null) {
            return ScoreKeeper.StoppingMetric.AUTO;
        }
        if ("mean_residual_deviance".equals(name)) {
            return ScoreKeeper.StoppingMetric.deviance;
        }
        try {
            return (ScoreKeeper.StoppingMetric)EnumUtils.valueOf(ScoreKeeper.StoppingMetric.class, name);
        }
        catch (IllegalArgumentException ignored) {
            return ScoreKeeper.StoppingMetric.AUTO;
        }
    }

    /**
     * A step that trains candidate models into a private leaderboard and then applies a
     * {@link ModelSelectionStrategy} to decide which models to add to or remove from the main
     * AutoML leaderboard. The MODEL_COUNT constraint is ignored by default.
     */
    public static abstract class SelectionStep<M extends Model>
    extends ModelingStep<M> {
        public SelectionStep(IAlgo algo, String id, int weight, AutoML autoML) {
            super(algo, id, weight, autoML);
            this._ignoredConstraints = new AutoML.Constraint[]{AutoML.Constraint.MODEL_COUNT};
        }

        @Override
        protected WorkAllocations.Work makeWork() {
            return new WorkAllocations.Work(this._id, this._algo, WorkAllocations.JobType.Selection, this._weight);
        }

        @Override
        protected WorkAllocations.Work getAllocatedWork() {
            return this.getWorkAllocations().getAllocation(this._id, this._algo);
        }

        @Override
        protected Key<Models> makeKey(String name, boolean withCounter) {
            return this.aml().makeKey(name, "selection", withCounter);
        }

        /**
         * Creates a leaderboard named {@code name}, using the main leaderboard's frame and sort
         * metric. When {@code eventLog} is {@code null}, a dedicated event log keyed by
         * {@code name} is created, and the holder's cleanup removes it too.
         */
        private ModelSelectionStrategies.LeaderboardHolder makeLeaderboard(String name, final EventLog eventLog) {
            Leaderboard amlLeaderboard = this.aml().leaderboard();
            final EventLog tmpEventLog = eventLog == null ? EventLog.getOrMake(Key.make(name)) : eventLog;
            final Leaderboard tmpLeaderboard = Leaderboard.getOrMake(name, tmpEventLog, amlLeaderboard.leaderboardFrame(), amlLeaderboard.getSortMetric());
            return new ModelSelectionStrategies.LeaderboardHolder(){

                @Override
                public Leaderboard get() {
                    return tmpLeaderboard;
                }

                @Override
                public void cleanup() {
                    tmpLeaderboard.removeModels(tmpLeaderboard.getModelKeys(), false);
                    tmpLeaderboard.remove(false);
                    if (eventLog == null) {
                        tmpEventLog.remove();
                    }
                }
            };
        }

        /** Creates a temporary leaderboard named {@code "tmp_" + name} with its own event log. */
        protected ModelSelectionStrategies.LeaderboardHolder makeTmpLeaderboard(String name) {
            return this.makeLeaderboard("tmp_" + name, null);
        }

        /**
         * Starts the selection job. It trains candidates via {@link #startTraining} into a private
         * leaderboard, applies {@link #getSelectionStrategy()} to the current leaderboard models
         * and the candidates, then updates the main leaderboard and the resulting
         * {@link Models} accordingly.
         */
        @Override
        protected Job<Models> startJob() {
            final Key<Model>[] trainedModelKeys = this.getTrainedModelsKeys();
            final Key<Models> key = this.makeKey(this._algo + "_" + this._id, false);
            this.aml().trackKey(key);
            final Job job = new Job(key, Models.class.getName(), this._description);
            final WorkAllocations.Work work = this.getAllocatedWork();
            final double maxAssignedTimeSecs = this.isConstraintIgnored(AutoML.Constraint.TIMEOUT) ? 0.0 : this.computeAssignedTimeSecs(work);
            this.aml().eventLog().debug(EventLogEntry.Stage.ModelTraining, maxAssignedTimeSecs == 0.0 ? "No time limitation for " + key : "Time assigned for " + key + ": " + maxAssignedTimeSecs + "s");
            return job.start(new H2O.H2OCountedCompleter(){
                Models result;
                Key<Models> selectionKey;
                EventLog selectionEventLog;
                ModelSelectionStrategies.LeaderboardHolder selectionLeaderboard;
                {
                    this.result = new Models<Model>(key, Model.class, job);
                    this.selectionKey = Key.make(key + "_select");
                    this.selectionEventLog = EventLog.getOrMake(this.selectionKey);
                    // Qualified: makeLeaderboard belongs to the enclosing SelectionStep, not this completer.
                    this.selectionLeaderboard = SelectionStep.this.makeLeaderboard(this.selectionKey.toString(), this.selectionEventLog);
                    this.result.delete_and_lock(job);
                }

                public void compute2() {
                    Countdown countdown = Countdown.fromSeconds(maxAssignedTimeSecs);
                    ModelingStepsExecutor localExecutor = new ModelingStepsExecutor(this.selectionLeaderboard.get(), this.selectionEventLog, countdown);
                    localExecutor.start();
                    Job<Models> innerTraining = SelectionStep.this.startTraining(this.selectionKey, maxAssignedTimeSecs);
                    localExecutor.monitor(innerTraining, work, job, false);
                    Log.debug("Selection leaderboard " + this.selectionLeaderboard.get()._key, this.selectionLeaderboard.get().toLogString());
                    ModelSelectionStrategy.Selection<Model> selection = SelectionStep.this.getSelectionStrategy().select(trainedModelKeys, this.selectionLeaderboard.get().getModelKeys());
                    Leaderboard lb = SelectionStep.this.aml().leaderboard();
                    Log.debug("Selection result for job " + key, ToStringBuilder.reflectionToString(selection));
                    lb.removeModels(selection._remove, true);
                    lb.addModels(selection._add);
                    this.result.unlock(job);
                    this.result.addModels(selection._add);
                    this.tryComplete();
                }

                public void onCompletion(CountedCompleter caller) {
                    Keyed.remove(this.selectionKey, new Futures(), false);
                    // Detach models that are still owned by the main leaderboard...
                    this.selectionLeaderboard.get().removeModels(trainedModelKeys, false);
                    // ...and delete candidates that were not kept in the result.
                    this.selectionLeaderboard.get().removeModels((Key[])Arrays.stream(this.selectionLeaderboard.get().getModelKeys()).filter(k -> !ArrayUtils.contains((Object[])this.result.getModelKeys(), (Object)k)).toArray(Key[]::new), true);
                    this.selectionLeaderboard.cleanup();
                    if (!SelectionStep.this.aml().eventLog()._key.equals(this.selectionEventLog._key)) {
                        this.selectionEventLog.remove();
                    }
                    super.onCompletion(caller);
                }

                public boolean onExceptionalCompletion(Throwable ex, CountedCompleter caller) {
                    this.result.unlock(job._key, false);
                    Keyed.remove(this.selectionKey);
                    this.selectionLeaderboard.get().remove();
                    if (!SelectionStep.this.aml().eventLog()._key.equals(this.selectionEventLog._key)) {
                        this.selectionEventLog.remove();
                    }
                    return super.onExceptionalCompletion(ex, caller);
                }
            }, (long)work._weight, maxAssignedTimeSecs);
        }

        /**
         * Starts training the candidate models for this selection.
         *
         * @param result         key under which the candidates are collected
         * @param maxRuntimeSecs time budget in seconds; 0 means unlimited
         */
        protected abstract Job<Models> startTraining(Key<Models> result, double maxRuntimeSecs);

        /** Returns the strategy used to pick which models to keep. */
        protected abstract ModelSelectionStrategy getSelectionStrategy();

        /**
         * Wraps {@code job}, which produces a single {@link Model} or a {@link ModelContainer},
         * into a job producing a {@link Models} stored at {@code result}. A container result is
         * removed (shallowly) once its model keys have been copied.
         *
         * @throws H2OIllegalArgumentException (from the wrapper job) if {@code job}'s result is
         *         neither a Model nor a ModelContainer
         */
        protected Job<Models> asModelsJob(final Job job, final Key<Models> result) {
            final Job jModels = new Job(result, Models.class.getName(), job._description);
            // Job.start expects its time limit in seconds (startJob passes maxAssignedTimeSecs),
            // so convert the wrapped job's millisecond limit; 0 (unlimited) stays 0.
            double maxRuntimeSecs = job._max_runtime_msecs / 1000.0;
            return jModels.start(new H2O.H2OCountedCompleter(){
                Models models;
                {
                    this.models = new Models<Model>(result, Model.class, jModels);
                    this.models.delete_and_lock(jModels);
                }

                public void compute2() {
                    ModelingStepsExecutor.ensureStopRequestPropagated(job, jModels);
                    Keyed res = job.get();
                    this.models.unlock(jModels);
                    if (res instanceof Model) {
                        this.models.addModel(((Model)res)._key);
                    } else if (res instanceof ModelContainer) {
                        this.models.addModels(((ModelContainer)res).getModelKeys());
                        res.remove(false);
                    } else {
                        throw new H2OIllegalArgumentException("Can only convert jobs producing a single Model or ModelContainer.");
                    }
                    this.tryComplete();
                }
            }, job._work, maxRuntimeSecs);
        }
    }

    /** A step that runs a hyperparameter (grid) search for one algorithm. */
    public static abstract class GridStep<M extends Model>
    extends ModelingStep<M> {
        public static final int DEFAULT_GRID_TRAINING_WEIGHT = 20;

        public GridStep(IAlgo algo, String id, int cost, AutoML autoML) {
            super(algo, id, cost, autoML);
        }

        @Override
        protected abstract Job<Grid> startJob();

        @Override
        protected WorkAllocations.Work makeWork() {
            return new WorkAllocations.Work(this._id, this._algo, WorkAllocations.JobType.HyperparamSearch, this._weight);
        }

        @Override
        protected WorkAllocations.Work getAllocatedWork() {
            return this.getWorkAllocations().getAllocation(this._id, this._algo);
        }

        @Override
        protected Key<Grid> makeKey(String name, boolean withCounter) {
            return this.aml().makeKey(name, "grid", withCounter);
        }

        /** Same as {@link #hyperparameterSearch(Key, Model.Parameters, Map)} with a generated key. */
        protected Job<Grid> hyperparameterSearch(Model.Parameters baseParms, Map<String, Object[]> searchParms) {
            return this.hyperparameterSearch(null, baseParms, searchParms);
        }

        /**
         * Configures {@code baseParms} with the AutoML common, stopping and custom settings, bounds
         * a copy of the AutoML search criteria by the time and model budget allocated to this
         * step, and starts the search.
         *
         * @param key         destination grid key; generated from the algorithm name when null
         * @param baseParms   base model parameters (mutated in place)
         * @param searchParms hyperparameter name to candidate values
         * @return the running grid-search job
         * @throws H2OIllegalArgumentException if a default instance of the parameters class
         *         cannot be created
         */
        protected Job<Grid> hyperparameterSearch(Key<Grid> key, Model.Parameters baseParms, Map<String, Object[]> searchParms) {
            Model.Parameters defaults;
            try {
                defaults = baseParms.getClass().getDeclaredConstructor().newInstance();
            }
            catch (Exception e) {
                this.aml().eventLog().warn(EventLogEntry.Stage.ModelTraining, "Internal error doing hyperparameter search");
                throw new H2OIllegalArgumentException("Hyperparameter search can't create a new instance of Model.Parameters subclass: " + baseParms.getClass());
            }
            this.initTimeConstraints(baseParms, 0.0);
            this.setCommonModelBuilderParams(baseParms);
            this.setStoppingCriteria(baseParms, defaults);
            this.setCustomParams(baseParms);
            AutoMLBuildSpec buildSpec = this.aml().getBuildSpec();
            HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria searchCriteria = (HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria)buildSpec.build_control.stopping_criteria.getSearchCriteria().clone();
            WorkAllocations.Work work = this.getAllocatedWork();
            double maxAssignedTimeSecs = this.isConstraintIgnored(AutoML.Constraint.TIMEOUT) ? 0.0 : this.computeAssignedTimeSecs(work);
            // Model budget is shared among exploration work only, excluding StackedEnsemble.
            int maxAssignedModels = (int)Math.ceil((float)this.aml().remainingModels() * this.getWorkAllocations().remainingWorkRatio(work, isExplorationWork.and(w -> w._algo != Algo.StackedEnsemble)));
            // A 0 limit in the criteria means "unset": take the assigned budget; otherwise cap it.
            searchCriteria.set_max_runtime_secs(searchCriteria.max_runtime_secs() == 0.0 ? maxAssignedTimeSecs : Math.min(searchCriteria.max_runtime_secs(), maxAssignedTimeSecs));
            searchCriteria.set_max_models(searchCriteria.max_models() == 0 ? maxAssignedModels : Math.min(searchCriteria.max_models(), maxAssignedModels));
            if (null == key) {
                key = this.makeKey(this._algo.name(), true);
            }
            this.aml().trackKey(key);
            Log.debug("Hyperparameter search: " + this._algo.name() + ", time remaining (ms): " + this.aml().timeRemainingMs());
            this.aml().eventLog().debug(EventLogEntry.Stage.ModelTraining, searchCriteria.max_runtime_secs() == 0.0 ? "No time limitation for " + key : "Time assigned for " + key + ": " + searchCriteria.max_runtime_secs() + "s");
            return this.startSearch(key, baseParms, searchParms, searchCriteria);
        }
    }

    /** A step that trains a single model. */
    public static abstract class ModelStep<M extends Model>
    extends ModelingStep<M> {
        public static final int DEFAULT_MODEL_TRAINING_WEIGHT = 10;

        public ModelStep(IAlgo algo, String id, int cost, AutoML autoML) {
            super(algo, id, cost, autoML);
        }

        @Override
        protected abstract Job<M> startJob();

        @Override
        protected WorkAllocations.Work makeWork() {
            return new WorkAllocations.Work(this._id, this._algo, WorkAllocations.JobType.ModelBuild, this._weight);
        }

        @Override
        protected WorkAllocations.Work getAllocatedWork() {
            return this.getWorkAllocations().getAllocation(this._id, this._algo);
        }

        @Override
        protected Key<M> makeKey(String name, boolean withCounter) {
            return this.aml().makeKey(name, null, withCounter);
        }

        /** Same as {@link #trainModel(Key, Model.Parameters)} with a generated key. */
        protected Job<M> trainModel(Model.Parameters parms) {
            return this.trainModel(null, parms);
        }

        /**
         * Configures {@code parms} with the AutoML time, common, seed, stopping and custom
         * settings, and starts training. Unless TIMEOUT is ignored, the runtime is capped by the
         * time allocated to this step; if TIMEOUT is ignored, the runtime limit is set to 0.
         *
         * @param key   destination model key; generated from the algorithm name when null
         * @param parms model parameters (mutated in place)
         * @return the running training job, or {@code null} if the model was skipped
         */
        protected Job<M> trainModel(Key<M> key, Model.Parameters parms) {
            String algoName = ModelBuilder.algoName(this._algo.urlName());
            if (null == key) {
                key = this.makeKey(algoName, true);
            }
            Model.Parameters defaults = ModelBuilder.make((String)this._algo.urlName(), null, null)._parms;
            this.initTimeConstraints(parms, 0.0);
            this.setCommonModelBuilderParams(parms);
            this.setSeed(parms, defaults, SeedPolicy.Incremental);
            this.setStoppingCriteria(parms, defaults);
            this.setCustomParams(parms);
            if (this.isConstraintIgnored(AutoML.Constraint.TIMEOUT)) {
                parms._max_runtime_secs = 0.0;
            } else {
                double maxAssignedTimeSecs = this.computeAssignedTimeSecs(this.getAllocatedWork());
                parms._max_runtime_secs = parms._max_runtime_secs == 0.0 ? maxAssignedTimeSecs : Math.min(parms._max_runtime_secs, maxAssignedTimeSecs);
            }
            Log.debug("Training model: " + algoName + ", time remaining (ms): " + this.aml().timeRemainingMs());
            this.aml().eventLog().debug(EventLogEntry.Stage.ModelTraining, parms._max_runtime_secs == 0.0 ? "No time limitation for " + key : "Time assigned for " + key + ": " + parms._max_runtime_secs + "s");
            return this.startModel(key, parms);
        }
    }

    /** How {@link #setSeed} fills a seed that is still at its default value. */
    protected static enum SeedPolicy {
        /** Leave the seed unchanged. */
        None,
        /** Use the seed from the build spec. */
        Global,
        /** Use the AutoML incremental seed and advance it. */
        Incremental;

    }
}

