/*
 * Decompiled with CFR 0.152.
 */
package hivemall.ftvec.trans;

import hivemall.utils.hadoop.HiveUtils;
import java.util.ArrayList;
import java.util.List;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;

@Description(name="binarize_label", value="_FUNC_(int/long positive, int/long negative, ...) - Returns positive/negative records that are represented as (..., int label) where label is 0 or 1")
@UDFType(deterministic=true, stateful=false)
public final class BinarizeLabelUDTF
extends GenericUDTF {
    private PrimitiveObjectInspector positiveOI;
    private PrimitiveObjectInspector negativeOI;
    private Object[] positiveObjs;
    private Object[] negativeObjs;

    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length < 3) {
            throw new UDFArgumentException("binarize_label(int/long positive, int/long negative, *) takes at least three arguments");
        }
        this.positiveOI = HiveUtils.asIntCompatibleOI(argOIs[0]);
        this.negativeOI = HiveUtils.asIntCompatibleOI(argOIs[1]);
        this.positiveObjs = new Object[argOIs.length - 1];
        this.positiveObjs[this.positiveObjs.length - 1] = 1;
        this.negativeObjs = new Object[argOIs.length - 1];
        this.negativeObjs[this.negativeObjs.length - 1] = 0;
        ArrayList<String> fieldNames = new ArrayList<String>();
        ArrayList<Object> fieldOIs = new ArrayList<Object>();
        for (int i = 2; i < argOIs.length; ++i) {
            fieldNames.add("c" + (i - 2));
            fieldOIs.add(argOIs[i]);
        }
        fieldNames.add("c" + (argOIs.length - 2));
        fieldOIs.add(PrimitiveObjectInspectorFactory.javaIntObjectInspector);
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    public void process(Object[] args) throws HiveException {
        Object[] positiveObjs = this.positiveObjs;
        int last = positiveObjs.length - 1;
        for (int i = 0; i < last; ++i) {
            positiveObjs[i] = args[i + 2];
        }
        int positive = PrimitiveObjectInspectorUtils.getInt((Object)args[0], (PrimitiveObjectInspector)this.positiveOI);
        for (int i = 0; i < positive; ++i) {
            this.forward(positiveObjs);
        }
        Object[] negativeObjs = this.negativeObjs;
        int last2 = negativeObjs.length - 1;
        for (int i = 0; i < last2; ++i) {
            negativeObjs[i] = args[i + 2];
        }
        int negative = PrimitiveObjectInspectorUtils.getInt((Object)args[1], (PrimitiveObjectInspector)this.negativeOI);
        for (int i = 0; i < negative; ++i) {
            this.forward(negativeObjs);
        }
    }

    public void close() throws HiveException {
        this.positiveObjs = null;
        this.negativeObjs = null;
    }
}

