/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.xgboost.utils;

import biz.k11i.xgboost.Predictor;
import biz.k11i.xgboost.util.FVec;
import hivemall.utils.io.FastByteArrayInputStream;
import hivemall.utils.io.IOUtils;
import hivemall.xgboost.XGBoostBatchPredictUDTF.LabeledPointWithRowId;
import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;

import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.io.Text;

/**
 * Static helper routines for bridging Hivemall and the xgboost4j native library:
 * version lookup, {@link DMatrix}/{@link Booster} construction, model
 * (de)serialization to Hive {@link Text}, and feature-vector parsing.
 */
public final class XGBoostUtils {

    private XGBoostUtils() {}

    /**
     * Returns the xgboost4j version read from {@code xgboost4j-version.properties} on the
     * classpath.
     *
     * @return the version string, or {@code "<unknown>"} when the property is absent
     * @throws HiveException if the properties resource is missing or cannot be read
     */
    @Nonnull
    public static String getVersion() throws HiveException {
        final Properties props = new Properties();
        try (InputStream versionResourceFile = Thread.currentThread()
                .getContextClassLoader()
                .getResourceAsStream("xgboost4j-version.properties")) {
            if (versionResourceFile == null) {
                // getResourceAsStream() returns null when the resource is absent;
                // fail with a descriptive error instead of an NPE from Properties.load(null)
                throw new HiveException(
                    "Failed to load xgboost4j-version.properties: resource not found");
            }
            props.load(versionResourceFile);
        } catch (IOException e) {
            throw new HiveException("Failed to load xgboost4j-version.properties", e);
        }
        return props.getProperty("version", "<unknown>");
    }

    /**
     * Builds a {@link DMatrix} from labeled points.
     *
     * @param data labeled points carrying row identifiers
     * @throws XGBoostError if the native matrix construction fails
     */
    @Nonnull
    public static DMatrix createDMatrix(@Nonnull final List<LabeledPointWithRowId> data)
            throws XGBoostError {
        // Widen the element type; DMatrix consumes an Iterator<LabeledPoint>.
        final List<LabeledPoint> points = new ArrayList<LabeledPoint>(data);
        return new DMatrix(points.iterator(), "");
    }

    /**
     * Instantiates a {@link Booster} for the given matrix and training parameters.
     *
     * <p>Booster does not expose a public constructor, so the package-private
     * {@code Booster(Map, DMatrix[])} constructor is invoked reflectively.
     *
     * @throws XGBoostError if native booster creation fails
     */
    @Nonnull
    public static Booster createBooster(@Nonnull DMatrix matrix,
            @Nonnull Map<String, Object> params) throws NoSuchMethodException, XGBoostError,
            IllegalAccessException, InvocationTargetException, InstantiationException {
        Constructor<Booster> ctor =
                Booster.class.getDeclaredConstructor(Map.class, DMatrix[].class);
        ctor.setAccessible(true); // constructor is not public
        return ctor.newInstance(new Object[] {params, new DMatrix[] {matrix}});
    }

    /**
     * Releases the native memory behind {@code matrix}, best-effort.
     * A {@code null} argument and any disposal failure are silently tolerated.
     */
    public static void close(@Nullable final DMatrix matrix) {
        if (matrix == null) {
            return;
        }
        try {
            matrix.dispose();
        } catch (Throwable ignored) {
            // best-effort native cleanup; nothing the caller can do on failure
        }
    }

    /**
     * Releases the native memory behind {@code booster}, best-effort.
     * A {@code null} argument and any disposal failure are silently tolerated.
     */
    public static void close(@Nullable final Booster booster) {
        if (booster == null) {
            return;
        }
        try {
            booster.dispose();
        } catch (Throwable ignored) {
            // best-effort native cleanup; nothing the caller can do on failure
        }
    }

    /**
     * Serializes a {@link Booster} into a compressed, Base91-style {@link Text} suitable for
     * storing in a Hive table.
     *
     * @throws HiveException wrapping any failure (native code may throw Errors, hence Throwable)
     */
    @Nonnull
    public static Text serializeBooster(@Nonnull final Booster booster) throws HiveException {
        try {
            byte[] b = IOUtils.toCompressedText(booster.toByteArray());
            return new Text(b);
        } catch (Throwable e) {
            throw new HiveException("Failed to serialize a booster", e);
        }
    }

    /**
     * Restores a {@link Booster} previously produced by {@link #serializeBooster(Booster)}.
     *
     * @throws HiveException wrapping any failure (native code may throw Errors, hence Throwable)
     */
    @Nonnull
    public static Booster deserializeBooster(@Nonnull final Text model) throws HiveException {
        try {
            byte[] b = IOUtils.fromCompressedText(model.getBytes(), model.getLength());
            return XGBoost.loadModel(new FastByteArrayInputStream(b));
        } catch (Throwable e) {
            throw new HiveException("Failed to deserialize a booster", e);
        }
    }

    /**
     * Builds a pure-Java {@link Predictor} from a serialized model produced by
     * {@link #serializeBooster(Booster)}.
     *
     * @throws HiveException wrapping any failure during decompression or model parsing
     */
    @Nonnull
    public static Predictor loadPredictor(@Nonnull final Text model) throws HiveException {
        try {
            byte[] b = IOUtils.fromCompressedText(model.getBytes(), model.getLength());
            return new Predictor(new FastByteArrayInputStream(b));
        } catch (Throwable e) {
            throw new HiveException("Failed to create a predictor", e);
        }
    }

    /**
     * Parses {@code row[start..end)} as sparse {@code index:value} feature strings into an
     * {@link FVec}. {@code null} elements are skipped.
     *
     * @param row array of {@code index:value} tokens (elements may be {@code null})
     * @param start first index to parse, inclusive
     * @param end last index to parse, exclusive
     * @throws IllegalArgumentException if a token is malformed or not numeric
     */
    @Nonnull
    public static FVec parseRowAsFVec(@Nonnull final String[] row, final int start, final int end) {
        // size by the actual span (not row.length) to avoid over-allocation and rehashing
        final Map<Integer, Float> map = new HashMap<>((int) ((end - start) * 1.5));
        for (int i = start; i < end; i++) {
            final String str = row[i];
            if (str == null) {
                continue; // missing feature
            }
            final int pos = str.indexOf(':');
            if (pos < 1) { // ':' absent or index part empty
                throw new IllegalArgumentException("Invalid feature format: " + str);
            }
            final int index;
            final float value;
            try {
                index = Integer.parseInt(str.substring(0, pos));
                value = Float.parseFloat(str.substring(pos + 1));
            } catch (NumberFormatException e) {
                // preserve the cause so the offending token AND the parse failure are reported
                throw new IllegalArgumentException("Failed to parse a feature value: " + str, e);
            }
            map.put(index, value);
        }

        return FVec.Transformer.fromMap(map);
    }

}