001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing,
013 * software distributed under the License is distributed on an
014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 * KIND, either express or implied.  See the License for the
016 * specific language governing permissions and limitations
017 * under the License.
018 */
019package hivemall.xgboost;
020
021import java.util.ArrayList;
022import java.util.List;
023
024import javax.annotation.Nonnull;
025
026import org.apache.hadoop.hive.ql.exec.Description;
027import org.apache.hadoop.hive.ql.metadata.HiveException;
028import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
029import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
030import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
031import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
032import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
033import org.apache.hadoop.io.Writable;
034
035//@formatter:off
036@Description(name = "xgboost_predict_triple",
037        value = "_FUNC_(PRIMITIVE rowid, array<string|double> features, string model_id, array<string> pred_model [, string options]) "
038                + "- Returns a prediction result as (string rowid, string label, double probability)",
039        extended = "select\n" + 
040                "  rowid,\n" + 
041                "  label,\n" + 
042                "  avg(prob) as prob\n" + 
043                "from (\n" + 
044                "  select\n" + 
045                "    xgboost_predict_triple(rowid, features, model_id, model) as (rowid, label, prob)\n" + 
046                "  from\n" + 
047                "    xgb_model l\n" + 
048                "    LEFT OUTER JOIN xgb_input r\n" + 
049                ") t\n" + 
050                "group by rowid, label;")
051//@formatter:on
052public final class XGBoostPredictTripleUDTF extends XGBoostOnlinePredictUDTF {
053
054    public XGBoostPredictTripleUDTF() {
055        super(new Object[3]);
056    }
057
058    /** Return (string rowid, int label, double probability) as a result */
059    @Override
060    protected StructObjectInspector getReturnOI(@Nonnull PrimitiveObjectInspector rowIdOI) {
061        List<String> fieldNames = new ArrayList<>(3);
062        List<ObjectInspector> fieldOIs = new ArrayList<>(3);
063        fieldNames.add("rowid");
064        fieldOIs.add(PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(
065            rowIdOI.getPrimitiveCategory()));
066        fieldNames.add("label");
067        fieldOIs.add(PrimitiveObjectInspectorFactory.javaIntObjectInspector);
068        fieldNames.add("proba");
069        fieldOIs.add(PrimitiveObjectInspectorFactory.javaDoubleObjectInspector);
070
071        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
072    }
073
074    @Override
075    protected void forwardPredicted(@Nonnull Writable rowId, @Nonnull double[] predicted)
076            throws HiveException {
077        final Object[] forwardObj = _forwardObj;
078        forwardObj[0] = rowId;
079        for (int j = 0, ncols = predicted.length; j < ncols; j++) {
080            forwardObj[1] = Integer.valueOf(j);
081            forwardObj[2] = Double.valueOf(predicted[j]);
082            forward(forwardObj);
083        }
084    }
085
086}