/*
 * Decompiled with CFR 0.152.
 */
package ai.h2o.automl.modeling;

import ai.h2o.automl.Algo;
import ai.h2o.automl.AutoML;
import ai.h2o.automl.ModelingStep;
import ai.h2o.automl.ModelingSteps;
import ai.h2o.automl.ModelingStepsProvider;
import hex.Model;
import hex.genmodel.utils.DistributionFamily;
import hex.grid.Grid;
import hex.tree.xgboost.XGBoostModel;
import java.util.HashMap;
import java.util.Map;
import water.Job;

public class XGBoostStepsProvider
implements ModelingStepsProvider<XGBoostSteps> {
    @Override
    public String getName() {
        return Algo.XGBoost.name();
    }

    @Override
    public XGBoostSteps newInstance(AutoML aml) {
        return new XGBoostSteps(aml);
    }

    public static class XGBoostSteps
    extends ModelingSteps {
        private ModelingStep[] defaults = new XGBoostModelStep[]{new XGBoostModelStep("def_1", 10, this.aml(), false){

            @Override
            protected Job<XGBoostModel> startJob() {
                XGBoostModel.XGBoostParameters xgBoostParameters = this.prepareModelParameters();
                xgBoostParameters._max_depth = 10;
                xgBoostParameters._min_rows = 5.0;
                xgBoostParameters._sample_rate = 0.6;
                xgBoostParameters._col_sample_rate = 0.8;
                xgBoostParameters._col_sample_rate_per_tree = 0.8;
                if (this._emulateLightGBM) {
                    xgBoostParameters._max_leaves = 1 << xgBoostParameters._max_depth;
                    xgBoostParameters._max_depth *= 2;
                    xgBoostParameters._min_sum_hessian_in_leaf = (float)xgBoostParameters._min_rows;
                }
                return this.trainModel((Model.Parameters)xgBoostParameters);
            }
        }, new XGBoostModelStep("def_2", 10, this.aml(), false){

            @Override
            protected Job<XGBoostModel> startJob() {
                XGBoostModel.XGBoostParameters xgBoostParameters = this.prepareModelParameters();
                xgBoostParameters._max_depth = 20;
                xgBoostParameters._min_rows = 10.0;
                xgBoostParameters._sample_rate = 0.6;
                xgBoostParameters._col_sample_rate = 0.8;
                xgBoostParameters._col_sample_rate_per_tree = 0.8;
                if (this._emulateLightGBM) {
                    xgBoostParameters._max_leaves = 1 << xgBoostParameters._max_depth;
                    xgBoostParameters._max_depth *= 2;
                    xgBoostParameters._min_sum_hessian_in_leaf = (float)xgBoostParameters._min_rows;
                }
                return this.trainModel((Model.Parameters)xgBoostParameters);
            }
        }, new XGBoostModelStep("def_3", 10, this.aml(), false){

            @Override
            protected Job<XGBoostModel> startJob() {
                XGBoostModel.XGBoostParameters xgBoostParameters = this.prepareModelParameters();
                xgBoostParameters._max_depth = 5;
                xgBoostParameters._min_rows = 3.0;
                xgBoostParameters._sample_rate = 0.8;
                xgBoostParameters._col_sample_rate = 0.8;
                xgBoostParameters._col_sample_rate_per_tree = 0.8;
                if (this._emulateLightGBM) {
                    xgBoostParameters._max_leaves = 1 << xgBoostParameters._max_depth;
                    xgBoostParameters._max_depth *= 2;
                    xgBoostParameters._min_sum_hessian_in_leaf = (float)xgBoostParameters._min_rows;
                }
                return this.trainModel((Model.Parameters)xgBoostParameters);
            }
        }};
        private ModelingStep[] grids = new XGBoostGridStep[]{new XGBoostGridStep("grid_1", 100, this.aml(), false){

            @Override
            protected Job<Grid> startJob() {
                XGBoostModel.XGBoostParameters xgBoostParameters = this.prepareModelParameters();
                HashMap<String, Object[]> searchParams = new HashMap<String, Object[]>();
                if (this._emulateLightGBM) {
                    searchParams.put("_max_leaves", new Integer[]{32, 1024, 32768, 0x100000});
                    searchParams.put("_max_depth", new Integer[]{10, 20, 50});
                    searchParams.put("_min_sum_hessian_in_leaf", new Double[]{0.01, 0.1, 1.0, 3.0, 5.0, 10.0, 15.0, 20.0});
                } else {
                    searchParams.put("_max_depth", new Integer[]{5, 10, 15, 20});
                    searchParams.put("_min_rows", new Double[]{0.01, 0.1, 1.0, 3.0, 5.0, 10.0, 15.0, 20.0});
                }
                searchParams.put("_sample_rate", new Double[]{0.6, 0.8, 1.0});
                searchParams.put("_col_sample_rate", new Double[]{0.6, 0.8, 1.0});
                searchParams.put("_col_sample_rate_per_tree", new Double[]{0.7, 0.8, 0.9, 1.0});
                searchParams.put("_booster", new XGBoostModel.XGBoostParameters.Booster[]{XGBoostModel.XGBoostParameters.Booster.gbtree, XGBoostModel.XGBoostParameters.Booster.gbtree, XGBoostModel.XGBoostParameters.Booster.dart});
                searchParams.put("_reg_lambda", new Float[]{Float.valueOf(0.001f), Float.valueOf(0.01f), Float.valueOf(0.1f), Float.valueOf(1.0f), Float.valueOf(10.0f), Float.valueOf(100.0f)});
                searchParams.put("_reg_alpha", new Float[]{Float.valueOf(0.001f), Float.valueOf(0.01f), Float.valueOf(0.1f), Float.valueOf(0.5f), Float.valueOf(1.0f)});
                return this.hyperparameterSearch((Model.Parameters)xgBoostParameters, searchParams);
            }
        }};

        static XGBoostModel.XGBoostParameters prepareModelParameters(AutoML aml, boolean emulateLightGBM) {
            XGBoostModel.XGBoostParameters xgBoostParameters = new XGBoostModel.XGBoostParameters();
            if (emulateLightGBM) {
                xgBoostParameters._tree_method = XGBoostModel.XGBoostParameters.TreeMethod.hist;
                xgBoostParameters._grow_policy = XGBoostModel.XGBoostParameters.GrowPolicy.lossguide;
            }
            xgBoostParameters._distribution = aml.getResponseColumn().isBinary() && !aml.getResponseColumn().isNumeric() ? DistributionFamily.bernoulli : (aml.getResponseColumn().isCategorical() ? DistributionFamily.multinomial : DistributionFamily.AUTO);
            xgBoostParameters._score_tree_interval = 5;
            xgBoostParameters._stopping_rounds = 5;
            xgBoostParameters._ntrees = 10000;
            xgBoostParameters._learn_rate = 0.05;
            return xgBoostParameters;
        }

        public XGBoostSteps(AutoML autoML) {
            super(autoML);
        }

        @Override
        protected ModelingStep[] getDefaultModels() {
            return this.defaults;
        }

        @Override
        protected ModelingStep[] getGrids() {
            return this.grids;
        }

        static abstract class XGBoostGridStep
        extends ModelingStep.GridStep<XGBoostModel> {
            boolean _emulateLightGBM;

            public XGBoostGridStep(String id, int weight, AutoML autoML, boolean emulateLightGBM) {
                super(Algo.XGBoost, id, weight, autoML);
                this._emulateLightGBM = emulateLightGBM;
            }

            XGBoostModel.XGBoostParameters prepareModelParameters() {
                return XGBoostSteps.prepareModelParameters(this.aml(), this._emulateLightGBM);
            }
        }

        static abstract class XGBoostModelStep
        extends ModelingStep.ModelStep<XGBoostModel> {
            boolean _emulateLightGBM;

            XGBoostModelStep(String id, int weight, AutoML autoML, boolean emulateLightGBM) {
                super(Algo.XGBoost, id, weight, autoML);
                this._emulateLightGBM = emulateLightGBM;
            }

            XGBoostModel.XGBoostParameters prepareModelParameters() {
                return XGBoostSteps.prepareModelParameters(this.aml(), this._emulateLightGBM);
            }
        }
    }
}

