/*
 * Decompiled with CFR 0.152.
 */
package hivemall.knn.distance;

import hivemall.knn.distance.HammingDistanceUDF;
import java.math.BigInteger;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDF;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.io.FloatWritable;

/**
 * Hive UDF that returns a Jaccard distance between two values.
 *
 * <p>Two families of overloads are provided:
 * <ul>
 *   <li>Integer / decimal-string inputs: the distance is derived from the Hamming
 *       distance between the two numbers, using {@code k} as the total number of
 *       compared positions (128 when omitted).</li>
 *   <li>List inputs: the distance is {@code 1 - |A ∩ B| / |A ∪ B|}, computed over
 *       the distinct elements of both lists.</li>
 * </ul>
 *
 * <p>The list overload reuses instance-level scratch sets between calls, so a single
 * instance must not be invoked from several threads at the same time.
 */
@Description(name="jaccard_distance", value="_FUNC_(integer A, integer B [,int k=128]) - Returns Jaccard distance between A and B", extended="select \n  jaccard_distance(0,3) as c1, \n  jaccard_distance(\"0\",\"3\") as c2, -- 0=0x00, 0=0x11\n  jaccard_distance(0,4) as c3\n;\n\nc1      c2      c3\n0.03125 0.03125 0.015625")
@UDFType(deterministic=true, stateful=false)
public final class JaccardDistanceUDF
extends UDF {
    /** Scratch set reused across rows to avoid allocating a new set per call. */
    private final Set<Object> union = new HashSet<Object>();
    /** Scratch set reused across rows to avoid allocating a new set per call. */
    private final Set<Object> intersect = new HashSet<Object>();

    /**
     * Returns the distance between {@code a} and {@code b} using the default {@code k} of 128.
     *
     * @param a first value
     * @param b second value
     * @return the distance wrapped in a new {@link FloatWritable}
     */
    public FloatWritable evaluate(long a, long b) {
        return evaluate(a, b, 128);
    }

    /**
     * Returns the distance between {@code a} and {@code b}, derived from their Hamming distance.
     *
     * @param a first value
     * @param b second value
     * @param k total number of positions the Hamming distance is measured against;
     *          it is not validated, so {@code k == 0} produces a non-finite result
     * @return the distance wrapped in a new {@link FloatWritable}
     */
    public FloatWritable evaluate(long a, long b, int k) {
        return distanceFromHamming(HammingDistanceUDF.hammingDistance(a, b), k);
    }

    /**
     * Returns the distance between two decimal-string integers using the default {@code k} of 128.
     *
     * @param a decimal string representation of the first integer
     * @param b decimal string representation of the second integer
     * @return the distance wrapped in a new {@link FloatWritable}
     * @throws NumberFormatException if either argument is not a valid decimal integer
     * @throws NullPointerException if either argument is {@code null}
     */
    public FloatWritable evaluate(String a, String b) {
        return evaluate(a, b, 128);
    }

    /**
     * Returns the distance between two decimal-string integers, derived from their
     * Hamming distance. The strings are parsed with {@link BigInteger#BigInteger(String)},
     * so values wider than 64 bits are accepted.
     *
     * @param a decimal string representation of the first integer
     * @param b decimal string representation of the second integer
     * @param k total number of positions the Hamming distance is measured against;
     *          it is not validated, so {@code k == 0} produces a non-finite result
     * @return the distance wrapped in a new {@link FloatWritable}
     * @throws NumberFormatException if either argument is not a valid decimal integer
     * @throws NullPointerException if either argument is {@code null}
     */
    public FloatWritable evaluate(String a, String b, int k) {
        BigInteger ai = new BigInteger(a);
        BigInteger bi = new BigInteger(b);
        return distanceFromHamming(HammingDistanceUDF.hammingDistance(ai, bi), k);
    }

    /**
     * Returns the set-based Jaccard distance {@code 1 - |A ∩ B| / |A ∪ B|} between two lists.
     * Duplicate elements within a list are counted once.
     *
     * <p>Degenerate inputs are handled without touching the scratch sets:
     * both {@code null} or both empty gives {@code 0}; exactly one {@code null}
     * or exactly one empty gives {@code 1}.
     *
     * @param a first list, may be {@code null}
     * @param b second list, may be {@code null}
     * @return the distance wrapped in a new {@link FloatWritable}
     */
    public FloatWritable evaluate(List<String> a, List<String> b) {
        if (a == null && b == null) {
            return new FloatWritable(0.0f);
        }
        if (a == null || b == null) {
            return new FloatWritable(1.0f);
        }
        int asize = a.size();
        int bsize = b.size();
        if (asize == 0 && bsize == 0) {
            return new FloatWritable(0.0f);
        }
        if (asize == 0 || bsize == 0) {
            return new FloatWritable(1.0f);
        }
        try {
            // Materialize b as a hash set first: retainAll() probes its argument with
            // contains() once per element, which is linear per probe on a List.
            union.addAll(b);
            intersect.addAll(a);
            intersect.retainAll(union);
            // Only now extend the set of b into the real union of a and b.
            union.addAll(a);
            float j = (float) intersect.size() / (float) union.size();
            return new FloatWritable(1.0f - j);
        } finally {
            // Always reset the shared scratch sets so a failed call cannot leak
            // elements into the next invocation.
            union.clear();
            intersect.clear();
        }
    }

    /**
     * Converts a Hamming distance measured over {@code k} positions into the distance
     * returned by the numeric overloads.
     *
     * <p>The match ratio {@code (k - hammingDistance) / k} is rescaled with
     * {@code 2 * (ratio - 0.5)} and the result is subtracted from one.
     * Deliberately not named {@code evaluate} so it is never mistaken for a UDF entry point.
     */
    private static FloatWritable distanceFromHamming(int hammingDistance, int k) {
        int countMatches = k - hammingDistance;
        float jaccard = (float) countMatches / (float) k;
        float pseudoJaccard = 2.0f * (jaccard - 0.5f);
        return new FloatWritable(1.0f - pseudoJaccard);
    }
}

