001/* 002 * Licensed to the Apache Software Foundation (ASF) under one 003 * or more contributor license agreements. See the NOTICE file 004 * distributed with this work for additional information 005 * regarding copyright ownership. The ASF licenses this file 006 * to you under the Apache License, Version 2.0 (the 007 * "License"); you may not use this file except in compliance 008 * with the License. You may obtain a copy of the License at 009 * 010 * http://www.apache.org/licenses/LICENSE-2.0 011 * 012 * Unless required by applicable law or agreed to in writing, 013 * software distributed under the License is distributed on an 014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 015 * KIND, either express or implied. See the License for the 016 * specific language governing permissions and limitations 017 * under the License. 018 */ 019package hivemall.xgboost; 020 021import java.util.ArrayList; 022import java.util.List; 023 024import javax.annotation.Nonnull; 025 026import org.apache.hadoop.hive.ql.exec.Description; 027import org.apache.hadoop.hive.ql.metadata.HiveException; 028import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; 029import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; 030import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; 031import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; 032import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; 033import org.apache.hadoop.io.Writable; 034 035//@formatter:off 036@Description(name = "xgboost_predict_one", 037 value = "_FUNC_(PRIMITIVE rowid, array<string|double> features, string model_id, array<string> pred_model [, string options]) " 038 + "- Returns a prediction result as (string rowid, double predicted)", 039 extended = "select\n" + 040 " rowid, \n" + 041 " avg(predicted) as predicted\n" + 042 "from (\n" + 043 " select\n" + 044 " xgboost_predict_one(rowid, features, model_id, model) as (rowid, predicted)\n" + 045 " from\n" + 046 " xgb_model l\n" + 047 " LEFT OUTER JOIN xgb_input r\n" + 048 ") t\n" + 049 "group by rowid;") 050//@formatter:on 051public final class XGBoostPredictOneUDTF extends XGBoostOnlinePredictUDTF { 052 053 public XGBoostPredictOneUDTF() { 054 super(new Object[2]); 055 } 056 057 /** Return (string rowid, double predicted) as a result */ 058 @Override 059 protected StructObjectInspector getReturnOI(@Nonnull PrimitiveObjectInspector rowIdOI) { 060 List<String> fieldNames = new ArrayList<>(2); 061 List<ObjectInspector> fieldOIs = new ArrayList<>(2); 062 fieldNames.add("rowid"); 063 fieldOIs.add(PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector( 064 rowIdOI.getPrimitiveCategory())); 065 fieldNames.add("predicted"); 066 fieldOIs.add(PrimitiveObjectInspectorFactory.javaDoubleObjectInspector); 067 return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); 068 } 069 070 @Override 071 protected void forwardPredicted(@Nonnull Writable rowId, @Nonnull double[] predicted) 072 throws HiveException { 073 final Object[] forwardObj = _forwardObj; 074 forwardObj[0] = rowId; 075 forwardObj[1] = Double.valueOf(predicted[0]); 076 forward(forwardObj); 077 } 078 079}