/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.xgboost;

import hivemall.UDTFWithOptions;
import hivemall.utils.collections.lists.FloatArrayList;
import hivemall.utils.collections.lists.IntArrayList;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.WritableUtils;
import hivemall.utils.lang.Primitives;
import hivemall.xgboost.utils.NativeLibLoader;
import hivemall.xgboost.utils.XGBoostUtils;
import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoostError;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;

//@formatter:off
@Description(name = "xgboost_batch_predict",
        value = "_FUNC_(PRIMITIVE rowid, array<string|double> features, string model_id, array<string> pred_model [, string options]) "
                + "- Returns a prediction result as (string rowid, array<double> predicted)",
        extended = "select\n" +
                "  rowid, \n" +
                "  array_avg(predicted) as predicted,\n" +
                "  avg(predicted[0]) as predicted0\n" +
                "from (\n" +
                "  select\n" +
                "    xgboost_batch_predict(rowid, features, model_id, model) as (rowid, predicted)\n" +
                "  from\n" +
                "    xgb_model l\n" +
                "    LEFT OUTER JOIN xgb_input r\n" +
                ") t\n" +
                "group by rowid;")
//@formatter:on
/**
 * A Hive UDTF that runs batched XGBoost prediction.
 *
 * <p>Rows are grouped per {@code model_id} into batches of {@code batch_size} (default 128);
 * when a batch fills up — and once more for any remainder at {@link #close()} — it is converted
 * to a {@link DMatrix}, scored with the cached {@link Booster}, and each row is forwarded as
 * {@code (rowid, array<float> predicted)}.
 *
 * <p>Features may be dense ({@code array<double-compatible>}) or sparse
 * ({@code array<string>} of {@code "index:value"} entries).
 */
public final class XGBoostBatchPredictUDTF extends UDTFWithOptions {

    // Object inspectors for the input arguments, resolved in initialize()
    private PrimitiveObjectInspector rowIdOI;
    private ListObjectInspector featureListOI;
    private boolean denseFeatures;
    @Nullable
    private PrimitiveObjectInspector featureElemOI; // only set when denseFeatures == true
    private StringObjectInspector modelIdOI;
    private StringObjectInspector modelOI;

    // Per-task state: deserialized boosters and pending rows, both keyed by model_id
    private transient Map<String, Booster> mapToModel;
    private transient Map<String, List<LabeledPointWithRowId>> rowBuffer;

    // Number of rows accumulated per model before a predict+flush is triggered
    private int _batchSize;

    // Reusable 2-element row {rowid, predicted} passed to forward()
    @Nonnull
    protected transient final Object[] _forwardObj;

    // Load the XGBoost native library once per JVM before any Booster/DMatrix use
    static {
        NativeLibLoader.initXGBoost();
    }

    public XGBoostBatchPredictUDTF() {
        super();
        this._forwardObj = new Object[2];
    }

    @Override
    protected Options getOptions() {
        Options opts = new Options();
        opts.addOption("batch_size", true, "Number of rows to predict together [default: 128]");
        return opts;
    }

    /**
     * Parses the optional 5th argument (a constant option string).
     *
     * @return the parsed {@link CommandLine}, or {@code null} when no options were given
     * @throws UDFArgumentException if {@code batch_size} is not a positive integer
     */
    @Override
    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
        int batchSize = 128;
        CommandLine cl = null;
        if (argOIs.length >= 5) {
            String rawArgs = HiveUtils.getConstString(argOIs, 4);
            cl = parseOptions(rawArgs);
            batchSize = Primitives.parseInt(cl.getOptionValue("batch_size"), batchSize);
            if (batchSize < 1) {
                throw new UDFArgumentException("batch_size must be greater than 0: " + batchSize);
            }
        }
        this._batchSize = batchSize;
        return cl;
    }

    @Override
    public StructObjectInspector initialize(@Nonnull ObjectInspector[] argOIs)
            throws UDFArgumentException {
        if (argOIs.length != 4 && argOIs.length != 5) {
            showHelp("Invalid argment length=" + argOIs.length);
        }
        processOptions(argOIs);

        this.rowIdOI = HiveUtils.asPrimitiveObjectInspector(argOIs, 0);

        // features are dense (numeric list) or sparse ("index:value" strings)
        this.featureListOI = HiveUtils.asListOI(argOIs, 1);
        ObjectInspector elemOI = featureListOI.getListElementObjectInspector();
        if (HiveUtils.isNumberOI(elemOI)) {
            this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
            this.denseFeatures = true;
        } else if (HiveUtils.isStringOI(elemOI)) {
            this.denseFeatures = false;
        } else {
            throw new UDFArgumentException(
                "Expected array<string|double> for the 2nd argment but got an unexpected features type: "
                        + featureListOI.getTypeName());
        }
        this.modelIdOI = HiveUtils.asStringOI(argOIs, 2);
        this.modelOI = HiveUtils.asStringOI(argOIs, 3);

        return getReturnOI(rowIdOI);
    }

    /**
     * Builds the output schema. Override this to change the output depending on a task type.
     *
     * @return {@code (rowid, array<float> predicted)} where rowid keeps the input's primitive type
     */
    @Nonnull
    protected StructObjectInspector getReturnOI(@Nonnull PrimitiveObjectInspector rowIdOI) {
        List<String> fieldNames = new ArrayList<>(2);
        List<ObjectInspector> fieldOIs = new ArrayList<>(2);
        fieldNames.add("rowid");
        fieldOIs.add(PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(
            rowIdOI.getPrimitiveCategory()));
        fieldNames.add("predicted");
        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
            PrimitiveObjectInspectorFactory.writableFloatObjectInspector));
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Buffers one input row under its model_id, deserializing and caching the Booster on first
     * sight of the id. Flushes (predicts + forwards) when the buffer reaches the batch size.
     * Rows with NULL features are silently skipped.
     */
    @Override
    public void process(Object[] args) throws HiveException {
        if (mapToModel == null) { // lazily initialized on the first row
            this.mapToModel = new HashMap<String, Booster>();
            this.rowBuffer = new HashMap<String, List<LabeledPointWithRowId>>();
        }
        if (args[1] == null) {
            return;
        }

        String modelId =
                PrimitiveObjectInspectorUtils.getString(nonNullArgument(args, 2), modelIdOI);
        Booster model = mapToModel.get(modelId);
        if (model == null) {
            // deserialize once per model_id; subsequent rows reuse the cached Booster
            Text arg3 = modelOI.getPrimitiveWritableObject(nonNullArgument(args, 3));
            model = XGBoostUtils.deserializeBooster(arg3);
            mapToModel.put(modelId, model);
        }

        List<LabeledPointWithRowId> rowBatch = rowBuffer.get(modelId);
        if (rowBatch == null) {
            rowBatch = new ArrayList<LabeledPointWithRowId>(_batchSize);
            rowBuffer.put(modelId, rowBatch);
        }
        LabeledPointWithRowId row = parseRow(args);
        rowBatch.add(row);
        if (rowBatch.size() >= _batchSize) {
            predictAndFlush(model, rowBatch);
        }
    }

    /**
     * Converts one Hive row into a labeled point carrying its rowid. The label is a dummy
     * (0.f) since prediction does not use it.
     */
    @Nonnull
    private LabeledPointWithRowId parseRow(@Nonnull Object[] args) throws UDFArgumentException {
        final Writable rowId = HiveUtils.copyToWritable(nonNullArgument(args, 0), rowIdOI);

        final Object arg1 = args[1];
        if (denseFeatures) {
            return parseDenseFeatures(rowId, arg1, featureListOI, featureElemOI);
        } else {
            return parseSparseFeatures(rowId, arg1, featureListOI);
        }
    }

    /**
     * Parses a dense numeric feature vector. NULL elements become {@code Float.NaN}, which
     * XGBoost treats as a missing value.
     */
    @Nonnull
    private static LabeledPointWithRowId parseDenseFeatures(@Nonnull final Writable rowId,
            @Nonnull final Object argObj, @Nonnull final ListObjectInspector featureListOI,
            @Nonnull final PrimitiveObjectInspector featureElemOI) throws UDFArgumentException {
        final int size = featureListOI.getListLength(argObj);

        final float[] values = new float[size];
        for (int i = 0; i < size; i++) {
            final Object o = featureListOI.getListElement(argObj, i);
            if (o == null) {
                values[i] = Float.NaN; // missing value marker for XGBoost
            } else {
                float v = PrimitiveObjectInspectorUtils.getFloat(o, featureElemOI);
                values[i] = v;
            }
        }

        return new LabeledPointWithRowId(rowId, /* dummy label */ 0.f, null, values);
    }

    /**
     * Parses sparse {@code "index:value"} string features. NULL elements are skipped.
     *
     * @throws UDFArgumentException if an element lacks a {@code ':'} separator after a
     *         non-empty index, or either side fails numeric parsing
     */
    @Nonnull
    private static LabeledPointWithRowId parseSparseFeatures(@Nonnull final Writable rowId,
            @Nonnull final Object argObj, @Nonnull final ListObjectInspector featureListOI)
            throws UDFArgumentException {
        final int size = featureListOI.getListLength(argObj);
        final IntArrayList indices = new IntArrayList(size);
        final FloatArrayList values = new FloatArrayList(size);

        for (int i = 0; i < size; i++) {
            Object f = featureListOI.getListElement(argObj, i);
            if (f == null) {
                continue;
            }
            final String str = f.toString();
            final int pos = str.indexOf(':');
            if (pos < 1) { // ':' missing or index part empty
                throw new UDFArgumentException("Invalid feature format: " + str);
            }
            final int index;
            final float value;
            try {
                index = Integer.parseInt(str.substring(0, pos));
                value = Float.parseFloat(str.substring(pos + 1));
            } catch (NumberFormatException e) {
                throw new UDFArgumentException("Failed to parse a feature value: " + str);
            }
            indices.add(index);
            values.add(value);
        }

        return new LabeledPointWithRowId(rowId, /* dummy label */ 0.f, indices.toArray(),
            values.toArray());
    }

    /**
     * Flushes any remaining buffered rows and disposes every cached Booster.
     *
     * <p>Fixes over the previous implementation: (1) a Booster whose buffer happened to be
     * empty at close time (last batch flushed exactly at the batch-size boundary) was never
     * disposed; (2) if one model's flush threw, the remaining models were never disposed;
     * (3) close() NPE'd when process() was never invoked (empty input split). All Boosters
     * are now released in a finally block regardless of flush outcome.
     */
    @Override
    public void close() throws HiveException {
        if (rowBuffer == null) {
            return; // process() was never called; nothing to flush or release
        }
        try {
            for (Entry<String, List<LabeledPointWithRowId>> e : rowBuffer.entrySet()) {
                List<LabeledPointWithRowId> rowBatch = e.getValue();
                if (rowBatch.isEmpty()) {
                    continue;
                }
                Booster model = Objects.requireNonNull(mapToModel.get(e.getKey()));
                predictAndFlush(model, rowBatch);
            }
        } finally {
            // release native resources for ALL cached models, even on flush failure
            for (Booster model : mapToModel.values()) {
                XGBoostUtils.close(model);
            }
            this.rowBuffer = null;
            this.mapToModel = null;
        }
    }

    /**
     * Scores the batch with the given model and forwards the results, then clears the batch.
     * The temporary DMatrix is always released, even when prediction fails.
     */
    private void predictAndFlush(@Nonnull final Booster model,
            @Nonnull final List<LabeledPointWithRowId> rowBatch) throws HiveException {
        DMatrix testData = null;
        final float[][] predicted;
        try {
            testData = XGBoostUtils.createDMatrix(rowBatch);
            predicted = model.predict(testData);
        } catch (XGBoostError e) {
            throw new HiveException("Exception caused at prediction", e);
        } finally {
            XGBoostUtils.close(testData);
        }
        forwardPredicted(rowBatch, predicted);
        rowBatch.clear();
    }

    /**
     * Forwards one {@code (rowid, predicted)} pair per batched row, reusing a single
     * FloatWritable list and the preallocated forward array to avoid per-row allocation.
     *
     * @throws HiveException if the prediction matrix row count differs from the batch size
     */
    private void forwardPredicted(@Nonnull final List<LabeledPointWithRowId> rowBatch,
            @Nonnull final float[][] predicted) throws HiveException {
        if (rowBatch.size() != predicted.length) {
            throw new HiveException(String.format("buf.size() = %d but predicted.length = %d",
                rowBatch.size(), predicted.length));
        }
        if (predicted.length == 0) {
            return;
        }

        final int ncols = predicted[0].length;
        final List<FloatWritable> list = WritableUtils.newFloatList(ncols);

        final Object[] forwardObj = this._forwardObj;
        forwardObj[1] = list; // the list is mutated in place for each row

        for (int i = 0; i < predicted.length; i++) {
            Writable rowId = Objects.requireNonNull(rowBatch.get(i)).getRowId();
            forwardObj[0] = rowId;
            WritableUtils.setValues(predicted[i], list);
            forward(forwardObj);
        }
    }

    /** A {@link LabeledPoint} that additionally carries the originating row's id. */
    public static final class LabeledPointWithRowId extends LabeledPoint {
        private static final long serialVersionUID = -7150841669515184648L;

        @Nonnull
        final Writable rowId;

        LabeledPointWithRowId(@Nonnull Writable rowId, float label, @Nullable int[] indices,
                @Nonnull float[] values) {
            super(label, indices, values);
            this.rowId = rowId;
        }

        @Nonnull
        public Writable getRowId() {
            return rowId;
        }
    }

}