/*
 * Decompiled with CFR 0.152.
 */
package hivemall.tools.array;

import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.Preconditions;
import java.io.IOException;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

@Description(name="select_k_best", value="_FUNC_(array<number> array, const array<number> importance, const int k) - Returns selected top-k elements as array<double>")
@UDFType(deterministic=true, stateful=false)
public final class SelectKBestUDF
extends GenericUDF {
    private ListObjectInspector featuresOI;
    private PrimitiveObjectInspector featureOI;
    private ListObjectInspector importanceListOI;
    private PrimitiveObjectInspector importanceElemOI;
    private int _k;
    private List<DoubleWritable> _result;
    private int[] _topKIndices;

    public ObjectInspector initialize(ObjectInspector[] OIs) throws UDFArgumentException {
        if (OIs.length != 3) {
            throw new UDFArgumentLengthException("Specify three arguments: " + OIs.length);
        }
        if (!HiveUtils.isNumberListOI(OIs[0])) {
            throw new UDFArgumentTypeException(0, "Only array<number> type argument is acceptable but " + OIs[0].getTypeName() + " was passed as `features`");
        }
        if (!HiveUtils.isNumberListOI(OIs[1])) {
            throw new UDFArgumentTypeException(1, "Only array<number> type argument is acceptable but " + OIs[1].getTypeName() + " was passed as `importance_list`");
        }
        if (!HiveUtils.isIntegerOI(OIs[2])) {
            throw new UDFArgumentTypeException(2, "Only int type argument is acceptable but " + OIs[2].getTypeName() + " was passed as `k`");
        }
        this.featuresOI = HiveUtils.asListOI(OIs[0]);
        this.featureOI = HiveUtils.asDoubleCompatibleOI(this.featuresOI.getListElementObjectInspector());
        this.importanceListOI = HiveUtils.asListOI(OIs[1]);
        this.importanceElemOI = HiveUtils.asDoubleCompatibleOI(this.importanceListOI.getListElementObjectInspector());
        this._k = HiveUtils.getConstInt(OIs[2]);
        Preconditions.checkArgument(this._k >= 1, UDFArgumentException.class);
        ArrayList<DoubleWritable> result = new ArrayList<DoubleWritable>(this._k);
        for (int i = 0; i < this._k; ++i) {
            result.add(new DoubleWritable());
        }
        this._result = result;
        return ObjectInspectorFactory.getStandardListObjectInspector((ObjectInspector)PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
    }

    public List<DoubleWritable> evaluate(GenericUDF.DeferredObject[] dObj) throws HiveException {
        int i;
        double[] features = HiveUtils.asDoubleArray(dObj[0].get(), this.featuresOI, this.featureOI);
        double[] importanceList = HiveUtils.asDoubleArray(dObj[1].get(), this.importanceListOI, this.importanceElemOI);
        Preconditions.checkNotNull(features, UDFArgumentException.class);
        Preconditions.checkNotNull(importanceList, UDFArgumentException.class);
        Preconditions.checkArgument(features.length == importanceList.length, UDFArgumentException.class);
        Preconditions.checkArgument(features.length >= this._k, UDFArgumentException.class);
        int[] topKIndices = this._topKIndices;
        if (topKIndices == null) {
            ArrayList<AbstractMap.SimpleEntry<Integer, Double>> list = new ArrayList<AbstractMap.SimpleEntry<Integer, Double>>();
            for (i = 0; i < importanceList.length; ++i) {
                list.add(new AbstractMap.SimpleEntry<Integer, Double>(i, importanceList[i]));
            }
            Collections.sort(list, new Comparator<Map.Entry<Integer, Double>>(){

                @Override
                public int compare(Map.Entry<Integer, Double> o1, Map.Entry<Integer, Double> o2) {
                    return o1.getValue() > o2.getValue() ? -1 : 1;
                }
            });
            topKIndices = new int[this._k];
            for (i = 0; i < topKIndices.length; ++i) {
                topKIndices[i] = (Integer)((Map.Entry)list.get(i)).getKey();
            }
            this._topKIndices = topKIndices;
        }
        List<DoubleWritable> result = this._result;
        for (i = 0; i < topKIndices.length; ++i) {
            int idx = topKIndices[i];
            DoubleWritable d = result.get(i);
            double f = features[idx];
            d.set(f);
        }
        return result;
    }

    public void close() throws IOException {
        this._result = null;
        this._topKIndices = null;
    }

    public String getDisplayString(String[] children) {
        StringBuilder sb = new StringBuilder();
        sb.append("select_k_best");
        sb.append("(");
        if (children.length > 0) {
            sb.append(children[0]);
            for (int i = 1; i < children.length; ++i) {
                sb.append(", ");
                sb.append(children[i]);
            }
        }
        sb.append(")");
        return sb.toString();
    }
}

