/*
 * Decompiled with CFR 0.152.
 */
package ai.h2o.automl;

import ai.h2o.automl.Algo;
import ai.h2o.automl.AutoML;
import ai.h2o.automl.AutoMLBuildSpec;
import ai.h2o.automl.StepDefinition;
import ai.h2o.automl.WorkAllocations;
import ai.h2o.automl.events.EventLogEntry;
import ai.h2o.automl.leaderboard.Leaderboard;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelParametersBuilderFactory;
import hex.ScoreKeeper;
import hex.ensemble.StackedEnsembleModel;
import hex.grid.Grid;
import hex.grid.GridSearch;
import hex.grid.HyperSpaceSearchCriteria;
import java.util.Map;
import water.Iced;
import water.Job;
import water.Key;
import water.exceptions.H2OIllegalArgumentException;
import water.util.EnumUtils;
import water.util.Log;

/**
 * A single unit of work in an AutoML run for one {@link Algo}: either training one model
 * ({@link ModelStep}) or running one hyperparameter search ({@link GridStep}).
 *
 * <p>Subclasses define how the step's result is keyed, how its work is registered in the run's
 * {@link WorkAllocations}, and how its {@link Job} is started. This base class holds the
 * parameter plumbing shared by all steps:
 * <ul>
 *   <li>copying the AutoML build spec (frames, columns, cross-validation, class balancing) into
 *       the model parameters;</li>
 *   <li>applying the run's stopping criteria;</li>
 *   <li>applying the custom algorithm parameters from the build spec.</li>
 * </ul>
 *
 * @param <M> the type of model this step produces
 */
public abstract class ModelingStep<M extends Model>
extends Iced<ModelingStep> {
    /** The AutoML run this step belongs to (transient). */
    private transient AutoML _aml;
    /** Algorithm used by this step. */
    protected final Algo _algo;
    /** Step identifier; together with {@link #_algo} it is used to look up this step's allocated work. */
    protected final String _id;
    /** Weight passed to {@link WorkAllocations.Work} when this step's work entry is created. */
    protected int _weight;
    /** When true, {@link ModelStep#trainModel} resets the model's {@code _max_runtime_secs} to 0. */
    protected boolean _ignoreConstraints;
    /** Human-readable description, used as the training job description; defaults to "ALGO id". */
    protected String _description;
    /** NOTE(review): presumably the {@link StepDefinition} this step was created from; assigned outside this class. */
    StepDefinition _fromDef;

    protected ModelingStep(Algo algo, String id, int weight, AutoML autoML) {
        _algo = algo;
        _id = id;
        _weight = weight;
        _aml = autoML;
        _description = algo.name() + " " + id;
    }

    /** Returns the work allocated to this step in the current run, or {@code null} if none. */
    protected abstract WorkAllocations.Work getAllocatedWork();

    /**
     * Creates the key under which this step's result is stored.
     *
     * @param name        base name for the key
     * @param withCounter whether the AutoML run should append a counter to the name
     */
    protected abstract Key makeKey(String name, boolean withCounter);

    /** Creates the work entry describing this step for the run's {@link WorkAllocations}. */
    protected abstract WorkAllocations.Work makeWork();

    /** Starts the job that performs this step. */
    protected abstract Job startJob();

    /** Returns the AutoML run that owns this step. */
    protected AutoML aml() {
        return _aml;
    }

    /** A step can run only when the run's work allocations contain an entry for it. */
    protected boolean canRun() {
        return getAllocatedWork() != null;
    }

    /** Returns the work allocations of the owning AutoML run. */
    protected WorkAllocations getWorkAllocations() {
        return aml()._workAllocations;
    }

    /** Returns the models currently on the run's leaderboard. */
    protected Model[] getTrainedModels() {
        return aml().leaderboard().getModels();
    }

    /** Returns the keys of the models currently on the run's leaderboard. */
    protected Key<Model>[] getTrainedModelsKeys() {
        return aml().leaderboard().getModelKeys();
    }

    /** Returns whether cross-validation is enabled for the owning AutoML run. */
    protected boolean isCVEnabled() {
        return aml().isCVEnabled();
    }

    /**
     * Copies the run-wide settings from the AutoML build spec into {@code params}: training and
     * validation frames, response and ignored columns, and cross-validation options.
     * For non-ensemble models it also copies fold, weights and class-balancing settings.
     *
     * @param params model parameters to update in place
     */
    void setCommonModelBuilderParams(Model.Parameters params) {
        AutoML aml = aml();
        params._train = aml._trainingFrame._key;
        if (aml._validationFrame != null) {
            params._valid = aml._validationFrame._key;
        }

        AutoMLBuildSpec buildSpec = aml.getBuildSpec();
        params._response_column = buildSpec.input_spec.response_column;
        params._ignored_columns = buildSpec.input_spec.ignored_columns;

        // NOTE(review): these settings are skipped for stacked ensembles, presumably because an
        // ensemble takes them from its base models; confirm.
        if (!(params instanceof StackedEnsembleModel.StackedEnsembleParameters)) {
            // CV predictions are forced on when there is no blending frame.
            params._keep_cross_validation_predictions =
                    aml.getBlendingFrame() == null || buildSpec.build_control.keep_cross_validation_predictions;
            params._fold_column = buildSpec.input_spec.fold_column;
            params._weights_column = buildSpec.input_spec.weights_column;
            if (buildSpec.input_spec.fold_column == null) {
                params._nfolds = buildSpec.build_control.nfolds;
                if (buildSpec.build_control.nfolds > 1) {
                    // NOTE(review): Modulo presumably gives every model identical, deterministic folds.
                    params._fold_assignment = Model.Parameters.FoldAssignmentScheme.Modulo;
                }
            }
            if (buildSpec.build_control.balance_classes) {
                params._balance_classes = buildSpec.build_control.balance_classes;
                params._class_sampling_factors = buildSpec.build_control.class_sampling_factors;
                params._max_after_balance_size = buildSpec.build_control.max_after_balance_size;
            }
        }
        params._keep_cross_validation_models = buildSpec.build_control.keep_cross_validation_models;
        params._keep_cross_validation_fold_assignment =
                buildSpec.build_control.nfolds != 0 && buildSpec.build_control.keep_cross_validation_fold_assignment;
        params._export_checkpoints_dir = buildSpec.build_control.export_checkpoints_dir;
    }

    /**
     * Applies the custom parameters for this step's algorithm, if the build spec defines any.
     *
     * @param params model parameters to update in place
     */
    void setCustomParams(Model.Parameters params) {
        AutoMLBuildSpec.AutoMLCustomParameters customParams = aml().getBuildSpec().build_models.algo_parameters;
        if (customParams == null) {
            return;
        }
        customParams.applyCustomParameters(_algo, params);
    }

    /**
     * Applies the run's stopping criteria to {@code parms}. Seed, stopping metric, stopping rounds
     * and stopping tolerance are overwritten only when they still equal the value in
     * {@code defaults}. The per-model max runtime is always overwritten.
     *
     * @param parms             parameters to update in place
     * @param defaults          a default-initialized instance of the same parameters class
     * @param isIndividualModel true for a single model build; only then is a per-model seed derived
     */
    void setStoppingCriteria(Model.Parameters parms, Model.Parameters defaults, boolean isIndividualModel) {
        AutoMLBuildSpec buildSpec = aml().getBuildSpec();
        parms._max_runtime_secs = buildSpec.build_control.stopping_criteria.max_runtime_secs_per_model();

        // Each individually trained model gets run seed + a running counter, so seeds are
        // distinct yet reproducible. -1 presumably means "no run seed".
        if (isIndividualModel
                && parms._seed == defaults._seed
                && buildSpec.build_control.stopping_criteria.seed() != -1L) {
            parms._seed = buildSpec.build_control.stopping_criteria.seed()
                    + (long) aml().individualModelsTrained.getAndIncrement();
        }

        if (parms._stopping_metric == defaults._stopping_metric) {
            parms._stopping_metric = buildSpec.build_control.stopping_criteria.stopping_metric();
        }
        if (parms._stopping_metric == ScoreKeeper.StoppingMetric.AUTO) {
            // Fall back to the leaderboard sort metric; AUC is mapped to logloss.
            String sortMetric = getSortMetric();
            if (sortMetric != null) {
                parms._stopping_metric = "auc".equals(sortMetric)
                        ? ScoreKeeper.StoppingMetric.logloss
                        : metricValueOf(sortMetric);
            }
        }

        if (parms._stopping_rounds == defaults._stopping_rounds) {
            parms._stopping_rounds = buildSpec.build_control.stopping_criteria.stopping_rounds();
        }
        if (parms._stopping_tolerance == defaults._stopping_tolerance) {
            parms._stopping_tolerance = buildSpec.build_control.stopping_criteria.stopping_tolerance();
        }
    }

    /** Returns the leaderboard's sort metric, or {@code null} when there is no leaderboard. */
    private String getSortMetric() {
        Leaderboard leaderboard = aml().leaderboard();
        return leaderboard == null ? null : leaderboard.getSortMetric();
    }

    /**
     * Maps a leaderboard metric name to a stopping metric. Returns {@code AUTO} for {@code null}
     * and for names that match no {@link ScoreKeeper.StoppingMetric}.
     */
    private static ScoreKeeper.StoppingMetric metricValueOf(String name) {
        if (name == null) {
            return ScoreKeeper.StoppingMetric.AUTO;
        }
        if ("mean_residual_deviance".equals(name)) {
            return ScoreKeeper.StoppingMetric.deviance;
        }
        try {
            return EnumUtils.valueOf(ScoreKeeper.StoppingMetric.class, name);
        } catch (IllegalArgumentException ignored) {
            // Sort metric has no stopping-metric counterpart: let the algorithm choose.
            return ScoreKeeper.StoppingMetric.AUTO;
        }
    }

    /** A modeling step that runs a random-discrete hyperparameter search for one algorithm. */
    public static abstract class GridStep<M extends Model>
    extends ModelingStep<M> {
        /** Default weight for grid-search steps (not referenced in this class). */
        public static final int DEFAULT_GRID_TRAINING_WEIGHT = 20;

        public GridStep(Algo algo, String id, int cost, AutoML autoML) {
            super(algo, id, cost, autoML);
        }

        @Override
        protected abstract Job<Grid> startJob();

        @Override
        protected WorkAllocations.Work makeWork() {
            return new WorkAllocations.Work(_id, _algo, WorkAllocations.JobType.HyperparamSearch, _weight);
        }

        @Override
        protected WorkAllocations.Work getAllocatedWork() {
            return getWorkAllocations().getAllocation(_id, _algo);
        }

        @Override
        protected Key<Grid> makeKey(String name, boolean withCounter) {
            return aml().gridKey(name, withCounter);
        }

        /** Starts a hyperparameter search under a freshly generated grid key. */
        protected Job<Grid> hyperparameterSearch(Model.Parameters baseParms, Map<String, Object[]> searchParms) {
            return hyperparameterSearch(null, baseParms, searchParms);
        }

        /**
         * Configures {@code baseParms} with the run-wide settings and starts a grid search. The
         * search's time and model budgets are capped by this step's share of what remains in the run.
         *
         * @param key         grid key, or {@code null} to generate one
         * @param baseParms   base model parameters; modified in place
         * @param searchParms hyperparameter name to candidate values
         * @return the started grid-search job
         * @throws H2OIllegalArgumentException if a default instance of {@code baseParms}' class
         *                                     cannot be created
         */
        protected Job<Grid> hyperparameterSearch(Key<Grid> key, Model.Parameters baseParms, Map<String, Object[]> searchParms) {
            // A default-constructed instance lets setStoppingCriteria detect user overrides.
            Model.Parameters defaults;
            try {
                // Class#newInstance is deprecated: it can rethrow undeclared checked exceptions.
                defaults = baseParms.getClass().getDeclaredConstructor().newInstance();
            } catch (Exception e) {
                // NOTE(review): the cause `e` is not attached to the thrown exception.
                aml().eventLog().warn(EventLogEntry.Stage.ModelTraining, "Internal error doing hyperparameter search");
                throw new H2OIllegalArgumentException("Hyperparameter search can't create a new instance of Model.Parameters subclass: " + baseParms.getClass());
            }
            setCommonModelBuilderParams(baseParms);
            setStoppingCriteria(baseParms, defaults, false);
            setCustomParams(baseParms);

            AutoMLBuildSpec buildSpec = aml().getBuildSpec();
            // Clone so that capping below does not change the build spec's search criteria.
            HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria searchCriteria =
                    (HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria)
                            buildSpec.build_control.stopping_criteria.getSearchCriteria().clone();

            // This step's share of the remaining time and non-ensemble model budget.
            WorkAllocations.Work work = getAllocatedWork();
            double maxAssignedTimeSecs = aml().timeRemainingMs() * getWorkAllocations().remainingWorkRatio(work) / 1e3;
            int maxAssignedModels = (int) Math.ceil(aml().remainingModels()
                    * getWorkAllocations().remainingWorkRatio(work, w -> w._algo != Algo.StackedEnsemble));

            // 0 is treated as "unset": use the assigned budget; otherwise keep the tighter limit.
            searchCriteria.set_max_runtime_secs(searchCriteria.max_runtime_secs() == 0.0
                    ? maxAssignedTimeSecs
                    : Math.min(searchCriteria.max_runtime_secs(), maxAssignedTimeSecs));
            searchCriteria.set_max_models(searchCriteria.max_models() == 0
                    ? maxAssignedModels
                    : Math.min(searchCriteria.max_models(), maxAssignedModels));

            if (key == null) {
                key = makeKey(_algo.name(), true);
            }
            aml().addGridKey(key);
            aml().eventLog().info(EventLogEntry.Stage.ModelTraining, "AutoML: starting " + key + " hyperparameter search");
            Log.debug("Hyperparameter search: " + _algo.name() + ", time remaining (ms): " + aml().timeRemainingMs());
            // NOTE(review): the trailing 1 is presumably the grid's model-building parallelism (sequential).
            return GridSearch.startGridSearch(key, baseParms, searchParms,
                    new GridSearch.SimpleParametersBuilderFactory(), searchCriteria, 1);
        }
    }

    /** A modeling step that trains a single model for one algorithm. */
    public static abstract class ModelStep<M extends Model>
    extends ModelingStep<M> {
        /** Default weight for single-model steps (not referenced in this class). */
        public static final int DEFAULT_MODEL_TRAINING_WEIGHT = 10;

        public ModelStep(Algo algo, String id, int cost, AutoML autoML) {
            super(algo, id, cost, autoML);
        }

        @Override
        protected abstract Job<M> startJob();

        @Override
        protected WorkAllocations.Work makeWork() {
            return new WorkAllocations.Work(_id, _algo, WorkAllocations.JobType.ModelBuild, _weight);
        }

        @Override
        protected WorkAllocations.Work getAllocatedWork() {
            return getWorkAllocations().getAllocation(_id, _algo);
        }

        @Override
        protected Key<M> makeKey(String name, boolean withCounter) {
            return aml().modelKey(name, withCounter);
        }

        /** Trains a model under a freshly generated model key. */
        protected Job<M> trainModel(Model.Parameters parms) {
            return trainModel(null, parms);
        }

        /**
         * Configures {@code parms} with the run-wide settings and starts training one model.
         *
         * @param key   model key, or {@code null} to generate one from the algorithm name
         * @param parms model parameters; installed on the builder and modified in place
         * @return the training job, or {@code null} if training start failed with an
         *         {@link H2OIllegalArgumentException}; the failure is logged to the event log
         */
        protected Job<M> trainModel(Key<M> key, Model.Parameters parms) {
            String algoName = ModelBuilder.algoName(_algo.urlName());
            if (key == null) {
                key = makeKey(algoName, true);
            }
            Job<M> job = new Job<>(key, ModelBuilder.javaName(_algo.urlName()), _description);
            ModelBuilder builder = ModelBuilder.make(_algo.urlName(), job, key);

            // The builder's own freshly created parameters serve as reference defaults.
            Model.Parameters defaults = builder._parms;
            builder._parms = parms;
            setCommonModelBuilderParams(builder._parms);
            setStoppingCriteria(builder._parms, defaults, true);
            setCustomParams(builder._parms);

            // Cap the model's runtime by the run's remaining time, unless constraints are ignored.
            // A value of 0 presumably means "no limit".
            if (_ignoreConstraints) {
                builder._parms._max_runtime_secs = 0.0;
            } else {
                double timeRemainingSecs = aml().timeRemainingMs() / 1e3;
                builder._parms._max_runtime_secs = builder._parms._max_runtime_secs == 0.0
                        ? timeRemainingSecs
                        : Math.min(builder._parms._max_runtime_secs, timeRemainingSecs);
            }

            aml().eventLog().info(EventLogEntry.Stage.ModelTraining, "AutoML: starting " + key + " model training");
            builder.init(false); // NOTE(review): presumably validates the parameters
            Log.debug("Training model: " + algoName + ", time remaining (ms): " + aml().timeRemainingMs());
            try {
                return builder.trainModelOnH2ONode();
            } catch (H2OIllegalArgumentException exception) {
                aml().eventLog().warn(EventLogEntry.Stage.ModelTraining, "Skipping training of model " + key + " due to exception: " + exception);
                return null;
            }
        }
    }
}

