001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing,
013 * software distributed under the License is distributed on an
014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 * KIND, either express or implied.  See the License for the
016 * specific language governing permissions and limitations
017 * under the License.
018 */
019package hivemall.xgboost;
020
021import java.util.ArrayList;
022import java.util.List;
023
024import javax.annotation.Nonnull;
025
026import org.apache.hadoop.hive.ql.exec.Description;
027import org.apache.hadoop.hive.ql.metadata.HiveException;
028import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
029import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
030import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
031import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
032import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
033import org.apache.hadoop.io.Writable;
034
035//@formatter:off
036@Description(name = "xgboost_predict_one",
037        value = "_FUNC_(PRIMITIVE rowid, array<string|double> features, string model_id, array<string> pred_model [, string options]) "
038                + "- Returns a prediction result as (string rowid, double predicted)",
039        extended = "select\n" + 
040                "  rowid, \n" + 
041                "  avg(predicted) as predicted\n" + 
042                "from (\n" + 
043                "  select\n" + 
044                "    xgboost_predict_one(rowid, features, model_id, model) as (rowid, predicted)\n" + 
045                "  from\n" + 
046                "    xgb_model l\n" + 
047                "    LEFT OUTER JOIN xgb_input r\n" + 
048                ") t\n" + 
049                "group by rowid;")
050//@formatter:on
051public final class XGBoostPredictOneUDTF extends XGBoostOnlinePredictUDTF {
052
053    public XGBoostPredictOneUDTF() {
054        super(new Object[2]);
055    }
056
057    /** Return (string rowid, double predicted) as a result */
058    @Override
059    protected StructObjectInspector getReturnOI(@Nonnull PrimitiveObjectInspector rowIdOI) {
060        List<String> fieldNames = new ArrayList<>(2);
061        List<ObjectInspector> fieldOIs = new ArrayList<>(2);
062        fieldNames.add("rowid");
063        fieldOIs.add(PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(
064            rowIdOI.getPrimitiveCategory()));
065        fieldNames.add("predicted");
066        fieldOIs.add(PrimitiveObjectInspectorFactory.javaDoubleObjectInspector);
067        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
068    }
069
070    @Override
071    protected void forwardPredicted(@Nonnull Writable rowId, @Nonnull double[] predicted)
072            throws HiveException {
073        final Object[] forwardObj = _forwardObj;
074        forwardObj[0] = rowId;
075        forwardObj[1] = Double.valueOf(predicted[0]);
076        forward(forwardObj);
077    }
078
079}