/*
 * Decompiled with CFR 0.152.
 */
package hivemall.evaluation;

import hivemall.evaluation.BinaryResponsesMeasures;
import hivemall.utils.hadoop.HiveUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import javax.annotation.Nonnull;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.LongWritable;

@Description(name="mrr", value="_FUNC_(array rankItems, array correctItems [, const int recommendSize = rankItems.size]) - Returns MRR")
public final class MRRUDAF
extends AbstractGenericUDAFResolver {
    public GenericUDAFEvaluator getEvaluator(@Nonnull TypeInfo[] typeInfo) throws SemanticException {
        if (typeInfo.length != 2 && typeInfo.length != 3) {
            throw new UDFArgumentTypeException(typeInfo.length - 1, "_FUNC_ takes two or three arguments");
        }
        ListTypeInfo arg1type = HiveUtils.asListTypeInfo(typeInfo[0]);
        if (!HiveUtils.isPrimitiveTypeInfo(arg1type.getListElementTypeInfo())) {
            throw new UDFArgumentTypeException(0, "The first argument `array rankItems` is invalid form: " + typeInfo[0]);
        }
        ListTypeInfo arg2type = HiveUtils.asListTypeInfo(typeInfo[1]);
        if (!HiveUtils.isPrimitiveTypeInfo(arg2type.getListElementTypeInfo())) {
            throw new UDFArgumentTypeException(1, "The second argument `array correctItems` is invalid form: " + typeInfo[1]);
        }
        return new Evaluator();
    }

    public static class MRRAggregationBuffer
    extends GenericUDAFEvaluator.AbstractAggregationBuffer {
        double sum;
        long count;

        void reset() {
            this.sum = 0.0;
            this.count = 0L;
        }

        void merge(double o_sum, long o_count) {
            this.sum += o_sum;
            this.count += o_count;
        }

        double get() {
            if (this.count == 0L) {
                return 0.0;
            }
            return this.sum / (double)this.count;
        }

        void iterate(@Nonnull List<?> recommendList, @Nonnull List<?> truthList, @Nonnull int recommendSize) {
            this.sum += BinaryResponsesMeasures.ReciprocalRank(recommendList, truthList, recommendSize);
            ++this.count;
        }
    }

    public static class Evaluator
    extends GenericUDAFEvaluator {
        private ListObjectInspector recommendListOI;
        private ListObjectInspector truthListOI;
        private PrimitiveObjectInspector recommendSizeOI;
        private StructObjectInspector internalMergeOI;
        private StructField countField;
        private StructField sumField;

        public ObjectInspector init(GenericUDAFEvaluator.Mode mode, ObjectInspector[] parameters) throws HiveException {
            assert (parameters.length >= 1 && parameters.length <= 3) : parameters.length;
            super.init(mode, parameters);
            if (mode == GenericUDAFEvaluator.Mode.PARTIAL1 || mode == GenericUDAFEvaluator.Mode.COMPLETE) {
                this.recommendListOI = (ListObjectInspector)parameters[0];
                this.truthListOI = (ListObjectInspector)parameters[1];
                if (parameters.length == 3) {
                    this.recommendSizeOI = HiveUtils.asIntegerOI(parameters[2]);
                }
            } else {
                StructObjectInspector soi;
                this.internalMergeOI = soi = (StructObjectInspector)parameters[0];
                this.countField = soi.getStructFieldRef("count");
                this.sumField = soi.getStructFieldRef("sum");
            }
            Object outputOI = mode == GenericUDAFEvaluator.Mode.PARTIAL1 || mode == GenericUDAFEvaluator.Mode.PARTIAL2 ? Evaluator.internalMergeOI() : PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
            return outputOI;
        }

        private static StructObjectInspector internalMergeOI() {
            ArrayList<String> fieldNames = new ArrayList<String>();
            ArrayList<Object> fieldOIs = new ArrayList<Object>();
            fieldNames.add("sum");
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
            fieldNames.add("count");
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
            return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
        }

        public MRRAggregationBuffer getNewAggregationBuffer() throws HiveException {
            MRRAggregationBuffer myAggr = new MRRAggregationBuffer();
            this.reset((GenericUDAFEvaluator.AggregationBuffer)myAggr);
            return myAggr;
        }

        public void reset(GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
            MRRAggregationBuffer myAggr = (MRRAggregationBuffer)agg;
            myAggr.reset();
        }

        public void iterate(GenericUDAFEvaluator.AggregationBuffer agg, Object[] parameters) throws HiveException {
            List truthList;
            MRRAggregationBuffer myAggr = (MRRAggregationBuffer)agg;
            List recommendList = this.recommendListOI.getList(parameters[0]);
            if (recommendList == null) {
                recommendList = Collections.emptyList();
            }
            if ((truthList = this.truthListOI.getList(parameters[1])) == null) {
                return;
            }
            int recommendSize = recommendList.size();
            if (parameters.length == 3 && (recommendSize = PrimitiveObjectInspectorUtils.getInt((Object)parameters[2], (PrimitiveObjectInspector)this.recommendSizeOI)) < 0) {
                throw new UDFArgumentException("The third argument `int recommendSize` must be in greater than or equals to 0: " + recommendSize);
            }
            myAggr.iterate(recommendList, truthList, recommendSize);
        }

        public Object terminatePartial(GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
            MRRAggregationBuffer myAggr = (MRRAggregationBuffer)agg;
            Object[] partialResult = new Object[]{new DoubleWritable(myAggr.sum), new LongWritable(myAggr.count)};
            return partialResult;
        }

        public void merge(GenericUDAFEvaluator.AggregationBuffer agg, Object partial) throws HiveException {
            if (partial == null) {
                return;
            }
            Object sumObj = this.internalMergeOI.getStructFieldData(partial, this.sumField);
            Object countObj = this.internalMergeOI.getStructFieldData(partial, this.countField);
            double sum = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(sumObj);
            long count = PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(countObj);
            MRRAggregationBuffer myAggr = (MRRAggregationBuffer)agg;
            myAggr.merge(sum, count);
        }

        public DoubleWritable terminate(GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
            MRRAggregationBuffer myAggr = (MRRAggregationBuffer)agg;
            double result = myAggr.get();
            return new DoubleWritable(result);
        }
    }
}

