001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing,
013 * software distributed under the License is distributed on an
014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 * KIND, either express or implied.  See the License for the
016 * specific language governing permissions and limitations
017 * under the License.
018 */
019package hivemall.xgboost;
020
021import hivemall.UDTFWithOptions;
022import hivemall.annotations.VisibleForTesting;
023import hivemall.utils.collections.lists.FloatArrayList;
024import hivemall.utils.hadoop.HadoopUtils;
025import hivemall.utils.hadoop.HiveUtils;
026import hivemall.utils.lang.OptionUtils;
027import hivemall.utils.math.MathUtils;
028import hivemall.xgboost.utils.DMatrixBuilder;
029import hivemall.xgboost.utils.DenseDMatrixBuilder;
030import hivemall.xgboost.utils.NativeLibLoader;
031import hivemall.xgboost.utils.SparseDMatrixBuilder;
032import hivemall.xgboost.utils.XGBoostUtils;
033import matrix4j.utils.lang.ArrayUtils;
034import matrix4j.utils.lang.Primitives;
035import ml.dmlc.xgboost4j.java.Booster;
036import ml.dmlc.xgboost4j.java.DMatrix;
037import ml.dmlc.xgboost4j.java.XGBoostError;
038
039import java.lang.reflect.InvocationTargetException;
040import java.util.ArrayList;
041import java.util.Arrays;
042import java.util.HashMap;
043import java.util.List;
044import java.util.Map;
045import java.util.Random;
046
047import javax.annotation.Nonnegative;
048import javax.annotation.Nonnull;
049
050import org.apache.commons.cli.CommandLine;
051import org.apache.commons.cli.Options;
052import org.apache.commons.logging.Log;
053import org.apache.commons.logging.LogFactory;
054import org.apache.hadoop.hive.ql.exec.Description;
055import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
056import org.apache.hadoop.hive.ql.metadata.HiveException;
057import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
058import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
059import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
060import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
061import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
062import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
063import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
064import org.apache.hadoop.io.Text;
065
066/**
067 * UDTF for train_xgboost
068 */
069//@formatter:off
070@Description(name = "train_xgboost",
071        value = "_FUNC_(array<string|double> features, <int|double> target, const string options)"
072                + " - Returns a relation consists of <string model_id, array<string> pred_model>",
073        extended = "SELECT \n" + 
074                "  train_xgboost(features, label, '-objective binary:logistic -iters 10') \n" + 
075                "    as (model_id, model)\n" + 
076                "from (\n" + 
077                "  select features, label\n" + 
078                "  from xgb_input\n" + 
079                "  cluster by rand(43) -- shuffle\n" + 
080                ") shuffled;")
081//@formatter:on
082public class XGBoostTrainUDTF extends UDTFWithOptions {
    private static final Log logger = LogFactory.getLog(XGBoostTrainUDTF.class);

    // Settings for the XGBoost native library
    static {
        // Load the native XGBoost library once when this class is first used;
        // training cannot proceed without it.
        NativeLibLoader.initXGBoost();
    }

    // For input parameters
    // OI for the features column (1st argument): array<number> or array<string>
    private ListObjectInspector featureListOI;
    // OI for a single feature element; numeric for dense input, string for sparse input
    private PrimitiveObjectInspector featureElemOI;
    // OI for the numeric target/label column (2nd argument)
    private PrimitiveObjectInspector targetOI;

    // For training input buffering
    // true when features arrive as array<number> (dense); false for array<string>
    private boolean denseInput;
    // accumulates feature rows until close() builds the DMatrix
    private DMatrixBuilder matrixBuilder;
    // accumulates one label per buffered row, parallel to matrixBuilder rows
    private FloatArrayList labels;

    // For XGBoost options
    // hyperparameters parsed from the option string; consumed by close()/train()
    @Nonnull
    protected final Map<String, Object> params = new HashMap<String, Object>();

    // number of classes for multiclass objectives (-num_class); 0 otherwise
    protected int numClass;
    // resolved from the -objective option in processOptions()
    protected ObjectiveType objectiveType = null;
106
107    public enum ObjectiveType {
108        regression, binary, multiclass, rank, other;
109
110        @Nonnull
111        public static ObjectiveType resolve(@Nonnull String objective) {
112            if (objective.startsWith("reg:")) {
113                return regression;
114            } else if (objective.startsWith("binary:")) {
115                return binary;
116            } else if (objective.startsWith("multi:")) {
117                return multiclass;
118            } else if (objective.startsWith("rank:")) {
119                return rank;
120            } else {
121                return other;
122            }
123        }
124    }
125
126
127    public XGBoostTrainUDTF() {}
128
129    @Override
130    protected Options getOptions() {
131        final Options opts = new Options();
132
133        opts.addOption("num_round", "iters", true, "Number of boosting iterations [default: 10]");
134        opts.addOption("maximize_evaluation_metrics", true,
135            "Maximize evaluation metrics [default: false]");
136        opts.addOption("num_early_stopping_rounds", true,
137            "Minimum rounds required for early stopping [default: 0]");
138        opts.addOption("validation_ratio", true,
139            "Validation ratio in range [0.0,1.0] [default: 0.2]");
140
141        /** General parameters */
142        opts.addOption("booster", true,
143            "Set a booster to use, gbtree or gblinear or dart. [default: gbree]");
144        opts.addOption("silent", true, "Deprecated. Please use verbosity instead. "
145                + "0 means printing running messages, 1 means silent mode [default: 1]");
146        opts.addOption("verbosity", true, "Verbosity of printing messages. "
147                + "Choices: 0 (silent), 1 (warning), 2 (info), 3 (debug). [default: 0]");
148        opts.addOption("disable_default_eval_metric", true,
149            "NFlag to disable default metric. Set to >0 to disable. [default: 0]");
150        opts.addOption("num_pbuffer", true,
151            "Size of prediction buffer [default: set automatically by xgboost]");
152        opts.addOption("num_feature", true,
153            "Feature dimension used in boosting [default: set automatically by xgboost]");
154
155        /** Parameters among Boosters */
156        opts.addOption("lambda", "reg_lambda", true,
157            "L2 regularization term on weights. Increasing this value will make model more conservative."
158                    + " [default: 1.0 for gbtree, 0.0 for gblinear]");
159        opts.addOption("alpha", "reg_alpha", true,
160            "L1 regularization term on weights. Increasing this value will make model more conservative."
161                    + " [default: 0.0]");
162        opts.addOption("updater", true,
163            "A comma-separated string that defines the sequence of tree updaters to run. "
164                    + "For a full list of valid inputs, please refer to XGBoost Parameters."
165                    + " [default: 'grow_colmaker,prune' for gbtree, 'shotgun' for gblinear]");
166
167        /** Parameters for Tree Booster */
168        opts.addOption("eta", "learning_rate", true,
169            "Step size shrinkage used in update to prevents overfitting [default: 0.3]");
170        opts.addOption("gamma", "min_split_loss", true,
171            "Minimum loss reduction required to make a further partition on a leaf node of the tree."
172                    + " [default: 0.0]");
173        opts.addOption("max_depth", true, "Max depth of decision tree [default: 6]");
174        opts.addOption("min_child_weight", true,
175            "Minimum sum of instance weight (hessian) needed in a child [default: 1.0]");
176        opts.addOption("max_delta_step", true,
177            "Maximum delta step we allow each tree's weight estimation to be [default: 0]");
178        opts.addOption("subsample", true,
179            "Subsample ratio of the training instance in range (0.0,1.0] [default: 1.0]");
180        opts.addOption("colsample_bytree", true,
181            "Subsample ratio of columns when constructing each tree [default: 1.0]");
182        opts.addOption("colsample_bylevel", true,
183            "Subsample ratio of columns for each level [default: 1.0]");
184        opts.addOption("colsample_bynode", true,
185            "Subsample ratio of columns for each node [default: 1.0]");
186        // tree_method
187        opts.addOption("tree_method", true,
188            "The tree construction algorithm used in XGBoost. [default: auto, Choices: auto, exact, approx, hist]");
189        opts.addOption("sketch_eps", true,
190            "This roughly translates into O(1 / sketch_eps) number of bins. \n"
191                    + "Compared to directly select number of bins, this comes with theoretical guarantee with sketch accuracy.\n"
192                    + "Only used for tree_method=approx. Usually user does not have to tune this.  [default: 0.03]");
193        opts.addOption("scale_pos_weight", true,
194            "ontrol the balance of positive and negative weights, useful for unbalanced classes. "
195                    + "A typical value to consider: sum(negative instances) / sum(positive instances)"
196                    + " [default: 1.0]");
197        opts.addOption("refresh_leaf", true,
198            "This is a parameter of the refresh updater plugin. "
199                    + "When this flag is 1, tree leafs as well as tree nodes’ stats are updated. "
200                    + "When it is 0, only node stats are updated. [default: 1]");
201        opts.addOption("process_type", true,
202            "A type of boosting process to run. [Choices: default, update]");
203        opts.addOption("grow_policy", true,
204            "Controls a way new nodes are added to the tree. Currently supported only if tree_method is set to hist."
205                    + " [default: depthwise, Choices: depthwise, lossguide]");
206        opts.addOption("max_leaves", true,
207            "Maximum number of nodes to be added. Only relevant when grow_policy=lossguide is set. [default: 0]");
208        opts.addOption("max_bin", true,
209            "Maximum number of discrete bins to bucket continuous features. Only used if tree_method is set to hist."
210                    + " [default: 256]");
211        opts.addOption("num_parallel_tree", true,
212            "Number of parallel trees constructed during each iteration. This option is used to support boosted random forest. "
213                    + "Usually no need to tune (default 1 is enough) for gradient boosting trees."
214                    + " [default: 1]");
215
216        /** Parameters for Dart Booster (booster=dart) */
217        opts.addOption("sample_type", true,
218            "Type of sampling algorithm. [Choices: uniform (default), weighted]");
219        opts.addOption("normalize_type", true,
220            "Type of normalization algorithm. [Choices: tree (default), forest]");
221        opts.addOption("rate_drop", true, "Dropout rate in range [0.0, 1.0]. [default: 0.0]");
222        opts.addOption("one_drop", true,
223            "When this flag is enabled, at least one tree is always dropped during the dropout. "
224                    + "0 or 1. [default: 0]");
225        opts.addOption("skip_drop", true,
226            "Probability of skipping the dropout procedure during a boosting iteration "
227                    + "in range [0.0, 1.0]. [default: 0.0]");
228
229        /** Parameters for Linear Booster (booster=gblinear) */
230        opts.addOption("lambda_bias", true, "L2 regularization term on bias [default: 0.0]");
231        opts.addOption("feature_selector", true, "Feature selection and ordering method. "
232                + "[Choices: cyclic (default), shuffle, random, greedy, thrifty]");
233        opts.addOption("top_k", true,
234            "The number of top features to select in greedy and thrifty feature selector. "
235                    + "The value of 0 means using all the features. [default: 0]");
236
237        /** Parameters for Tweedie Regression (objective=reg:tweedie) */
238        opts.addOption("tweedie_variance_power", true,
239            "Parameter that controls the variance of the Tweedie distribution in range [1.0, 2.0]."
240                    + " [default: 1.5]");
241
242        /** Learning Task Parameters */
243        opts.addOption("objective", true,
244            "Specifies the learning task and the corresponding learning objective. "
245                    + "Examples: reg:linear, reg:logistic, multi:softmax. "
246                    + "For a full list of valid inputs, refer to XGBoost Parameters. "
247                    + "[default: reg:linear]");
248        opts.addOption("base_score", true,
249            "Initial prediction score of all instances, global bias [default: 0.5]");
250        opts.addOption("eval_metric", true,
251            "Evaluation metrics for validation data. A default metric is assigned according to the objective:\n"
252                    + "- rmse: for regression\n" + "- error: for classification\n"
253                    + "- map: for ranking\n"
254                    + "For a list of valid inputs, see XGBoost Parameters.");
255        opts.addOption("seed", true, "Random number seed. [default: 43]");
256        opts.addOption("num_class", true, "Number of classes to classify");
257
258        return opts;
259    }
260
261    @Nonnull
262    @Override
263    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
264        final CommandLine cl;
265        if (argOIs.length >= 3) {
266            String rawArgs = HiveUtils.getConstString(argOIs, 2);
267            cl = parseOptions(rawArgs);
268        } else {
269            cl = parseOptions(""); // use default options
270        }
271
272        String objective = cl.getOptionValue("objective");
273        if (objective == null) {
274            showHelp("Please provide \"-objective XXX\" option in the 3rd argument.\n\n"
275                    + "Here is the list of supported objectives: \n"
276                    + " - Regression:\n {reg:squarederror, reg:logistic, reg:gamma, reg:tweedie}\n"
277                    + " - Binary classification: {binary:logistic, binary:logitraw, binary:hinge}\n"
278                    + " - Multiclass classification:\n {multi:softmax, multi:softprob}\n"
279                    + " - Ranking:\n {rank:pairwise, rank:ndcg, rank:map}\n"
280                    + " - Other:\n {count:poisson, survival:cox}");
281        }
282        if (objective.equals("reg:squarederror")) {
283            // reg:linear is deprecated synonym of reg:squarederror
284            // however, reg:squarederror is not supported in xgboost-predictor yet
285            // https://github.com/dmlc/xgboost/pull/4267
286            objective = "reg:linear";
287        }
288        final String booster = cl.getOptionValue("booster", "gbtree");
289
290        int numRound = Primitives.parseInt(cl.getOptionValue("num_round"), 10);
291        params.put("num_round", numRound);
292        params.put("maximize_evaluation_metrics",
293            Primitives.parseBoolean(cl.getOptionValue("maximize_evaluation_metrics"), false));
294        params.put("num_early_stopping_rounds",
295            Primitives.parseInt(cl.getOptionValue("num_early_stopping_rounds"), 0));
296        double validationRatio =
297                Primitives.parseDouble(cl.getOptionValue("validation_ratio"), 0.2d);
298        if (validationRatio < 0.d || validationRatio >= 1.d) {
299            throw new UDFArgumentException("Invalid validation_ratio=" + validationRatio);
300        }
301        params.put("validation_ratio", validationRatio);
302
303        /** General parameters */
304        params.put("booster", booster);
305        params.put("silent", Primitives.parseInt(cl.getOptionValue("silent"), 1));
306        params.put("verbosity", Primitives.parseInt(cl.getOptionValue("verbosity"), 0));
307        params.put("nthread", Primitives.parseInt(cl.getOptionValue("nthread"), 1));
308        params.put("disable_default_eval_metric",
309            Primitives.parseInt(cl.getOptionValue("disable_default_eval_metric"), 0));
310        if (cl.hasOption("num_pbuffer")) {
311            params.put("num_pbuffer", Integer.valueOf(cl.getOptionValue("num_pbuffer")));
312        }
313        if (cl.hasOption("num_feature")) {
314            params.put("num_feature", Integer.valueOf(cl.getOptionValue("num_feature")));
315        }
316
317        /** Parameters for Tree Booster (booster=gbtree) */
318        if (booster.equals("gbtree")) {
319            params.put("eta", Primitives.parseDouble(cl.getOptionValue("eta"), 0.3d));
320            params.put("gamma", Primitives.parseDouble(cl.getOptionValue("gamma"), 0.d));
321            params.put("max_depth", Primitives.parseInt(cl.getOptionValue("max_depth"), 6));
322            params.put("min_child_weight",
323                Primitives.parseDouble(cl.getOptionValue("min_child_weight"), 1.d));
324            params.put("max_delta_step",
325                Primitives.parseDouble(cl.getOptionValue("max_delta_step"), 0.d));
326            params.put("subsample", Primitives.parseDouble(cl.getOptionValue("subsample"), 1.d));
327            params.put("colsamle_bytree",
328                Primitives.parseDouble(cl.getOptionValue("colsample_bytree"), 1.d));
329            params.put("colsamle_bylevel",
330                Primitives.parseDouble(cl.getOptionValue("colsamle_bylevel"), 1.d));
331            params.put("colsamle_bynode",
332                Primitives.parseDouble(cl.getOptionValue("colsamle_bynode"), 1.d));
333            params.put("lambda", Primitives.parseDouble(cl.getOptionValue("lambda"), 1.d));
334            params.put("alpha", Primitives.parseDouble(cl.getOptionValue("alpha"), 0.d));
335            params.put("tree_method", cl.getOptionValue("tree_method", "auto"));
336            params.put("sketch_eps",
337                Primitives.parseDouble(cl.getOptionValue("sketch_eps"), 0.03d));
338            params.put("scale_pos_weight",
339                Primitives.parseDouble(cl.getOptionValue("scale_pos_weight"), 1.d));
340            params.put("updater", cl.getOptionValue("updater", "grow_colmaker,prune"));
341            params.put("refresh_leaf", Primitives.parseInt(cl.getOptionValue("refresh_leaf"), 1));
342            params.put("process_type", cl.getOptionValue("process_type", "default"));
343            params.put("grow_policy", cl.getOptionValue("grow_policy", "depthwise"));
344            params.put("max_leaves", Primitives.parseInt(cl.getOptionValue("max_leaves"), 0));
345            params.put("max_bin", Primitives.parseInt(cl.getOptionValue("max_bin"), 256));
346            params.put("num_parallel_tree",
347                Primitives.parseInt(cl.getOptionValue("num_parallel_tree"), 1));
348        }
349
350        /** Parameters for Dart Booster (booster=dart) */
351        if (booster.equals("dart")) {
352            params.put("sample_type", cl.getOptionValue("sample_type", "uniform"));
353            params.put("normalize_type", cl.getOptionValue("normalize_type", "tree"));
354            params.put("rate_drop", Primitives.parseDouble(cl.getOptionValue("rate_drop"), 0.d));
355            params.put("one_drop", Primitives.parseInt(cl.getOptionValue("one_drop"), 0));
356            params.put("skip_drop", Primitives.parseDouble(cl.getOptionValue("skip_drop"), 0.d));
357        }
358
359        /** Parameters for Linear Booster (booster=gblinear) */
360        if (booster.equals("gblinear")) {
361            params.put("lambda", Primitives.parseDouble(cl.getOptionValue("lambda"), 0.d));
362            params.put("lambda_bias",
363                Primitives.parseDouble(cl.getOptionValue("lambda_bias"), 0.d));
364            params.put("alpha", Primitives.parseDouble(cl.getOptionValue("alpha"), 0.d));
365            params.put("updater", cl.getOptionValue("updater", "shotgun"));
366            params.put("feature_selector", cl.getOptionValue("feature_selector", "cyclic"));
367            params.put("top_k", Primitives.parseInt(cl.getOptionValue("top_k"), 0));
368        }
369
370        /** Parameters for Tweedie Regression (objective=reg:tweedie) */
371        if (objective.equals("reg:tweedie")) {
372            params.put("tweedie_variance_power",
373                Primitives.parseDouble(cl.getOptionValue("tweedie_variance_power"), 1.5d));
374        }
375
376        /** Parameters for Poisson Regression (objective=count:poisson) */
377        if (objective.equals("count:poisson")) {
378            // max_delta_step is set to 0.7 by default in poisson regression (used to safeguard optimization)
379            params.put("max_delta_step",
380                Primitives.parseDouble(cl.getOptionValue("max_delta_step"), 0.7d));
381        }
382
383        /** Learning Task Parameters */
384        params.put("objective", objective);
385        params.put("base_score", Primitives.parseDouble(cl.getOptionValue("base_score"), 0.5d));
386        if (cl.hasOption("eval_metric")) {
387            params.put("eval_metric", cl.getOptionValue("eval_metric"));
388        }
389        params.put("seed", Primitives.parseLong(cl.getOptionValue("seed"), 43L));
390
391        if (cl.hasOption("num_class")) {
392            this.numClass = Integer.parseInt(cl.getOptionValue("num_class"));
393            params.put("num_class", numClass);
394        } else {
395            if (objective.startsWith("multi:")) {
396                throw new UDFArgumentException(
397                    "-num_class is required for multiclass classification");
398            }
399        }
400
401        if (logger.isInfoEnabled()) {
402            logger.info("XGboost training hyperparameters: " + params.toString());
403        }
404
405        this.objectiveType = ObjectiveType.resolve(objective);
406
407        return cl;
408    }
409
410    @Override
411    public StructObjectInspector initialize(@Nonnull ObjectInspector[] argOIs)
412            throws UDFArgumentException {
413        if (argOIs.length != 2 && argOIs.length != 3) {
414            showHelp("Invalid argment length=" + argOIs.length);
415        }
416        processOptions(argOIs);
417
418        ListObjectInspector listOI = HiveUtils.asListOI(argOIs, 0);
419        ObjectInspector elemOI = listOI.getListElementObjectInspector();
420        this.featureListOI = listOI;
421        if (HiveUtils.isNumberOI(elemOI)) {
422            this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
423            this.denseInput = true;
424            this.matrixBuilder = new DenseDMatrixBuilder(8192);
425        } else if (HiveUtils.isStringOI(elemOI)) {
426            this.featureElemOI = HiveUtils.asStringOI(elemOI);
427            this.denseInput = false;
428            this.matrixBuilder = new SparseDMatrixBuilder(8192);
429        } else {
430            throw new UDFArgumentException(
431                "train_xgboost takes array<double> or array<string> for the first argument: "
432                        + listOI.getTypeName());
433        }
434        this.targetOI = HiveUtils.asDoubleCompatibleOI(argOIs, 1);
435        this.labels = new FloatArrayList(1024);
436
437        final List<String> fieldNames = new ArrayList<>(2);
438        final List<ObjectInspector> fieldOIs = new ArrayList<>(2);
439        fieldNames.add("model_id");
440        fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector);
441        fieldNames.add("model");
442        fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
443        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
444    }
445
446    /** To validate target range, overrides this method */
447    protected float processTargetValue(final float target) throws HiveException {
448        switch (objectiveType) {
449            case binary: {
450                if (target != -1 && target != 0 && target != 1) {
451                    throw new UDFArgumentException(
452                        "Invalid label value for classification: " + target);
453                }
454                return target > 0.f ? 1.f : 0.f;
455            }
456            case multiclass: {
457                final int clazz = (int) target;
458                if (clazz != target) {
459                    throw new UDFArgumentException(
460                        "Invalid target value for class label: " + target);
461                }
462                if (clazz < 0 || clazz >= numClass) {
463                    throw new UDFArgumentException("target must be {0.0, ..., "
464                            + String.format("%.1f", (numClass - 1.0)) + "}: " + target);
465                }
466                return target;
467            }
468            default:
469                return target;
470        }
471    }
472
473    @Override
474    public void process(@Nonnull Object[] args) throws HiveException {
475        if (args[0] == null) {
476            throw new HiveException("array<double> features was null");
477        }
478        parseFeatures(args[0], matrixBuilder);
479
480        float target = PrimitiveObjectInspectorUtils.getFloat(args[1], targetOI);
481        labels.add(processTargetValue(target));
482    }
483
484    private void parseFeatures(@Nonnull final Object argObj,
485            @Nonnull final DMatrixBuilder builder) {
486        if (denseInput) {
487            final int length = featureListOI.getListLength(argObj);
488            for (int i = 0; i < length; i++) {
489                Object o = featureListOI.getListElement(argObj, i);
490                if (o == null) {
491                    continue;
492                }
493                float v = PrimitiveObjectInspectorUtils.getFloat(o, featureElemOI);
494                builder.nextColumn(i, v);
495            }
496        } else {
497            final int length = featureListOI.getListLength(argObj);
498            for (int i = 0; i < length; i++) {
499                Object o = featureListOI.getListElement(argObj, i);
500                if (o == null) {
501                    continue;
502                }
503                String fv = o.toString();
504                builder.nextColumn(fv);
505            }
506        }
507        builder.nextRow();
508    }
509
    /**
     * Finalizes training: builds the DMatrix from the buffered rows, trains (optionally with
     * a held-out validation split for early stopping), and forwards a single
     * (model_id, serialized model) row. All native resources are released in finally blocks.
     */
    @Override
    public void close() throws HiveException {
        DMatrix dmatrix = null;
        Booster booster = null;
        try {
            dmatrix = matrixBuilder.buildMatrix(labels.toArray(true));
            // release the row buffers; they are no longer needed once the DMatrix is built
            this.matrixBuilder = null;
            this.labels = null;

            final int round = OptionUtils.getInt(params, "num_round");
            final int earlyStoppingRounds = OptionUtils.getInt(params, "num_early_stopping_rounds");
            if (earlyStoppingRounds > 0) {
                double validationRatio = OptionUtils.getDouble(params, "validation_ratio");
                long seed = OptionUtils.getLong(params, "seed");

                // shuffle row indices deterministically (seeded) before splitting
                int numRows = (int) dmatrix.rowNum();
                int[] rows = MathUtils.permutation(numRows);
                ArrayUtils.shuffle(rows, new Random(seed));

                // first numTest shuffled rows -> validation set, the rest -> training set
                int numTest = (int) (numRows * validationRatio);
                DMatrix dtrain = null, dtest = null;
                try {
                    dtest = dmatrix.slice(Arrays.copyOf(rows, numTest));
                    dtrain = dmatrix.slice(Arrays.copyOfRange(rows, numTest, rows.length));
                    booster = train(dtrain, dtest, round, earlyStoppingRounds, params);
                } finally {
                    // the slices are independent native handles; free them regardless of outcome
                    XGBoostUtils.close(dtrain);
                    XGBoostUtils.close(dtest);
                }
            } else {
                // no early stopping requested: train on all rows
                booster = train(dmatrix, round, params);
            }
            onFinishTraining(booster);

            // Output the built model
            String modelId = generateUniqueModelId();
            Text predModel = XGBoostUtils.serializeBooster(booster);

            logger.info("model_id:" + modelId.toString() + ", size:" + predModel.getLength());
            forward(new Object[] {modelId, predModel});
        } catch (Throwable e) {
            // wrap everything (incl. XGBoostError) so Hive reports the failure
            throw new HiveException(e);
        } finally {
            XGBoostUtils.close(dmatrix);
            XGBoostUtils.close(booster);
        }
    }
557
    /**
     * Hook invoked once after training completes, before the model is serialized and
     * forwarded. No-op by default; subclasses (e.g. tests) may override to inspect the booster.
     */
    @VisibleForTesting
    protected void onFinishTraining(@Nonnull Booster booster) {}
560
561    @Nonnull
562    private static Booster train(@Nonnull final DMatrix dtrain, @Nonnegative final int round,
563            @Nonnull final Map<String, Object> params)
564            throws NoSuchMethodException, IllegalAccessException, InvocationTargetException,
565            InstantiationException, XGBoostError {
566        final Booster booster = XGBoostUtils.createBooster(dtrain, params);
567        for (int iter = 0; iter < round; iter++) {
568            booster.update(dtrain, iter);
569        }
570        return booster;
571    }
572
    /**
     * Trains a booster with per-iteration evaluation on a held-out test set, stopping early
     * when the evaluation metric has not improved for {@code earlyStoppingRounds} iterations.
     *
     * @param dtrain training split
     * @param dtest validation split used for the evaluation metric
     * @param round maximum number of boosting iterations
     * @param earlyStoppingRounds patience (in iterations) before stopping; assumed > 0
     * @param params XGBoost hyperparameters
     * @return the trained booster (caller is responsible for closing it)
     */
    @Nonnull
    private static Booster train(@Nonnull final DMatrix dtrain, @Nonnull final DMatrix dtest,
            @Nonnegative final int round, @Nonnegative final int earlyStoppingRounds,
            @Nonnull final Map<String, Object> params)
            throws NoSuchMethodException, IllegalAccessException, InvocationTargetException,
            InstantiationException, XGBoostError {
        final Booster booster = XGBoostUtils.createBooster(dtrain, params);

        // direction of "better": larger score wins when maximizing, smaller when minimizing
        final boolean maximizeEvaluationMetrics =
                OptionUtils.getBoolean(params, "maximize_evaluation_metrics");
        float bestScore = maximizeEvaluationMetrics ? -Float.MAX_VALUE : Float.MAX_VALUE;
        int bestIteration = 0;

        // evalSet writes the metric value for dtest into metricsOut[0]
        final float[] metricsOut = new float[1];
        for (int iter = 0; iter < round; iter++) {
            booster.update(dtrain, iter);

            String evalInfo =
                    booster.evalSet(new DMatrix[] {dtest}, new String[] {"test"}, iter, metricsOut);
            logger.info(evalInfo);

            final float score = metricsOut[0];
            if (maximizeEvaluationMetrics) {
                // Update best score if the current score is better (no update when equal)
                if (score > bestScore) {
                    bestScore = score;
                    bestIteration = iter;
                }
            } else {
                if (score < bestScore) {
                    bestScore = score;
                    bestIteration = iter;
                }
            }

            // stop once the gap since the best iteration reaches the patience threshold
            if (shouldEarlyStop(earlyStoppingRounds, iter, bestIteration)) {
                logger.info(
                    String.format("early stopping after %d rounds away from the best iteration",
                        earlyStoppingRounds));
                break;
            }
        }

        return booster;
    }
618
619    private static boolean shouldEarlyStop(final int earlyStoppingRounds, final int iter,
620            final int bestIteration) {
621        return iter - bestIteration >= earlyStoppingRounds;
622    }
623
624    @Nonnull
625    private static String generateUniqueModelId() {
626        return "xgbmodel-" + HadoopUtils.getUniqueTaskIdString();
627    }
628
629}