/*
 * Decompiled with CFR 0.152.
 */
package hivemall.xgboost;

import biz.k11i.xgboost.Predictor;
import biz.k11i.xgboost.util.FVec;
import hivemall.UDTFWithOptions;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.WritableUtils;
import hivemall.xgboost.utils.XGBoostUtils;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;

@Description(name="xgboost_predict", value="_FUNC_(PRIMITIVE rowid, array<string|double> features, string model_id, array<string> pred_model [, string options]) - Returns a prediction result as (string rowid, array<double> predicted)", extended="select\n  rowid, \n  array_avg(predicted) as predicted,\n  avg(predicted[0]) as predicted0\nfrom (\n  select\n    xgboost_predict(rowid, features, model_id, model) as (rowid, predicted)\n  from\n    xgb_model l\n    LEFT OUTER JOIN xgb_input r\n) t\ngroup by rowid;")
/**
 * Hive UDTF that scores one feature vector per input row with an XGBoost
 * {@link Predictor} and forwards a {@code (rowid, predicted)} row.
 *
 * <p>Arguments: {@code rowid}, {@code features} (a list of numbers for dense
 * input, or a list of {@code "index:value"} strings for sparse input),
 * {@code model_id}, the serialized model, and an optional constant options
 * string. Loaded predictors are cached per model id until {@link #close()}.
 */
public class XGBoostOnlinePredictUDTF extends UDTFWithOptions {
    /** Inspector for the 1st argument (row id). */
    private PrimitiveObjectInspector rowIdOI;
    /** Inspector for the 2nd argument (feature list). */
    private ListObjectInspector featureListOI;
    /** True when the feature list holds numbers; false when it holds "index:value" strings. */
    private boolean denseFeatures;
    /** Element inspector for dense features; only assigned when {@link #denseFeatures} is true. */
    @Nullable
    private PrimitiveObjectInspector featureElemOI;
    /** Inspector for the 3rd argument (model id). */
    private StringObjectInspector modelIdOI;
    /** Inspector for the 4th argument (serialized model). */
    private StringObjectInspector modelOI;
    /** Cache of loaded predictors keyed by model id; created lazily in {@link #process(Object[])}. */
    @Nullable
    private transient Map<String, Predictor> mapToModel;
    /** Reused output row: slot 0 is the row id, slot 1 is the list of predictions. */
    @Nonnull
    protected final transient Object[] _forwardObj;
    /** Previously forwarded prediction list, handed back to the conversion helper on the next row. */
    @Nullable
    protected transient List<DoubleWritable> _predictedCache;

    /** Creates a UDTF that forwards two-column rows. */
    public XGBoostOnlinePredictUDTF() {
        this(new Object[2]);
    }

    /**
     * Creates a UDTF that forwards rows through the given array.
     *
     * @param forwardObj the array reused for every forwarded row; slots 0 and 1
     *        are written by {@link #forwardPredicted(Writable, double[])}
     */
    protected XGBoostOnlinePredictUDTF(@Nonnull Object[] forwardObj) {
        this._forwardObj = forwardObj;
    }

    /**
     * Returns the supported command-line options.
     *
     * @return an empty option set; this class defines no options of its own
     */
    @Override
    protected Options getOptions() {
        return new Options();
    }

    /**
     * Parses the optional 5th argument as an options string.
     *
     * @param argOIs the argument inspectors passed to {@code initialize}
     * @return the parsed command line, or {@code null} when fewer than five arguments were given
     * @throws UDFArgumentException if the options argument cannot be read or parsed
     */
    @Override
    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length < 5) {
            return null;
        }
        String rawArgs = HiveUtils.getConstString(argOIs, 4);
        return this.parseOptions(rawArgs);
    }

    /**
     * Validates the argument inspectors and records how each argument is read.
     *
     * @param argOIs inspectors for the 4 or 5 call arguments
     * @return the output schema built by {@link #getReturnOI(PrimitiveObjectInspector)}
     * @throws UDFArgumentException if the feature list elements are neither numbers nor strings,
     *         or if a helper rejects one of the arguments
     */
    @Override
    public StructObjectInspector initialize(@Nonnull ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length != 4 && argOIs.length != 5) {
            // NOTE(review): showHelp is declared outside this file; presumably it throws — confirm.
            this.showHelp("Invalid argment length=" + argOIs.length);
        }
        this.processOptions(argOIs);

        this.rowIdOI = HiveUtils.asPrimitiveObjectInspector(argOIs, 0);

        ListObjectInspector listOI = HiveUtils.asListOI(argOIs, 1);
        this.featureListOI = listOI;
        ObjectInspector elemOI = listOI.getListElementObjectInspector();
        if (HiveUtils.isNumberOI(elemOI)) {
            this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
            this.denseFeatures = true;
        } else if (HiveUtils.isStringOI(elemOI)) {
            this.denseFeatures = false;
        } else {
            throw new UDFArgumentException("Expected array<string|double> for the 2nd argment but got an unexpected features type: " + listOI.getTypeName());
        }

        this.modelIdOI = HiveUtils.asStringOI(argOIs, 2);
        this.modelOI = HiveUtils.asStringOI(argOIs, 3);
        return this.getReturnOI(this.rowIdOI);
    }

    /**
     * Builds the output schema: {@code rowid} (writable form of the input row id's
     * primitive category) and {@code predicted} (list of writable doubles).
     *
     * @param rowIdOI inspector of the row id argument
     * @return the struct inspector describing forwarded rows
     */
    @Nonnull
    protected StructObjectInspector getReturnOI(@Nonnull PrimitiveObjectInspector rowIdOI) {
        List<String> fieldNames = new ArrayList<String>(2);
        // Must be a List<ObjectInspector>: the struct factory does not accept List<Object>.
        List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(2);

        fieldNames.add("rowid");
        fieldOIs.add(PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(rowIdOI.getPrimitiveCategory()));

        fieldNames.add("predicted");
        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));

        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Scores one input row and forwards the prediction.
     *
     * <p>Rows whose feature list is {@code null} are skipped without output. The
     * predictor for a model id is loaded from the 4th argument on first use and
     * reused for later rows carrying the same id.
     *
     * @param args the row's argument values, in the order declared by the function
     * @throws HiveException if an argument is rejected, the model cannot be loaded,
     *         a feature cannot be parsed, or forwarding fails
     */
    @Override
    public void process(Object[] args) throws HiveException {
        if (this.mapToModel == null) {
            this.mapToModel = new HashMap<String, Predictor>();
        }
        Object featuresArg = args[1];
        if (featuresArg == null) {
            return;
        }

        // NOTE(review): nonNullArgument is not defined in this file (presumably inherited);
        // assumed to reject a null argument at the given index — confirm.
        String modelId = PrimitiveObjectInspectorUtils.getString(XGBoostOnlinePredictUDTF.nonNullArgument(args, 2), this.modelIdOI);
        Predictor model = this.mapToModel.get(modelId);
        if (model == null) {
            // The model argument is only read on a cache miss.
            Text serializedModel = this.modelOI.getPrimitiveWritableObject(XGBoostOnlinePredictUDTF.nonNullArgument(args, 3));
            model = XGBoostUtils.loadPredictor(serializedModel);
            this.mapToModel.put(modelId, model);
        }

        Writable rowId = HiveUtils.copyToWritable(XGBoostOnlinePredictUDTF.nonNullArgument(args, 0), this.rowIdOI);
        FVec features;
        if (this.denseFeatures) {
            features = this.parseDenseFeatures(featuresArg);
        } else {
            features = XGBoostOnlinePredictUDTF.parseSparseFeatures(this.featureListOI.getList(featuresArg));
        }
        this.predictAndForward(model, rowId, features);
    }

    /**
     * Converts a list of numbers into a dense feature vector.
     *
     * @param argObj the raw feature list object, read through {@link #featureListOI}
     * @return a vector with one slot per list element; {@code null} elements become {@code NaN}
     * @throws UDFArgumentException declared for symmetry with the sparse parser
     */
    @Nonnull
    private FVec parseDenseFeatures(@Nonnull Object argObj) throws UDFArgumentException {
        int length = this.featureListOI.getListLength(argObj);
        double[] values = new double[length];
        for (int i = 0; i < length; ++i) {
            Object elem = this.featureListOI.getListElement(argObj, i);
            values[i] = (elem == null)
                    ? Double.NaN
                    : PrimitiveObjectInspectorUtils.getDouble(elem, this.featureElemOI);
        }
        // NOTE(review): the second argument is presumably xgboost-predictor's
        // "treat zero as missing" flag — confirm against the library documentation.
        return FVec.Transformer.fromArray(values, false);
    }

    /**
     * Converts a list of {@code "index:value"} entries into a sparse feature vector.
     *
     * <p>{@code null} entries are skipped. When an index occurs more than once the
     * last occurrence wins.
     *
     * @param featureList the feature entries; each is converted with {@code toString()}
     * @return a vector backed by the parsed index-to-value map
     * @throws UDFArgumentException if an entry has no {@code ':'} after at least one
     *         character, or if its index or value is not a valid number
     */
    @Nonnull
    private static FVec parseSparseFeatures(@Nonnull List<?> featureList) throws UDFArgumentException {
        Map<Integer, Double> map = new HashMap<Integer, Double>((int)((double)featureList.size() * 1.5));
        for (Object f : featureList) {
            if (f == null) {
                continue;
            }
            String str = f.toString();
            int pos = str.indexOf(':');
            // pos == -1: no separator; pos == 0: empty index.
            if (pos < 1) {
                throw new UDFArgumentException("Invalid feature format: " + str);
            }
            int index;
            double value;
            try {
                index = Integer.parseInt(str.substring(0, pos));
                value = Double.parseDouble(str.substring(pos + 1));
            } catch (NumberFormatException e) {
                throw new UDFArgumentException("Failed to parse a feature value: " + str);
            }
            map.put(index, value);
        }
        return FVec.Transformer.fromMap(map);
    }

    /**
     * Runs the model on the features and forwards the result for the given row id.
     *
     * @param model the predictor to apply
     * @param rowId the row id to emit alongside the prediction
     * @param features the feature vector to score
     * @throws HiveException if forwarding fails
     */
    private void predictAndForward(@Nonnull Predictor model, @Nonnull Writable rowId, @Nonnull FVec features) throws HiveException {
        double[] predicted = model.predict(features);
        this.forwardPredicted(rowId, predicted);
    }

    /**
     * Forwards one {@code (rowId, predicted)} row through the reused output array.
     *
     * @param rowId the row id to place in slot 0
     * @param predicted the raw predictions, converted to a writable list for slot 1
     * @throws HiveException if the downstream collector fails
     */
    protected void forwardPredicted(@Nonnull Writable rowId, @Nonnull double[] predicted) throws HiveException {
        // The previous list is passed back to the helper and the result is kept for the next row.
        List<DoubleWritable> list = WritableUtils.toWritableList(predicted, this._predictedCache);
        this._predictedCache = list;

        Object[] forwardObj = this._forwardObj;
        forwardObj[0] = rowId;
        forwardObj[1] = list;
        this.forward(forwardObj);
    }

    /**
     * Drops the predictor cache so the loaded models can be garbage collected.
     *
     * @throws HiveException never thrown by this implementation
     */
    @Override
    public void close() throws HiveException {
        this.mapToModel = null;
    }
}

