/*
 * Decompiled with CFR 0.152.
 */
package hivemall.xgboost;

import hivemall.UDTFWithOptions;
import hivemall.annotations.VisibleForTesting;
import hivemall.utils.collections.lists.FloatArrayList;
import hivemall.utils.hadoop.HadoopUtils;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.OptionUtils;
import hivemall.utils.math.MathUtils;
import hivemall.xgboost.utils.DMatrixBuilder;
import hivemall.xgboost.utils.DenseDMatrixBuilder;
import hivemall.xgboost.utils.NativeLibLoader;
import hivemall.xgboost.utils.SparseDMatrixBuilder;
import hivemall.xgboost.utils.XGBoostUtils;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import matrix4j.utils.lang.ArrayUtils;
import matrix4j.utils.lang.Primitives;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoostError;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.io.Text;

/**
 * Hive UDTF that trains an XGBoost booster on all rows it receives.
 *
 * <p>Each {@link #process(Object[])} call appends one feature vector to a {@link DMatrixBuilder}
 * and one label to a {@link FloatArrayList}. Training happens once, in {@link #close()}, which
 * forwards a single row {@code (model_id, model)} holding the serialized booster.
 *
 * <p>Features are either a dense {@code array<number>} (list position = feature index) or an
 * {@code array<string>} handed to a {@link SparseDMatrixBuilder}; the element type of the first
 * argument selects the builder.
 */
@Description(name = "train_xgboost",
        value = "_FUNC_(array<string|double> features, <int|double> target, const string options)"
                + " - Returns a relation consisting of <string model_id, string model>",
        extended = "SELECT \n"
                + "  train_xgboost(features, label, '-objective binary:logistic -iters 10') \n"
                + "    as (model_id, model)\n"
                + "from (\n"
                + "  select features, label\n"
                + "  from xgb_input\n"
                + "  cluster by rand(43) -- shuffle\n"
                + ") shuffled;")
public class XGBoostTrainUDTF extends UDTFWithOptions {
    private static final Log logger = LogFactory.getLog(XGBoostTrainUDTF.class);

    static {
        // Presumably loads the native (JNI) XGBoost library ahead of any DMatrix/Booster use.
        NativeLibLoader.initXGBoost();
    }

    // Input inspectors, resolved in initialize()
    private ListObjectInspector featureListOI;
    private PrimitiveObjectInspector featureElemOI;
    private PrimitiveObjectInspector targetOI;
    /** true for array<number> features, false for array<string> features. */
    private boolean denseInput;

    // Training rows accumulated by process(); released as soon as close() has built the DMatrix
    private DMatrixBuilder matrixBuilder;
    private FloatArrayList labels;

    /** XGBoost parameters plus UDTF-only keys (num_round, validation_ratio, ...) from the options. */
    @Nonnull
    protected final Map<String, Object> params = new HashMap<String, Object>();
    /** Value of -num_class; only meaningful when the option was given. */
    protected int numClass;
    /** Objective family resolved from -objective in processOptions(). */
    protected ObjectiveType objectiveType = null;

    protected Options getOptions() {
        Options opts = new Options();
        // General / UDTF-level options
        opts.addOption("num_round", "iters", true, "Number of boosting iterations [default: 10]");
        opts.addOption("maximize_evaluation_metrics", true, "Maximize evaluation metrics [default: false]");
        opts.addOption("num_early_stopping_rounds", true, "Minimum rounds required for early stopping [default: 0]");
        opts.addOption("validation_ratio", true, "Validation ratio in range [0.0,1.0] [default: 0.2]");
        opts.addOption("booster", true, "Set a booster to use, gbtree or gblinear or dart. [default: gbtree]");
        opts.addOption("silent", true, "Deprecated. Please use verbosity instead. 0 means printing running messages, 1 means silent mode [default: 1]");
        opts.addOption("verbosity", true, "Verbosity of printing messages. Choices: 0 (silent), 1 (warning), 2 (info), 3 (debug). [default: 0]");
        opts.addOption("nthread", true, "Number of parallel threads used to run XGBoost [default: 1]");
        opts.addOption("disable_default_eval_metric", true, "Flag to disable default metric. Set to >0 to disable. [default: 0]");
        opts.addOption("num_pbuffer", true, "Size of prediction buffer [default: set automatically by xgboost]");
        opts.addOption("num_feature", true, "Feature dimension used in boosting [default: set automatically by xgboost]");
        // Booster parameters
        opts.addOption("lambda", "reg_lambda", true, "L2 regularization term on weights. Increasing this value will make model more conservative. [default: 1.0 for gbtree, 0.0 for gblinear]");
        opts.addOption("alpha", "reg_alpha", true, "L1 regularization term on weights. Increasing this value will make model more conservative. [default: 0.0]");
        opts.addOption("updater", true, "A comma-separated string that defines the sequence of tree updaters to run. For a full list of valid inputs, please refer to XGBoost Parameters. [default: 'grow_colmaker,prune' for gbtree, 'shotgun' for gblinear]");
        opts.addOption("eta", "learning_rate", true, "Step size shrinkage used in update to prevent overfitting [default: 0.3]");
        opts.addOption("gamma", "min_split_loss", true, "Minimum loss reduction required to make a further partition on a leaf node of the tree. [default: 0.0]");
        opts.addOption("max_depth", true, "Max depth of decision tree [default: 6]");
        opts.addOption("min_child_weight", true, "Minimum sum of instance weight (hessian) needed in a child [default: 1.0]");
        opts.addOption("max_delta_step", true, "Maximum delta step we allow each tree's weight estimation to be [default: 0]");
        opts.addOption("subsample", true, "Subsample ratio of the training instance in range (0.0,1.0] [default: 1.0]");
        opts.addOption("colsample_bytree", true, "Subsample ratio of columns when constructing each tree [default: 1.0]");
        opts.addOption("colsample_bylevel", true, "Subsample ratio of columns for each level [default: 1.0]");
        opts.addOption("colsample_bynode", true, "Subsample ratio of columns for each node [default: 1.0]");
        opts.addOption("tree_method", true, "The tree construction algorithm used in XGBoost. [default: auto, Choices: auto, exact, approx, hist]");
        opts.addOption("sketch_eps", true, "This roughly translates into O(1 / sketch_eps) number of bins. \nCompared to directly select number of bins, this comes with theoretical guarantee with sketch accuracy.\nOnly used for tree_method=approx. Usually user does not have to tune this.  [default: 0.03]");
        opts.addOption("scale_pos_weight", true, "Control the balance of positive and negative weights, useful for unbalanced classes. A typical value to consider: sum(negative instances) / sum(positive instances) [default: 1.0]");
        opts.addOption("refresh_leaf", true, "This is a parameter of the refresh updater plugin. When this flag is 1, tree leafs as well as tree nodes\u2019 stats are updated. When it is 0, only node stats are updated. [default: 1]");
        opts.addOption("process_type", true, "A type of boosting process to run. [Choices: default, update]");
        opts.addOption("grow_policy", true, "Controls a way new nodes are added to the tree. Currently supported only if tree_method is set to hist. [default: depthwise, Choices: depthwise, lossguide]");
        opts.addOption("max_leaves", true, "Maximum number of nodes to be added. Only relevant when grow_policy=lossguide is set. [default: 0]");
        opts.addOption("max_bin", true, "Maximum number of discrete bins to bucket continuous features. Only used if tree_method is set to hist. [default: 256]");
        opts.addOption("num_parallel_tree", true, "Number of parallel trees constructed during each iteration. This option is used to support boosted random forest. Usually no need to tune (default 1 is enough) for gradient boosting trees. [default: 1]");
        // DART booster
        opts.addOption("sample_type", true, "Type of sampling algorithm. [Choices: uniform (default), weighted]");
        opts.addOption("normalize_type", true, "Type of normalization algorithm. [Choices: tree (default), forest]");
        opts.addOption("rate_drop", true, "Dropout rate in range [0.0, 1.0]. [default: 0.0]");
        opts.addOption("one_drop", true, "When this flag is enabled, at least one tree is always dropped during the dropout. 0 or 1. [default: 0]");
        opts.addOption("skip_drop", true, "Probability of skipping the dropout procedure during a boosting iteration in range [0.0, 1.0]. [default: 0.0]");
        // Linear booster
        opts.addOption("lambda_bias", true, "L2 regularization term on bias [default: 0.0]");
        opts.addOption("feature_selector", true, "Feature selection and ordering method. [Choices: cyclic (default), shuffle, random, greedy, thrifty]");
        opts.addOption("top_k", true, "The number of top features to select in greedy and thrifty feature selector. The value of 0 means using all the features. [default: 0]");
        // Learning task
        opts.addOption("tweedie_variance_power", true, "Parameter that controls the variance of the Tweedie distribution in range [1.0, 2.0]. [default: 1.5]");
        opts.addOption("objective", true, "Specifies the learning task and the corresponding learning objective. Examples: reg:linear, reg:logistic, multi:softmax. For a full list of valid inputs, refer to XGBoost Parameters. [default: reg:linear]");
        opts.addOption("base_score", true, "Initial prediction score of all instances, global bias [default: 0.5]");
        opts.addOption("eval_metric", true, "Evaluation metrics for validation data. A default metric is assigned according to the objective:\n- rmse: for regression\n- error: for classification\n- map: for ranking\nFor a list of valid inputs, see XGBoost Parameters.");
        opts.addOption("seed", true, "Random number seed. [default: 43]");
        opts.addOption("num_class", true, "Number of classes to classify");
        return opts;
    }

    /**
     * Parses the constant option string (3rd argument) into {@link #params}, resolves
     * {@link #objectiveType} and, for {@code -num_class}, {@link #numClass}.
     *
     * @param argOIs UDTF argument inspectors; index 2, when present, must be a constant string
     * @return the parsed command line
     * @throws UDFArgumentException when -objective is missing, validation_ratio is outside
     *         [0.0, 1.0), or a multiclass objective is used without -num_class
     */
    @Nonnull
    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
        final CommandLine cl;
        if (argOIs.length >= 3) {
            String rawArgs = HiveUtils.getConstString(argOIs, 2);
            cl = parseOptions(rawArgs);
        } else {
            cl = parseOptions("");
        }

        String objective = cl.getOptionValue("objective");
        if (objective == null) {
            showHelp("Please provide \"-objective XXX\" option in the 3rd argument.\n\nHere is the list of supported objectives: \n - Regression:\n {reg:squarederror, reg:logistic, reg:gamma, reg:tweedie}\n - Binary classification: {binary:logistic, binary:logitraw, binary:hinge}\n - Multiclass classification:\n {multi:softmax, multi:softprob}\n - Ranking:\n {rank:pairwise, rank:ndcg, rank:map}\n - Other:\n {count:poisson, survival:cox}");
            // Defensive: never continue with a null objective, even if showHelp() returned.
            throw new UDFArgumentException("-objective option is required");
        }
        if (objective.equals("reg:squarederror")) {
            // Mapped to the older alias; presumably the bundled XGBoost predates reg:squarederror.
            objective = "reg:linear";
        }
        final String booster = cl.getOptionValue("booster", "gbtree");

        // General / UDTF-level parameters
        params.put("num_round", Primitives.parseInt(cl.getOptionValue("num_round"), 10));
        params.put("maximize_evaluation_metrics",
            Primitives.parseBoolean(cl.getOptionValue("maximize_evaluation_metrics"), false));
        params.put("num_early_stopping_rounds",
            Primitives.parseInt(cl.getOptionValue("num_early_stopping_rounds"), 0));
        final double validationRatio =
                Primitives.parseDouble(cl.getOptionValue("validation_ratio"), 0.2);
        // Expressed as a positive range test so that NaN is rejected too.
        if (!(validationRatio >= 0.0 && validationRatio < 1.0)) {
            throw new UDFArgumentException("Invalid validation_ratio=" + validationRatio);
        }
        params.put("validation_ratio", validationRatio);
        params.put("booster", booster);
        params.put("silent", Primitives.parseInt(cl.getOptionValue("silent"), 1));
        params.put("verbosity", Primitives.parseInt(cl.getOptionValue("verbosity"), 0));
        params.put("nthread", Primitives.parseInt(cl.getOptionValue("nthread"), 1));
        params.put("disable_default_eval_metric",
            Primitives.parseInt(cl.getOptionValue("disable_default_eval_metric"), 0));
        if (cl.hasOption("num_pbuffer")) {
            params.put("num_pbuffer", Integer.valueOf(cl.getOptionValue("num_pbuffer")));
        }
        if (cl.hasOption("num_feature")) {
            params.put("num_feature", Integer.valueOf(cl.getOptionValue("num_feature")));
        }

        // Booster parameters. DART is built on gbtree and accepts every tree parameter as well.
        if (booster.equals("gbtree") || booster.equals("dart")) {
            putTreeBoosterParams(cl);
        }
        if (booster.equals("dart")) {
            putDartBoosterParams(cl);
        }
        if (booster.equals("gblinear")) {
            putLinearBoosterParams(cl);
        }

        // Objective-specific parameters
        if (objective.equals("reg:tweedie")) {
            params.put("tweedie_variance_power",
                Primitives.parseDouble(cl.getOptionValue("tweedie_variance_power"), 1.5));
        }
        if (objective.equals("count:poisson")) {
            // Replaces the tree default (0.0) with 0.7 unless the user set max_delta_step.
            params.put("max_delta_step",
                Primitives.parseDouble(cl.getOptionValue("max_delta_step"), 0.7));
        }
        params.put("objective", objective);
        params.put("base_score", Primitives.parseDouble(cl.getOptionValue("base_score"), 0.5));
        if (cl.hasOption("eval_metric")) {
            params.put("eval_metric", cl.getOptionValue("eval_metric"));
        }
        params.put("seed", Primitives.parseLong(cl.getOptionValue("seed"), 43L));
        if (cl.hasOption("num_class")) {
            this.numClass = Integer.parseInt(cl.getOptionValue("num_class"));
            params.put("num_class", numClass);
        } else if (objective.startsWith("multi:")) {
            throw new UDFArgumentException("-num_class is required for multiclass classification");
        }

        if (logger.isInfoEnabled()) {
            logger.info("XGboost training hyperparameters: " + params.toString());
        }
        this.objectiveType = ObjectiveType.resolve(objective);
        return cl;
    }

    /** Puts the tree-booster parameters shared by {@code gbtree} and {@code dart}. */
    private void putTreeBoosterParams(@Nonnull CommandLine cl) {
        params.put("eta", Primitives.parseDouble(cl.getOptionValue("eta"), 0.3));
        params.put("gamma", Primitives.parseDouble(cl.getOptionValue("gamma"), 0.0));
        params.put("max_depth", Primitives.parseInt(cl.getOptionValue("max_depth"), 6));
        params.put("min_child_weight",
            Primitives.parseDouble(cl.getOptionValue("min_child_weight"), 1.0));
        params.put("max_delta_step",
            Primitives.parseDouble(cl.getOptionValue("max_delta_step"), 0.0));
        params.put("subsample", Primitives.parseDouble(cl.getOptionValue("subsample"), 1.0));
        // Keys (and option names) must be spelled exactly as XGBoost's colsample_* parameters;
        // a misspelled key would never configure the intended parameter.
        params.put("colsample_bytree",
            Primitives.parseDouble(cl.getOptionValue("colsample_bytree"), 1.0));
        params.put("colsample_bylevel",
            Primitives.parseDouble(cl.getOptionValue("colsample_bylevel"), 1.0));
        params.put("colsample_bynode",
            Primitives.parseDouble(cl.getOptionValue("colsample_bynode"), 1.0));
        params.put("lambda", Primitives.parseDouble(cl.getOptionValue("lambda"), 1.0));
        params.put("alpha", Primitives.parseDouble(cl.getOptionValue("alpha"), 0.0));
        final String treeMethod = cl.getOptionValue("tree_method", "auto");
        params.put("tree_method", treeMethod);
        params.put("sketch_eps", Primitives.parseDouble(cl.getOptionValue("sketch_eps"), 0.03));
        params.put("scale_pos_weight",
            Primitives.parseDouble(cl.getOptionValue("scale_pos_weight"), 1.0));
        if (cl.hasOption("updater")) {
            params.put("updater", cl.getOptionValue("updater"));
        } else if (treeMethod.equals("auto") || treeMethod.equals("exact")) {
            // An explicitly set updater takes precedence over tree_method inside XGBoost, so the
            // default sequence is only supplied where it cannot shadow e.g. approx or hist.
            params.put("updater", "grow_colmaker,prune");
        }
        params.put("refresh_leaf", Primitives.parseInt(cl.getOptionValue("refresh_leaf"), 1));
        params.put("process_type", cl.getOptionValue("process_type", "default"));
        params.put("grow_policy", cl.getOptionValue("grow_policy", "depthwise"));
        params.put("max_leaves", Primitives.parseInt(cl.getOptionValue("max_leaves"), 0));
        params.put("max_bin", Primitives.parseInt(cl.getOptionValue("max_bin"), 256));
        params.put("num_parallel_tree",
            Primitives.parseInt(cl.getOptionValue("num_parallel_tree"), 1));
    }

    /** Puts the dropout parameters specific to the {@code dart} booster. */
    private void putDartBoosterParams(@Nonnull CommandLine cl) {
        params.put("sample_type", cl.getOptionValue("sample_type", "uniform"));
        params.put("normalize_type", cl.getOptionValue("normalize_type", "tree"));
        params.put("rate_drop", Primitives.parseDouble(cl.getOptionValue("rate_drop"), 0.0));
        params.put("one_drop", Primitives.parseInt(cl.getOptionValue("one_drop"), 0));
        params.put("skip_drop", Primitives.parseDouble(cl.getOptionValue("skip_drop"), 0.0));
    }

    /** Puts the parameters of the {@code gblinear} booster. */
    private void putLinearBoosterParams(@Nonnull CommandLine cl) {
        params.put("lambda", Primitives.parseDouble(cl.getOptionValue("lambda"), 0.0));
        params.put("lambda_bias", Primitives.parseDouble(cl.getOptionValue("lambda_bias"), 0.0));
        params.put("alpha", Primitives.parseDouble(cl.getOptionValue("alpha"), 0.0));
        params.put("updater", cl.getOptionValue("updater", "shotgun"));
        params.put("feature_selector", cl.getOptionValue("feature_selector", "cyclic"));
        params.put("top_k", Primitives.parseInt(cl.getOptionValue("top_k"), 0));
    }

    /**
     * Validates the arguments, parses the options and picks a dense or sparse matrix builder
     * from the element type of the feature array.
     *
     * @return a struct inspector for {@code (model_id string, model string)}
     * @throws UDFArgumentException on a wrong argument count, bad options, or unsupported types
     */
    @Override
    public StructObjectInspector initialize(@Nonnull ObjectInspector[] argOIs)
            throws UDFArgumentException {
        if (argOIs.length != 2 && argOIs.length != 3) {
            showHelp("Invalid argument length=" + argOIs.length);
        }
        processOptions(argOIs);

        ListObjectInspector listOI = HiveUtils.asListOI(argOIs, 0);
        ObjectInspector elemOI = listOI.getListElementObjectInspector();
        this.featureListOI = listOI;
        if (HiveUtils.isNumberOI(elemOI)) {
            this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
            this.denseInput = true;
            this.matrixBuilder = new DenseDMatrixBuilder(8192);
        } else if (HiveUtils.isStringOI(elemOI)) {
            this.featureElemOI = HiveUtils.asStringOI(elemOI);
            this.denseInput = false;
            this.matrixBuilder = new SparseDMatrixBuilder(8192);
        } else {
            throw new UDFArgumentException(
                "train_xgboost takes array<double> or array<string> for the first argument: "
                        + listOI.getTypeName());
        }
        this.targetOI = HiveUtils.asDoubleCompatibleOI(argOIs, 1);
        this.labels = new FloatArrayList(1024);

        ArrayList<String> fieldNames = new ArrayList<String>(2);
        ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(2);
        fieldNames.add("model_id");
        fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        fieldNames.add("model");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Validates/normalizes a raw target value according to {@link #objectiveType}.
     *
     * <p>Binary labels must be -1, 0 or 1 and are mapped to {0, 1}. Multiclass labels must be
     * integral values in [0, numClass). Other objectives pass the value through unchanged.
     *
     * @throws HiveException (a {@link UDFArgumentException}) for an invalid label
     */
    protected float processTargetValue(float target) throws HiveException {
        switch (objectiveType) {
            case binary: {
                if (target != -1.0f && target != 0.0f && target != 1.0f) {
                    throw new UDFArgumentException(
                        "Invalid label value for classification: " + target);
                }
                return target > 0.0f ? 1.0f : 0.0f;
            }
            case multiclass: {
                int clazz = (int) target;
                if ((float) clazz != target) {
                    throw new UDFArgumentException("Invalid target value for class label: " + target);
                }
                if (clazz < 0 || clazz >= numClass) {
                    throw new UDFArgumentException("target must be {0.0, ..., "
                            + String.format("%.1f", numClass - 1.0) + "}: " + target);
                }
                return target;
            }
            default:
                return target;
        }
    }

    /**
     * Buffers one training example.
     *
     * @param args {@code args[0]} feature array, {@code args[1]} target value
     * @throws HiveException if features or target is null, or the target is invalid
     */
    @Override
    public void process(@Nonnull Object[] args) throws HiveException {
        if (args[0] == null) {
            throw new HiveException("array<double> features was null");
        }
        // getFloat(null, oi) yields 0, which would silently turn a missing label into class 0.
        if (args[1] == null) {
            throw new HiveException("target was null");
        }
        float target = PrimitiveObjectInspectorUtils.getFloat(args[1], targetOI);
        // Validate the label before buffering the row so features and labels stay aligned.
        float label = processTargetValue(target);
        parseFeatures(args[0], matrixBuilder);
        labels.add(label);
    }

    /**
     * Appends one row to {@code builder}: dense input uses the list position as the column
     * index; string input passes each element's {@code toString()}. Null elements are skipped.
     */
    private void parseFeatures(@Nonnull Object argObj, @Nonnull DMatrixBuilder builder) {
        final int length = featureListOI.getListLength(argObj);
        for (int i = 0; i < length; ++i) {
            Object o = featureListOI.getListElement(argObj, i);
            if (o == null) {
                continue;
            }
            if (denseInput) {
                float v = PrimitiveObjectInspectorUtils.getFloat(o, featureElemOI);
                builder.nextColumn(i, v);
            } else {
                builder.nextColumn(o.toString());
            }
        }
        builder.nextRow();
    }

    /**
     * Builds the training matrix from the buffered rows, trains the booster and forwards
     * {@code (model_id, model)}.
     *
     * <p>When {@code num_early_stopping_rounds > 0}, a {@code validation_ratio} share of the
     * rows (shuffled with {@code seed}) is held out and evaluated after every round. Native
     * matrices and the booster are released in all cases. If no rows were received, nothing is
     * trained or forwarded.
     *
     * @throws HiveException on any failure while building, training, serializing or forwarding
     */
    @Override
    public void close() throws HiveException {
        if (labels.size() == 0) {
            logger.warn("No training examples were given; skipping XGBoost training");
            this.matrixBuilder = null;
            this.labels = null;
            return;
        }

        DMatrix dmatrix = null;
        Booster booster = null;
        try {
            dmatrix = matrixBuilder.buildMatrix(labels.toArray(true));
            // The Java-side buffers are no longer needed once the DMatrix exists.
            this.matrixBuilder = null;
            this.labels = null;

            final int round = OptionUtils.getInt(params, "num_round");
            final int earlyStoppingRounds = OptionUtils.getInt(params, "num_early_stopping_rounds");
            if (earlyStoppingRounds > 0) {
                booster = trainWithValidation(dmatrix, round, earlyStoppingRounds);
            } else {
                booster = train(dmatrix, round, params);
            }
            onFinishTraining(booster);

            String modelId = generateUniqueModelId();
            Text predModel = XGBoostUtils.serializeBooster(booster);
            logger.info("model_id:" + modelId + ", size:" + predModel.getLength());
            forward(new Object[] {modelId, predModel});
        } catch (HiveException e) {
            throw e;
        } catch (Throwable e) {
            throw new HiveException(e);
        } finally {
            XGBoostUtils.close(dmatrix);
            XGBoostUtils.close(booster);
        }
    }

    /**
     * Splits {@code dmatrix} into shuffled train/validation slices and trains with early
     * stopping. Falls back to plain training when the ratio leaves no validation rows.
     */
    @Nonnull
    private Booster trainWithValidation(@Nonnull DMatrix dmatrix, @Nonnegative int round,
            @Nonnegative int earlyStoppingRounds) throws NoSuchMethodException,
            IllegalAccessException, InvocationTargetException, InstantiationException,
            XGBoostError {
        final double validationRatio = OptionUtils.getDouble(params, "validation_ratio");
        final long seed = OptionUtils.getLong(params, "seed");
        final int numRows = (int) dmatrix.rowNum();
        final int numTest = (int) (numRows * validationRatio);
        if (numTest <= 0) {
            logger.warn("validation_ratio=" + validationRatio + " leaves no validation rows out of "
                    + numRows + " rows; training without early stopping");
            return train(dmatrix, round, params);
        }

        int[] rows = MathUtils.permutation(numRows);
        ArrayUtils.shuffle(rows, new Random(seed));
        DMatrix dtrain = null;
        DMatrix dtest = null;
        try {
            dtest = dmatrix.slice(Arrays.copyOf(rows, numTest));
            dtrain = dmatrix.slice(Arrays.copyOfRange(rows, numTest, rows.length));
            return train(dtrain, dtest, round, earlyStoppingRounds, params);
        } finally {
            XGBoostUtils.close(dtrain);
            XGBoostUtils.close(dtest);
        }
    }

    /** Hook invoked with the trained booster before it is serialized; no-op by default. */
    @VisibleForTesting
    protected void onFinishTraining(@Nonnull Booster booster) {}

    /** Runs {@code round} boosting iterations on {@code dtrain} without validation. */
    @Nonnull
    private static Booster train(@Nonnull DMatrix dtrain, @Nonnegative int round,
            @Nonnull Map<String, Object> params) throws NoSuchMethodException,
            IllegalAccessException, InvocationTargetException, InstantiationException,
            XGBoostError {
        Booster booster = XGBoostUtils.createBooster(dtrain, params);
        for (int iter = 0; iter < round; ++iter) {
            booster.update(dtrain, iter);
        }
        return booster;
    }

    /**
     * Boosts on {@code dtrain}, evaluating {@code dtest} after each round, and stops once
     * {@code earlyStoppingRounds} rounds have passed without improving the best score. The
     * returned booster keeps every round trained up to the stopping point.
     */
    @Nonnull
    private static Booster train(@Nonnull DMatrix dtrain, @Nonnull DMatrix dtest,
            @Nonnegative int round, @Nonnegative int earlyStoppingRounds,
            @Nonnull Map<String, Object> params) throws NoSuchMethodException,
            IllegalAccessException, InvocationTargetException, InstantiationException,
            XGBoostError {
        Booster booster = XGBoostUtils.createBooster(dtrain, params);
        final boolean maximizeEvaluationMetrics =
                OptionUtils.getBoolean(params, "maximize_evaluation_metrics");
        float bestScore = maximizeEvaluationMetrics ? -Float.MAX_VALUE : Float.MAX_VALUE;
        int bestIteration = 0;
        final float[] metricsOut = new float[1];
        for (int iter = 0; iter < round; ++iter) {
            booster.update(dtrain, iter);
            String evalInfo =
                    booster.evalSet(new DMatrix[] {dtest}, new String[] {"test"}, iter, metricsOut);
            logger.info(evalInfo);

            final float score = metricsOut[0];
            final boolean improved =
                    maximizeEvaluationMetrics ? score > bestScore : score < bestScore;
            if (improved) {
                bestScore = score;
                bestIteration = iter;
            }
            if (shouldEarlyStop(earlyStoppingRounds, iter, bestIteration)) {
                logger.info(String.format(
                    "early stopping after %d rounds away from the best iteration",
                    earlyStoppingRounds));
                break;
            }
        }
        return booster;
    }

    /** True once {@code iter} is at least {@code earlyStoppingRounds} past the best iteration. */
    private static boolean shouldEarlyStop(int earlyStoppingRounds, int iter, int bestIteration) {
        return iter - bestIteration >= earlyStoppingRounds;
    }

    /** Returns {@code "xgbmodel-"} followed by the task-unique id from HadoopUtils. */
    @Nonnull
    private static String generateUniqueModelId() {
        return "xgbmodel-" + HadoopUtils.getUniqueTaskIdString();
    }

    /** Coarse objective family, derived from the prefix of the -objective value. */
    public static enum ObjectiveType {
        regression, binary, multiclass, rank, other;

        /** Maps "reg:", "binary:", "multi:", "rank:" prefixes; anything else is {@link #other}. */
        @Nonnull
        public static ObjectiveType resolve(@Nonnull String objective) {
            if (objective.startsWith("reg:")) {
                return regression;
            }
            if (objective.startsWith("binary:")) {
                return binary;
            }
            if (objective.startsWith("multi:")) {
                return multiclass;
            }
            if (objective.startsWith("rank:")) {
                return rank;
            }
            return other;
        }
    }
}

