001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing,
013 * software distributed under the License is distributed on an
014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 * KIND, either express or implied.  See the License for the
016 * specific language governing permissions and limitations
017 * under the License.
018 */
019package hivemall.xgboost;
020
021import biz.k11i.xgboost.Predictor;
022import biz.k11i.xgboost.util.FVec;
023import hivemall.UDTFWithOptions;
024import hivemall.utils.hadoop.HiveUtils;
025import hivemall.utils.hadoop.WritableUtils;
026import hivemall.xgboost.utils.XGBoostUtils;
027
028import java.util.ArrayList;
029import java.util.HashMap;
030import java.util.List;
031import java.util.Map;
032
033import javax.annotation.Nonnull;
034import javax.annotation.Nullable;
035
036import org.apache.commons.cli.CommandLine;
037import org.apache.commons.cli.Options;
038import org.apache.hadoop.hive.ql.exec.Description;
039import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
040import org.apache.hadoop.hive.ql.metadata.HiveException;
041import org.apache.hadoop.hive.serde2.io.DoubleWritable;
042import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
043import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
044import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
045import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
046import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
047import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
048import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
049import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
050import org.apache.hadoop.io.Text;
051import org.apache.hadoop.io.Writable;
052
053//@formatter:off
054@Description(name = "xgboost_predict",
055        value = "_FUNC_(PRIMITIVE rowid, array<string|double> features, string model_id, array<string> pred_model [, string options]) "
056                + "- Returns a prediction result as (string rowid, array<double> predicted)",
057        extended = "select\n" + 
058                "  rowid, \n" + 
059                "  array_avg(predicted) as predicted,\n" + 
060                "  avg(predicted[0]) as predicted0\n" + 
061                "from (\n" + 
062                "  select\n" + 
063                "    xgboost_predict(rowid, features, model_id, model) as (rowid, predicted)\n" + 
064                "  from\n" + 
065                "    xgb_model l\n" + 
066                "    LEFT OUTER JOIN xgb_input r\n" + 
067                ") t\n" + 
068                "group by rowid;")
069//@formatter:on
070public class XGBoostOnlinePredictUDTF extends UDTFWithOptions {
071
072    // For input parameters
073    private PrimitiveObjectInspector rowIdOI;
074    private ListObjectInspector featureListOI;
075    private boolean denseFeatures;
076    @Nullable
077    private PrimitiveObjectInspector featureElemOI;
078    private StringObjectInspector modelIdOI;
079    private StringObjectInspector modelOI;
080
081    // For input buffer
082    @Nullable
083    private transient Map<String, Predictor> mapToModel;
084
085    @Nonnull
086    protected transient final Object[] _forwardObj;
087    @Nullable
088    protected transient List<DoubleWritable> _predictedCache;
089
090    public XGBoostOnlinePredictUDTF() {
091        this(new Object[2]);
092    }
093
094    protected XGBoostOnlinePredictUDTF(@Nonnull Object[] forwardObj) {
095        super();
096        this._forwardObj = forwardObj;
097    }
098
099    @Override
100    protected Options getOptions() {
101        Options opts = new Options();
102        // not yet supported
103        return opts;
104    }
105
106    @Override
107    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
108        CommandLine cl = null;
109        if (argOIs.length >= 5) {
110            String rawArgs = HiveUtils.getConstString(argOIs, 4);
111            cl = parseOptions(rawArgs);
112        }
113        return cl;
114    }
115
116    @Override
117    public StructObjectInspector initialize(@Nonnull ObjectInspector[] argOIs)
118            throws UDFArgumentException {
119        if (argOIs.length != 4 && argOIs.length != 5) {
120            showHelp("Invalid argment length=" + argOIs.length);
121        }
122        processOptions(argOIs);
123
124        this.rowIdOI = HiveUtils.asPrimitiveObjectInspector(argOIs, 0);
125        ListObjectInspector listOI = HiveUtils.asListOI(argOIs, 1);
126        this.featureListOI = listOI;
127        ObjectInspector elemOI = listOI.getListElementObjectInspector();
128        if (HiveUtils.isNumberOI(elemOI)) {
129            this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
130            this.denseFeatures = true;
131        } else if (HiveUtils.isStringOI(elemOI)) {
132            this.denseFeatures = false;
133        } else {
134            throw new UDFArgumentException(
135                "Expected array<string|double> for the 2nd argment but got an unexpected features type: "
136                        + listOI.getTypeName());
137        }
138        this.modelIdOI = HiveUtils.asStringOI(argOIs, 2);
139        this.modelOI = HiveUtils.asStringOI(argOIs, 3);
140        return getReturnOI(rowIdOI);
141    }
142
143    /** Override this to output predicted results depending on a task type */
144    /** Return (primitive rowid, array<double> predicted) as a result */
145    @Nonnull
146    protected StructObjectInspector getReturnOI(@Nonnull PrimitiveObjectInspector rowIdOI) {
147        List<String> fieldNames = new ArrayList<>(2);
148        List<ObjectInspector> fieldOIs = new ArrayList<>(2);
149        fieldNames.add("rowid");
150        fieldOIs.add(PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(
151            rowIdOI.getPrimitiveCategory()));
152        fieldNames.add("predicted");
153        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
154            PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
155        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
156    }
157
158    @Override
159    public void process(Object[] args) throws HiveException {
160        if (mapToModel == null) {
161            this.mapToModel = new HashMap<String, Predictor>();
162        }
163        if (args[1] == null) {// features is null
164            return;
165        }
166
167        String modelId =
168                PrimitiveObjectInspectorUtils.getString(nonNullArgument(args, 2), modelIdOI);
169        Predictor model = mapToModel.get(modelId);
170        if (model == null) {
171            Text arg3 = modelOI.getPrimitiveWritableObject(nonNullArgument(args, 3));
172            model = XGBoostUtils.loadPredictor(arg3);
173            mapToModel.put(modelId, model);
174        }
175
176        Writable rowId = HiveUtils.copyToWritable(nonNullArgument(args, 0), rowIdOI);
177        FVec features = denseFeatures ? parseDenseFeatures(args[1])
178                : parseSparseFeatures(featureListOI.getList(args[1]));
179
180        predictAndForward(model, rowId, features);
181    }
182
183    @Nonnull
184    private FVec parseDenseFeatures(@Nonnull Object argObj) throws UDFArgumentException {
185        final int length = featureListOI.getListLength(argObj);
186        final double[] values = new double[length];
187        for (int i = 0; i < length; i++) {
188            final Object o = featureListOI.getListElement(argObj, i);
189            final double v;
190            if (o == null) {
191                v = Double.NaN;
192            } else {
193                v = PrimitiveObjectInspectorUtils.getDouble(o, featureElemOI);
194            }
195            values[i] = v;
196
197        }
198        return FVec.Transformer.fromArray(values, false);
199    }
200
201    @Nonnull
202    private static FVec parseSparseFeatures(@Nonnull final List<?> featureList)
203            throws UDFArgumentException {
204        final Map<Integer, Double> map = new HashMap<>((int) (featureList.size() * 1.5));
205        for (Object f : featureList) {
206            if (f == null) {
207                continue;
208            }
209            String str = f.toString();
210            final int pos = str.indexOf(':');
211            if (pos < 1) {
212                throw new UDFArgumentException("Invalid feature format: " + str);
213            }
214            final int index;
215            final double value;
216            try {
217                index = Integer.parseInt(str.substring(0, pos));
218                value = Double.parseDouble(str.substring(pos + 1));
219            } catch (NumberFormatException e) {
220                throw new UDFArgumentException("Failed to parse a feature value: " + str);
221            }
222            map.put(index, value);
223        }
224
225        return FVec.Transformer.fromMap(map);
226    }
227
228    private void predictAndForward(@Nonnull final Predictor model, @Nonnull final Writable rowId,
229            @Nonnull final FVec features) throws HiveException {
230        double[] predicted = model.predict(features);
231        // predicted[0] has
232        //    - probability ("binary:logistic")
233        //    - class label ("multi:softmax")
234        forwardPredicted(rowId, predicted);
235    }
236
237    protected void forwardPredicted(@Nonnull final Writable rowId,
238            @Nonnull final double[] predicted) throws HiveException {
239        List<DoubleWritable> list = WritableUtils.toWritableList(predicted, _predictedCache);
240        this._predictedCache = list;
241        Object[] forwardObj = this._forwardObj;
242        forwardObj[0] = rowId;
243        forwardObj[1] = list;
244        forward(forwardObj);
245    }
246
247    @Override
248    public void close() throws HiveException {
249        this.mapToModel = null;
250    }
251
252}