/*
 * Decompiled with CFR 0.152.
 */
package hivemall.xgboost.utils;

import biz.k11i.xgboost.Predictor;
import biz.k11i.xgboost.util.FVec;
import hivemall.utils.io.FastByteArrayInputStream;
import hivemall.utils.io.IOUtils;
import hivemall.xgboost.XGBoostBatchPredictUDTF;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.io.Text;

/**
 * Static helpers for building, (de)serializing, disposing and feeding XGBoost objects
 * ({@link DMatrix}, {@link Booster}, {@link Predictor}, {@link FVec}).
 */
public final class XGBoostUtils {

    /** Utility class; not instantiable. */
    private XGBoostUtils() {}

    /**
     * Returns the xgboost4j version recorded in the {@code xgboost4j-version.properties} classpath
     * resource.
     *
     * @return the value of the {@code version} property, or {@code "<unknown>"} if the resource
     *         does not define it
     * @throws HiveException if the resource is missing from the classpath or cannot be read
     */
    @Nonnull
    public static String getVersion() throws HiveException {
        ClassLoader loader = Thread.currentThread().getContextClassLoader();
        if (loader == null) {
            // Some threads have no context class loader; fall back to the loader of this class.
            loader = XGBoostUtils.class.getClassLoader();
        }
        final Properties props = new Properties();
        try (InputStream in = loader.getResourceAsStream("xgboost4j-version.properties")) {
            if (in == null) {
                // getResourceAsStream signals "not found" with null; Properties.load(null) would
                // otherwise fail with a NullPointerException instead of the declared HiveException.
                throw new HiveException(
                    "xgboost4j-version.properties is not found in the classpath");
            }
            props.load(in);
        } catch (IOException e) {
            throw new HiveException("Failed to load xgboost4j-version.properties", e);
        }
        return props.getProperty("version", "<unknown>");
    }

    /**
     * Builds a {@link DMatrix} from the given labeled points.
     *
     * <p>The points are first copied into a fresh list whose iterator is handed to the
     * {@code DMatrix(Iterator, String)} constructor with an empty cache-info string.
     *
     * @param data labeled points to load
     * @return a new matrix; the caller is responsible for disposing it (see {@link #close(DMatrix)})
     * @throws XGBoostError if xgboost4j fails to build the matrix
     */
    @Nonnull
    public static DMatrix createDMatrix(
            @Nonnull List<XGBoostBatchPredictUDTF.LabeledPointWithRowId> data) throws XGBoostError {
        final ArrayList<XGBoostBatchPredictUDTF.LabeledPointWithRowId> points =
                new ArrayList<XGBoostBatchPredictUDTF.LabeledPointWithRowId>(data);
        return new DMatrix(points.iterator(), "");
    }

    /**
     * Creates a {@link Booster} over a single training matrix by reflectively invoking Booster's
     * {@code (Map, DMatrix[])} constructor.
     *
     * @param matrix the matrix passed to the constructor (wrapped in a one-element array)
     * @param params booster parameters passed to the constructor as-is
     * @return the newly constructed booster
     * @throws NoSuchMethodException if Booster has no {@code (Map, DMatrix[])} constructor
     * @throws XGBoostError declared for compatibility with existing callers
     * @throws IllegalAccessException if the constructor cannot be made accessible
     * @throws InvocationTargetException if the constructor itself throws; the original exception
     *         is available as the cause
     * @throws InstantiationException if Booster cannot be instantiated
     */
    @Nonnull
    public static Booster createBooster(@Nonnull DMatrix matrix, @Nonnull Map<String, Object> params)
            throws NoSuchMethodException, XGBoostError, IllegalAccessException,
            InvocationTargetException, InstantiationException {
        // Looked up via getDeclaredConstructor and forced accessible, presumably because the
        // constructor is not public in xgboost4j.
        final Constructor<Booster> ctor =
                Booster.class.getDeclaredConstructor(Map.class, DMatrix[].class);
        ctor.setAccessible(true);
        return ctor.newInstance(params, new DMatrix[] {matrix});
    }

    /**
     * Disposes the given matrix, ignoring {@code null} and any error raised by disposal.
     *
     * @param matrix the matrix to dispose, may be {@code null}
     */
    public static void close(@Nullable DMatrix matrix) {
        if (matrix == null) {
            return;
        }
        try {
            matrix.dispose();
        } catch (Throwable ignored) {
            // Best-effort cleanup: a failure to release native resources must not mask the
            // caller's own result or exception.
        }
    }

    /**
     * Disposes the given booster, ignoring {@code null} and any error raised by disposal.
     *
     * @param booster the booster to dispose, may be {@code null}
     */
    public static void close(@Nullable Booster booster) {
        if (booster == null) {
            return;
        }
        try {
            booster.dispose();
        } catch (Throwable ignored) {
            // Best-effort cleanup: a failure to release native resources must not mask the
            // caller's own result or exception.
        }
    }

    /**
     * Serializes a booster to {@link Text}: its {@code toByteArray()} output is encoded with
     * {@code IOUtils.toCompressedText}.
     *
     * @param booster the booster to serialize
     * @return the encoded model
     * @throws HiveException wrapping any failure during serialization
     */
    @Nonnull
    public static Text serializeBooster(@Nonnull Booster booster) throws HiveException {
        try {
            final byte[] b = IOUtils.toCompressedText(booster.toByteArray());
            return new Text(b);
        } catch (Throwable e) {
            // Boundary with native xgboost code: wrap everything into the UDF's exception type.
            throw new HiveException("Failed to serialize a booster", e);
        }
    }

    /**
     * Restores a booster from {@link Text} produced by {@link #serializeBooster(Booster)}: the
     * bytes are decoded with {@code IOUtils.fromCompressedText} and loaded with
     * {@code XGBoost.loadModel}.
     *
     * @param model the encoded model; only its first {@code getLength()} bytes are used
     * @return the loaded booster
     * @throws HiveException wrapping any failure during decoding or loading
     */
    @Nonnull
    public static Booster deserializeBooster(@Nonnull Text model) throws HiveException {
        try {
            // Text's backing array may be longer than its content, hence the explicit length.
            final byte[] b = IOUtils.fromCompressedText(model.getBytes(), model.getLength());
            return XGBoost.loadModel(new FastByteArrayInputStream(b));
        } catch (Throwable e) {
            throw new HiveException("Failed to deserialize a booster", e);
        }
    }

    /**
     * Builds a pure-Java {@link Predictor} from {@link Text} produced by
     * {@link #serializeBooster(Booster)}.
     *
     * @param model the encoded model; only its first {@code getLength()} bytes are used
     * @return a predictor over the decoded model
     * @throws HiveException wrapping any failure during decoding or model parsing
     */
    @Nonnull
    public static Predictor loadPredictor(@Nonnull Text model) throws HiveException {
        try {
            // Text's backing array may be longer than its content, hence the explicit length.
            final byte[] b = IOUtils.fromCompressedText(model.getBytes(), model.getLength());
            return new Predictor(new FastByteArrayInputStream(b));
        } catch (Throwable e) {
            throw new HiveException("Failed to create a predictor", e);
        }
    }

    /**
     * Parses the features {@code row[start]} .. {@code row[end - 1]} into a sparse {@link FVec}.
     *
     * <p>Each non-null entry must have the form {@code index:value}, where {@code index} is an
     * int and {@code value} a float. {@code null} entries are skipped. If an index occurs more
     * than once, the last occurrence wins.
     *
     * @param row feature strings
     * @param start first position to parse (inclusive)
     * @param end last position to parse (exclusive)
     * @return a feature vector backed by the parsed index/value pairs
     * @throws IllegalArgumentException if an entry has no {@code ':'} after a non-empty index, or
     *         its index/value cannot be parsed (the {@link NumberFormatException} is kept as cause)
     */
    @Nonnull
    public static FVec parseRowAsFVec(@Nonnull String[] row, int start, int end) {
        // Presize for the scanned range only; 1.5x keeps the map below HashMap's 0.75 load factor.
        final int expected = Math.max(end - start, 0);
        final Map<Integer, Float> map = new HashMap<Integer, Float>((int) (expected * 1.5));
        for (int i = start; i < end; i++) {
            final String feature = row[i];
            if (feature == null) {
                continue;
            }
            final int pos = feature.indexOf(':');
            if (pos < 1) {
                // Either no separator or an empty index part.
                throw new IllegalArgumentException("Invalid feature format: " + feature);
            }
            int index;
            float value;
            try {
                index = Integer.parseInt(feature.substring(0, pos));
                value = Float.parseFloat(feature.substring(pos + 1));
            } catch (NumberFormatException e) {
                throw new IllegalArgumentException("Failed to parse a feature value: " + feature,
                    e);
            }
            map.put(index, value);
        }
        return FVec.Transformer.fromMap(map);
    }
}

