/*
 * Decompiled with CFR 0.152.
 */
package hivemall.knn.lsh;

import hivemall.UDTFWithOptions;
import hivemall.model.FeatureValue;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hashing.HashFunction;
import hivemall.utils.hashing.HashFunctionFactory;
import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

@Description(name="minhash", value="_FUNC_(ANY item, array<int|bigint|string> features [, constant string options]) - Returns n different k-depth signatures (i.e., clusterid) for each item <clusterid, item>")
@UDFType(deterministic=true, stateful=false)
public final class MinHashUDTF
extends UDTFWithOptions {
    private ObjectInspector itemOI;
    private ListObjectInspector featureListOI;
    private boolean parseFeature;
    private Object[] forwardObjs;
    private int num_hashes = 5;
    private int num_keygroups = 2;
    private HashFunction[] hashFuncs;

    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length < 2) {
            throw new UDFArgumentException("_FUNC_ takes more than 2 arguments: ANY item, Array<Int|BigInt|Text> features [, constant String options]");
        }
        this.itemOI = argOIs[0];
        this.featureListOI = (ListObjectInspector)argOIs[1];
        ObjectInspector featureRawOI = this.featureListOI.getListElementObjectInspector();
        String keyTypeName = featureRawOI.getTypeName();
        if (!("string".equals(keyTypeName) || "int".equals(keyTypeName) || "bigint".equals(keyTypeName))) {
            throw new UDFArgumentTypeException(0, "1st argument must be Map of key type [Int|BitInt|Text]: " + keyTypeName);
        }
        this.parseFeature = "string".equals(keyTypeName);
        this.forwardObjs = new Object[2];
        this.processOptions(argOIs);
        ArrayList<String> fieldNames = new ArrayList<String>();
        ArrayList<Object> fieldOIs = new ArrayList<Object>();
        fieldNames.add("clusterid");
        fieldOIs.add(PrimitiveObjectInspectorFactory.javaIntObjectInspector);
        fieldNames.add("item");
        fieldOIs.add(this.itemOI);
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    @Override
    protected Options getOptions() {
        Options opts = new Options();
        opts.addOption("n", "hashes", true, "Generate N sets of minhash values for each row (DEFAULT: 5)");
        opts.addOption("k", "keygroups", true, "Use K minhash value (DEFAULT: 2)");
        return opts;
    }

    @Override
    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
        CommandLine cl = null;
        if (argOIs.length >= 3) {
            String numKeygroups;
            String rawArgs = HiveUtils.getConstString(argOIs[2]);
            cl = this.parseOptions(rawArgs);
            String numHashes = cl.getOptionValue("hashes");
            if (numHashes != null) {
                this.num_hashes = Integer.parseInt(numHashes);
            }
            if ((numKeygroups = cl.getOptionValue("keygroups")) != null) {
                this.num_keygroups = Integer.parseInt(numKeygroups);
            }
        }
        this.hashFuncs = HashFunctionFactory.create(this.num_hashes);
        return cl;
    }

    public void process(Object[] args) throws HiveException {
        Object[] forwardObjs = this.forwardObjs;
        forwardObjs[1] = args[0];
        List features = this.featureListOI.getList(args[1]);
        ObjectInspector featureInspector = this.featureListOI.getListElementObjectInspector();
        List<FeatureValue> ftvec = this.parseFeatures(features, featureInspector, this.parseFeature);
        this.computeAndForwardSignatures(ftvec, forwardObjs);
    }

    private void computeAndForwardSignatures(List<FeatureValue> features, Object[] forwardObjs) throws HiveException {
        PriorityQueue<Integer> minhashes = new PriorityQueue<Integer>();
        for (int i = 0; i < this.num_hashes; ++i) {
            float weightedMinHashValues = Float.MAX_VALUE;
            for (FeatureValue fv : features) {
                float w;
                Object f = fv.getFeature();
                int hashIndex = Math.abs(this.hashFuncs[i].hash(f));
                float hashValue = MinHashUDTF.calcWeightedHashValue(hashIndex, w = fv.getValueAsFloat());
                if (!(hashValue < weightedMinHashValues)) continue;
                weightedMinHashValues = hashValue;
                minhashes.offer(hashIndex);
            }
            forwardObjs[0] = MinHashUDTF.getSignature(minhashes, this.num_keygroups);
            this.forward(forwardObjs);
            minhashes.clear();
        }
    }

    private static float calcWeightedHashValue(int hashIndex, float w) throws HiveException {
        if (w < 0.0f) {
            throw new HiveException("Non-negative value is not accepted for a feature weight");
        }
        if (w == 0.0f) {
            return Float.MAX_VALUE;
        }
        return (float)hashIndex / w;
    }

    private static int getSignature(PriorityQueue<Integer> candidates, int keyGroups) {
        int numCandidates = candidates.size();
        if (numCandidates == 0) {
            return 0;
        }
        int size = Math.min(numCandidates, keyGroups);
        int result = 1;
        for (int i = 0; i < size; ++i) {
            int nextmin = candidates.poll();
            result = 31 * result + nextmin;
        }
        return result & Integer.MAX_VALUE;
    }

    public void close() throws HiveException {
    }
}

