/*
 * Decompiled with CFR 0.152.
 */
package hivemall.tools.vector;

import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.StringUtils;
import java.util.Arrays;
import java.util.List;
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;

@Description(name="vector_add", value="_FUNC_(array<NUMBER> x, array<NUMBER> y) - Perform vector ADD operation.", extended="SELECT vector_add(array(1.0,2.0,3.0), array(2, 3, 4));\n[3.0,5.0,7.0]")
@UDFType(deterministic=true, stateful=false)
public final class VectorAddUDF
extends GenericUDF {
    private ListObjectInspector xOI;
    private ListObjectInspector yOI;
    private PrimitiveObjectInspector xElemOI;
    private PrimitiveObjectInspector yElemOI;
    private boolean floatingPoints;

    public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length != 2) {
            throw new UDFArgumentLengthException("Expected 2 arguments, but got " + argOIs.length);
        }
        this.xOI = HiveUtils.asListOI(argOIs[0]);
        this.yOI = HiveUtils.asListOI(argOIs[1]);
        this.xElemOI = HiveUtils.asNumberOI(this.xOI.getListElementObjectInspector());
        this.yElemOI = HiveUtils.asNumberOI(this.yOI.getListElementObjectInspector());
        if (HiveUtils.isIntegerOI((ObjectInspector)this.xElemOI) && HiveUtils.isIntegerOI((ObjectInspector)this.yElemOI)) {
            this.floatingPoints = false;
            return ObjectInspectorFactory.getStandardListObjectInspector((ObjectInspector)PrimitiveObjectInspectorFactory.javaLongObjectInspector);
        }
        this.floatingPoints = true;
        return ObjectInspectorFactory.getStandardListObjectInspector((ObjectInspector)PrimitiveObjectInspectorFactory.javaDoubleObjectInspector);
    }

    @Nullable
    public List<?> evaluate(@Nonnull GenericUDF.DeferredObject[] args) throws HiveException {
        int yLen;
        Object arg0 = args[0].get();
        Object arg1 = args[1].get();
        if (arg0 == null || arg1 == null) {
            return null;
        }
        int xLen = this.xOI.getListLength(arg0);
        if (xLen != (yLen = this.yOI.getListLength(arg1))) {
            throw new HiveException("vector lengths do not match. x=" + this.xOI.getList(arg0) + ", y=" + this.yOI.getList(arg1));
        }
        if (this.floatingPoints) {
            return this.evaluateDouble(arg0, arg1, xLen);
        }
        return this.evaluateLong(arg0, arg1, xLen);
    }

    @Nonnull
    private List<Double> evaluateDouble(@Nonnull Object vecX, @Nonnull Object vecY, @Nonnegative int size) {
        Double[] arr = new Double[size];
        for (int i = 0; i < size; ++i) {
            Object x = this.xOI.getListElement(vecX, i);
            Object y = this.yOI.getListElement(vecY, i);
            if (x == null || y == null) continue;
            double xd = PrimitiveObjectInspectorUtils.getDouble((Object)x, (PrimitiveObjectInspector)this.xElemOI);
            double yd = PrimitiveObjectInspectorUtils.getDouble((Object)y, (PrimitiveObjectInspector)this.yElemOI);
            double v = xd + yd;
            arr[i] = v;
        }
        return Arrays.asList(arr);
    }

    @Nonnull
    private List<Long> evaluateLong(@Nonnull Object vecX, @Nonnull Object vecY, @Nonnegative int size) {
        Long[] arr = new Long[size];
        for (int i = 0; i < size; ++i) {
            Object x = this.xOI.getListElement(vecX, i);
            Object y = this.yOI.getListElement(vecY, i);
            if (x == null || y == null) continue;
            long xd = PrimitiveObjectInspectorUtils.getLong((Object)x, (PrimitiveObjectInspector)this.xElemOI);
            long yd = PrimitiveObjectInspectorUtils.getLong((Object)y, (PrimitiveObjectInspector)this.yElemOI);
            long v = xd + yd;
            arr[i] = v;
        }
        return Arrays.asList(arr);
    }

    public String getDisplayString(String[] args) {
        return "vector_add(" + StringUtils.join(args, ',') + ")";
    }
}

