/*
 * Decompiled with CFR 0.152.
 */
package hivemall.factorization.fm;

import hivemall.utils.hadoop.HiveUtils;
import javax.annotation.Nonnull;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;

/**
 * Resolver for the {@code ffm_predict} aggregate, which sums the per-row terms of an FFM prediction
 * into a single double.
 *
 * <p>Each input row {@code (Wi, Vifj, Vjfi, Xi, Xj)} contributes according to which columns are
 * non-null (see {@link Evaluator#iterate}):
 * <ul>
 * <li>{@code Wi} non-null, {@code Xi} and {@code Xj} null: {@code Wi} is added as-is (the
 * {@code w0} term).</li>
 * <li>{@code Wi} and {@code Xi} non-null, {@code Xj} null: {@code Wi * Xi} is added.</li>
 * <li>{@code Wi} null and {@code Vifj}, {@code Vjfi}, {@code Xi}, {@code Xj} all non-null:
 * {@code <Vifj, Vjfi> * Xi * Xj} is added.</li>
 * </ul>
 * Rows matching none of these patterns are ignored.
 */
@Description(name="ffm_predict", value="_FUNC_(float Wi, array<float> Vifj, array<float> Vjfi, float Xi, float Xj) - Returns a prediction value in Double")
public final class FFMPredictGenericUDAF extends AbstractGenericUDAFResolver {

    // NOTE(review): private constructor; presumably Hive instantiates resolvers reflectively
    // (accessibility overridden) -- confirm before changing visibility.
    private FFMPredictGenericUDAF() {
    }

    /**
     * Validates the argument types and returns a new evaluator.
     *
     * @param typeInfo argument types, expected as (Wi: number, Vifj: array of float/double,
     *        Vjfi: array of float/double, Xi: number, Xj: number)
     * @return a new {@link Evaluator}
     * @throws SemanticException a {@code UDFArgumentLengthException} when the argument count is
     *         not 5, or a {@code UDFArgumentTypeException} when an argument has an unexpected type
     */
    public Evaluator getEvaluator(@Nonnull TypeInfo[] typeInfo) throws SemanticException {
        if (typeInfo.length != 5) {
            throw new UDFArgumentLengthException(
                "Expected argument length is 5 but given argument length was " + typeInfo.length);
        }
        if (!HiveUtils.isNumberTypeInfo(typeInfo[0])) {
            throw new UDFArgumentTypeException(0,
                "Number type is expected for the first argument Wi: " + typeInfo[0].getTypeName());
        }
        if (typeInfo[1].getCategory() != ObjectInspector.Category.LIST) {
            throw new UDFArgumentTypeException(1,
                "List type is expected for the second argument Vifj: " + typeInfo[1].getTypeName());
        }
        if (typeInfo[2].getCategory() != ObjectInspector.Category.LIST) {
            throw new UDFArgumentTypeException(2,
                "List type is expected for the third argument Vjfi: " + typeInfo[2].getTypeName());
        }
        ListTypeInfo typeInfo1 = (ListTypeInfo) typeInfo[1];
        if (!HiveUtils.isFloatingPointTypeInfo(typeInfo1.getListElementTypeInfo())) {
            throw new UDFArgumentTypeException(1,
                "Double or Float type is expected for the element type of list Vifj: "
                        + typeInfo1.getTypeName());
        }
        ListTypeInfo typeInfo2 = (ListTypeInfo) typeInfo[2];
        if (!HiveUtils.isFloatingPointTypeInfo(typeInfo2.getListElementTypeInfo())) {
            // Report the type of Vjfi itself (previously reported Vifj's type by mistake).
            throw new UDFArgumentTypeException(2,
                "Double or Float type is expected for the element type of list Vjfi: "
                        + typeInfo2.getTypeName());
        }
        if (!HiveUtils.isNumberTypeInfo(typeInfo[3])) {
            throw new UDFArgumentTypeException(3,
                "Number type is expected for the fourth argument Xi: " + typeInfo[3].getTypeName());
        }
        if (!HiveUtils.isNumberTypeInfo(typeInfo[4])) {
            throw new UDFArgumentTypeException(4,
                "Number type is expected for the fifth argument Xj: " + typeInfo[4].getTypeName());
        }
        return new Evaluator();
    }

    /**
     * Aggregation state: a single running sum of all contributed terms.
     */
    @GenericUDAFEvaluator.AggregationType(estimable=true)
    public static final class FFMPredictAggregationBuffer
            extends GenericUDAFEvaluator.AbstractAggregationBuffer {

        private double sum;

        FFMPredictAggregationBuffer() {
        }

        /** Clears the running sum back to zero. */
        void reset() {
            sum = 0.0;
        }

        /** Adds a partial sum produced by another buffer. */
        void merge(double o_sum) {
            sum += o_sum;
        }

        /** Returns the current running sum. */
        double get() {
            return sum;
        }

        /** Adds the {@code w0} term as-is. */
        void addW0(double W0) {
            sum += W0;
        }

        /** Adds the linear term {@code Wi * Xi}. */
        void addWiXi(double Wi, double Xi) {
            sum += Wi * Xi;
        }

        /**
         * Adds the pairwise term {@code <Vij, Vji> * Xi * Xj}.
         *
         * @throws UDFArgumentException if {@code Vij} and {@code Vji} differ in length
         */
        void addViVjXiXj(@Nonnull float[] Vij, @Nonnull float[] Vji, double Xi, double Xj)
                throws UDFArgumentException {
            if (Vij.length != Vji.length) {
                throw new UDFArgumentException("Mismatch in the number of factors");
            }
            double prod = 0.0;
            for (int f = 0; f < Vij.length; f++) {
                // Widen before multiplying so each product is computed in double, not float.
                prod += (double) Vij[f] * Vji[f];
            }
            sum += prod * Xi * Xj;
        }

        /** Estimated size in bytes: the buffer holds a single {@code double}. */
        public int estimate() {
            return 8;
        }
    }

    /**
     * Evaluator that accumulates row terms into a {@link FFMPredictAggregationBuffer} and emits the
     * sum as a {@link DoubleWritable} for both partial and final results.
     */
    public static final class Evaluator extends GenericUDAFEvaluator {

        // Inspectors for the original input columns (PARTIAL1 / COMPLETE modes).
        private PrimitiveObjectInspector wiOI;
        private ListObjectInspector vijOI;
        private ListObjectInspector vjiOI;
        private PrimitiveObjectInspector vijElemOI;
        private PrimitiveObjectInspector vjiElemOI;
        private PrimitiveObjectInspector xiOI;
        private PrimitiveObjectInspector xjOI;
        // Inspector for the partial sum (PARTIAL2 / FINAL modes).
        private DoubleObjectInspector mergeInputOI;

        /**
         * Sets up the input inspectors for the given mode.
         *
         * <p>In PARTIAL1/COMPLETE the five original arguments are expected. In PARTIAL2/FINAL
         * the single argument is the partial sum emitted by {@link #terminatePartial}.
         *
         * @return a writable double inspector, used for both partial and final output
         */
        public ObjectInspector init(GenericUDAFEvaluator.Mode mode, ObjectInspector[] parameters)
                throws HiveException {
            super.init(mode, parameters);
            if (mode == GenericUDAFEvaluator.Mode.PARTIAL1
                    || mode == GenericUDAFEvaluator.Mode.COMPLETE) {
                assert (parameters.length == 5) : "Expected 5 arguments but got " + parameters.length;
                wiOI = HiveUtils.asDoubleCompatibleOI(parameters, 0);
                vijOI = HiveUtils.asListOI(parameters, 1);
                vijElemOI = HiveUtils.asFloatingPointOI(vijOI.getListElementObjectInspector());
                vjiOI = HiveUtils.asListOI(parameters, 2);
                vjiElemOI = HiveUtils.asFloatingPointOI(vjiOI.getListElementObjectInspector());
                xiOI = HiveUtils.asDoubleCompatibleOI(parameters, 3);
                xjOI = HiveUtils.asDoubleCompatibleOI(parameters, 4);
            } else {
                // Merge side: only the partial aggregation is passed in, not the original args.
                assert (parameters.length == 1) : "Expected 1 partial argument but got "
                        + parameters.length;
                mergeInputOI = HiveUtils.asDoubleOI(parameters, 0);
            }
            return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
        }

        /** Creates a fresh buffer with a zero sum. */
        public FFMPredictAggregationBuffer getNewAggregationBuffer() throws HiveException {
            FFMPredictAggregationBuffer buffer = new FFMPredictAggregationBuffer();
            reset(buffer);
            return buffer;
        }

        /** Resets the buffer's running sum to zero. */
        public void reset(GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
            ((FFMPredictAggregationBuffer) agg).reset();
        }

        /**
         * Adds one row's contribution to the buffer, selecting the term by which columns are
         * non-null (see the class documentation of {@link FFMPredictGenericUDAF}).
         */
        public void iterate(GenericUDAFEvaluator.AggregationBuffer agg, Object[] parameters)
                throws HiveException {
            FFMPredictAggregationBuffer myAggr = (FFMPredictAggregationBuffer) agg;
            if (parameters[0] == null) {
                // Wi is null: pairwise term, which needs Vifj, Vjfi, Xi and Xj.
                if (parameters[3] == null || parameters[4] == null) {
                    return;
                }
                if (parameters[1] == null || parameters[2] == null) {
                    return;
                }
                // NOTE(review): the trailing `false` is forwarded to HiveUtils.asFloatArray;
                // presumably it governs null-element handling -- confirm in HiveUtils.
                float[] vij = HiveUtils.asFloatArray(parameters[1], vijOI, vijElemOI, false);
                float[] vji = HiveUtils.asFloatArray(parameters[2], vjiOI, vjiElemOI, false);
                double xi = PrimitiveObjectInspectorUtils.getDouble(parameters[3], xiOI);
                double xj = PrimitiveObjectInspectorUtils.getDouble(parameters[4], xjOI);
                myAggr.addViVjXiXj(vij, vji, xi, xj);
            } else {
                double wi = PrimitiveObjectInspectorUtils.getDouble(parameters[0], wiOI);
                if (parameters[3] == null && parameters[4] == null) {
                    // Neither Xi nor Xj: the w0 term.
                    myAggr.addW0(wi);
                } else if (parameters[4] == null) {
                    // Only Xi: the linear term Wi * Xi.
                    double xi = PrimitiveObjectInspectorUtils.getDouble(parameters[3], xiOI);
                    myAggr.addWiXi(wi, xi);
                }
                // Wi with a non-null Xj matches no pattern and is ignored.
            }
        }

        /** Emits the current sum as the partial aggregation. */
        public DoubleWritable terminatePartial(GenericUDAFEvaluator.AggregationBuffer agg)
                throws HiveException {
            return new DoubleWritable(((FFMPredictAggregationBuffer) agg).get());
        }

        /** Adds a partial sum (as produced by {@link #terminatePartial}); null partials are skipped. */
        public void merge(GenericUDAFEvaluator.AggregationBuffer agg, Object partial)
                throws HiveException {
            if (partial == null) {
                return;
            }
            FFMPredictAggregationBuffer myAggr = (FFMPredictAggregationBuffer) agg;
            myAggr.merge(mergeInputOI.get(partial));
        }

        /** Emits the final sum as the prediction value. */
        public DoubleWritable terminate(GenericUDAFEvaluator.AggregationBuffer agg)
                throws HiveException {
            return new DoubleWritable(((FFMPredictAggregationBuffer) agg).get());
        }
    }
}

