/*
 * Decompiled with CFR 0.152.
 */
package ai.h2o.automl.modeling;

import ai.h2o.automl.Algo;
import ai.h2o.automl.AutoML;
import ai.h2o.automl.ModelSelectionStrategies;
import ai.h2o.automl.ModelSelectionStrategy;
import ai.h2o.automl.ModelingStep;
import ai.h2o.automl.ModelingSteps;
import ai.h2o.automl.Models;
import ai.h2o.automl.events.EventLogEntry;
import hex.Model;
import hex.ModelParametersBuilderFactory;
import hex.genmodel.utils.DistributionFamily;
import hex.grid.GridSearch;
import hex.grid.HyperSpaceSearchCriteria;
import hex.grid.HyperSpaceWalker;
import hex.grid.SequentialWalker;
import hex.tree.xgboost.XGBoostModel;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import water.Job;
import water.Key;

/**
 * XGBoost modeling steps for AutoML: three fixed default models, two hyper-parameter grids
 * (tree boosters and {@code gblinear}), and two optional exploitation steps that start from the
 * best XGBoost model trained so far (learning-rate annealing and a learning-rate search).
 */
public class XGBoostSteps
extends ModelingSteps {
    static final String NAME = Algo.XGBoost.name();

    /**
     * Fixed default models. They share {@code col_sample_rate = col_sample_rate_per_tree = 0.8}
     * and differ only in {@code max_depth}, {@code min_rows} and {@code sample_rate}.
     */
    private final ModelingStep[] defaults = new ModelingStep[]{
        new XGBoostModelStep("def_1", this.aml(), false) {
            @Override
            public XGBoostModel.XGBoostParameters prepareModelParameters() {
                return this.applyTreeSettings(super.prepareModelParameters(), 10, 5.0, 0.6);
            }
        },
        new XGBoostModelStep("def_2", this.aml(), false) {
            @Override
            public XGBoostModel.XGBoostParameters prepareModelParameters() {
                return this.applyTreeSettings(super.prepareModelParameters(), 15, 10.0, 0.6);
            }
        },
        new XGBoostModelStep("def_3", this.aml(), false) {
            @Override
            public XGBoostModel.XGBoostParameters prepareModelParameters() {
                return this.applyTreeSettings(super.prepareModelParameters(), 5, 3.0, 0.8);
            }
        }
    };

    /** Random hyper-parameter grids: one over tree boosters, one over the linear booster. */
    private final ModelingStep[] grids = new ModelingStep[]{
        new DefaultXGBoostGridStep("grid_1", this.aml()),
        new XGBoostGBLinearGridStep("grid_gblinear", this.aml())
    };

    /** Optional steps that refine the best XGBoost model found so far. */
    private final ModelingStep[] exploitation = new ModelingStep[]{
        new XGBoostExploitationStep("lr_annealing", this.aml(), false) {
            /** Retrains the best XGBoost with {@code learn_rate_annealing = 0.99}. */
            @Override
            protected Job<Models> startTraining(Key result, double maxRuntimeSecs) {
                this.resultKey = result;
                XGBoostModel bestXGB = this.getBestXGB();
                this.aml().eventLog().info(EventLogEntry.Stage.ModelSelection, "Retraining best XGBoost with learning rate annealing: " + bestXGB._key);
                XGBoostModel.XGBoostParameters params = this.copyInputParameters(bestXGB);
                params._max_runtime_secs = 0.0;
                params._learn_rate_annealing = 0.99;
                this.initTimeConstraints(params, maxRuntimeSecs);
                this.setStoppingCriteria(params, new XGBoostModel.XGBoostParameters());
                return this.asModelsJob(this.startModel(Key.make(result + "_model"), params), (Key<Models>) result);
            }
        },
        new XGBoostExploitationStep("lr_search", this.aml(), false) {
            /**
             * Runs a sequential grid over decreasing learning rates, starting from the best
             * XGBoost's parameters. The scoring interval grows as the learning rate shrinks.
             */
            @Override
            protected Job<Models> startTraining(Key result, double maxRuntimeSecs) {
                this.resultKey = result;
                XGBoostModel bestXGB = this.getBestXGB();
                this.aml().eventLog().info(EventLogEntry.Stage.ModelSelection, "Applying learning rate search on best XGBoost: " + bestXGB._key);
                XGBoostModel.XGBoostParameters params = this.copyInputParameters(bestXGB);
                params._max_runtime_secs = 0.0;
                // 0.0 is passed per model here. The step's time budget (maxRuntimeSecs) is applied
                // to the whole search via the grid stopping criteria below.
                this.initTimeConstraints(params, 0.0);
                this.setStoppingCriteria(params, new XGBoostModel.XGBoostParameters());
                int sti = params._score_tree_interval;
                Object[][] hyperParams = new Object[][]{
                    {"_learn_rate", "_score_tree_interval"},
                    {0.5, sti},
                    {0.2, 2 * sti},
                    {0.1, 3 * sti},
                    {0.05, 4 * sti},
                    {0.02, 5 * sti},
                    {0.01, 6 * sti},
                    {0.005, 7 * sti},
                    {0.002, 8 * sti},
                    {0.001, 9 * sti},
                    {5.0E-4, 10 * sti}
                };
                this.aml().eventLog().info(EventLogEntry.Stage.ModelTraining, "AutoML: starting " + this.resultKey + " model training").setNamedValue("start_" + this._provider + "_" + this._id, new Date(), EventLogEntry.epochFormat.get());
                HyperSpaceSearchCriteria.SequentialSearchCriteria searchCriteria =
                    new HyperSpaceSearchCriteria.SequentialSearchCriteria(
                        HyperSpaceSearchCriteria.StoppingCriteria.create()
                            // Pass the budget through unchanged. Truncating to whole seconds turned a
                            // positive sub-second budget into 0, which H2O treats as "no limit".
                            .maxRuntimeSecs(maxRuntimeSecs)
                            .stoppingMetric(params._stopping_metric)
                            // Fixed at 3 because the learning-rate sequence defined above is short.
                            .stoppingRounds(3)
                            .stoppingTolerance(params._stopping_tolerance)
                            .build());
                HyperSpaceWalker walker = new SequentialWalker(
                    params, hyperParams, new GridSearch.SimpleParametersBuilderFactory(), searchCriteria);
                // Last argument 1: presumably one model at a time (sequential building); confirm
                // against GridSearch.startGridSearch.
                return this.asModelsJob(GridSearch.startGridSearch(Key.make(result + "_grid"), walker, 1), (Key<Models>) result);
            }
        }
    };

    /**
     * Builds the base parameters shared by every XGBoost step.
     *
     * @param aml             the AutoML instance (currently unused)
     * @param emulateLightGBM if true, switch to the {@code hist} tree method with
     *                        {@code lossguide} growth
     * @return fresh parameters with {@code score_tree_interval = 5} and {@code ntrees = 10000}
     */
    static XGBoostModel.XGBoostParameters prepareModelParameters(AutoML aml, boolean emulateLightGBM) {
        XGBoostModel.XGBoostParameters params = new XGBoostModel.XGBoostParameters();
        if (emulateLightGBM) {
            params._tree_method = XGBoostModel.XGBoostParameters.TreeMethod.hist;
            params._grow_policy = XGBoostModel.XGBoostParameters.GrowPolicy.lossguide;
        }
        params._score_tree_interval = 5;
        params._ntrees = 10000;
        return params;
    }

    /**
     * Returns true when class balancing is requested and the distribution is bernoulli.
     * The comparison is null-safe with respect to the distribution family.
     */
    private static boolean isBalancingBinaryClasses(AutoML aml) {
        return aml.getBuildSpec().build_control.balance_classes
            && DistributionFamily.bernoulli.equals(aml.getDistributionFamily());
    }

    /**
     * Returns {@code dist[0] / dist[1]} from {@code aml.getClassDistribution()}, as a float.
     * Assumes index 0 is the negative class and index 1 the positive class. TODO confirm
     * against AutoML.getClassDistribution.
     */
    private static float negativeToPositiveRatio(AutoML aml) {
        double[] dist = aml.getClassDistribution();
        return (float) (dist[0] / dist[1]);
    }

    /** Candidate {@code reg_lambda} values; a fresh array per call so callers may not share state. */
    private static Float[] regLambdaValues() {
        return new Float[]{0.001f, 0.01f, 0.1f, 1.0f, 10.0f, 100.0f};
    }

    /** Candidate {@code reg_alpha} values; a fresh array per call so callers may not share state. */
    private static Float[] regAlphaValues() {
        return new Float[]{0.001f, 0.01f, 0.1f, 0.5f, 1.0f};
    }

    public XGBoostSteps(AutoML autoML) {
        super(autoML);
    }

    @Override
    public String getProvider() {
        return NAME;
    }

    @Override
    protected ModelingStep[] getDefaultModels() {
        return this.defaults;
    }

    @Override
    protected ModelingStep[] getGrids() {
        return this.grids;
    }

    @Override
    protected ModelingStep[] getOptionals() {
        return this.exploitation;
    }

    /** Grid over the {@code gblinear} booster, searching only L2/L1 regularization strengths. */
    static class XGBoostGBLinearGridStep
    extends XGBoostGridStep {
        public XGBoostGBLinearGridStep(String id, AutoML autoML) {
            super(id, autoML, false);
        }

        @Override
        public XGBoostModel.XGBoostParameters prepareModelParameters() {
            return XGBoostSteps.prepareModelParameters(this.aml(), false);
        }

        @Override
        public Map<String, Object[]> prepareSearchParameters() {
            Map<String, Object[]> searchParams = new HashMap<>();
            searchParams.put("_booster", new XGBoostModel.XGBoostParameters.Booster[]{XGBoostModel.XGBoostParameters.Booster.gblinear});
            searchParams.put("_reg_lambda", XGBoostSteps.regLambdaValues());
            searchParams.put("_reg_alpha", XGBoostSteps.regAlphaValues());
            return searchParams;
        }
    }

    /** Main XGBoost grid over tree-booster hyper-parameters. */
    static class DefaultXGBoostGridStep
    extends XGBoostGridStep {
        public DefaultXGBoostGridStep(String id, AutoML autoML) {
            super(id, autoML, false);
        }

        @Override
        public XGBoostModel.XGBoostParameters prepareModelParameters() {
            XGBoostModel.XGBoostParameters params = super.prepareModelParameters();
            // Keep the library default here. When balancing, scale_pos_weight is searched in
            // prepareSearchParameters instead.
            params._scale_pos_weight = new XGBoostModel.XGBoostParameters()._scale_pos_weight;
            return params;
        }

        @Override
        public Map<String, Object[]> prepareSearchParameters() {
            Map<String, Object[]> searchParams = new HashMap<>();
            if (this._emulateLightGBM) {
                searchParams.put("_max_leaves", new Integer[]{1 << 5, 1 << 10, 1 << 15, 1 << 20});
                searchParams.put("_max_depth", new Integer[]{10, 20, 50});
            } else {
                searchParams.put("_max_depth", new Integer[]{3, 6, 9, 12, 15});
                // Fractional min_rows candidates are only offered with a non-integer weights
                // column, presumably because weighted row counts can then fall below 1.
                if (this.aml().getWeightsColumn() == null || this.aml().getWeightsColumn().isInt()) {
                    searchParams.put("_min_rows", new Double[]{1.0, 3.0, 5.0, 10.0, 15.0, 20.0});
                } else {
                    searchParams.put("_min_rows", new Double[]{0.01, 0.1, 1.0, 3.0, 5.0, 10.0, 15.0, 20.0});
                }
            }
            searchParams.put("_sample_rate", new Double[]{0.6, 0.8, 1.0});
            searchParams.put("_col_sample_rate", new Double[]{0.6, 0.8, 1.0});
            searchParams.put("_col_sample_rate_per_tree", new Double[]{0.7, 0.8, 0.9, 1.0});
            // gbtree is listed twice, presumably to weight the random search towards it.
            searchParams.put("_booster", new XGBoostModel.XGBoostParameters.Booster[]{
                XGBoostModel.XGBoostParameters.Booster.gbtree,
                XGBoostModel.XGBoostParameters.Booster.gbtree,
                XGBoostModel.XGBoostParameters.Booster.dart
            });
            searchParams.put("_reg_lambda", XGBoostSteps.regLambdaValues());
            searchParams.put("_reg_alpha", XGBoostSteps.regAlphaValues());
            if (XGBoostSteps.isBalancingBinaryClasses(this.aml())) {
                float negPosRatio = XGBoostSteps.negativeToPositiveRatio(this.aml());
                // Imbalance expressed as a ratio >= 1, whichever class is the majority.
                float imbalanceRatio = negPosRatio < 1.0f ? 1.0f / negPosRatio : negPosRatio;
                searchParams.put("_scale_pos_weight", new Float[]{1.0f, negPosRatio});
                searchParams.put("_max_delta_step", new Float[]{0.0f, Math.min(5.0f, imbalanceRatio / 2.0f), Math.min(10.0f, imbalanceRatio)});
            }
            return searchParams;
        }
    }

    /**
     * Base class for steps that start from the best XGBoost model trained so far. Subclasses
     * implement {@code startTraining} and should store its result key in {@link #resultKey}.
     */
    static abstract class XGBoostExploitationStep
    extends ModelingStep.SelectionStep<XGBoostModel> {
        boolean _emulateLightGBM;

        /** Result key of the last {@code startTraining} call; null until training starts. */
        Key<Models> resultKey;

        public XGBoostExploitationStep(String id, AutoML autoML, boolean emulateLightGBM) {
            super(NAME, Algo.XGBoost, id, autoML);
            this._emulateLightGBM = emulateLightGBM;
            if (autoML.getBuildSpec().build_models.exploitation_ratio > 0.0) {
                this._ignoredConstraints = new AutoML.Constraint[]{AutoML.Constraint.MODEL_COUNT};
            }
        }

        /**
         * Returns the first XGBoost model from {@code getTrainedModels()}.
         *
         * @throws IndexOutOfBoundsException if no XGBoost model has been trained
         */
        protected XGBoostModel getBestXGB() {
            return this.getBestXGBs(1).get(0);
        }

        /**
         * Returns up to {@code topN} XGBoost models, in the order {@code getTrainedModels()}
         * yields them.
         *
         * @param topN maximum number of models to return; {@code <= 0} yields an empty list
         */
        protected List<XGBoostModel> getBestXGBs(int topN) {
            List<XGBoostModel> xgbs = new ArrayList<>();
            if (topN <= 0) {
                return xgbs;
            }
            for (Model model : this.getTrainedModels()) {
                if (model instanceof XGBoostModel) {
                    xgbs.add((XGBoostModel) model);
                    if (xgbs.size() == topN) {
                        break;
                    }
                }
            }
            return xgbs;
        }

        /** Returns a clone of {@code model._input_parms}, so it can be modified freely. */
        XGBoostModel.XGBoostParameters copyInputParameters(XGBoostModel model) {
            return (XGBoostModel.XGBoostParameters) ((XGBoostModel.XGBoostParameters) model._input_parms).clone();
        }

        /** Runnable only if the parent allows it and at least one XGBoost model exists. */
        @Override
        public boolean canRun() {
            return super.canRun() && !this.getBestXGBs(1).isEmpty();
        }

        /**
         * Compares the current best XGBoost against the newly trained models using
         * {@code KeepBestN(1, ...)}. Ranking uses a temporary leaderboard named after
         * {@link #resultKey}, or {@code <provider>_<id>} when no result key is set yet.
         */
        @Override
        protected ModelSelectionStrategy getSelectionStrategy() {
            return (originalModels, newModels) -> new ModelSelectionStrategies.KeepBestN(1, () -> this.makeTmpLeaderboard(Objects.toString(this.resultKey, this._provider + "_" + this._id))).select(new Key[]{this.getBestXGB()._key}, newModels);
        }
    }

    /** Base class for XGBoost grid steps; builds the shared base parameters. */
    static abstract class XGBoostGridStep
    extends ModelingStep.GridStep<XGBoostModel> {
        boolean _emulateLightGBM;

        public XGBoostGridStep(String id, AutoML autoML, boolean emulateLightGBM) {
            super(NAME, Algo.XGBoost, id, autoML);
            this._emulateLightGBM = emulateLightGBM;
        }

        public XGBoostModel.XGBoostParameters prepareModelParameters() {
            return XGBoostSteps.prepareModelParameters(this.aml(), this._emulateLightGBM);
        }
    }

    /** Base class for single-model XGBoost steps. */
    static abstract class XGBoostModelStep
    extends ModelingStep.ModelStep<XGBoostModel> {
        boolean _emulateLightGBM;

        XGBoostModelStep(String id, AutoML autoML, boolean emulateLightGBM) {
            super(NAME, Algo.XGBoost, id, autoML);
            this._emulateLightGBM = emulateLightGBM;
        }

        /**
         * Builds the shared base parameters. When balancing a bernoulli problem, it also sets
         * {@code scale_pos_weight} to the class-count ratio.
         */
        public XGBoostModel.XGBoostParameters prepareModelParameters() {
            XGBoostModel.XGBoostParameters params = XGBoostSteps.prepareModelParameters(this.aml(), this._emulateLightGBM);
            if (XGBoostSteps.isBalancingBinaryClasses(this.aml())) {
                params._scale_pos_weight = XGBoostSteps.negativeToPositiveRatio(this.aml());
            }
            return params;
        }

        /**
         * Applies the tree settings used by the default models. Column sampling rates are fixed
         * at 0.8. When emulating LightGBM, also sets {@code _max_leaves = 2^maxDepth} and doubles
         * {@code _max_depth}.
         *
         * @return {@code params}, modified in place
         */
        XGBoostModel.XGBoostParameters applyTreeSettings(XGBoostModel.XGBoostParameters params, int maxDepth, double minRows, double sampleRate) {
            params._max_depth = maxDepth;
            params._min_rows = minRows;
            params._sample_rate = sampleRate;
            params._col_sample_rate = 0.8;
            params._col_sample_rate_per_tree = 0.8;
            if (this._emulateLightGBM) {
                params._max_leaves = 1 << params._max_depth;
                params._max_depth *= 2;
            }
            return params;
        }
    }
}

