/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.xgboost;

import biz.k11i.xgboost.Predictor;
import biz.k11i.xgboost.util.FVec;
import hivemall.UDTFWithOptions;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.WritableUtils;
import hivemall.xgboost.utils.XGBoostUtils;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;

//@formatter:off
@Description(name = "xgboost_predict",
        value = "_FUNC_(PRIMITIVE rowid, array<string|double> features, string model_id, array<string> pred_model [, string options]) "
                + "- Returns a prediction result as (string rowid, array<double> predicted)",
        extended = "select\n" +
                "  rowid, \n" +
                "  array_avg(predicted) as predicted,\n" +
                "  avg(predicted[0]) as predicted0\n" +
                "from (\n" +
                "  select\n" +
                "    xgboost_predict(rowid, features, model_id, model) as (rowid, predicted)\n" +
                "  from\n" +
                "    xgb_model l\n" +
                "    LEFT OUTER JOIN xgb_input r\n" +
                ") t\n" +
                "group by rowid;")
//@formatter:on
/**
 * A UDTF that evaluates XGBoost models row-by-row using the pure-Java
 * xgboost-predictor library ({@link biz.k11i.xgboost.Predictor}).
 *
 * <p>Each call to {@link #process(Object[])} receives a (rowid, features,
 * model_id, model) tuple. Models are deserialized lazily from their string
 * representation and cached per {@code model_id} for the lifetime of the
 * task; the cache is released in {@link #close()}.
 *
 * <p>Features may be given either as {@code array<double>} (dense; index is
 * the array position, nulls become NaN) or {@code array<string>} of
 * {@code "index:value"} pairs (sparse).
 */
public class XGBoostOnlinePredictUDTF extends UDTFWithOptions {

    // Object inspectors for the input arguments, resolved in initialize()
    private PrimitiveObjectInspector rowIdOI;
    private ListObjectInspector featureListOI;
    private boolean denseFeatures;
    @Nullable
    private PrimitiveObjectInspector featureElemOI; // non-null only for dense (numeric) features
    private StringObjectInspector modelIdOI;
    private StringObjectInspector modelOI;

    // Lazily-built cache: model_id -> deserialized Predictor
    @Nullable
    private transient Map<String, Predictor> mapToModel;

    // Reusable 2-element forward buffer: {rowid, predicted}
    @Nonnull
    protected transient final Object[] _forwardObj;
    // Reusable list backing the "predicted" output column
    @Nullable
    protected transient List<DoubleWritable> _predictedCache;

    public XGBoostOnlinePredictUDTF() {
        this(new Object[2]);
    }

    protected XGBoostOnlinePredictUDTF(@Nonnull Object[] forwardObj) {
        super();
        this._forwardObj = forwardObj;
    }

    @Override
    protected Options getOptions() {
        Options opts = new Options();
        // not yet supported
        return opts;
    }

    @Override
    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
        CommandLine cl = null;
        if (argOIs.length >= 5) {
            String rawArgs = HiveUtils.getConstString(argOIs, 4);
            cl = parseOptions(rawArgs);
        }
        return cl;
    }

    /**
     * Validates the 4 (or 5, with options) arguments and resolves the object
     * inspectors used in {@link #process(Object[])}.
     *
     * @throws UDFArgumentException if the argument count or feature element
     *         type (must be numeric or string) is invalid
     */
    @Override
    public StructObjectInspector initialize(@Nonnull ObjectInspector[] argOIs)
            throws UDFArgumentException {
        if (argOIs.length != 4 && argOIs.length != 5) {
            showHelp("Invalid argument length=" + argOIs.length);
        }
        processOptions(argOIs);

        this.rowIdOI = HiveUtils.asPrimitiveObjectInspector(argOIs, 0);
        ListObjectInspector listOI = HiveUtils.asListOI(argOIs, 1);
        this.featureListOI = listOI;
        ObjectInspector elemOI = listOI.getListElementObjectInspector();
        if (HiveUtils.isNumberOI(elemOI)) {
            this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
            this.denseFeatures = true;
        } else if (HiveUtils.isStringOI(elemOI)) {
            this.denseFeatures = false;
        } else {
            throw new UDFArgumentException(
                "Expected array<string|double> for the 2nd argument but got an unexpected features type: "
                        + listOI.getTypeName());
        }
        this.modelIdOI = HiveUtils.asStringOI(argOIs, 2);
        this.modelOI = HiveUtils.asStringOI(argOIs, 3);
        return getReturnOI(rowIdOI);
    }

    /**
     * Builds the output schema: (primitive rowid, array&lt;double&gt; predicted).
     * Override this to output predicted results depending on a task type.
     */
    @Nonnull
    protected StructObjectInspector getReturnOI(@Nonnull PrimitiveObjectInspector rowIdOI) {
        List<String> fieldNames = new ArrayList<>(2);
        List<ObjectInspector> fieldOIs = new ArrayList<>(2);
        fieldNames.add("rowid");
        fieldOIs.add(PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(
            rowIdOI.getPrimitiveCategory()));
        fieldNames.add("predicted");
        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
            PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Predicts one row. Rows with a null feature array are silently skipped.
     * The model is deserialized on first use of each {@code model_id} and
     * cached for subsequent rows.
     */
    @Override
    public void process(Object[] args) throws HiveException {
        if (mapToModel == null) {
            this.mapToModel = new HashMap<>();
        }
        if (args[1] == null) {// features is null
            return;
        }

        String modelId =
                PrimitiveObjectInspectorUtils.getString(nonNullArgument(args, 2), modelIdOI);
        Predictor model = mapToModel.get(modelId);
        if (model == null) {
            Text arg3 = modelOI.getPrimitiveWritableObject(nonNullArgument(args, 3));
            model = XGBoostUtils.loadPredictor(arg3);
            mapToModel.put(modelId, model);
        }

        Writable rowId = HiveUtils.copyToWritable(nonNullArgument(args, 0), rowIdOI);
        FVec features = denseFeatures ? parseDenseFeatures(args[1])
                : parseSparseFeatures(featureListOI.getList(args[1]));

        predictAndForward(model, rowId, features);
    }

    /**
     * Converts a dense numeric feature array into an {@link FVec}.
     * Null elements are mapped to {@code Double.NaN} (treated as missing).
     */
    @Nonnull
    private FVec parseDenseFeatures(@Nonnull Object argObj) throws UDFArgumentException {
        final int length = featureListOI.getListLength(argObj);
        final double[] values = new double[length];
        for (int i = 0; i < length; i++) {
            final Object o = featureListOI.getListElement(argObj, i);
            final double v;
            if (o == null) {
                v = Double.NaN;
            } else {
                v = PrimitiveObjectInspectorUtils.getDouble(o, featureElemOI);
            }
            values[i] = v;
        }
        return FVec.Transformer.fromArray(values, false);
    }

    /**
     * Converts a sparse {@code "index:value"} string list into an {@link FVec}.
     * Null entries are skipped.
     *
     * @throws UDFArgumentException if an entry has no {@code ':'} separator
     *         after at least one index character, or the index/value do not parse
     */
    @Nonnull
    private static FVec parseSparseFeatures(@Nonnull final List<?> featureList)
            throws UDFArgumentException {
        // presize above size/load-factor to avoid rehashing while filling
        final Map<Integer, Double> map = new HashMap<>((int) (featureList.size() * 1.5));
        for (Object f : featureList) {
            if (f == null) {
                continue;
            }
            String str = f.toString();
            final int pos = str.indexOf(':');
            if (pos < 1) {
                throw new UDFArgumentException("Invalid feature format: " + str);
            }
            final int index;
            final double value;
            try {
                index = Integer.parseInt(str.substring(0, pos));
                value = Double.parseDouble(str.substring(pos + 1));
            } catch (NumberFormatException e) {
                throw new UDFArgumentException("Failed to parse a feature value: " + str);
            }
            map.put(index, value);
        }

        return FVec.Transformer.fromMap(map);
    }

    private void predictAndForward(@Nonnull final Predictor model, @Nonnull final Writable rowId,
            @Nonnull final FVec features) throws HiveException {
        double[] predicted = model.predict(features);
        // predicted[0] has
        // - probability ("binary:logistic")
        // - class label ("multi:softmax")
        forwardPredicted(rowId, predicted);
    }

    /**
     * Forwards (rowid, predicted) to the collector, reusing the cached
     * DoubleWritable list and forward buffer to minimize per-row allocation.
     */
    protected void forwardPredicted(@Nonnull final Writable rowId,
            @Nonnull final double[] predicted) throws HiveException {
        List<DoubleWritable> list = WritableUtils.toWritableList(predicted, _predictedCache);
        this._predictedCache = list;
        Object[] forwardObj = this._forwardObj;
        forwardObj[0] = rowId;
        forwardObj[1] = list;
        forward(forwardObj);
    }

    /** Releases the per-task model cache. */
    @Override
    public void close() throws HiveException {
        this.mapToModel = null;
    }

}