001/* 002 * Licensed to the Apache Software Foundation (ASF) under one 003 * or more contributor license agreements. See the NOTICE file 004 * distributed with this work for additional information 005 * regarding copyright ownership. The ASF licenses this file 006 * to you under the Apache License, Version 2.0 (the 007 * "License"); you may not use this file except in compliance 008 * with the License. You may obtain a copy of the License at 009 * 010 * http://www.apache.org/licenses/LICENSE-2.0 011 * 012 * Unless required by applicable law or agreed to in writing, 013 * software distributed under the License is distributed on an 014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 015 * KIND, either express or implied. See the License for the 016 * specific language governing permissions and limitations 017 * under the License. 018 */ 019package hivemall.xgboost; 020 021import java.util.ArrayList; 022import java.util.List; 023 024import javax.annotation.Nonnull; 025 026import org.apache.hadoop.hive.ql.exec.Description; 027import org.apache.hadoop.hive.ql.metadata.HiveException; 028import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; 029import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; 030import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; 031import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; 032import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; 033import org.apache.hadoop.io.Writable; 034 035//@formatter:off 036@Description(name = "xgboost_predict_triple", 037 value = "_FUNC_(PRIMITIVE rowid, array<string|double> features, string model_id, array<string> pred_model [, string options]) " 038 + "- Returns a prediction result as (string rowid, string label, double probability)", 039 extended = "select\n" + 040 " rowid,\n" + 041 " label,\n" + 042 " avg(prob) as prob\n" + 043 "from (\n" + 044 " select\n" + 045 " xgboost_predict_triple(rowid, features, model_id, model) as (rowid, label, prob)\n" + 046 " from\n" + 047 " xgb_model l\n" + 048 " LEFT OUTER JOIN xgb_input r\n" + 049 ") t\n" + 050 "group by rowid, label;") 051//@formatter:on 052public final class XGBoostPredictTripleUDTF extends XGBoostOnlinePredictUDTF { 053 054 public XGBoostPredictTripleUDTF() { 055 super(new Object[3]); 056 } 057 058 /** Return (string rowid, int label, double probability) as a result */ 059 @Override 060 protected StructObjectInspector getReturnOI(@Nonnull PrimitiveObjectInspector rowIdOI) { 061 List<String> fieldNames = new ArrayList<>(3); 062 List<ObjectInspector> fieldOIs = new ArrayList<>(3); 063 fieldNames.add("rowid"); 064 fieldOIs.add(PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector( 065 rowIdOI.getPrimitiveCategory())); 066 fieldNames.add("label"); 067 fieldOIs.add(PrimitiveObjectInspectorFactory.javaIntObjectInspector); 068 fieldNames.add("proba"); 069 fieldOIs.add(PrimitiveObjectInspectorFactory.javaDoubleObjectInspector); 070 071 return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); 072 } 073 074 @Override 075 protected void forwardPredicted(@Nonnull Writable rowId, @Nonnull double[] predicted) 076 throws HiveException { 077 final Object[] forwardObj = _forwardObj; 078 forwardObj[0] = rowId; 079 for (int j = 0, ncols = predicted.length; j < ncols; j++) { 080 forwardObj[1] = Integer.valueOf(j); 081 forwardObj[2] = Double.valueOf(predicted[j]); 082 forward(forwardObj); 083 } 084 } 085 086}