/*
 * Decompiled with CFR 0.152.
 */
package ai.h2o.automl.modeling;

import ai.h2o.automl.Algo;
import ai.h2o.automl.AutoML;
import ai.h2o.automl.ModelParametersProvider;
import ai.h2o.automl.ModelSelectionStrategies;
import ai.h2o.automl.ModelSelectionStrategy;
import ai.h2o.automl.ModelingStep;
import ai.h2o.automl.ModelingSteps;
import ai.h2o.automl.ModelingStepsProvider;
import ai.h2o.automl.Models;
import ai.h2o.automl.events.EventLogEntry;
import hex.Model;
import hex.grid.Grid;
import hex.tree.SharedTreeModel;
import hex.tree.gbm.GBMModel;
import java.util.HashMap;
import java.util.Objects;
import water.Job;
import water.Key;

public class GBMStepsProvider
implements ModelingStepsProvider<GBMSteps>,
ModelParametersProvider<GBMModel.GBMParameters> {
    /**
     * Identifies this provider by the name of the GBM algorithm constant.
     *
     * @return the result of {@code Algo.GBM.name()}
     */
    @Override
    public String getName() {
        final Algo algo = Algo.GBM;
        return algo.name();
    }

    /**
     * Creates a new set of GBM modeling steps bound to the given AutoML instance.
     *
     * @param aml the AutoML instance handed to the {@link GBMSteps} constructor
     * @return a freshly constructed {@link GBMSteps}
     */
    @Override
    public GBMSteps newInstance(AutoML aml) {
        final GBMSteps steps = new GBMSteps(aml);
        return steps;
    }

    /**
     * Creates an untouched GBM parameters object.
     *
     * @return a new {@link GBMModel.GBMParameters} built with its no-arg constructor
     */
    @Override
    public GBMModel.GBMParameters newDefaultParameters() {
        final GBMModel.GBMParameters parameters = new GBMModel.GBMParameters();
        return parameters;
    }

    public static class GBMSteps
    extends ModelingSteps {
        /**
         * Default GBM model steps "def_1" .. "def_5", each created with weight 10.
         * Every step starts from {@link GBMModelStep#prepareModelParameters()} and differs
         * only in the {@code _max_depth} / {@code _min_rows} pair it sets before training.
         */
        private ModelingStep[] defaults = new GBMModelStep[]{
            new GBMModelStep("def_1", 10, aml()) {
                @Override
                protected Job<GBMModel> startJob() {
                    // max_depth 6, min_rows 1
                    GBMModel.GBMParameters params = prepareModelParameters();
                    params._max_depth = 6;
                    params._min_rows = 1.0;
                    return trainModel((Model.Parameters) params);
                }
            },
            new GBMModelStep("def_2", 10, aml()) {
                @Override
                protected Job<GBMModel> startJob() {
                    // max_depth 7, min_rows 10
                    GBMModel.GBMParameters params = prepareModelParameters();
                    params._max_depth = 7;
                    params._min_rows = 10.0;
                    return trainModel((Model.Parameters) params);
                }
            },
            new GBMModelStep("def_3", 10, aml()) {
                @Override
                protected Job<GBMModel> startJob() {
                    // max_depth 8, min_rows 10
                    GBMModel.GBMParameters params = prepareModelParameters();
                    params._max_depth = 8;
                    params._min_rows = 10.0;
                    return trainModel((Model.Parameters) params);
                }
            },
            new GBMModelStep("def_4", 10, aml()) {
                @Override
                protected Job<GBMModel> startJob() {
                    // max_depth 10, min_rows 10
                    GBMModel.GBMParameters params = prepareModelParameters();
                    params._max_depth = 10;
                    params._min_rows = 10.0;
                    return trainModel((Model.Parameters) params);
                }
            },
            new GBMModelStep("def_5", 10, aml()) {
                @Override
                protected Job<GBMModel> startJob() {
                    // max_depth 15, min_rows 100
                    GBMModel.GBMParameters params = prepareModelParameters();
                    params._max_depth = 15;
                    params._min_rows = 100.0;
                    return trainModel((Model.Parameters) params);
                }
            }
        };
        /**
         * Grid-search steps. The single "grid_1" step (weight 60) starts from
         * {@link GBMGridStep#prepareModelParameters()} and hands a fixed set of candidate
         * values per hyperparameter to {@code hyperparameterSearch}.
         */
        private ModelingStep[] grids = new GBMGridStep[]{
            new GBMGridStep("grid_1", 60, aml()) {
                @Override
                protected Job<Grid> startJob() {
                    GBMModel.GBMParameters baseParams = prepareModelParameters();

                    // Candidate values keyed by the parameter field name.
                    HashMap<String, Object[]> hyperParams = new HashMap<>();
                    hyperParams.put("_max_depth", new Integer[]{3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17});
                    hyperParams.put("_min_rows", new Integer[]{1, 5, 10, 15, 30, 100});
                    hyperParams.put("_sample_rate", new Double[]{0.5, 0.6, 0.7, 0.8, 0.9, 1.0});
                    hyperParams.put("_col_sample_rate", new Double[]{0.4, 0.7, 1.0});
                    hyperParams.put("_col_sample_rate_per_tree", new Double[]{0.4, 0.7, 1.0});
                    hyperParams.put("_min_split_improvement", new Double[]{1.0E-4, 1.0E-5});

                    return hyperparameterSearch((Model.Parameters) baseParams, hyperParams);
                }
            }
        };
        /**
         * Exploitation steps. The single "lr_annealing" step (weight 10) clones the parameters
         * of the GBM returned by {@link GBMExploitationStep#getBestGBM()}, enables learning-rate
         * annealing and retrains. Its selection strategy delegates to
         * {@code ModelSelectionStrategies.KeepBestN} with n = 1, applied to the key of the GBM
         * returned by {@code getBestGBM()} and the newly trained models.
         */
        private ModelingStep[] exploitation = new ModelingStep[]{new GBMExploitationStep("lr_annealing", 10, this.aml()){
            // Key passed to the most recent startTraining call; null until training has been started.
            // (The constructor arguments are supplied by the class instance creation expression above;
            // an anonymous class cannot contain an explicit super(...) call.)
            Key<Models> resultKey = null;

            /**
             * Starts retraining of the current best GBM with learning-rate annealing.
             *
             * @param result key under which the resulting models are published; also remembered
             *               for naming the temporary leaderboard in the selection strategy
             * @param maxRuntimeSecs time budget forwarded to {@code initTimeConstraints}
             * @return the job wrapping the single retrained model as a {@code Models} result
             */
            @Override
            protected Job<Models> startTraining(Key result, double maxRuntimeSecs) {
                this.resultKey = result;
                // canRun() guarantees a GBM exists before this step is started.
                GBMModel bestGBM = this.getBestGBM();
                this.aml().eventLog().info(EventLogEntry.Stage.ModelSelection, "Retraining best GBM with learning rate annealing: " + bestGBM._key);
                // Work on a copy so the parameters of the already trained model are left untouched.
                GBMModel.GBMParameters gbmParameters = (GBMModel.GBMParameters)((GBMModel.GBMParameters)bestGBM._parms).clone();
                gbmParameters._ntrees = 10000;
                // Clear the runtime limit inherited from the clone before time constraints are re-initialised.
                gbmParameters._max_runtime_secs = 0.0;
                gbmParameters._learn_rate_annealing = 0.99;
                this.initTimeConstraints((Model.Parameters)gbmParameters, maxRuntimeSecs);
                // Second argument is a fresh default-parameters instance — presumably the baseline
                // used to decide which stopping criteria to override; confirm in setStoppingCriteria.
                this.setStoppingCriteria((Model.Parameters)gbmParameters, (Model.Parameters)new GBMModel.GBMParameters());
                return this.asModelsJob(this.startModel(Key.make(result + "_model"), gbmParameters), (Key<Models>)result);
            }

            /**
             * @return a strategy that runs {@code KeepBestN(1)} over the key of the GBM returned by
             *         {@code getBestGBM()} and the new models, using a temporary leaderboard named
             *         after {@code resultKey}, or {@code _algo + "_" + _id} when no key has been set
             */
            @Override
            protected ModelSelectionStrategy getSelectionStrategy() {
                return (originalModels, newModels) -> new ModelSelectionStrategies.KeepBestN(1, () -> this.makeTmpLeaderboard(Objects.toString(this.resultKey, this._algo + "_" + this._id))).select(new Key[]{this.getBestGBM()._key}, newModels);
            }
        }};

        /**
         * Builds the base GBM parameters that the step classes below start from.
         *
         * @return a new parameters object with {@code _score_tree_interval} set to 5 and
         *         {@code _histogram_type} set to {@code HistogramType.AUTO}
         */
        static GBMModel.GBMParameters prepareModelParameters() {
            final GBMModel.GBMParameters base = new GBMModel.GBMParameters();
            base._histogram_type = SharedTreeModel.SharedTreeParameters.HistogramType.AUTO;
            base._score_tree_interval = 5;
            return base;
        }

        /**
         * Creates the GBM steps for the given AutoML instance.
         *
         * @param autoML the AutoML instance, passed straight to the {@code ModelingSteps} constructor
         */
        public GBMSteps(AutoML autoML) {
            super(autoML);
        }

        /**
         * @return the array held in the {@code defaults} field (not copied)
         */
        @Override
        protected ModelingStep[] getDefaultModels() {
            return defaults;
        }

        /**
         * @return the array held in the {@code grids} field (not copied)
         */
        @Override
        protected ModelingStep[] getGrids() {
            return grids;
        }

        /**
         * @return the array held in the {@code exploitation} field (not copied)
         */
        @Override
        protected ModelingStep[] getExploitation() {
            return exploitation;
        }

        /**
         * Base class for GBM selection steps that build on an already trained GBM.
         * The step is only runnable when {@link #getBestGBM()} finds such a model.
         */
        static abstract class GBMExploitationStep
        extends ModelingStep.SelectionStep<GBMModel> {
            /**
             * @param id     step identifier
             * @param weight step weight
             * @param autoML owning AutoML instance
             */
            public GBMExploitationStep(String id, int weight, AutoML autoML) {
                super(Algo.GBM, id, weight, autoML);
            }

            /**
             * Finds the first GBM among the trained models.
             * NOTE(review): this is the "best" GBM only if {@code getTrainedModels()} returns
             * models ordered best-first — confirm against its implementation.
             *
             * @return the first {@link GBMModel} in {@code getTrainedModels()}, or {@code null}
             *         when none is present
             */
            protected GBMModel getBestGBM() {
                for (Model candidate : getTrainedModels()) {
                    if (candidate instanceof GBMModel) {
                        return (GBMModel) candidate;
                    }
                }
                return null;
            }

            /**
             * @return {@code true} only when the superclass allows the step to run and a trained
             *         GBM is available
             */
            @Override
            protected boolean canRun() {
                if (!super.canRun()) {
                    return false;
                }
                return getBestGBM() != null;
            }
        }

        /**
         * Base class for GBM grid-search steps.
         */
        static abstract class GBMGridStep
        extends ModelingStep.GridStep<GBMModel> {
            /**
             * @param id     step identifier
             * @param weight step weight
             * @param autoML owning AutoML instance
             */
            public GBMGridStep(String id, int weight, AutoML autoML) {
                super(Algo.GBM, id, weight, autoML);
            }

            /**
             * Builds the starting parameters for a grid step.
             *
             * @return the shared base parameters from {@link GBMSteps#prepareModelParameters()}
             *         with {@code _ntrees} set to 10000
             */
            GBMModel.GBMParameters prepareModelParameters() {
                // Qualified call: targets the static factory on GBMSteps, not this method.
                final GBMModel.GBMParameters params = GBMSteps.prepareModelParameters();
                params._ntrees = 10000;
                return params;
            }
        }

        /**
         * Base class for the single-model GBM steps.
         */
        static abstract class GBMModelStep
        extends ModelingStep.ModelStep<GBMModel> {
            /**
             * @param id     step identifier
             * @param weight step weight
             * @param autoML owning AutoML instance
             */
            GBMModelStep(String id, int weight, AutoML autoML) {
                super(Algo.GBM, id, weight, autoML);
            }

            /**
             * Builds the starting parameters for a single-model step.
             *
             * @return the shared base parameters from {@link GBMSteps#prepareModelParameters()}
             *         with {@code _ntrees} set to 10000 and the row, column and per-tree column
             *         sample rates all set to 0.8
             */
            GBMModel.GBMParameters prepareModelParameters() {
                // Qualified call: targets the static factory on GBMSteps, not this method.
                final GBMModel.GBMParameters params = GBMSteps.prepareModelParameters();
                params._ntrees = 10000;
                // Same sampling rate for rows, columns and per-tree columns.
                params._sample_rate = 0.8;
                params._col_sample_rate = 0.8;
                params._col_sample_rate_per_tree = 0.8;
                return params;
            }
        }
    }
}

