/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.xgboost;

import hivemall.UDTFWithOptions;
import hivemall.annotations.VisibleForTesting;
import hivemall.utils.collections.lists.FloatArrayList;
import hivemall.utils.hadoop.HadoopUtils;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.OptionUtils;
import hivemall.utils.math.MathUtils;
import hivemall.xgboost.utils.DMatrixBuilder;
import hivemall.xgboost.utils.DenseDMatrixBuilder;
import hivemall.xgboost.utils.NativeLibLoader;
import hivemall.xgboost.utils.SparseDMatrixBuilder;
import hivemall.xgboost.utils.XGBoostUtils;
import matrix4j.utils.lang.ArrayUtils;
import matrix4j.utils.lang.Primitives;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoostError;

import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.io.Text;

/**
 * UDTF for train_xgboost
 */
//@formatter:off
@Description(name = "train_xgboost",
        value = "_FUNC_(array<string|double> features, <int|double> target, const string options)"
                + " - Returns a relation consists of <string model_id, array<string> pred_model>",
        extended = "SELECT \n" + 
                "  train_xgboost(features, label, '-objective binary:logistic -iters 10') \n" + 
                "  as (model_id, model)\n" + 
                "from (\n" + 
                "  select features, label\n" + 
                "  from xgb_input\n" + 
                "  cluster by rand(43) -- shuffle\n" + 
                ") shuffled;")
//@formatter:on
public class XGBoostTrainUDTF extends UDTFWithOptions {
    private static final Log logger = LogFactory.getLog(XGBoostTrainUDTF.class);

    // Settings for the XGBoost native library
    static {
        NativeLibLoader.initXGBoost();
    }

    // For input parameters
    private ListObjectInspector featureListOI;
    private PrimitiveObjectInspector featureElemOI;
    private PrimitiveObjectInspector targetOI;

    // For training input buffering
    private boolean denseInput;
    private DMatrixBuilder matrixBuilder;
    private FloatArrayList labels;

    // For XGBoost options
    @Nonnull
    protected final Map<String, Object> params = new HashMap<String, Object>();

    protected int numClass;
    protected ObjectiveType objectiveType = null;

    /**
     * Broad category of an XGBoost learning objective, resolved from the objective name prefix
     * (e.g., {@code binary:logistic} -&gt; {@link #binary}). Used for target-value validation.
     */
    public enum ObjectiveType {
        regression, binary, multiclass, rank, other;

        @Nonnull
        public static ObjectiveType resolve(@Nonnull String objective) {
            if (objective.startsWith("reg:")) {
                return regression;
            } else if (objective.startsWith("binary:")) {
                return binary;
            } else if (objective.startsWith("multi:")) {
                return multiclass;
            } else if (objective.startsWith("rank:")) {
                return rank;
            } else {
                return other;
            }
        }
    }

    public XGBoostTrainUDTF() {}

    @Override
    protected Options getOptions() {
        final Options opts = new Options();

        opts.addOption("num_round", "iters", true, "Number of boosting iterations [default: 10]");
        opts.addOption("maximize_evaluation_metrics", true,
            "Maximize evaluation metrics [default: false]");
        opts.addOption("num_early_stopping_rounds", true,
            "Minimum rounds required for early stopping [default: 0]");
        opts.addOption("validation_ratio", true,
            "Validation ratio in range [0.0,1.0) [default: 0.2]");

        /** General parameters */
        opts.addOption("booster", true,
            "Set a booster to use, gbtree or gblinear or dart. [default: gbtree]");
        opts.addOption("silent", true, "Deprecated. Please use verbosity instead. "
                + "0 means printing running messages, 1 means silent mode [default: 1]");
        opts.addOption("verbosity", true, "Verbosity of printing messages. "
                + "Choices: 0 (silent), 1 (warning), 2 (info), 3 (debug). [default: 0]");
        opts.addOption("nthread", true,
            "Number of parallel threads used to run XGBoost [default: 1]");
        opts.addOption("disable_default_eval_metric", true,
            "Flag to disable default metric. Set to >0 to disable. [default: 0]");
        opts.addOption("num_pbuffer", true,
            "Size of prediction buffer [default: set automatically by xgboost]");
        opts.addOption("num_feature", true,
            "Feature dimension used in boosting [default: set automatically by xgboost]");

        /** Parameters among Boosters */
        opts.addOption("lambda", "reg_lambda", true,
            "L2 regularization term on weights. Increasing this value will make model more conservative."
                    + " [default: 1.0 for gbtree, 0.0 for gblinear]");
        opts.addOption("alpha", "reg_alpha", true,
            "L1 regularization term on weights. Increasing this value will make model more conservative."
                    + " [default: 0.0]");
        opts.addOption("updater", true,
            "A comma-separated string that defines the sequence of tree updaters to run. "
                    + "For a full list of valid inputs, please refer to XGBoost Parameters."
                    + " [default: 'grow_colmaker,prune' for gbtree, 'shotgun' for gblinear]");

        /** Parameters for Tree Booster */
        opts.addOption("eta", "learning_rate", true,
            "Step size shrinkage used in update to prevents overfitting [default: 0.3]");
        opts.addOption("gamma", "min_split_loss", true,
            "Minimum loss reduction required to make a further partition on a leaf node of the tree."
                    + " [default: 0.0]");
        opts.addOption("max_depth", true, "Max depth of decision tree [default: 6]");
        opts.addOption("min_child_weight", true,
            "Minimum sum of instance weight (hessian) needed in a child [default: 1.0]");
        opts.addOption("max_delta_step", true,
            "Maximum delta step we allow each tree's weight estimation to be [default: 0]");
        opts.addOption("subsample", true,
            "Subsample ratio of the training instance in range (0.0,1.0] [default: 1.0]");
        opts.addOption("colsample_bytree", true,
            "Subsample ratio of columns when constructing each tree [default: 1.0]");
        opts.addOption("colsample_bylevel", true,
            "Subsample ratio of columns for each level [default: 1.0]");
        opts.addOption("colsample_bynode", true,
            "Subsample ratio of columns for each node [default: 1.0]");
        // tree_method
        opts.addOption("tree_method", true,
            "The tree construction algorithm used in XGBoost. [default: auto, Choices: auto, exact, approx, hist]");
        opts.addOption("sketch_eps", true,
            "This roughly translates into O(1 / sketch_eps) number of bins. \n"
                    + "Compared to directly select number of bins, this comes with theoretical guarantee with sketch accuracy.\n"
                    + "Only used for tree_method=approx. Usually user does not have to tune this. [default: 0.03]");
        opts.addOption("scale_pos_weight", true,
            "Control the balance of positive and negative weights, useful for unbalanced classes. "
                    + "A typical value to consider: sum(negative instances) / sum(positive instances)"
                    + " [default: 1.0]");
        opts.addOption("refresh_leaf", true,
            "This is a parameter of the refresh updater plugin. "
                    + "When this flag is 1, tree leafs as well as tree nodes’ stats are updated. "
                    + "When it is 0, only node stats are updated. [default: 1]");
        opts.addOption("process_type", true,
            "A type of boosting process to run. [Choices: default, update]");
        opts.addOption("grow_policy", true,
            "Controls a way new nodes are added to the tree. Currently supported only if tree_method is set to hist."
                    + " [default: depthwise, Choices: depthwise, lossguide]");
        opts.addOption("max_leaves", true,
            "Maximum number of nodes to be added. Only relevant when grow_policy=lossguide is set. [default: 0]");
        opts.addOption("max_bin", true,
            "Maximum number of discrete bins to bucket continuous features. Only used if tree_method is set to hist."
                    + " [default: 256]");
        opts.addOption("num_parallel_tree", true,
            "Number of parallel trees constructed during each iteration. This option is used to support boosted random forest. "
                    + "Usually no need to tune (default 1 is enough) for gradient boosting trees."
                    + " [default: 1]");

        /** Parameters for Dart Booster (booster=dart) */
        opts.addOption("sample_type", true,
            "Type of sampling algorithm. [Choices: uniform (default), weighted]");
        opts.addOption("normalize_type", true,
            "Type of normalization algorithm. [Choices: tree (default), forest]");
        opts.addOption("rate_drop", true, "Dropout rate in range [0.0, 1.0]. [default: 0.0]");
        opts.addOption("one_drop", true,
            "When this flag is enabled, at least one tree is always dropped during the dropout. "
                    + "0 or 1. [default: 0]");
        opts.addOption("skip_drop", true,
            "Probability of skipping the dropout procedure during a boosting iteration "
                    + "in range [0.0, 1.0]. [default: 0.0]");

        /** Parameters for Linear Booster (booster=gblinear) */
        opts.addOption("lambda_bias", true, "L2 regularization term on bias [default: 0.0]");
        opts.addOption("feature_selector", true, "Feature selection and ordering method. "
                + "[Choices: cyclic (default), shuffle, random, greedy, thrifty]");
        opts.addOption("top_k", true,
            "The number of top features to select in greedy and thrifty feature selector. "
                    + "The value of 0 means using all the features. [default: 0]");

        /** Parameters for Tweedie Regression (objective=reg:tweedie) */
        opts.addOption("tweedie_variance_power", true,
            "Parameter that controls the variance of the Tweedie distribution in range [1.0, 2.0]."
                    + " [default: 1.5]");

        /** Learning Task Parameters */
        opts.addOption("objective", true,
            "Specifies the learning task and the corresponding learning objective. "
                    + "Examples: reg:linear, reg:logistic, multi:softmax. "
                    + "For a full list of valid inputs, refer to XGBoost Parameters. "
                    + "[default: reg:linear]");
        opts.addOption("base_score", true,
            "Initial prediction score of all instances, global bias [default: 0.5]");
        opts.addOption("eval_metric", true,
            "Evaluation metrics for validation data. A default metric is assigned according to the objective:\n"
                    + "- rmse: for regression\n" + "- error: for classification\n"
                    + "- map: for ranking\n"
                    + "For a list of valid inputs, see XGBoost Parameters.");
        opts.addOption("seed", true, "Random number seed. [default: 43]");
        opts.addOption("num_class", true, "Number of classes to classify");

        return opts;
    }

    /**
     * Parses the constant option string (3rd UDTF argument, if any) and fills {@link #params}
     * with the XGBoost hyperparameters to pass to the booster.
     *
     * @throws UDFArgumentException if a required option is missing or a value is out of range
     */
    @Nonnull
    @Override
    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
        final CommandLine cl;
        if (argOIs.length >= 3) {
            String rawArgs = HiveUtils.getConstString(argOIs, 2);
            cl = parseOptions(rawArgs);
        } else {
            cl = parseOptions(""); // use default options
        }

        String objective = cl.getOptionValue("objective");
        if (objective == null) {
            showHelp("Please provide \"-objective XXX\" option in the 3rd argument.\n\n"
                    + "Here is the list of supported objectives: \n"
                    + " - Regression:\n   {reg:squarederror, reg:logistic, reg:gamma, reg:tweedie}\n"
                    + " - Binary classification: {binary:logistic, binary:logitraw, binary:hinge}\n"
                    + " - Multiclass classification:\n   {multi:softmax, multi:softprob}\n"
                    + " - Ranking:\n   {rank:pairwise, rank:ndcg, rank:map}\n"
                    + " - Other:\n   {count:poisson, survival:cox}");
        }
        if (objective.equals("reg:squarederror")) {
            // reg:linear is deprecated synonym of reg:squarederror
            // however, reg:squarederror is not supported in xgboost-predictor yet
            // https://github.com/dmlc/xgboost/pull/4267
            objective = "reg:linear";
        }
        final String booster = cl.getOptionValue("booster", "gbtree");

        int numRound = Primitives.parseInt(cl.getOptionValue("num_round"), 10);
        params.put("num_round", numRound);
        params.put("maximize_evaluation_metrics",
            Primitives.parseBoolean(cl.getOptionValue("maximize_evaluation_metrics"), false));
        params.put("num_early_stopping_rounds",
            Primitives.parseInt(cl.getOptionValue("num_early_stopping_rounds"), 0));
        double validationRatio =
                Primitives.parseDouble(cl.getOptionValue("validation_ratio"), 0.2d);
        // ratio of 1.0 would leave no training rows, hence the exclusive upper bound
        if (validationRatio < 0.d || validationRatio >= 1.d) {
            throw new UDFArgumentException("Invalid validation_ratio=" + validationRatio);
        }
        params.put("validation_ratio", validationRatio);

        /** General parameters */
        params.put("booster", booster);
        params.put("silent", Primitives.parseInt(cl.getOptionValue("silent"), 1));
        params.put("verbosity", Primitives.parseInt(cl.getOptionValue("verbosity"), 0));
        params.put("nthread", Primitives.parseInt(cl.getOptionValue("nthread"), 1));
        params.put("disable_default_eval_metric",
            Primitives.parseInt(cl.getOptionValue("disable_default_eval_metric"), 0));
        if (cl.hasOption("num_pbuffer")) {
            params.put("num_pbuffer", Integer.valueOf(cl.getOptionValue("num_pbuffer")));
        }
        if (cl.hasOption("num_feature")) {
            params.put("num_feature", Integer.valueOf(cl.getOptionValue("num_feature")));
        }

        /** Parameters for Tree Booster (booster=gbtree) */
        if (booster.equals("gbtree")) {
            params.put("eta", Primitives.parseDouble(cl.getOptionValue("eta"), 0.3d));
            params.put("gamma", Primitives.parseDouble(cl.getOptionValue("gamma"), 0.d));
            params.put("max_depth", Primitives.parseInt(cl.getOptionValue("max_depth"), 6));
            params.put("min_child_weight",
                Primitives.parseDouble(cl.getOptionValue("min_child_weight"), 1.d));
            params.put("max_delta_step",
                Primitives.parseDouble(cl.getOptionValue("max_delta_step"), 0.d));
            params.put("subsample", Primitives.parseDouble(cl.getOptionValue("subsample"), 1.d));
            // BUGFIX: parameter keys and option lookups were misspelled as "colsamle_*",
            // so user-supplied colsample settings were silently ignored by XGBoost
            params.put("colsample_bytree",
                Primitives.parseDouble(cl.getOptionValue("colsample_bytree"), 1.d));
            params.put("colsample_bylevel",
                Primitives.parseDouble(cl.getOptionValue("colsample_bylevel"), 1.d));
            params.put("colsample_bynode",
                Primitives.parseDouble(cl.getOptionValue("colsample_bynode"), 1.d));
            params.put("lambda", Primitives.parseDouble(cl.getOptionValue("lambda"), 1.d));
            params.put("alpha", Primitives.parseDouble(cl.getOptionValue("alpha"), 0.d));
            params.put("tree_method", cl.getOptionValue("tree_method", "auto"));
            params.put("sketch_eps",
                Primitives.parseDouble(cl.getOptionValue("sketch_eps"), 0.03d));
            params.put("scale_pos_weight",
                Primitives.parseDouble(cl.getOptionValue("scale_pos_weight"), 1.d));
            params.put("updater", cl.getOptionValue("updater", "grow_colmaker,prune"));
            params.put("refresh_leaf", Primitives.parseInt(cl.getOptionValue("refresh_leaf"), 1));
            params.put("process_type", cl.getOptionValue("process_type", "default"));
            params.put("grow_policy", cl.getOptionValue("grow_policy", "depthwise"));
            params.put("max_leaves", Primitives.parseInt(cl.getOptionValue("max_leaves"), 0));
            params.put("max_bin", Primitives.parseInt(cl.getOptionValue("max_bin"), 256));
            params.put("num_parallel_tree",
                Primitives.parseInt(cl.getOptionValue("num_parallel_tree"), 1));
        }

        /** Parameters for Dart Booster (booster=dart) */
        if (booster.equals("dart")) {
            params.put("sample_type", cl.getOptionValue("sample_type", "uniform"));
            params.put("normalize_type", cl.getOptionValue("normalize_type", "tree"));
            params.put("rate_drop", Primitives.parseDouble(cl.getOptionValue("rate_drop"), 0.d));
            params.put("one_drop", Primitives.parseInt(cl.getOptionValue("one_drop"), 0));
            params.put("skip_drop", Primitives.parseDouble(cl.getOptionValue("skip_drop"), 0.d));
        }

        /** Parameters for Linear Booster (booster=gblinear) */
        if (booster.equals("gblinear")) {
            params.put("lambda", Primitives.parseDouble(cl.getOptionValue("lambda"), 0.d));
            params.put("lambda_bias",
                Primitives.parseDouble(cl.getOptionValue("lambda_bias"), 0.d));
            params.put("alpha", Primitives.parseDouble(cl.getOptionValue("alpha"), 0.d));
            params.put("updater", cl.getOptionValue("updater", "shotgun"));
            params.put("feature_selector", cl.getOptionValue("feature_selector", "cyclic"));
            params.put("top_k", Primitives.parseInt(cl.getOptionValue("top_k"), 0));
        }

        /** Parameters for Tweedie Regression (objective=reg:tweedie) */
        if (objective.equals("reg:tweedie")) {
            params.put("tweedie_variance_power",
                Primitives.parseDouble(cl.getOptionValue("tweedie_variance_power"), 1.5d));
        }

        /** Parameters for Poisson Regression (objective=count:poisson) */
        if (objective.equals("count:poisson")) {
            // max_delta_step is set to 0.7 by default in poisson regression (used to safeguard optimization)
            params.put("max_delta_step",
                Primitives.parseDouble(cl.getOptionValue("max_delta_step"), 0.7d));
        }

        /** Learning Task Parameters */
        params.put("objective", objective);
        params.put("base_score", Primitives.parseDouble(cl.getOptionValue("base_score"), 0.5d));
        if (cl.hasOption("eval_metric")) {
            params.put("eval_metric", cl.getOptionValue("eval_metric"));
        }
        params.put("seed", Primitives.parseLong(cl.getOptionValue("seed"), 43L));

        if (cl.hasOption("num_class")) {
            this.numClass = Integer.parseInt(cl.getOptionValue("num_class"));
            params.put("num_class", numClass);
        } else {
            if (objective.startsWith("multi:")) {
                throw new UDFArgumentException(
                    "-num_class is required for multiclass classification");
            }
        }

        if (logger.isInfoEnabled()) {
            logger.info("XGboost training hyperparameters: " + params.toString());
        }

        this.objectiveType = ObjectiveType.resolve(objective);

        return cl;
    }

    /**
     * Validates the UDTF arguments, configures the feature/target inspectors and the matrix
     * builder (dense for array&lt;number&gt;, sparse for array&lt;string&gt;), and declares the
     * output schema {@code <string model_id, string model>}.
     */
    @Override
    public StructObjectInspector initialize(@Nonnull ObjectInspector[] argOIs)
            throws UDFArgumentException {
        if (argOIs.length != 2 && argOIs.length != 3) {
            showHelp("Invalid argument length=" + argOIs.length);
        }
        processOptions(argOIs);

        ListObjectInspector listOI = HiveUtils.asListOI(argOIs, 0);
        ObjectInspector elemOI = listOI.getListElementObjectInspector();
        this.featureListOI = listOI;
        if (HiveUtils.isNumberOI(elemOI)) {
            this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
            this.denseInput = true;
            this.matrixBuilder = new DenseDMatrixBuilder(8192);
        } else if (HiveUtils.isStringOI(elemOI)) {
            this.featureElemOI = HiveUtils.asStringOI(elemOI);
            this.denseInput = false;
            this.matrixBuilder = new SparseDMatrixBuilder(8192);
        } else {
            throw new UDFArgumentException(
                "train_xgboost takes array<double> or array<string> for the first argument: "
                        + listOI.getTypeName());
        }
        this.targetOI = HiveUtils.asDoubleCompatibleOI(argOIs, 1);
        this.labels = new FloatArrayList(1024);

        final List<String> fieldNames = new ArrayList<>(2);
        final List<ObjectInspector> fieldOIs = new ArrayList<>(2);
        fieldNames.add("model_id");
        fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        fieldNames.add("model");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /** To validate target range, overrides this method */
    protected float processTargetValue(final float target) throws HiveException {
        switch (objectiveType) {
            case binary: {
                if (target != -1 && target != 0 && target != 1) {
                    throw new UDFArgumentException(
                        "Invalid label value for classification: " + target);
                }
                // normalize {-1,0,1} labels to the {0,1} expected by binary objectives
                return target > 0.f ? 1.f : 0.f;
            }
            case multiclass: {
                final int clazz = (int) target;
                if (clazz != target) {
                    throw new UDFArgumentException(
                        "Invalid target value for class label: " + target);
                }
                if (clazz < 0 || clazz >= numClass) {
                    throw new UDFArgumentException("target must be {0.0, ..., "
                            + String.format("%.1f", (numClass - 1.0)) + "}: " + target);
                }
                return target;
            }
            default:
                return target;
        }
    }

    /**
     * Buffers one training row: features go into the matrix builder, the (validated) target
     * into the label list. Actual training is deferred to {@link #close()}.
     */
    @Override
    public void process(@Nonnull Object[] args) throws HiveException {
        if (args[0] == null) {
            throw new HiveException("array<double> features was null");
        }
        parseFeatures(args[0], matrixBuilder);

        float target = PrimitiveObjectInspectorUtils.getFloat(args[1], targetOI);
        labels.add(processTargetValue(target));
    }

    /** Appends one feature row to the builder; null elements are skipped. */
    private void parseFeatures(@Nonnull final Object argObj,
            @Nonnull final DMatrixBuilder builder) {
        if (denseInput) {
            final int length = featureListOI.getListLength(argObj);
            for (int i = 0; i < length; i++) {
                Object o = featureListOI.getListElement(argObj, i);
                if (o == null) {
                    continue;
                }
                float v = PrimitiveObjectInspectorUtils.getFloat(o, featureElemOI);
                builder.nextColumn(i, v);
            }
        } else {
            final int length = featureListOI.getListLength(argObj);
            for (int i = 0; i < length; i++) {
                Object o = featureListOI.getListElement(argObj, i);
                if (o == null) {
                    continue;
                }
                String fv = o.toString();
                builder.nextColumn(fv);
            }
        }
        builder.nextRow();
    }

    /**
     * Trains the booster over the buffered rows (optionally holding out a shuffled validation
     * split when early stopping is enabled) and forwards {@code <model_id, serialized model>}.
     */
    @Override
    public void close() throws HiveException {
        DMatrix dmatrix = null;
        Booster booster = null;
        try {
            dmatrix = matrixBuilder.buildMatrix(labels.toArray(true));
            // release buffers eagerly; they can be large
            this.matrixBuilder = null;
            this.labels = null;

            final int round = OptionUtils.getInt(params, "num_round");
            final int earlyStoppingRounds = OptionUtils.getInt(params, "num_early_stopping_rounds");
            if (earlyStoppingRounds > 0) {
                double validationRatio = OptionUtils.getDouble(params, "validation_ratio");
                long seed = OptionUtils.getLong(params, "seed");

                int numRows = (int) dmatrix.rowNum();
                int[] rows = MathUtils.permutation(numRows);
                ArrayUtils.shuffle(rows, new Random(seed));

                int numTest = (int) (numRows * validationRatio);
                DMatrix dtrain = null, dtest = null;
                try {
                    dtest = dmatrix.slice(Arrays.copyOf(rows, numTest));
                    dtrain = dmatrix.slice(Arrays.copyOfRange(rows, numTest, rows.length));
                    booster = train(dtrain, dtest, round, earlyStoppingRounds, params);
                } finally {
                    XGBoostUtils.close(dtrain);
                    XGBoostUtils.close(dtest);
                }
            } else {
                booster = train(dmatrix, round, params);
            }
            onFinishTraining(booster);

            // Output the built model
            String modelId = generateUniqueModelId();
            Text predModel = XGBoostUtils.serializeBooster(booster);

            logger.info("model_id:" + modelId + ", size:" + predModel.getLength());
            forward(new Object[] {modelId, predModel});
        } catch (Throwable e) {
            throw new HiveException(e);
        } finally {
            XGBoostUtils.close(dmatrix);
            XGBoostUtils.close(booster);
        }
    }

    @VisibleForTesting
    protected void onFinishTraining(@Nonnull Booster booster) {}

    /** Plain training loop without a validation set. */
    @Nonnull
    private static Booster train(@Nonnull final DMatrix dtrain, @Nonnegative final int round,
            @Nonnull final Map<String, Object> params)
            throws NoSuchMethodException, IllegalAccessException, InvocationTargetException,
            InstantiationException, XGBoostError {
        final Booster booster = XGBoostUtils.createBooster(dtrain, params);
        for (int iter = 0; iter < round; iter++) {
            booster.update(dtrain, iter);
        }
        return booster;
    }

    /** Training loop with per-iteration evaluation on {@code dtest} and early stopping. */
    @Nonnull
    private static Booster train(@Nonnull final DMatrix dtrain, @Nonnull final DMatrix dtest,
            @Nonnegative final int round, @Nonnegative final int earlyStoppingRounds,
            @Nonnull final Map<String, Object> params)
            throws NoSuchMethodException, IllegalAccessException, InvocationTargetException,
            InstantiationException, XGBoostError {
        final Booster booster = XGBoostUtils.createBooster(dtrain, params);

        final boolean maximizeEvaluationMetrics =
                OptionUtils.getBoolean(params, "maximize_evaluation_metrics");
        float bestScore = maximizeEvaluationMetrics ? -Float.MAX_VALUE : Float.MAX_VALUE;
        int bestIteration = 0;

        final float[] metricsOut = new float[1];
        for (int iter = 0; iter < round; iter++) {
            booster.update(dtrain, iter);

            String evalInfo =
                    booster.evalSet(new DMatrix[] {dtest}, new String[] {"test"}, iter, metricsOut);
            logger.info(evalInfo);

            final float score = metricsOut[0];
            if (maximizeEvaluationMetrics) {
                // Update best score if the current score is better (no update when equal)
                if (score > bestScore) {
                    bestScore = score;
                    bestIteration = iter;
                }
            } else {
                if (score < bestScore) {
                    bestScore = score;
                    bestIteration = iter;
                }
            }

            if (shouldEarlyStop(earlyStoppingRounds, iter, bestIteration)) {
                logger.info(
                    String.format("early stopping after %d rounds away from the best iteration",
                        earlyStoppingRounds));
                break;
            }
        }

        return booster;
    }

    private static boolean shouldEarlyStop(final int earlyStoppingRounds, final int iter,
            final int bestIteration) {
        return iter - bestIteration >= earlyStoppingRounds;
    }

    @Nonnull
    private static String generateUniqueModelId() {
        return "xgbmodel-" + HadoopUtils.getUniqueTaskIdString();
    }

}