/*
 * Decompiled with CFR 0.152.
 */
package hivemall.tools.vector;

import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.StringUtils;
import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import javax.annotation.Nonnull;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;

@Description(name="vector_dot", value="_FUNC_(array<NUMBER> x, array<NUMBER> y) - Performs vector dot product.", extended="SELECT vector_dot(array(1.0,2.0,3.0),array(2.0,3.0,4.0));\n20\n\nSELECT vector_dot(array(1.0,2.0,3.0),2);\n[2.0,4.0,6.0]")
@UDFType(deterministic=true, stateful=false)
public final class VectorDotUDF
extends GenericUDF {
    private Evaluator evaluator;

    public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length != 2) {
            throw new UDFArgumentLengthException("Expected 2 arguments, but got " + argOIs.length);
        }
        ObjectInspector argOI0 = argOIs[0];
        if (!HiveUtils.isNumberListOI(argOI0)) {
            throw new UDFArgumentException("Expected array<number> for the first argument: " + argOI0.getTypeName());
        }
        ListObjectInspector xListOI = HiveUtils.asListOI(argOI0);
        ObjectInspector argOI1 = argOIs[1];
        if (HiveUtils.isNumberListOI(argOI1)) {
            this.evaluator = new Dot2DVectors(xListOI, HiveUtils.asListOI(argOI1));
            return PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
        }
        if (HiveUtils.isNumberOI(argOI1)) {
            this.evaluator = new Multiply2D1D(xListOI, argOI1);
            return ObjectInspectorFactory.getStandardListObjectInspector((ObjectInspector)PrimitiveObjectInspectorFactory.javaDoubleObjectInspector);
        }
        throw new UDFArgumentException("Expected array<number> or number for the send argument: " + argOI1.getTypeName());
    }

    public Object evaluate(GenericUDF.DeferredObject[] args) throws HiveException {
        Object arg0 = args[0].get();
        Object arg1 = args[1].get();
        if (arg0 == null || arg1 == null) {
            return null;
        }
        return this.evaluator.dot(arg0, arg1);
    }

    public String getDisplayString(String[] args) {
        return "vector_dot(" + StringUtils.join(args, ',') + ")";
    }

    static final class Dot2DVectors
    implements Evaluator {
        private static final long serialVersionUID = -8783159823009951347L;
        private final ListObjectInspector xListOI;
        private final ListObjectInspector yListOI;
        private final PrimitiveObjectInspector xElemOI;
        private final PrimitiveObjectInspector yElemOI;

        Dot2DVectors(@Nonnull ListObjectInspector xListOI, @Nonnull ListObjectInspector yListOI) throws UDFArgumentTypeException {
            this.xListOI = xListOI;
            this.yListOI = yListOI;
            this.xElemOI = HiveUtils.asNumberOI(xListOI.getListElementObjectInspector());
            this.yElemOI = HiveUtils.asNumberOI(yListOI.getListElementObjectInspector());
        }

        @Override
        public Double dot(@Nonnull Object x, @Nonnull Object y) throws HiveException {
            int yLen;
            int xLen = this.xListOI.getListLength(x);
            if (xLen != (yLen = this.yListOI.getListLength(y))) {
                throw new HiveException("vector lengths do not match. x=" + this.xListOI.getList(x) + ", y=" + this.yListOI.getList(y));
            }
            double result = 0.0;
            for (int i = 0; i < xLen; ++i) {
                Object xi = this.xListOI.getListElement(x, i);
                Object yi = this.yListOI.getListElement(y, i);
                if (xi == null || yi == null) continue;
                double xd = PrimitiveObjectInspectorUtils.getDouble((Object)xi, (PrimitiveObjectInspector)this.xElemOI);
                double yd = PrimitiveObjectInspectorUtils.getDouble((Object)yi, (PrimitiveObjectInspector)this.yElemOI);
                double v = xd * yd;
                result += v;
            }
            return result;
        }
    }

    static final class Multiply2D1D
    implements Evaluator {
        private static final long serialVersionUID = -9090211833041797311L;
        private final ListObjectInspector xListOI;
        private final PrimitiveObjectInspector xElemOI;
        private final PrimitiveObjectInspector yOI;

        Multiply2D1D(@Nonnull ListObjectInspector xListOI, @Nonnull ObjectInspector yOI) throws UDFArgumentTypeException {
            this.xListOI = xListOI;
            this.xElemOI = HiveUtils.asNumberOI(xListOI.getListElementObjectInspector());
            this.yOI = HiveUtils.asNumberOI(yOI);
        }

        @Override
        public List<Double> dot(@Nonnull Object x, @Nonnull Object y) throws HiveException {
            double yd = PrimitiveObjectInspectorUtils.getDouble((Object)y, (PrimitiveObjectInspector)this.yOI);
            int xLen = this.xListOI.getListLength(x);
            Double[] arr = new Double[xLen];
            for (int i = 0; i < xLen; ++i) {
                Object xi = this.xListOI.getListElement(x, i);
                if (xi == null) continue;
                double xd = PrimitiveObjectInspectorUtils.getDouble((Object)xi, (PrimitiveObjectInspector)this.xElemOI);
                double v = xd * yd;
                arr[i] = v;
            }
            return Arrays.asList(arr);
        }
    }

    static interface Evaluator
    extends Serializable {
        @Nonnull
        public Object dot(@Nonnull Object var1, @Nonnull Object var2) throws HiveException;
    }
}

