001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing,
013 * software distributed under the License is distributed on an
014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 * KIND, either express or implied.  See the License for the
016 * specific language governing permissions and limitations
017 * under the License.
018 */
019package hivemall.xgboost.utils;
020
021import biz.k11i.xgboost.Predictor;
022import biz.k11i.xgboost.util.FVec;
023import hivemall.utils.io.FastByteArrayInputStream;
024import hivemall.utils.io.IOUtils;
025import hivemall.xgboost.XGBoostBatchPredictUDTF.LabeledPointWithRowId;
026import ml.dmlc.xgboost4j.LabeledPoint;
027import ml.dmlc.xgboost4j.java.Booster;
028import ml.dmlc.xgboost4j.java.DMatrix;
029import ml.dmlc.xgboost4j.java.XGBoost;
030import ml.dmlc.xgboost4j.java.XGBoostError;
031
032import java.io.IOException;
033import java.io.InputStream;
034import java.lang.reflect.Constructor;
035import java.lang.reflect.InvocationTargetException;
036import java.util.ArrayList;
037import java.util.HashMap;
038import java.util.List;
039import java.util.Map;
040import java.util.Properties;
041
042import javax.annotation.Nonnull;
043import javax.annotation.Nullable;
044
045import org.apache.hadoop.hive.ql.metadata.HiveException;
046import org.apache.hadoop.io.Text;
047
048public final class XGBoostUtils {
049
050    private XGBoostUtils() {}
051
052    @Nonnull
053    public static String getVersion() throws HiveException {
054        Properties props = new Properties();
055        try (InputStream versionResourceFile =
056                Thread.currentThread().getContextClassLoader().getResourceAsStream(
057                    "xgboost4j-version.properties")) {
058            props.load(versionResourceFile);
059        } catch (IOException e) {
060            throw new HiveException("Failed to load xgboost4j-version.properties", e);
061        }
062        return props.getProperty("version", "<unknown>");
063    }
064
065    @Nonnull
066    public static DMatrix createDMatrix(@Nonnull final List<LabeledPointWithRowId> data)
067            throws XGBoostError {
068        final List<LabeledPoint> points = new ArrayList<>(data.size());
069        for (LabeledPointWithRowId d : data) {
070            points.add(d);
071        }
072        return new DMatrix(points.iterator(), "");
073    }
074
075    @Nonnull
076    public static Booster createBooster(@Nonnull DMatrix matrix,
077            @Nonnull Map<String, Object> params) throws NoSuchMethodException, XGBoostError,
078            IllegalAccessException, InvocationTargetException, InstantiationException {
079        Class<?>[] args = {Map.class, DMatrix[].class};
080        Constructor<Booster> ctor = Booster.class.getDeclaredConstructor(args);
081        ctor.setAccessible(true);
082        return ctor.newInstance(new Object[] {params, new DMatrix[] {matrix}});
083    }
084
085    public static void close(@Nullable final DMatrix matrix) {
086        if (matrix == null) {
087            return;
088        }
089        try {
090            matrix.dispose();
091        } catch (Throwable e) {
092            ;
093        }
094    }
095
096    public static void close(@Nullable final Booster booster) {
097        if (booster == null) {
098            return;
099        }
100        try {
101            booster.dispose();
102        } catch (Throwable e) {
103            ;
104        }
105    }
106
107    @Nonnull
108    public static Text serializeBooster(@Nonnull final Booster booster) throws HiveException {
109        try {
110            byte[] b = IOUtils.toCompressedText(booster.toByteArray());
111            return new Text(b);
112        } catch (Throwable e) {
113            throw new HiveException("Failed to serialize a booster", e);
114        }
115    }
116
117    @Nonnull
118    public static Booster deserializeBooster(@Nonnull final Text model) throws HiveException {
119        try {
120            byte[] b = IOUtils.fromCompressedText(model.getBytes(), model.getLength());
121            return XGBoost.loadModel(new FastByteArrayInputStream(b));
122        } catch (Throwable e) {
123            throw new HiveException("Failed to deserialize a booster", e);
124        }
125    }
126
127    @Nonnull
128    public static Predictor loadPredictor(@Nonnull final Text model) throws HiveException {
129        try {
130            byte[] b = IOUtils.fromCompressedText(model.getBytes(), model.getLength());
131            return new Predictor(new FastByteArrayInputStream(b));
132        } catch (Throwable e) {
133            throw new HiveException("Failed to create a predictor", e);
134        }
135    }
136
137    @Nonnull
138    public static FVec parseRowAsFVec(@Nonnull final String[] row, final int start, final int end) {
139        final Map<Integer, Float> map = new HashMap<>((int) (row.length * 1.5));
140        for (int i = start; i < end; i++) {
141            String f = row[i];
142            if (f == null) {
143                continue;
144            }
145            String str = f.toString();
146            final int pos = str.indexOf(':');
147            if (pos < 1) {
148                throw new IllegalArgumentException("Invalid feature format: " + str);
149            }
150            final int index;
151            final float value;
152            try {
153                index = Integer.parseInt(str.substring(0, pos));
154                value = Float.parseFloat(str.substring(pos + 1));
155            } catch (NumberFormatException e) {
156                throw new IllegalArgumentException("Failed to parse a feature value: " + str);
157            }
158            map.put(index, value);
159        }
160
161        return FVec.Transformer.fromMap(map);
162    }
163
164}