/*
 * Decompiled with CFR 0.152.
 */
package hivemall.ftvec.selection;

import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.WritableUtils;
import hivemall.utils.lang.Preconditions;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import javax.annotation.Nonnull;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFParameterInfo;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;

@Description(name="snr", value="_FUNC_(array<number> features, array<int> one-hot class label) - Returns Signal Noise Ratio for each feature as array<double>")
public class SignalNoiseRatioUDAF
extends AbstractGenericUDAFResolver {
    public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo info) throws SemanticException {
        ObjectInspector[] OIs = info.getParameterObjectInspectors();
        if (OIs.length != 2) {
            throw new UDFArgumentLengthException("Specify two arguments: " + OIs.length);
        }
        if (!HiveUtils.isNumberListOI(OIs[0])) {
            throw new UDFArgumentTypeException(0, "Only array<number> type argument is acceptable but " + OIs[0].getTypeName() + " was passed as `features`");
        }
        if (!HiveUtils.isListOI(OIs[1]) || !HiveUtils.isIntegerOI(((ListObjectInspector)OIs[1]).getListElementObjectInspector())) {
            throw new UDFArgumentTypeException(1, "Only array<int> type argument is acceptable but " + OIs[1].getTypeName() + " was passed as `labels`");
        }
        return new SignalNoiseRatioUDAFEvaluator();
    }

    static class SignalNoiseRatioUDAFEvaluator
    extends GenericUDAFEvaluator {
        private ListObjectInspector featuresOI;
        private PrimitiveObjectInspector featureOI;
        private ListObjectInspector labelsOI;
        private PrimitiveObjectInspector labelOI;
        private StructObjectInspector structOI;
        private StructField countsField;
        private StructField meansField;
        private StructField variancesField;
        private ListObjectInspector countsOI;
        private LongObjectInspector countOI;
        private ListObjectInspector meansOI;
        private ListObjectInspector meanListOI;
        private DoubleObjectInspector meanElemOI;
        private ListObjectInspector variancesOI;
        private ListObjectInspector varianceListOI;
        private DoubleObjectInspector varianceElemOI;

        SignalNoiseRatioUDAFEvaluator() {
        }

        public ObjectInspector init(GenericUDAFEvaluator.Mode mode, ObjectInspector[] OIs) throws HiveException {
            super.init(mode, OIs);
            if (mode == GenericUDAFEvaluator.Mode.PARTIAL1 || mode == GenericUDAFEvaluator.Mode.COMPLETE) {
                this.featuresOI = HiveUtils.asListOI(OIs[0]);
                this.featureOI = HiveUtils.asDoubleCompatibleOI(this.featuresOI.getListElementObjectInspector());
                this.labelsOI = HiveUtils.asListOI(OIs[1]);
                this.labelOI = HiveUtils.asIntegerOI(this.labelsOI.getListElementObjectInspector());
            } else {
                this.structOI = (StructObjectInspector)OIs[0];
                this.countsField = this.structOI.getStructFieldRef("counts");
                this.countsOI = HiveUtils.asListOI(this.countsField.getFieldObjectInspector());
                this.countOI = HiveUtils.asLongOI(this.countsOI.getListElementObjectInspector());
                this.meansField = this.structOI.getStructFieldRef("means");
                this.meansOI = HiveUtils.asListOI(this.meansField.getFieldObjectInspector());
                this.meanListOI = HiveUtils.asListOI(this.meansOI.getListElementObjectInspector());
                this.meanElemOI = HiveUtils.asDoubleOI(this.meanListOI.getListElementObjectInspector());
                this.variancesField = this.structOI.getStructFieldRef("variances");
                this.variancesOI = HiveUtils.asListOI(this.variancesField.getFieldObjectInspector());
                this.varianceListOI = HiveUtils.asListOI(this.variancesOI.getListElementObjectInspector());
                this.varianceElemOI = HiveUtils.asDoubleOI(this.varianceListOI.getListElementObjectInspector());
            }
            if (mode == GenericUDAFEvaluator.Mode.PARTIAL1 || mode == GenericUDAFEvaluator.Mode.PARTIAL2) {
                ArrayList<StandardListObjectInspector> fieldOIs = new ArrayList<StandardListObjectInspector>();
                fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector((ObjectInspector)PrimitiveObjectInspectorFactory.writableLongObjectInspector));
                fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector((ObjectInspector)ObjectInspectorFactory.getStandardListObjectInspector((ObjectInspector)PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)));
                fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector((ObjectInspector)ObjectInspectorFactory.getStandardListObjectInspector((ObjectInspector)PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)));
                return ObjectInspectorFactory.getStandardStructObjectInspector(Arrays.asList("counts", "means", "variances"), fieldOIs);
            }
            return ObjectInspectorFactory.getStandardListObjectInspector((ObjectInspector)PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
        }

        public GenericUDAFEvaluator.AbstractAggregationBuffer getNewAggregationBuffer() throws HiveException {
            SignalNoiseRatioAggregationBuffer myAgg = new SignalNoiseRatioAggregationBuffer();
            this.reset((GenericUDAFEvaluator.AggregationBuffer)myAgg);
            return myAgg;
        }

        public void reset(GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
            SignalNoiseRatioAggregationBuffer myAgg = (SignalNoiseRatioAggregationBuffer)agg;
            myAgg.reset();
        }

        public void iterate(GenericUDAFEvaluator.AggregationBuffer agg, Object[] parameters) throws HiveException {
            Object featuresObj = parameters[0];
            Object labelsObj = parameters[1];
            Preconditions.checkNotNull(featuresObj);
            Preconditions.checkNotNull(labelsObj);
            SignalNoiseRatioAggregationBuffer myAgg = (SignalNoiseRatioAggregationBuffer)agg;
            List labels = this.labelsOI.getList(labelsObj);
            int nClasses = labels.size();
            Preconditions.checkArgument(nClasses >= 2, UDFArgumentException.class);
            List features = this.featuresOI.getList(featuresObj);
            int nFeatures = features.size();
            Preconditions.checkArgument(nFeatures >= 1, UDFArgumentException.class);
            if (myAgg.counts == null) {
                myAgg.init(nClasses, nFeatures);
            } else {
                Preconditions.checkArgument(nClasses == myAgg.counts.length, UDFArgumentException.class);
                Preconditions.checkArgument(nFeatures == myAgg.means[0].length, UDFArgumentException.class);
            }
            int clazz = SignalNoiseRatioUDAFEvaluator.hotIndex(labels, this.labelOI);
            long n = myAgg.counts[clazz];
            int n2 = clazz;
            myAgg.counts[n2] = myAgg.counts[n2] + 1L;
            for (int i = 0; i < nFeatures; ++i) {
                double x = PrimitiveObjectInspectorUtils.getDouble(features.get(i), (PrimitiveObjectInspector)this.featureOI);
                double meanN = myAgg.means[clazz][i];
                double varianceN = myAgg.variances[clazz][i];
                myAgg.means[clazz][i] = ((double)n * meanN + x) / ((double)n + 1.0);
                myAgg.variances[clazz][i] = ((double)n * varianceN + (x - meanN) * (x - myAgg.means[clazz][i])) / ((double)n + 1.0);
            }
        }

        private static int hotIndex(@Nonnull List<?> labels, PrimitiveObjectInspector labelOI) throws UDFArgumentException {
            int nClasses = labels.size();
            int clazz = -1;
            for (int i = 0; i < nClasses; ++i) {
                int label = PrimitiveObjectInspectorUtils.getInt(labels.get(i), (PrimitiveObjectInspector)labelOI);
                if (label == 1) {
                    if (clazz != -1) {
                        throw new UDFArgumentException("Specify one-hot vectorized array. Multiple hot elements found.");
                    }
                    clazz = i;
                    continue;
                }
                if (label == 0) continue;
                throw new UDFArgumentException("Assumed one-hot encoding (0/1) but found an invalid label: " + label);
            }
            if (clazz == -1) {
                throw new UDFArgumentException("Specify one-hot vectorized array for label. Hot element not found.");
            }
            return clazz;
        }

        public void merge(GenericUDAFEvaluator.AggregationBuffer agg, Object other) throws HiveException {
            if (other == null) {
                return;
            }
            SignalNoiseRatioAggregationBuffer myAgg = (SignalNoiseRatioAggregationBuffer)agg;
            List counts = this.countsOI.getList(this.structOI.getStructFieldData(other, this.countsField));
            List means = this.meansOI.getList(this.structOI.getStructFieldData(other, this.meansField));
            List variances = this.variancesOI.getList(this.structOI.getStructFieldData(other, this.variancesField));
            int nClasses = counts.size();
            int nFeatures = this.meanListOI.getListLength(means.get(0));
            if (myAgg.counts == null) {
                myAgg.init(nClasses, nFeatures);
            }
            for (int i = 0; i < nClasses; ++i) {
                long n = myAgg.counts[i];
                long cnt = PrimitiveObjectInspectorUtils.getLong(counts.get(i), (PrimitiveObjectInspector)this.countOI);
                if (cnt == 0L) continue;
                List mean = this.meanListOI.getList(means.get(i));
                List variance = this.varianceListOI.getList(variances.get(i));
                int n2 = i;
                myAgg.counts[n2] = myAgg.counts[n2] + cnt;
                for (int j = 0; j < nFeatures; ++j) {
                    double meanN = myAgg.means[i][j];
                    double meanM = PrimitiveObjectInspectorUtils.getDouble(mean.get(j), (PrimitiveObjectInspector)this.meanElemOI);
                    double varianceN = myAgg.variances[i][j];
                    double varianceM = PrimitiveObjectInspectorUtils.getDouble(variance.get(j), (PrimitiveObjectInspector)this.varianceElemOI);
                    if (n == 0L) {
                        myAgg.means[i][j] = meanM;
                        myAgg.variances[i][j] = varianceM;
                        continue;
                    }
                    myAgg.means[i][j] = ((double)n * meanN + (double)cnt * meanM) / (double)(n + cnt);
                    myAgg.variances[i][j] = (varianceN * (double)(n - 1L) + varianceM * (double)(cnt - 1L) + Math.pow(meanN - meanM, 2.0) * (double)n * (double)cnt / (double)(n + cnt)) / (double)(n + cnt - 1L);
                }
            }
        }

        public Object terminatePartial(GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
            SignalNoiseRatioAggregationBuffer myAgg = (SignalNoiseRatioAggregationBuffer)agg;
            Object[] partialResult = new Object[3];
            partialResult[0] = WritableUtils.toWritableList(myAgg.counts);
            ArrayList<List<DoubleWritable>> means = new ArrayList<List<DoubleWritable>>();
            for (double[] mean : myAgg.means) {
                means.add(WritableUtils.toWritableList(mean));
            }
            partialResult[1] = means;
            ArrayList<List<DoubleWritable>> variances = new ArrayList<List<DoubleWritable>>();
            for (double[] variance : myAgg.variances) {
                variances.add(WritableUtils.toWritableList(variance));
            }
            partialResult[2] = variances;
            return partialResult;
        }

        public Object terminate(GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
            SignalNoiseRatioAggregationBuffer myAgg = (SignalNoiseRatioAggregationBuffer)agg;
            int nClasses = myAgg.counts.length;
            int nFeatures = myAgg.means[0].length;
            double[] result = new double[nFeatures];
            double[] sds = new double[nClasses];
            for (int i = 0; i < nFeatures; ++i) {
                sds[0] = Math.sqrt(myAgg.variances[0][i]);
                for (int j = 1; j < nClasses; ++j) {
                    sds[j] = Math.sqrt(myAgg.variances[j][i]);
                    if (myAgg.counts[j] == 0L) continue;
                    for (int k = 0; k < j; ++k) {
                        double snr;
                        if (myAgg.counts[k] == 0L || myAgg.counts[j] == 1L && myAgg.counts[k] == 1L || Double.isNaN(snr = Math.abs(myAgg.means[j][i] - myAgg.means[k][i]) / (sds[j] + sds[k]))) continue;
                        int n = i;
                        result[n] = result[n] + snr;
                    }
                }
            }
            return WritableUtils.toWritableList(result);
        }

        @GenericUDAFEvaluator.AggregationType(estimable=true)
        static class SignalNoiseRatioAggregationBuffer
        extends GenericUDAFEvaluator.AbstractAggregationBuffer {
            long[] counts;
            double[][] means;
            double[][] variances;

            SignalNoiseRatioAggregationBuffer() {
            }

            public int estimate() {
                return this.counts == null ? 0 : 8 * this.counts.length + 8 * this.means.length * this.means[0].length + 8 * this.variances.length * this.variances[0].length;
            }

            public void init(int nClasses, int nFeatures) {
                this.counts = new long[nClasses];
                this.means = new double[nClasses][nFeatures];
                this.variances = new double[nClasses][nFeatures];
            }

            public void reset() {
                if (this.counts != null) {
                    Arrays.fill(this.counts, 0L);
                    for (double[] mean : this.means) {
                        Arrays.fill(mean, 0.0);
                    }
                    for (double[] variance : this.variances) {
                        Arrays.fill(variance, 0.0);
                    }
                }
            }
        }
    }
}

