/*
 * Decompiled with CFR 0.152.
 */
package hivemall.tools.array;

import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.WritableUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import javax.annotation.Nonnull;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableDoubleObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableLongObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.IntWritable;

/**
 * Hive UDAF {@code array_avg(array<number>)}: computes, position by position, the mean of the
 * numeric elements of all input arrays and returns it as an {@code array<double>}.
 *
 * <p>All non-null input arrays must have the same length; a length mismatch raises a
 * {@link HiveException}. Null arrays are ignored, and null elements are skipped (they do not
 * contribute to the per-position count). A position that never received a non-null element
 * yields {@code 0.0}.
 */
@Description(name="array_avg", value="_FUNC_(array<number>) - Returns an array<double> in which each element is the mean of a set of numbers", extended="WITH input as (\n  select array(1.0, 2.0, 3.0) as nums\n  UNION ALL\n  select array(2.0, 3.0, 4.0) as nums\n)\nselect\n  array_avg(nums)\nfrom\n  input;\n\n[\"1.5\",\"2.5\",\"3.5\"]")
public final class ArrayAvgGenericUDAF
extends AbstractGenericUDAFResolver {

    /**
     * Validates the call signature and returns the evaluator.
     *
     * @param typeInfo argument types of the call; exactly one {@code LIST}-category type is accepted
     * @return a new {@link Evaluator}
     * @throws SemanticException (as {@link UDFArgumentTypeException}) when the argument count or
     *         category is wrong
     */
    @Override
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] typeInfo) throws SemanticException {
        if (typeInfo.length != 1) {
            throw new UDFArgumentTypeException(typeInfo.length - 1, "One argument is expected, taking an array as an argument");
        }
        if (typeInfo[0].getCategory() != ObjectInspector.Category.LIST) {
            throw new UDFArgumentTypeException(0, "One argument is expected, taking an array as an argument, but got "
                    + typeInfo[0].getTypeName());
        }
        return new Evaluator();
    }

    /**
     * Per-group aggregation state: a running sum and a non-null count for every array position.
     */
    @GenericUDAFEvaluator.AggregationType(estimable=true)
    public static final class ArrayAvgAggregationBuffer
    extends GenericUDAFEvaluator.AbstractAggregationBuffer {
        /** Number of positions being averaged, or -1 before the first input/partial is seen. */
        int _size;
        /** Per-position sum of the non-null elements seen so far. */
        double[] _sum;
        /** Per-position number of non-null elements seen so far. */
        long[] _count;

        /** Returns the buffer to its "nothing seen yet" state. */
        void reset() {
            this._size = -1;
            this._sum = null;
            this._count = null;
        }

        /**
         * Allocates zeroed accumulators for {@code size} positions.
         *
         * @param size array length; 0 is allowed (an empty input array averages to an empty array)
         */
        void init(int size) throws HiveException {
            assert (size >= 0) : size;
            this._size = size;
            this._sum = new double[size];
            this._count = new long[size];
        }

        /**
         * Adds one input array to the running sums and counts.
         *
         * @param tuple non-null array value
         * @param listOI inspector for {@code tuple}
         * @param elemOI inspector for the array elements; each element is read as a double
         * @throws HiveException if the array length differs from previously seen arrays
         */
        void doIterate(@Nonnull Object tuple, @Nonnull ListObjectInspector listOI, @Nonnull PrimitiveObjectInspector elemOI) throws HiveException {
            final int size = listOI.getListLength(tuple);
            if (this._size == -1) {
                this.init(size);
            }
            if (size != this._size) {
                throw new HiveException("Mismatch in the number of elements at tuple: " + tuple.toString()
                        + " (expected " + this._size + " but got " + size + ")");
            }
            final double[] sum = this._sum;
            final long[] count = this._count;
            for (int i = 0; i < size; i++) {
                final Object o = listOI.getListElement(tuple, i);
                if (o == null) {
                    continue; // null elements neither add to the sum nor to the count
                }
                sum[i] += PrimitiveObjectInspectorUtils.getDouble(o, elemOI);
                count[i]++;
            }
        }

        /**
         * Folds a partial aggregate (as produced by {@link Evaluator#terminatePartial}) into this
         * buffer.
         *
         * @param o_size number of positions in the partial
         * @param o_sum list of per-position sums, read as writable doubles
         * @param o_count list of per-position counts, read as writable longs
         * @param sumOI list inspector for {@code o_sum}
         * @param countOI list inspector for {@code o_count}
         * @throws HiveException if {@code o_size} differs from this buffer's already-initialized size
         */
        void merge(int o_size, @Nonnull Object o_sum, @Nonnull Object o_count, @Nonnull StandardListObjectInspector sumOI, @Nonnull StandardListObjectInspector countOI) throws HiveException {
            final WritableDoubleObjectInspector sumElemOI = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
            final WritableLongObjectInspector countElemOI = PrimitiveObjectInspectorFactory.writableLongObjectInspector;
            if (o_size != this._size) {
                if (this._size == -1) {
                    this.init(o_size);
                } else {
                    throw new HiveException("Mismatch in the number of elements: expected " + this._size
                            + " but got " + o_size);
                }
            }
            final double[] sum = this._sum;
            final long[] count = this._count;
            for (int i = 0; i < this._size; i++) {
                sum[i] += sumElemOI.get(sumOI.getListElement(o_sum, i));
                count[i] += countElemOI.get(countOI.getListElement(o_count, i));
            }
        }

        /** Rough byte-size estimate of this buffer, as requested by {@code estimable=true}. */
        @Override
        public int estimate() {
            if (this._size == -1) {
                return 8;
            }
            return 4 + 2 * (32 + 8 * this._size);
        }
    }

    /**
     * Evaluator for {@code array_avg}. PARTIAL1/COMPLETE consume raw input arrays; PARTIAL2/FINAL
     * consume the {@code struct<size:int, sum:array<double>, count:array<bigint>>} partials built by
     * {@link #terminatePartial}.
     */
    public static class Evaluator
    extends GenericUDAFEvaluator {
        // Inspectors for raw input (PARTIAL1 / COMPLETE).
        private ListObjectInspector inputListOI;
        private PrimitiveObjectInspector inputListElemOI;
        // Inspectors for partial aggregates (PARTIAL2 / FINAL).
        private StructObjectInspector internalMergeOI;
        private StructField sizeField;
        private StructField sumField;
        private StructField countField;
        private WritableIntObjectInspector sizeOI;
        private StandardListObjectInspector sumOI;
        private StandardListObjectInspector countOI;

        /**
         * Captures the input inspectors for the given mode and returns the output inspector.
         *
         * @return the partial struct inspector for PARTIAL1/PARTIAL2, otherwise a standard list
         *         inspector over writable doubles
         */
        @Override
        public ObjectInspector init(GenericUDAFEvaluator.Mode mode, ObjectInspector[] parameters) throws HiveException {
            assert (parameters.length == 1);
            super.init(mode, parameters);
            if (mode == GenericUDAFEvaluator.Mode.PARTIAL1 || mode == GenericUDAFEvaluator.Mode.COMPLETE) {
                this.inputListOI = (ListObjectInspector) parameters[0];
                this.inputListElemOI = HiveUtils.asDoubleCompatibleOI(this.inputListOI.getListElementObjectInspector());
            } else {
                final StructObjectInspector soi = (StructObjectInspector) parameters[0];
                this.internalMergeOI = soi;
                this.sizeField = soi.getStructFieldRef("size");
                this.sumField = soi.getStructFieldRef("sum");
                this.countField = soi.getStructFieldRef("count");
                this.sizeOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector;
                this.sumOI = ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
                this.countOI = ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
            }
            final ObjectInspector outputOI;
            if (mode == GenericUDAFEvaluator.Mode.PARTIAL1 || mode == GenericUDAFEvaluator.Mode.PARTIAL2) {
                outputOI = Evaluator.internalMergeOI();
            } else {
                outputOI = ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
            }
            return outputOI;
        }

        /** Builds the inspector for the partial aggregate struct (size, sum, count). */
        private static StructObjectInspector internalMergeOI() {
            final List<String> fieldNames = new ArrayList<String>();
            final List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
            fieldNames.add("size");
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
            fieldNames.add("sum");
            fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
            fieldNames.add("count");
            fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableLongObjectInspector));
            return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
        }

        @Override
        public ArrayAvgAggregationBuffer getNewAggregationBuffer() throws HiveException {
            final ArrayAvgAggregationBuffer aggr = new ArrayAvgAggregationBuffer();
            this.reset(aggr);
            return aggr;
        }

        @Override
        public void reset(GenericUDAFEvaluator.AggregationBuffer aggr) throws HiveException {
            ((ArrayAvgAggregationBuffer) aggr).reset();
        }

        /** Accumulates one input row; null arrays are ignored. */
        @Override
        public void iterate(GenericUDAFEvaluator.AggregationBuffer aggr, Object[] parameters) throws HiveException {
            final Object tuple = parameters[0];
            if (tuple != null) {
                ((ArrayAvgAggregationBuffer) aggr).doIterate(tuple, this.inputListOI, this.inputListElemOI);
            }
        }

        /**
         * Emits {@code [IntWritable size, sum list, count list]}, or {@code null} if no array has
         * been seen yet.
         */
        @Override
        public Object terminatePartial(GenericUDAFEvaluator.AggregationBuffer aggr) throws HiveException {
            final ArrayAvgAggregationBuffer myAggr = (ArrayAvgAggregationBuffer) aggr;
            if (myAggr._size == -1) {
                return null;
            }
            return new Object[] {new IntWritable(myAggr._size), WritableUtils.toWritableList(myAggr._sum),
                    WritableUtils.toWritableList(myAggr._count)};
        }

        /** Merges one partial struct into {@code aggr}; null partials are ignored. */
        @Override
        public void merge(GenericUDAFEvaluator.AggregationBuffer aggr, Object partial) throws HiveException {
            if (partial == null) {
                return;
            }
            final ArrayAvgAggregationBuffer myAggr = (ArrayAvgAggregationBuffer) aggr;
            final int size = this.sizeOI.get(this.internalMergeOI.getStructFieldData(partial, this.sizeField));
            assert (size != -1);
            Object sum = this.internalMergeOI.getStructFieldData(partial, this.sumField);
            Object count = this.internalMergeOI.getStructFieldData(partial, this.countField);
            // Lazily-deserialized arrays are unwrapped to java.util.List so the standard list
            // inspectors in sumOI/countOI can read them.
            if (sum instanceof LazyBinaryArray) {
                sum = ((LazyBinaryArray) sum).getList();
            }
            if (count instanceof LazyBinaryArray) {
                count = ((LazyBinaryArray) count).getList();
            }
            myAggr.merge(size, sum, count, this.sumOI, this.countOI);
        }

        /**
         * Returns the per-position means, or {@code null} if no array has been seen. Positions with
         * no non-null element yield {@code 0.0}.
         */
        @Override
        public List<DoubleWritable> terminate(GenericUDAFEvaluator.AggregationBuffer aggr) throws HiveException {
            final ArrayAvgAggregationBuffer myAggr = (ArrayAvgAggregationBuffer) aggr;
            final int size = myAggr._size;
            if (size == -1) {
                return null;
            }
            final double[] sum = myAggr._sum;
            final long[] count = myAggr._count;
            final DoubleWritable[] ary = new DoubleWritable[size];
            for (int i = 0; i < size; i++) {
                final long c = count[i];
                // Stay in double precision: the output type is array<double>, so narrowing
                // through float would silently drop precision.
                final double avg = (c == 0L) ? 0.0d : sum[i] / c;
                ary[i] = new DoubleWritable(avg);
            }
            return Arrays.asList(ary);
        }
    }
}

