/*
 * Decompiled with CFR 0.152.
 */
package ai.h2o.automl.modeling;

import ai.h2o.automl.Algo;
import ai.h2o.automl.AutoML;
import ai.h2o.automl.ModelSelectionStrategies;
import ai.h2o.automl.ModelSelectionStrategy;
import ai.h2o.automl.ModelingStep;
import ai.h2o.automl.ModelingSteps;
import ai.h2o.automl.Models;
import ai.h2o.automl.events.EventLogEntry;
import hex.Model;
import hex.ModelParametersBuilderFactory;
import hex.genmodel.utils.DistributionFamily;
import hex.grid.Grid;
import hex.grid.GridSearch;
import hex.grid.HyperSpaceSearchCriteria;
import hex.grid.HyperSpaceWalker;
import hex.grid.SequentialWalker;
import hex.tree.xgboost.XGBoostModel;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Objects;
import water.Job;
import water.Key;

/**
 * AutoML modeling steps for XGBoost.
 *
 * <p>Provides:
 * <ul>
 *   <li>three default models with fixed hyperparameters ({@code def_1}..{@code def_3});</li>
 *   <li>one random hyperparameter grid ({@code grid_1});</li>
 *   <li>two exploitation steps that refine the best XGBoost model trained so far:
 *       {@code lr_annealing}, which retrains it with learning-rate annealing, and
 *       {@code lr_search}, which runs a sequential search over smaller learning rates.</li>
 * </ul>
 *
 * <p>Every step in this class is constructed with {@code emulateLightGBM = false}. When that
 * flag is true, {@link #prepareModelParameters(AutoML, boolean)} switches to the
 * {@code hist}/{@code lossguide} tree configuration, and the steps adjust depth/leaf settings.
 */
public class XGBoostSteps extends ModelingSteps {

    private final ModelingStep[] defaults = new XGBoostModelStep[]{
        new XGBoostModelStep("def_1", 10, this.aml(), false) {
            @Override
            protected Job<XGBoostModel> startJob() {
                XGBoostModel.XGBoostParameters xgBoostParameters = this.prepareModelParameters();
                xgBoostParameters._max_depth = 10;
                xgBoostParameters._min_rows = 5.0;
                xgBoostParameters._sample_rate = 0.6;
                xgBoostParameters._col_sample_rate = 0.8;
                xgBoostParameters._col_sample_rate_per_tree = 0.8;
                if (this._emulateLightGBM) {
                    XGBoostSteps.applyLightGBMEmulation(xgBoostParameters);
                }
                return this.trainModel((Model.Parameters) xgBoostParameters);
            }
        },
        new XGBoostModelStep("def_2", 10, this.aml(), false) {
            @Override
            protected Job<XGBoostModel> startJob() {
                XGBoostModel.XGBoostParameters xgBoostParameters = this.prepareModelParameters();
                xgBoostParameters._max_depth = 20;
                xgBoostParameters._min_rows = 10.0;
                xgBoostParameters._sample_rate = 0.6;
                xgBoostParameters._col_sample_rate = 0.8;
                xgBoostParameters._col_sample_rate_per_tree = 0.8;
                if (this._emulateLightGBM) {
                    XGBoostSteps.applyLightGBMEmulation(xgBoostParameters);
                }
                return this.trainModel((Model.Parameters) xgBoostParameters);
            }
        },
        new XGBoostModelStep("def_3", 10, this.aml(), false) {
            @Override
            protected Job<XGBoostModel> startJob() {
                XGBoostModel.XGBoostParameters xgBoostParameters = this.prepareModelParameters();
                xgBoostParameters._max_depth = 5;
                xgBoostParameters._min_rows = 3.0;
                xgBoostParameters._sample_rate = 0.8;
                xgBoostParameters._col_sample_rate = 0.8;
                xgBoostParameters._col_sample_rate_per_tree = 0.8;
                if (this._emulateLightGBM) {
                    XGBoostSteps.applyLightGBMEmulation(xgBoostParameters);
                }
                return this.trainModel((Model.Parameters) xgBoostParameters);
            }
        }
    };

    private final ModelingStep[] grids = new XGBoostGridStep[]{
        new XGBoostGridStep("grid_1", 100, this.aml(), false) {
            @Override
            protected Job<Grid> startJob() {
                XGBoostModel.XGBoostParameters xgBoostParameters = this.prepareModelParameters();
                HashMap<String, Object[]> searchParams = new HashMap<>();
                if (this._emulateLightGBM) {
                    // Leaf-wise growth: search on leaf count (2^5 .. 2^20) with generous depth limits.
                    searchParams.put("_max_leaves", new Integer[]{1 << 5, 1 << 10, 1 << 15, 1 << 20});
                    searchParams.put("_max_depth", new Integer[]{10, 20, 50});
                } else {
                    searchParams.put("_max_depth", new Integer[]{5, 10, 15, 20});
                    if (this.aml().getWeightsColumn() == null || this.aml().getWeightsColumn().isInt()) {
                        searchParams.put("_min_rows", new Double[]{1.0, 3.0, 5.0, 10.0, 15.0, 20.0});
                    } else {
                        // Non-integer weights: also try min_rows below 1
                        // (presumably because weighted row counts can be fractional).
                        searchParams.put("_min_rows", new Double[]{0.01, 0.1, 1.0, 3.0, 5.0, 10.0, 15.0, 20.0});
                    }
                }
                searchParams.put("_sample_rate", new Double[]{0.6, 0.8, 1.0});
                searchParams.put("_col_sample_rate", new Double[]{0.6, 0.8, 1.0});
                searchParams.put("_col_sample_rate_per_tree", new Double[]{0.7, 0.8, 0.9, 1.0});
                // gbtree is listed twice on purpose, presumably so the random walker picks it
                // about twice as often as dart (verify against the walker's sampling).
                searchParams.put("_booster", new XGBoostModel.XGBoostParameters.Booster[]{
                    XGBoostModel.XGBoostParameters.Booster.gbtree,
                    XGBoostModel.XGBoostParameters.Booster.gbtree,
                    XGBoostModel.XGBoostParameters.Booster.dart
                });
                searchParams.put("_reg_lambda", new Float[]{0.001f, 0.01f, 0.1f, 1.0f, 10.0f, 100.0f});
                searchParams.put("_reg_alpha", new Float[]{0.001f, 0.01f, 0.1f, 0.5f, 1.0f});
                return this.hyperparameterSearch((Model.Parameters) xgBoostParameters, searchParams);
            }
        }
    };

    private final ModelingStep[] exploitation = new ModelingStep[]{
        new XGBoostExploitationStep("lr_annealing", 10, this.aml(), false) {
            /**
             * Retrains the best XGBoost model found so far.
             *
             * <p>Settings on the cloned parameters:
             * <ul>
             *   <li>{@code ntrees} is raised to 10000;</li>
             *   <li>the per-model runtime limit is reset;</li>
             *   <li>{@code learn_rate_annealing} is set to 0.99;</li>
             *   <li>time constraints and stopping criteria are re-initialized
             *       ({@code maxRuntimeSecs} is passed to the time constraints).</li>
             * </ul>
             */
            @Override
            protected Job<Models> startTraining(Key result, double maxRuntimeSecs) {
                this.resultKey = result;
                XGBoostModel bestXGB = this.getBestXGB();
                this.aml().eventLog().info(EventLogEntry.Stage.ModelSelection, "Retraining best XGBoost with learning rate annealing: " + bestXGB._key);
                XGBoostModel.XGBoostParameters xgBoostParameters = (XGBoostModel.XGBoostParameters) ((XGBoostModel.XGBoostParameters) bestXGB._parms).clone();
                xgBoostParameters._ntrees = 10000;
                xgBoostParameters._max_runtime_secs = 0.0;
                xgBoostParameters._learn_rate_annealing = 0.99;
                this.initTimeConstraints((Model.Parameters) xgBoostParameters, maxRuntimeSecs);
                this.setStoppingCriteria((Model.Parameters) xgBoostParameters, (Model.Parameters) new XGBoostModel.XGBoostParameters());
                return this.asModelsJob(this.startModel(Key.make((String) (result + "_model")), xgBoostParameters), (Key<Models>) result);
            }
        },
        new XGBoostExploitationStep("lr_search", 40, this.aml(), false) {
            /**
             * Starts a sequential grid over the best XGBoost model's parameters.
             *
             * <p>Each candidate pairs a smaller learning rate with a proportionally longer
             * scoring interval. The whole grid is bounded by {@code maxRuntimeSecs}, truncated
             * to whole seconds, and by a stopping criterion with 2 rounds.
             */
            @Override
            protected Job<Models> startTraining(Key result, double maxRuntimeSecs) {
                this.resultKey = result;
                XGBoostModel bestXGB = this.getBestXGB();
                this.aml().eventLog().info(EventLogEntry.Stage.ModelSelection, "Applying learning rate search on best XGBoost: " + bestXGB._key);
                XGBoostModel.XGBoostParameters xgBoostParameters = (XGBoostModel.XGBoostParameters) ((XGBoostModel.XGBoostParameters) bestXGB._parms).clone();
                XGBoostModel.XGBoostParameters defaults = new XGBoostModel.XGBoostParameters();
                xgBoostParameters._ntrees = 10000;
                xgBoostParameters._max_runtime_secs = 0.0;
                // Per-model time constraint is 0 here; the grid-level maxRuntimeSecs below
                // bounds the search instead (assumes 0 means "unconstrained" -- confirm).
                this.initTimeConstraints((Model.Parameters) xgBoostParameters, 0.0);
                this.setStoppingCriteria((Model.Parameters) xgBoostParameters, (Model.Parameters) defaults);
                // Reset eta to its default so the searched _learn_rate values are the ones that vary.
                xgBoostParameters._eta = defaults._eta;
                int sti = xgBoostParameters._score_tree_interval;
                // First row names the hyperparameters; each following row is one sequential candidate.
                Object[][] hyperParams = new Object[][]{
                    {"_learn_rate", "_score_tree_interval"},
                    {0.5, sti},
                    {0.2, 2 * sti},
                    {0.1, 3 * sti},
                    {0.05, 4 * sti},
                    {0.02, 5 * sti},
                    {0.01, 6 * sti}
                };
                this.aml().eventLog().info(EventLogEntry.Stage.ModelTraining, "AutoML: starting " + this.resultKey + " model training").setNamedValue("start_" + this._algo + "_" + this._id, new Date(), EventLogEntry.epochFormat.get());
                HyperSpaceSearchCriteria.SequentialSearchCriteria searchCriteria = new HyperSpaceSearchCriteria.SequentialSearchCriteria(
                        HyperSpaceSearchCriteria.StoppingCriteria.create()
                                .maxRuntimeSecs((double) ((int) maxRuntimeSecs))
                                .stoppingMetric(xgBoostParameters._stopping_metric)
                                .stoppingRounds(2)
                                .stoppingTolerance(xgBoostParameters._stopping_tolerance)
                                .build());
                HyperSpaceWalker walker = new SequentialWalker(
                        (Model.Parameters) xgBoostParameters,
                        hyperParams,
                        (ModelParametersBuilderFactory) new GridSearch.SimpleParametersBuilderFactory(),
                        searchCriteria);
                return this.asModelsJob(GridSearch.startGridSearch((Key) Key.make((String) (result + "_grid")), walker, 1), (Key<Models>) result);
            }
        }
    };

    /**
     * Builds the base XGBoost parameters shared by every step in this class.
     *
     * <p>The distribution is chosen from the response column:
     * <ul>
     *   <li>binary and non-numeric: {@code bernoulli};</li>
     *   <li>otherwise categorical: {@code multinomial};</li>
     *   <li>anything else: {@code AUTO}.</li>
     * </ul>
     * {@code score_tree_interval} is set to 5 and {@code ntrees} to 10000. The high tree count
     * presumably relies on early stopping to end training.
     *
     * @param aml the AutoML instance providing the response column
     * @param emulateLightGBM when true, use {@code hist} tree method and {@code lossguide} growth
     * @return a fresh parameter object
     */
    static XGBoostModel.XGBoostParameters prepareModelParameters(AutoML aml, boolean emulateLightGBM) {
        XGBoostModel.XGBoostParameters xgBoostParameters = new XGBoostModel.XGBoostParameters();
        if (emulateLightGBM) {
            xgBoostParameters._tree_method = XGBoostModel.XGBoostParameters.TreeMethod.hist;
            xgBoostParameters._grow_policy = XGBoostModel.XGBoostParameters.GrowPolicy.lossguide;
        }
        xgBoostParameters._distribution = aml.getResponseColumn().isBinary() && !aml.getResponseColumn().isNumeric() ? DistributionFamily.bernoulli : (aml.getResponseColumn().isCategorical() ? DistributionFamily.multinomial : DistributionFamily.AUTO);
        xgBoostParameters._score_tree_interval = 5;
        xgBoostParameters._ntrees = 10000;
        return xgBoostParameters;
    }

    /**
     * Applies the LightGBM-emulation adjustment used by the default steps.
     *
     * <p>Sets {@code _max_leaves} to {@code 2^_max_depth}, then doubles {@code _max_depth}.
     * Must be called after {@code _max_depth} has been assigned.
     */
    private static void applyLightGBMEmulation(XGBoostModel.XGBoostParameters params) {
        params._max_leaves = 1 << params._max_depth;
        params._max_depth *= 2;
    }

    public XGBoostSteps(AutoML autoML) {
        super(autoML);
    }

    @Override
    protected ModelingStep[] getDefaultModels() {
        return this.defaults;
    }

    @Override
    protected ModelingStep[] getGrids() {
        return this.grids;
    }

    @Override
    protected ModelingStep[] getExploitation() {
        return this.exploitation;
    }

    /**
     * Base class for steps that refine the best XGBoost model trained so far.
     *
     * <p>It holds the result key of the current run. It also provides the shared selection
     * strategy, which uses a {@code KeepBestN(1)} strategy over a temporary leaderboard. That
     * strategy is given the current best XGBoost key as the original and the newly trained
     * models.
     */
    static abstract class XGBoostExploitationStep extends ModelingStep.SelectionStep<XGBoostModel> {
        boolean _emulateLightGBM;

        /** Result key of the current run; null until a subclass's startTraining assigns it. */
        Key<Models> resultKey;

        /**
         * Returns the first XGBoost model from {@link #getBestXGBs(int)}.
         *
         * @throws IndexOutOfBoundsException if no XGBoost model has been trained
         *         ({@link #canRun()} guards against this)
         */
        protected XGBoostModel getBestXGB() {
            return this.getBestXGBs(1).get(0);
        }

        /**
         * Collects up to {@code topN} XGBoost models from {@code getTrainedModels()}.
         *
         * <p>Models are kept in the order that method yields them, presumably best first.
         * Returns an empty list when {@code topN <= 0}.
         */
        protected List<XGBoostModel> getBestXGBs(int topN) {
            List<XGBoostModel> xgbs = new ArrayList<>();
            if (topN <= 0) {
                return xgbs;
            }
            for (Model model : this.getTrainedModels()) {
                if (model instanceof XGBoostModel) {
                    xgbs.add((XGBoostModel) model);
                    if (xgbs.size() >= topN) {
                        break;
                    }
                }
            }
            return xgbs;
        }

        @Override
        protected boolean canRun() {
            return super.canRun() && !this.getBestXGBs(1).isEmpty();
        }

        /**
         * Keeps the single best model among the current best XGBoost and the new models.
         * {@code resultKey} is read lazily, when selection runs, so it reflects the key set by
         * startTraining.
         */
        @Override
        protected ModelSelectionStrategy getSelectionStrategy() {
            return (originalModels, newModels) -> new ModelSelectionStrategies.KeepBestN(1, () -> this.makeTmpLeaderboard(Objects.toString(this.resultKey, this._algo + "_" + this._id))).select(new Key[]{this.getBestXGB()._key}, newModels);
        }

        public XGBoostExploitationStep(String id, int weight, AutoML autoML, boolean emulateLightGBM) {
            super(Algo.XGBoost, id, weight, autoML);
            this._emulateLightGBM = emulateLightGBM;
        }
    }

    /** Base class for XGBoost grid-search steps; supplies the shared base parameters. */
    static abstract class XGBoostGridStep extends ModelingStep.GridStep<XGBoostModel> {
        boolean _emulateLightGBM;

        public XGBoostGridStep(String id, int weight, AutoML autoML, boolean emulateLightGBM) {
            super(Algo.XGBoost, id, weight, autoML);
            this._emulateLightGBM = emulateLightGBM;
        }

        XGBoostModel.XGBoostParameters prepareModelParameters() {
            return XGBoostSteps.prepareModelParameters(this.aml(), this._emulateLightGBM);
        }
    }

    /** Base class for single-model XGBoost steps; supplies the shared base parameters. */
    static abstract class XGBoostModelStep extends ModelingStep.ModelStep<XGBoostModel> {
        boolean _emulateLightGBM;

        XGBoostModelStep(String id, int weight, AutoML autoML, boolean emulateLightGBM) {
            super(Algo.XGBoost, id, weight, autoML);
            this._emulateLightGBM = emulateLightGBM;
        }

        XGBoostModel.XGBoostParameters prepareModelParameters() {
            return XGBoostSteps.prepareModelParameters(this.aml(), this._emulateLightGBM);
        }
    }
}

