/*
 * Decompiled with CFR 0.152.
 */
package com.github.aaronshan.functions.math;

import java.util.HashMap;
import java.util.Map;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

@Description(name="cosine_similarity", value="_FUNC_(map(varchar,double), map(varchar,double)) - cosine similarity between the given sparse vectors.", extended="Example:\n > select _FUNC_(map(varchar,double), map(varchar,double)) from src;")
public class UDFMathCosineSimilarity
extends GenericUDF {
    private static final int ARG_COUNT = 2;
    private transient MapObjectInspector leftMapOI;
    private transient MapObjectInspector rightMapOI;

    public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
        if (arguments.length != 2) {
            throw new UDFArgumentLengthException("The function cosine_similarity(map, map) takes exactly 2 arguments.");
        }
        for (int i = 0; i < 2; ++i) {
            if (arguments[i].getCategory().equals((Object)ObjectInspector.Category.MAP)) continue;
            throw new UDFArgumentTypeException(i, "\"map\" expected at function cosine_similarity, but \"" + arguments[i].getTypeName() + "\" is found");
        }
        this.leftMapOI = (MapObjectInspector)arguments[0];
        this.rightMapOI = (MapObjectInspector)arguments[1];
        ObjectInspector leftMapKeyOI = this.leftMapOI.getMapKeyObjectInspector();
        ObjectInspector leftMapValueOI = this.leftMapOI.getMapValueObjectInspector();
        ObjectInspector rightMapKeyOI = this.rightMapOI.getMapKeyObjectInspector();
        ObjectInspector rightMapValueOI = this.rightMapOI.getMapValueObjectInspector();
        if (!ObjectInspectorUtils.compareTypes((ObjectInspector)leftMapKeyOI, (ObjectInspector)rightMapKeyOI)) {
            throw new UDFArgumentTypeException(1, "\"" + leftMapKeyOI.getTypeName() + "\" expected at function cosine_similarity key, but \"" + rightMapKeyOI.getTypeName() + "\" is found");
        }
        if (!ObjectInspectorUtils.compareTypes((ObjectInspector)PrimitiveObjectInspectorFactory.javaStringObjectInspector, (ObjectInspector)leftMapKeyOI)) {
            throw new UDFArgumentTypeException(1, "\"" + PrimitiveObjectInspectorFactory.javaStringObjectInspector.getTypeName() + "\" expected at function cosine_similarity key, but \"" + leftMapKeyOI.getTypeName() + "\" is found");
        }
        if (!ObjectInspectorUtils.compareTypes((ObjectInspector)leftMapValueOI, (ObjectInspector)rightMapValueOI)) {
            throw new UDFArgumentTypeException(1, "\"" + leftMapValueOI.getTypeName() + "\" expected at function cosine_similarity value, but \"" + rightMapValueOI.getTypeName() + "\" is found");
        }
        if (!ObjectInspectorUtils.compareTypes((ObjectInspector)PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, (ObjectInspector)leftMapValueOI)) {
            throw new UDFArgumentTypeException(1, "\"" + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector.getTypeName() + "\" expected at function cosine_similarity value, but \"" + leftMapValueOI.getTypeName() + "\" is found");
        }
        return ObjectInspectorFactory.getStandardMapObjectInspector((ObjectInspector)leftMapKeyOI, (ObjectInspector)leftMapValueOI);
    }

    public Object evaluate(GenericUDF.DeferredObject[] arguments) throws HiveException {
        Object leftMapObj = arguments[0].get();
        Object rightMapObj = arguments[1].get();
        if (leftMapObj == null || rightMapObj == null) {
            return null;
        }
        Map leftMap = this.leftMapOI.getMap(leftMapObj);
        Map rightMap = this.leftMapOI.getMap(rightMapObj);
        Double normLeftMap = this.mapL2Norm(leftMap);
        Double normRightMap = this.mapL2Norm(rightMap);
        if (normLeftMap == null || normRightMap == null) {
            return null;
        }
        double dotProduct = this.mapDotProduct(leftMap, rightMap);
        return new DoubleWritable(dotProduct / (normLeftMap * normRightMap));
    }

    private double mapDotProduct(Map<?, ?> leftMap, Map<?, ?> rightMap) {
        double result = 0.0;
        for (Map.Entry<?, ?> entry : rightMap.entrySet()) {
            if (!leftMap.containsKey(entry.getKey())) continue;
            Double leftValue = (Double)leftMap.get(entry.getKey());
            Double rightValue = (Double)entry.getValue();
            result += leftValue * rightValue;
        }
        return result;
    }

    private Double mapL2Norm(Map<?, ?> map) {
        double norm = 0.0;
        for (Map.Entry<?, ?> entry : map.entrySet()) {
            if (entry.getValue() == null) {
                return null;
            }
            Double value = (Double)entry.getValue();
            norm += value * value;
        }
        return Math.sqrt(norm);
    }

    public String getDisplayString(String[] strings) {
        assert (strings.length == 2);
        return "cosine_similarity(" + strings[0] + ", " + strings[1] + ")";
    }
}

