/*
 * Decompiled with CFR 0.152.
 */
package hivemall.factorization.fm;

import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.WritableUtils;
import java.util.ArrayList;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableDoubleObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;

@Description(name="fm_predict", value="_FUNC_(Float Wj, array<float> Vjf, float Xj) - Returns a prediction value in Double")
public final class FMPredictGenericUDAF
extends AbstractGenericUDAFResolver {
    /**
     * Private no-arg constructor: code outside this class cannot call
     * {@code new FMPredictGenericUDAF()}.
     */
    // NOTE(review): presumably the hosting framework instantiates this resolver reflectively - confirm.
    private FMPredictGenericUDAF() {}

    /**
     * Validates the argument types of {@code fm_predict(Wj, Vjf, Xj)} and hands
     * back a fresh {@link Evaluator}.
     *
     * <p>Checks are performed in this order: argument count, {@code Wj} is a
     * number, {@code Vjf} is a list, the element type of {@code Vjf} is a
     * number, and finally {@code Xj} is a number.
     *
     * @param typeInfo type descriptors of the call arguments; exactly three are expected
     * @return a new evaluator instance
     * @throws UDFArgumentLengthException if the number of arguments is not 3
     * @throws UDFArgumentTypeException if any argument has an unexpected type
     */
    @Override
    public Evaluator getEvaluator(TypeInfo[] typeInfo) throws SemanticException {
        final int numArgs = typeInfo.length;
        if (numArgs != 3) {
            throw new UDFArgumentLengthException("Expected argument length is 3 but given argument length was " + numArgs);
        }

        final TypeInfo wType = typeInfo[0];
        final TypeInfo vType = typeInfo[1];
        final TypeInfo xType = typeInfo[2];

        // 1st argument: Wj
        if (!HiveUtils.isNumberTypeInfo(wType)) {
            throw new UDFArgumentTypeException(0, "Number type is expected for the first argument Wj: " + wType.getTypeName());
        }

        // 2nd argument: Vjf, a list whose elements are numbers
        if (vType.getCategory() != ObjectInspector.Category.LIST) {
            throw new UDFArgumentTypeException(1, "List type is expected for the second argument Vjf: " + vType.getTypeName());
        }
        final ListTypeInfo vListType = (ListTypeInfo) vType;
        final TypeInfo vElemType = vListType.getListElementTypeInfo();
        if (!HiveUtils.isNumberTypeInfo(vElemType)) {
            // the message reports the type name of the list itself, not of its element
            throw new UDFArgumentTypeException(1, "Number type is expected for the element type of list Vjf: " + vListType.getTypeName());
        }

        // 3rd argument: Xj
        if (!HiveUtils.isNumberTypeInfo(xType)) {
            throw new UDFArgumentTypeException(2, "Number type is expected for the third argument Xj: " + xType.getTypeName());
        }

        return new Evaluator();
    }

    @GenericUDAFEvaluator.AggregationType(estimable=true)
    public static class FMPredictAggregationBuffer
    extends GenericUDAFEvaluator.AbstractAggregationBuffer {
        private double ret;
        private double[] sumVjXj;
        private double[] sumV2X2;

        FMPredictAggregationBuffer() {
        }

        void reset() {
            this.ret = 0.0;
            this.sumVjXj = null;
            this.sumV2X2 = null;
        }

        void iterate(double Wj) {
            this.ret += Wj;
        }

        void iterate(double Wj, double Xj, @Nonnull Object Vif, @Nonnull ListObjectInspector vOI, @Nonnull PrimitiveObjectInspector vElemOI) throws HiveException {
            this.ret += Wj * Xj;
            int factors = vOI.getListLength(Vif);
            if (factors < 1) {
                throw new HiveException("# of Factor should be more than 0: " + factors);
            }
            if (this.sumVjXj == null) {
                this.sumVjXj = new double[factors];
                this.sumV2X2 = new double[factors];
            } else if (this.sumVjXj.length != factors) {
                throw new HiveException("Mismatch in the number of factors");
            }
            int f = 0;
            while (f < factors) {
                Object o = vOI.getListElement(Vif, f);
                if (o == null) {
                    throw new HiveException("Vj" + f + " should not be null");
                }
                double v = PrimitiveObjectInspectorUtils.getDouble((Object)o, (PrimitiveObjectInspector)vElemOI);
                double vx = v * Xj;
                int n = f;
                this.sumVjXj[n] = this.sumVjXj[n] + vx;
                int n2 = f++;
                this.sumV2X2[n2] = this.sumV2X2[n2] + vx * vx;
            }
        }

        void merge(double o_ret, @Nullable Object o_sumVjXj, @Nullable Object o_sumV2X2, @Nonnull StandardListObjectInspector sumVjXjOI, @Nonnull StandardListObjectInspector sumV2X2OI) throws HiveException {
            this.ret += o_ret;
            if (o_sumVjXj == null) {
                return;
            }
            if (o_sumV2X2 == null) {
                throw new HiveException("o_sumV2X2 should not be null");
            }
            int factors = sumVjXjOI.getListLength(o_sumVjXj);
            if (this.sumVjXj == null) {
                this.sumVjXj = new double[factors];
                this.sumV2X2 = new double[factors];
            } else if (this.sumVjXj.length != factors) {
                throw new HiveException("Mismatch in the number of factors");
            }
            WritableDoubleObjectInspector doubleOI = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
            int f = 0;
            while (f < factors) {
                Object o1 = sumVjXjOI.getListElement(o_sumVjXj, f);
                Object o2 = sumV2X2OI.getListElement(o_sumV2X2, f);
                double d1 = doubleOI.get(o1);
                double d2 = doubleOI.get(o2);
                int n = f;
                this.sumVjXj[n] = this.sumVjXj[n] + d1;
                int n2 = f++;
                this.sumV2X2[n2] = this.sumV2X2[n2] + d2;
            }
        }

        double getPrediction() {
            double predict = this.ret;
            if (this.sumVjXj != null) {
                int factors = this.sumVjXj.length;
                for (int f = 0; f < factors; ++f) {
                    double d1 = this.sumVjXj[f];
                    double d2 = this.sumV2X2[f];
                    predict += 0.5 * (d1 * d1 - d2);
                }
            }
            return predict;
        }

        public int estimate() {
            if (this.sumVjXj == null) {
                return 24;
            }
            return 8 + 2 * (32 + 8 * this.sumVjXj.length);
        }
    }

    public static class Evaluator
    extends GenericUDAFEvaluator {
        private PrimitiveObjectInspector wOI;
        private ListObjectInspector vOI;
        private PrimitiveObjectInspector vElemOI;
        private PrimitiveObjectInspector xOI;
        private StructObjectInspector internalMergeOI;
        private StructField retField;
        private StructField sumVjXjField;
        private StructField sumV2X2Field;
        private WritableDoubleObjectInspector retOI;
        private StandardListObjectInspector sumVjXjOI;
        private StandardListObjectInspector sumV2X2OI;

        public ObjectInspector init(GenericUDAFEvaluator.Mode mode, ObjectInspector[] parameters) throws HiveException {
            assert (parameters.length == 3);
            super.init(mode, parameters);
            if (mode == GenericUDAFEvaluator.Mode.PARTIAL1 || mode == GenericUDAFEvaluator.Mode.COMPLETE) {
                this.wOI = HiveUtils.asDoubleCompatibleOI(parameters, 0);
                this.vOI = HiveUtils.asListOI(parameters, 1);
                this.vElemOI = HiveUtils.asDoubleCompatibleOI(this.vOI.getListElementObjectInspector());
                this.xOI = HiveUtils.asDoubleCompatibleOI(parameters, 2);
            } else {
                StructObjectInspector soi;
                this.internalMergeOI = soi = (StructObjectInspector)parameters[0];
                this.retField = soi.getStructFieldRef("ret");
                this.sumVjXjField = soi.getStructFieldRef("sumVjXj");
                this.sumV2X2Field = soi.getStructFieldRef("sumV2X2");
                this.retOI = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
                this.sumVjXjOI = ObjectInspectorFactory.getStandardListObjectInspector((ObjectInspector)PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
                this.sumV2X2OI = ObjectInspectorFactory.getStandardListObjectInspector((ObjectInspector)PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
            }
            Object outputOI = mode == GenericUDAFEvaluator.Mode.PARTIAL1 || mode == GenericUDAFEvaluator.Mode.PARTIAL2 ? Evaluator.internalMergeOI() : PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
            return outputOI;
        }

        private static StructObjectInspector internalMergeOI() {
            ArrayList<String> fieldNames = new ArrayList<String>();
            ArrayList<Object> fieldOIs = new ArrayList<Object>();
            fieldNames.add("ret");
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
            fieldNames.add("sumVjXj");
            fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector((ObjectInspector)PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
            fieldNames.add("sumV2X2");
            fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector((ObjectInspector)PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
            return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
        }

        public FMPredictAggregationBuffer getNewAggregationBuffer() throws HiveException {
            FMPredictAggregationBuffer buf = new FMPredictAggregationBuffer();
            buf.reset();
            return buf;
        }

        public void reset(GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
            FMPredictAggregationBuffer buf = (FMPredictAggregationBuffer)agg;
            buf.reset();
        }

        public void iterate(GenericUDAFEvaluator.AggregationBuffer agg, Object[] parameters) throws HiveException {
            if (parameters[0] == null) {
                return;
            }
            FMPredictAggregationBuffer buf = (FMPredictAggregationBuffer)agg;
            double w = PrimitiveObjectInspectorUtils.getDouble((Object)parameters[0], (PrimitiveObjectInspector)this.wOI);
            if (parameters[1] == null || this.vOI.getListLength(parameters[1]) == 0) {
                buf.iterate(w);
            } else {
                if (parameters[2] == null) {
                    throw new UDFArgumentException("The third argument Xj must not be null");
                }
                double x = PrimitiveObjectInspectorUtils.getDouble((Object)parameters[2], (PrimitiveObjectInspector)this.xOI);
                buf.iterate(w, x, parameters[1], this.vOI, this.vElemOI);
            }
        }

        public Object terminatePartial(GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
            FMPredictAggregationBuffer buf = (FMPredictAggregationBuffer)agg;
            Object[] partialResult = new Object[3];
            partialResult[0] = new DoubleWritable(buf.ret);
            if (buf.sumVjXj != null) {
                partialResult[1] = WritableUtils.toWritableList(buf.sumVjXj);
                partialResult[2] = WritableUtils.toWritableList(buf.sumV2X2);
            }
            return partialResult;
        }

        public void merge(GenericUDAFEvaluator.AggregationBuffer agg, Object partial) throws HiveException {
            if (partial == null) {
                return;
            }
            FMPredictAggregationBuffer buf = (FMPredictAggregationBuffer)agg;
            Object o1 = this.internalMergeOI.getStructFieldData(partial, this.retField);
            double ret = this.retOI.get(o1);
            Object sumVjXj = this.internalMergeOI.getStructFieldData(partial, this.sumVjXjField);
            Object sumV2X2 = this.internalMergeOI.getStructFieldData(partial, this.sumV2X2Field);
            if (sumVjXj instanceof LazyBinaryArray) {
                sumVjXj = ((LazyBinaryArray)sumVjXj).getList();
            }
            if (sumV2X2 instanceof LazyBinaryArray) {
                sumV2X2 = ((LazyBinaryArray)sumV2X2).getList();
            }
            buf.merge(ret, sumVjXj, sumV2X2, this.sumVjXjOI, this.sumV2X2OI);
        }

        public DoubleWritable terminate(GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
            FMPredictAggregationBuffer buf = (FMPredictAggregationBuffer)agg;
            double predict = buf.getPrediction();
            return new DoubleWritable(predict);
        }
    }
}

