001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing,
013 * software distributed under the License is distributed on an
014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 * KIND, either express or implied.  See the License for the
016 * specific language governing permissions and limitations
017 * under the License.
018 */
019package hivemall.xgboost;
020
021import hivemall.UDTFWithOptions;
022import hivemall.utils.collections.lists.FloatArrayList;
023import hivemall.utils.collections.lists.IntArrayList;
024import hivemall.utils.hadoop.HiveUtils;
025import hivemall.utils.hadoop.WritableUtils;
026import hivemall.utils.lang.Primitives;
027import hivemall.xgboost.utils.NativeLibLoader;
028import hivemall.xgboost.utils.XGBoostUtils;
029import ml.dmlc.xgboost4j.LabeledPoint;
030import ml.dmlc.xgboost4j.java.Booster;
031import ml.dmlc.xgboost4j.java.DMatrix;
032import ml.dmlc.xgboost4j.java.XGBoostError;
033
034import java.util.ArrayList;
035import java.util.HashMap;
036import java.util.List;
037import java.util.Map;
038import java.util.Map.Entry;
039import java.util.Objects;
040
041import javax.annotation.Nonnull;
042import javax.annotation.Nullable;
043
044import org.apache.commons.cli.CommandLine;
045import org.apache.commons.cli.Options;
046import org.apache.hadoop.hive.ql.exec.Description;
047import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
048import org.apache.hadoop.hive.ql.metadata.HiveException;
049import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
050import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
051import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
052import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
053import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
054import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
055import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
056import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
057import org.apache.hadoop.io.FloatWritable;
058import org.apache.hadoop.io.Text;
059import org.apache.hadoop.io.Writable;
060
061//@formatter:off
062@Description(name = "xgboost_batch_predict",
063        value = "_FUNC_(PRIMITIVE rowid, array<string|double> features, string model_id, array<string> pred_model [, string options]) "
064                + "- Returns a prediction result as (string rowid, array<double> predicted)",
065        extended = "select\n" + 
066                "  rowid, \n" + 
067                "  array_avg(predicted) as predicted,\n" + 
068                "  avg(predicted[0]) as predicted0\n" + 
069                "from (\n" + 
070                "  select\n" + 
071                "    xgboost_batch_predict(rowid, features, model_id, model) as (rowid, predicted)\n" + 
072                "  from\n" + 
073                "    xgb_model l\n" + 
074                "    LEFT OUTER JOIN xgb_input r\n" + 
075                ") t\n" + 
076                "group by rowid;")
077//@formatter:on
078public final class XGBoostBatchPredictUDTF extends UDTFWithOptions {
079
080    // For input parameters
081    private PrimitiveObjectInspector rowIdOI;
082    private ListObjectInspector featureListOI;
083    private boolean denseFeatures;
084    @Nullable
085    private PrimitiveObjectInspector featureElemOI;
086    private StringObjectInspector modelIdOI;
087    private StringObjectInspector modelOI;
088
089    // For input buffer
090    private transient Map<String, Booster> mapToModel;
091    private transient Map<String, List<LabeledPointWithRowId>> rowBuffer;
092
093    private int _batchSize;
094
095    @Nonnull
096    protected transient final Object[] _forwardObj;
097
098    // Settings for the XGBoost native library
099    static {
100        NativeLibLoader.initXGBoost();
101    }
102
103    public XGBoostBatchPredictUDTF() {
104        super();
105        this._forwardObj = new Object[2];
106    }
107
108    @Override
109    protected Options getOptions() {
110        Options opts = new Options();
111        opts.addOption("batch_size", true, "Number of rows to predict together [default: 128]");
112        return opts;
113    }
114
115    @Override
116    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
117        int batchSize = 128;
118        CommandLine cl = null;
119        if (argOIs.length >= 5) {
120            String rawArgs = HiveUtils.getConstString(argOIs, 4);
121            cl = parseOptions(rawArgs);
122            batchSize = Primitives.parseInt(cl.getOptionValue("batch_size"), batchSize);
123            if (batchSize < 1) {
124                throw new UDFArgumentException("batch_size must be greater than 0: " + batchSize);
125            }
126        }
127        this._batchSize = batchSize;
128        return cl;
129    }
130
131    @Override
132    public StructObjectInspector initialize(@Nonnull ObjectInspector[] argOIs)
133            throws UDFArgumentException {
134        if (argOIs.length != 4 && argOIs.length != 5) {
135            showHelp("Invalid argment length=" + argOIs.length);
136        }
137        processOptions(argOIs);
138
139        this.rowIdOI = HiveUtils.asPrimitiveObjectInspector(argOIs, 0);
140
141        this.featureListOI = HiveUtils.asListOI(argOIs, 1);
142        ObjectInspector elemOI = featureListOI.getListElementObjectInspector();
143        if (HiveUtils.isNumberOI(elemOI)) {
144            this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
145            this.denseFeatures = true;
146        } else if (HiveUtils.isStringOI(elemOI)) {
147            this.denseFeatures = false;
148        } else {
149            throw new UDFArgumentException(
150                "Expected array<string|double> for the 2nd argment but got an unexpected features type: "
151                        + featureListOI.getTypeName());
152        }
153        this.modelIdOI = HiveUtils.asStringOI(argOIs, 2);
154        this.modelOI = HiveUtils.asStringOI(argOIs, 3);
155
156        return getReturnOI(rowIdOI);
157    }
158
159    /** Override this to output predicted results depending on a task type */
160    /** Return (string rowid, array<double> predicted) as a result */
161    @Nonnull
162    protected StructObjectInspector getReturnOI(@Nonnull PrimitiveObjectInspector rowIdOI) {
163        List<String> fieldNames = new ArrayList<>(2);
164        List<ObjectInspector> fieldOIs = new ArrayList<>(2);
165        fieldNames.add("rowid");
166        fieldOIs.add(PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(
167            rowIdOI.getPrimitiveCategory()));
168        fieldNames.add("predicted");
169        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
170            PrimitiveObjectInspectorFactory.writableFloatObjectInspector));
171        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
172    }
173
174    @Override
175    public void process(Object[] args) throws HiveException {
176        if (mapToModel == null) {
177            this.mapToModel = new HashMap<String, Booster>();
178            this.rowBuffer = new HashMap<String, List<LabeledPointWithRowId>>();
179        }
180        if (args[1] == null) {
181            return;
182        }
183
184        String modelId =
185                PrimitiveObjectInspectorUtils.getString(nonNullArgument(args, 2), modelIdOI);
186        Booster model = mapToModel.get(modelId);
187        if (model == null) {
188            Text arg3 = modelOI.getPrimitiveWritableObject(nonNullArgument(args, 3));
189            model = XGBoostUtils.deserializeBooster(arg3);
190            mapToModel.put(modelId, model);
191        }
192
193        List<LabeledPointWithRowId> rowBatch = rowBuffer.get(modelId);
194        if (rowBatch == null) {
195            rowBatch = new ArrayList<LabeledPointWithRowId>(_batchSize);
196            rowBuffer.put(modelId, rowBatch);
197        }
198        LabeledPointWithRowId row = parseRow(args);
199        rowBatch.add(row);
200        if (rowBatch.size() >= _batchSize) {
201            predictAndFlush(model, rowBatch);
202        }
203    }
204
205    @Nonnull
206    private LabeledPointWithRowId parseRow(@Nonnull Object[] args) throws UDFArgumentException {
207        final Writable rowId = HiveUtils.copyToWritable(nonNullArgument(args, 0), rowIdOI);
208
209        final Object arg1 = args[1];
210        if (denseFeatures) {
211            return parseDenseFeatures(rowId, arg1, featureListOI, featureElemOI);
212        } else {
213            return parseSparseFeatures(rowId, arg1, featureListOI);
214        }
215    }
216
217    @Nonnull
218    private static LabeledPointWithRowId parseDenseFeatures(@Nonnull final Writable rowId,
219            @Nonnull final Object argObj, @Nonnull final ListObjectInspector featureListOI,
220            @Nonnull final PrimitiveObjectInspector featureElemOI) throws UDFArgumentException {
221        final int size = featureListOI.getListLength(argObj);
222
223        final float[] values = new float[size];
224        for (int i = 0; i < size; i++) {
225            final Object o = featureListOI.getListElement(argObj, i);
226            if (o == null) {
227                values[i] = Float.NaN;
228            } else {
229                float v = PrimitiveObjectInspectorUtils.getFloat(o, featureElemOI);
230                values[i] = v;
231            }
232        }
233
234        return new LabeledPointWithRowId(rowId, /* dummy label */ 0.f, null, values);
235
236    }
237
238    @Nonnull
239    private static LabeledPointWithRowId parseSparseFeatures(@Nonnull final Writable rowId,
240            @Nonnull final Object argObj, @Nonnull final ListObjectInspector featureListOI)
241            throws UDFArgumentException {
242        final int size = featureListOI.getListLength(argObj);
243        final IntArrayList indices = new IntArrayList(size);
244        final FloatArrayList values = new FloatArrayList(size);
245
246        for (int i = 0; i < size; i++) {
247            Object f = featureListOI.getListElement(argObj, i);
248            if (f == null) {
249                continue;
250            }
251            final String str = f.toString();
252            final int pos = str.indexOf(':');
253            if (pos < 1) {
254                throw new UDFArgumentException("Invalid feature format: " + str);
255            }
256            final int index;
257            final float value;
258            try {
259                index = Integer.parseInt(str.substring(0, pos));
260                value = Float.parseFloat(str.substring(pos + 1));
261            } catch (NumberFormatException e) {
262                throw new UDFArgumentException("Failed to parse a feature value: " + str);
263            }
264            indices.add(index);
265            values.add(value);
266        }
267
268        return new LabeledPointWithRowId(rowId, /* dummy label */ 0.f, indices.toArray(),
269            values.toArray());
270    }
271
272
273    @Override
274    public void close() throws HiveException {
275        for (Entry<String, List<LabeledPointWithRowId>> e : rowBuffer.entrySet()) {
276            String modelId = e.getKey();
277            List<LabeledPointWithRowId> rowBatch = e.getValue();
278            if (rowBatch.isEmpty()) {
279                continue;
280            }
281            final Booster model = Objects.requireNonNull(mapToModel.get(modelId));
282            try {
283                predictAndFlush(model, rowBatch);
284            } finally {
285                XGBoostUtils.close(model);
286            }
287        }
288        this.rowBuffer = null;
289        this.mapToModel = null;
290    }
291
292    private void predictAndFlush(@Nonnull final Booster model,
293            @Nonnull final List<LabeledPointWithRowId> rowBatch) throws HiveException {
294        DMatrix testData = null;
295        final float[][] predicted;
296        try {
297            testData = XGBoostUtils.createDMatrix(rowBatch);
298            predicted = model.predict(testData);
299        } catch (XGBoostError e) {
300            throw new HiveException("Exception caused at prediction", e);
301        } finally {
302            XGBoostUtils.close(testData);
303        }
304        forwardPredicted(rowBatch, predicted);
305        rowBatch.clear();
306    }
307
308    private void forwardPredicted(@Nonnull final List<LabeledPointWithRowId> rowBatch,
309            @Nonnull final float[][] predicted) throws HiveException {
310        if (rowBatch.size() != predicted.length) {
311            throw new HiveException(String.format("buf.size() = %d but predicted.length = %d",
312                rowBatch.size(), predicted.length));
313        }
314        if (predicted.length == 0) {
315            return;
316        }
317
318        final int ncols = predicted[0].length;
319        final List<FloatWritable> list = WritableUtils.newFloatList(ncols);
320
321        final Object[] forwardObj = this._forwardObj;
322        forwardObj[1] = list;
323
324        for (int i = 0; i < predicted.length; i++) {
325            Writable rowId = Objects.requireNonNull(rowBatch.get(i)).getRowId();
326            forwardObj[0] = rowId;
327            WritableUtils.setValues(predicted[i], list);
328            forward(forwardObj);
329        }
330    }
331
332    public static final class LabeledPointWithRowId extends LabeledPoint {
333        private static final long serialVersionUID = -7150841669515184648L;
334
335        @Nonnull
336        final Writable rowId;
337
338        LabeledPointWithRowId(@Nonnull Writable rowId, float label, @Nullable int[] indices,
339                @Nonnull float[] values) {
340            super(label, indices, values);
341            this.rowId = rowId;
342        }
343
344        @Nonnull
345        public Writable getRowId() {
346            return rowId;
347        }
348    }
349
350}