/*
 * Decompiled with CFR 0.152.
 */
package ai.h2o.automl.modeling;

import ai.h2o.automl.Algo;
import ai.h2o.automl.AutoML;
import ai.h2o.automl.ModelingStep;
import ai.h2o.automl.ModelingSteps;
import ai.h2o.automl.ModelingStepsProvider;
import hex.Model;
import hex.grid.Grid;
import hex.tree.SharedTreeModel;
import hex.tree.gbm.GBMModel;
import java.util.HashMap;
import water.Job;

public class GBMStepsProvider
implements ModelingStepsProvider<GBMSteps> {
    @Override
    public String getName() {
        return Algo.GBM.name();
    }

    @Override
    public GBMSteps newInstance(AutoML aml) {
        return new GBMSteps(aml);
    }

    public static class GBMSteps
    extends ModelingSteps {
        private ModelingStep[] defaults = new GBMModelStep[]{new GBMModelStep("def_1", 10, this.aml()){

            @Override
            protected Job<GBMModel> startJob() {
                GBMModel.GBMParameters gbmParameters = this.prepareModelParameters();
                gbmParameters._max_depth = 6;
                gbmParameters._min_rows = 1.0;
                return this.trainModel((Model.Parameters)gbmParameters);
            }
        }, new GBMModelStep("def_2", 10, this.aml()){

            @Override
            protected Job<GBMModel> startJob() {
                GBMModel.GBMParameters gbmParameters = this.prepareModelParameters();
                gbmParameters._max_depth = 7;
                gbmParameters._min_rows = 10.0;
                return this.trainModel((Model.Parameters)gbmParameters);
            }
        }, new GBMModelStep("def_3", 10, this.aml()){

            @Override
            protected Job<GBMModel> startJob() {
                GBMModel.GBMParameters gbmParameters = this.prepareModelParameters();
                gbmParameters._max_depth = 8;
                gbmParameters._min_rows = 10.0;
                return this.trainModel((Model.Parameters)gbmParameters);
            }
        }, new GBMModelStep("def_4", 10, this.aml()){

            @Override
            protected Job<GBMModel> startJob() {
                GBMModel.GBMParameters gbmParameters = this.prepareModelParameters();
                gbmParameters._max_depth = 10;
                gbmParameters._min_rows = 10.0;
                return this.trainModel((Model.Parameters)gbmParameters);
            }
        }, new GBMModelStep("def_5", 10, this.aml()){

            @Override
            protected Job<GBMModel> startJob() {
                GBMModel.GBMParameters gbmParameters = this.prepareModelParameters();
                gbmParameters._max_depth = 15;
                gbmParameters._min_rows = 100.0;
                return this.trainModel((Model.Parameters)gbmParameters);
            }
        }};
        private ModelingStep[] grids = new GBMGridStep[]{new GBMGridStep("grid_1", 60, this.aml()){

            @Override
            protected Job<Grid> startJob() {
                GBMModel.GBMParameters gbmParameters = this.prepareModelParameters();
                HashMap<String, Object[]> searchParams = new HashMap<String, Object[]>();
                searchParams.put("_max_depth", new Integer[]{3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17});
                searchParams.put("_min_rows", new Integer[]{1, 5, 10, 15, 30, 100});
                searchParams.put("_learn_rate", new Double[]{0.001, 0.005, 0.008, 0.01, 0.05, 0.08, 0.1, 0.5, 0.8});
                searchParams.put("_sample_rate", new Double[]{0.5, 0.6, 0.7, 0.8, 0.9, 1.0});
                searchParams.put("_col_sample_rate", new Double[]{0.4, 0.7, 1.0});
                searchParams.put("_col_sample_rate_per_tree", new Double[]{0.4, 0.7, 1.0});
                searchParams.put("_min_split_improvement", new Double[]{1.0E-4, 1.0E-5});
                return this.hyperparameterSearch((Model.Parameters)gbmParameters, searchParams);
            }
        }};

        static GBMModel.GBMParameters prepareModelParameters() {
            GBMModel.GBMParameters gbmParameters = new GBMModel.GBMParameters();
            gbmParameters._score_tree_interval = 5;
            gbmParameters._histogram_type = SharedTreeModel.SharedTreeParameters.HistogramType.AUTO;
            return gbmParameters;
        }

        public GBMSteps(AutoML autoML) {
            super(autoML);
        }

        @Override
        protected ModelingStep[] getDefaultModels() {
            return this.defaults;
        }

        @Override
        protected ModelingStep[] getGrids() {
            return this.grids;
        }

        static abstract class GBMGridStep
        extends ModelingStep.GridStep<GBMModel> {
            public GBMGridStep(String id, int weight, AutoML autoML) {
                super(Algo.GBM, id, weight, autoML);
            }

            GBMModel.GBMParameters prepareModelParameters() {
                GBMModel.GBMParameters gbmParameters = GBMSteps.prepareModelParameters();
                gbmParameters._ntrees = 10000;
                return gbmParameters;
            }
        }

        static abstract class GBMModelStep
        extends ModelingStep.ModelStep<GBMModel> {
            GBMModelStep(String id, int weight, AutoML autoML) {
                super(Algo.GBM, id, weight, autoML);
            }

            GBMModel.GBMParameters prepareModelParameters() {
                GBMModel.GBMParameters gbmParameters = GBMSteps.prepareModelParameters();
                gbmParameters._ntrees = 1000;
                gbmParameters._sample_rate = 0.8;
                gbmParameters._col_sample_rate = 0.8;
                gbmParameters._col_sample_rate_per_tree = 0.8;
                return gbmParameters;
            }
        }
    }
}

