/*
 * Decompiled with CFR 0.152.
 */
package hivemall.evaluation;

import hivemall.UDAFEvaluatorWithOptions;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.Primitives;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import javax.annotation.Nonnull;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.util.JavaDataModel;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.BooleanObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.LongWritable;

/**
 * Resolver for the {@code fmeasure} Hive UDAF.
 *
 * <p>{@code actual} and {@code predicted} must have the same type, and that type must be an
 * array, an integer type, or boolean. An optional third argument carries a constant option string
 * that {@link Evaluator} parses ({@code -beta}, {@code -average}).
 */
@Description(name="fmeasure", value="_FUNC_(array|int|boolean actual, array|int| boolean predicted [, const string options]) - Return a F-measure (f1score is the special with beta=1.0)")
public final class FMeasureUDAF extends AbstractGenericUDAFResolver {

    /**
     * Validates the argument types and returns a new {@link Evaluator}.
     *
     * @param typeInfo argument types: {@code actual}, {@code predicted} and an optional options
     *        argument (its content is validated later by {@link Evaluator#processOptions})
     * @return a fresh evaluator
     * @throws UDFArgumentTypeException if there are not 2 or 3 arguments, if either label argument
     *         is not an array/integer/boolean type, or if the two label types differ
     */
    @Override
    public GenericUDAFEvaluator getEvaluator(@Nonnull TypeInfo[] typeInfo) throws SemanticException {
        if (typeInfo.length != 2 && typeInfo.length != 3) {
            throw new UDFArgumentTypeException(typeInfo.length - 1, "_FUNC_ takes two or three arguments");
        }
        if (!isListOrIntOrBoolean(typeInfo[0])) {
            throw new UDFArgumentTypeException(0, "The first argument `array/int/boolean actual` is invalid form: " + typeInfo[0]);
        }
        if (!isListOrIntOrBoolean(typeInfo[1])) {
            throw new UDFArgumentTypeException(1, "The second argument `array/int/boolean predicted` is invalid form: " + typeInfo[1]);
        }
        if (!typeInfo[0].equals(typeInfo[1])) {
            throw new UDFArgumentTypeException(1, "The first argument `actual`'s type is " + typeInfo[0] + ", but the second argument `predicted`'s type is not match: " + typeInfo[1]);
        }
        return new Evaluator();
    }

    /** Returns true if {@code type} is a list, integer or boolean type. */
    private static boolean isListOrIntOrBoolean(@Nonnull TypeInfo type) {
        return HiveUtils.isListTypeInfo(type) || HiveUtils.isIntegerTypeInfo(type)
                || HiveUtils.isBooleanTypeInfo(type);
    }

    /**
     * Aggregation state: true-positive count, total actual/predicted label counts, and the
     * {@code beta}/{@code average} options that travel with partial results.
     */
    @GenericUDAFEvaluator.AggregationType(estimable = true)
    public static class FMeasureAggregationBuffer extends GenericUDAFEvaluator.AbstractAggregationBuffer {
        long tp;
        long totalActual;
        long totalPredicted;
        double beta;
        String average;

        /** Estimates the buffer's memory footprint for Hive's hash-aggregation accounting. */
        @Override
        public int estimate() {
            JavaDataModel model = JavaDataModel.get();
            // Four long/double fields: tp, totalActual, totalPredicted, beta.
            int size = model.primitive2() * 4;
            // average stays null on merge-side buffers until a partial is merged, because options
            // are only parsed in PARTIAL1/COMPLETE; lengthFor(null) would throw.
            if (this.average != null) {
                size += model.lengthFor(this.average);
            }
            return size;
        }

        /** Sets the options used by {@link #get()}. */
        void setOptions(double beta, String average) {
            this.beta = beta;
            this.average = average;
        }

        /** Clears the counters; options are left untouched. */
        void reset() {
            this.tp = 0L;
            this.totalActual = 0L;
            this.totalPredicted = 0L;
        }

        /** Adds another partial's counters and adopts its options. */
        void merge(long o_tp, long o_actual, long o_predicted, double beta, String average) {
            this.tp += o_tp;
            this.totalActual += o_actual;
            this.totalPredicted += o_predicted;
            this.beta = beta;
            this.average = average;
        }

        /**
         * Computes the F-beta score from the accumulated counts.
         *
         * <p>For {@code "micro"}, it uses the count form (1+b^2)tp / (b^2*totalActual + totalPredicted).
         * For any other value, it uses the precision/recall form (1+b^2)PR / (b^2*P + R).
         *
         * @return the score, or 0.0 when the divisor is not positive
         */
        double get() {
            final double squareBeta = this.beta * this.beta;
            final double numerator;
            final double divisor;
            if ("micro".equals(this.average)) {
                divisor = denom(this.tp, this.totalActual, this.totalPredicted, squareBeta);
                numerator = (1.0 + squareBeta) * (double) this.tp;
            } else {
                double precision = precision(this.tp, this.totalPredicted);
                double recall = recall(this.tp, this.totalActual);
                divisor = squareBeta * precision + recall;
                numerator = (1.0 + squareBeta) * precision * recall;
            }
            if (divisor > 0.0) {
                return numerator / divisor;
            }
            return 0.0;
        }

        /** Micro-average denominator: b^2*(tp+fn) + tp + fp. */
        private static double denom(long tp, long totalActual, long totalPredicted, double squareBeta) {
            long lp = totalActual - tp;
            long pl = totalPredicted - tp;
            return squareBeta * (double) (tp + lp) + (double) tp + (double) pl;
        }

        /** tp / totalPredicted, or 0.0 when nothing was predicted. */
        private static double precision(long tp, long totalPredicted) {
            return totalPredicted == 0L ? 0.0 : (double) tp / (double) totalPredicted;
        }

        /** tp / totalActual, or 0.0 when there were no actual labels. */
        private static double recall(long tp, long totalActual) {
            return totalActual == 0L ? 0.0 : (double) tp / (double) totalActual;
        }

        /**
         * Accumulates one row.
         *
         * <p>Each element of {@code predicted} that {@code actual} contains (per
         * {@link List#contains}) counts as one true positive. Duplicates in {@code predicted} are
         * counted per occurrence.
         */
        void iterate(@Nonnull List<?> actual, @Nonnull List<?> predicted) {
            int countTp = 0;
            for (Object p : predicted) {
                if (actual.contains(p)) {
                    ++countTp;
                }
            }
            this.tp += (long) countTp;
            this.totalActual += (long) actual.size();
            this.totalPredicted += (long) predicted.size();
        }
    }

    /**
     * Evaluator for {@code fmeasure}.
     *
     * <p>Options are parsed only in PARTIAL1/COMPLETE. In PARTIAL2/FINAL, {@code beta} and
     * {@code average} arrive inside each partial result and are copied by {@link #merge}.
     */
    public static class Evaluator extends UDAFEvaluatorWithOptions {
        private ObjectInspector actualOI;
        private ObjectInspector predictedOI;
        private StructObjectInspector internalMergeOI;
        private StructField tpField;
        private StructField totalActualField;
        private StructField totalPredictedField;
        private StructField betaOptionField;
        private StructField averageOptionField;
        // Inspectors of the partial-result fields, taken from the incoming struct so that both
        // in-memory (java String) and deserialized (writable) partials can be decoded.
        private PrimitiveObjectInspector tpOI;
        private PrimitiveObjectInspector totalActualOI;
        private PrimitiveObjectInspector totalPredictedOI;
        private PrimitiveObjectInspector betaOI;
        private PrimitiveObjectInspector averageOI;
        private double beta;
        private String average;

        @Override
        protected Options getOptions() {
            Options opts = new Options();
            opts.addOption("beta", true, "The weight of precision [default: 1.]");
            opts.addOption("average", true, "The way of average calculation [default: micro]");
            return opts;
        }

        /**
         * Parses the optional third argument and stores {@code beta} (default 1.0) and
         * {@code average} (default {@code "micro"}).
         *
         * @return the parsed command line, or null when no options argument was given
         * @throws UDFArgumentException if beta is not a positive number, or if average is not
         *         {@code binary} or {@code micro} ({@code macro} is explicitly unsupported)
         */
        @Override
        protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
            CommandLine cl = null;
            double beta = 1.0;
            String average = "micro";
            if (argOIs.length >= 3) {
                String rawArgs = HiveUtils.getConstString(argOIs[2]);
                cl = this.parseOptions(rawArgs);
                beta = Primitives.parseDouble(cl.getOptionValue("beta"), beta);
                // Written as !(beta > 0) so that NaN is rejected as well.
                if (!(beta > 0.0)) {
                    throw new UDFArgumentException("The third argument `double beta` must be greater than 0.0: " + beta);
                }
                average = cl.getOptionValue("average", average);
                if ("macro".equals(average)) {
                    throw new UDFArgumentException("\"-average macro\" is not supported");
                }
                if (!"binary".equals(average) && !"micro".equals(average)) {
                    throw new UDFArgumentException("The third argument `String average` must be one of the {binary, micro}: " + average);
                }
            }
            this.beta = beta;
            this.average = average;
            return cl;
        }

        /**
         * Initializes the evaluator for {@code mode}.
         *
         * @return the partial-result struct inspector for PARTIAL1/PARTIAL2, and a writable double
         *         inspector for FINAL/COMPLETE
         */
        @Override
        public ObjectInspector init(GenericUDAFEvaluator.Mode mode, ObjectInspector[] parameters) throws HiveException {
            final boolean fromOriginalData =
                    mode == GenericUDAFEvaluator.Mode.PARTIAL1 || mode == GenericUDAFEvaluator.Mode.COMPLETE;
            if (fromOriginalData) {
                assert (parameters.length == 2 || parameters.length == 3) : parameters.length;
            } else {
                // PARTIAL2/FINAL receive a single partial-aggregation inspector.
                assert (parameters.length == 1) : parameters.length;
            }
            super.init(mode, parameters);
            if (fromOriginalData) {
                this.processOptions(parameters);
                this.actualOI = parameters[0];
                this.predictedOI = parameters[1];
            } else {
                StructObjectInspector soi = (StructObjectInspector) parameters[0];
                this.internalMergeOI = soi;
                this.tpField = soi.getStructFieldRef("tp");
                this.totalActualField = soi.getStructFieldRef("totalActual");
                this.totalPredictedField = soi.getStructFieldRef("totalPredicted");
                this.betaOptionField = soi.getStructFieldRef("beta");
                this.averageOptionField = soi.getStructFieldRef("average");
                this.tpOI = asPrimitiveOI(this.tpField);
                this.totalActualOI = asPrimitiveOI(this.totalActualField);
                this.totalPredictedOI = asPrimitiveOI(this.totalPredictedField);
                this.betaOI = asPrimitiveOI(this.betaOptionField);
                this.averageOI = asPrimitiveOI(this.averageOptionField);
            }
            if (mode == GenericUDAFEvaluator.Mode.PARTIAL1 || mode == GenericUDAFEvaluator.Mode.PARTIAL2) {
                return createInternalMergeOI();
            }
            return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
        }

        /** Returns the primitive inspector of a partial-result struct field. */
        @Nonnull
        private static PrimitiveObjectInspector asPrimitiveOI(@Nonnull StructField field) {
            return (PrimitiveObjectInspector) field.getFieldObjectInspector();
        }

        /** Builds the inspector for the partial result emitted by {@link #terminatePartial}. */
        @Nonnull
        private static StructObjectInspector createInternalMergeOI() {
            List<String> fieldNames = new ArrayList<String>();
            List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
            fieldNames.add("tp");
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
            fieldNames.add("totalActual");
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
            fieldNames.add("totalPredicted");
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
            fieldNames.add("beta");
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
            fieldNames.add("average");
            fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector);
            return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
        }

        @Override
        public FMeasureAggregationBuffer getNewAggregationBuffer() throws HiveException {
            FMeasureAggregationBuffer myAggr = new FMeasureAggregationBuffer();
            this.reset(myAggr);
            return myAggr;
        }

        /** Clears the counters and re-applies this evaluator's options. */
        @Override
        public void reset(GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
            FMeasureAggregationBuffer myAggr = (FMeasureAggregationBuffer) agg;
            myAggr.reset();
            myAggr.setOptions(this.beta, this.average);
        }

        /**
         * Accumulates one (actual, predicted) row. Rows where either value is NULL are skipped.
         *
         * @throws UDFArgumentException if {@code -average binary} is used with array inputs, or if
         *         an integer label is not 1, 0 or -1
         */
        @Override
        public void iterate(GenericUDAFEvaluator.AggregationBuffer agg, Object[] parameters) throws HiveException {
            final boolean isList = HiveUtils.isListOI(this.actualOI) && HiveUtils.isListOI(this.predictedOI);
            if (isList && "binary".equals(this.average)) {
                throw new UDFArgumentException("\"-average binary\" is not supported when `predict` is array");
            }
            final Object actualObj = parameters[0];
            final Object predictedObj = parameters[1];
            if (actualObj == null || predictedObj == null) {
                return;
            }

            final List<?> actual;
            final List<?> predicted;
            if (isList) {
                actual = ((ListObjectInspector) this.actualOI).getList(actualObj);
                predicted = ((ListObjectInspector) this.predictedOI).getList(predictedObj);
            } else {
                final int actualLabel;
                final int predictedLabel;
                if (HiveUtils.isBooleanOI(this.actualOI)) {
                    actualLabel = asIntLabel(actualObj, (BooleanObjectInspector) this.actualOI);
                    predictedLabel = asIntLabel(predictedObj, (BooleanObjectInspector) this.predictedOI);
                } else {
                    actualLabel = asIntLabel(actualObj, HiveUtils.asIntegerOI(this.actualOI));
                    predictedLabel = asIntLabel(predictedObj, HiveUtils.asIntegerOI(this.predictedOI));
                }
                actual = asLabelList(actualLabel);
                predicted = asLabelList(predictedLabel);
            }
            FMeasureAggregationBuffer myAggr = (FMeasureAggregationBuffer) agg;
            myAggr.iterate(actual, predicted);
        }

        /**
         * Wraps a 0/1 label in a singleton list.
         *
         * <p>Under {@code -average binary}, a 0 (negative) label becomes an empty list, so only
         * the positive class contributes to tp and to the totals. This is applied identically to
         * boolean and integer inputs.
         */
        @Nonnull
        private List<Integer> asLabelList(int label) {
            if (label == 0 && "binary".equals(this.average)) {
                return Collections.emptyList();
            }
            return Arrays.asList(label);
        }

        /** Maps true to 1 and false to 0. */
        private static int asIntLabel(@Nonnull Object o, @Nonnull BooleanObjectInspector booleanOI) {
            return booleanOI.get(o) ? 1 : 0;
        }

        /** Maps 1 to 1, and 0 or -1 to 0; any other value is rejected. */
        private static int asIntLabel(@Nonnull Object o, @Nonnull PrimitiveObjectInspector intOI) throws UDFArgumentException {
            int value = PrimitiveObjectInspectorUtils.getInt(o, intOI);
            switch (value) {
                case 1:
                    return 1;
                case -1:
                case 0:
                    return 0;
                default:
                    throw new UDFArgumentException("Int label must be 1, 0 or -1: " + value);
            }
        }

        /** Emits the partial result in the layout declared by {@link #createInternalMergeOI}. */
        @Override
        public Object terminatePartial(GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
            FMeasureAggregationBuffer myAggr = (FMeasureAggregationBuffer) agg;
            return new Object[] {new LongWritable(myAggr.tp), new LongWritable(myAggr.totalActual),
                    new LongWritable(myAggr.totalPredicted), new DoubleWritable(myAggr.beta),
                    myAggr.average};
        }

        /**
         * Merges a partial result into {@code agg}.
         *
         * <p>Fields are decoded through the struct's own field inspectors rather than hard-coded
         * writable inspectors, because the declared {@code average} field is a java String
         * inspector.
         */
        @Override
        public void merge(GenericUDAFEvaluator.AggregationBuffer agg, Object partial) throws HiveException {
            if (partial == null) {
                return;
            }
            Object tpObj = this.internalMergeOI.getStructFieldData(partial, this.tpField);
            Object totalActualObj = this.internalMergeOI.getStructFieldData(partial, this.totalActualField);
            Object totalPredictedObj = this.internalMergeOI.getStructFieldData(partial, this.totalPredictedField);
            Object betaObj = this.internalMergeOI.getStructFieldData(partial, this.betaOptionField);
            Object averageObj = this.internalMergeOI.getStructFieldData(partial, this.averageOptionField);
            long tp = PrimitiveObjectInspectorUtils.getLong(tpObj, this.tpOI);
            long totalActual = PrimitiveObjectInspectorUtils.getLong(totalActualObj, this.totalActualOI);
            long totalPredicted = PrimitiveObjectInspectorUtils.getLong(totalPredictedObj, this.totalPredictedOI);
            double beta = PrimitiveObjectInspectorUtils.getDouble(betaObj, this.betaOI);
            String average = PrimitiveObjectInspectorUtils.getString(averageObj, this.averageOI);
            FMeasureAggregationBuffer myAggr = (FMeasureAggregationBuffer) agg;
            myAggr.merge(tp, totalActual, totalPredicted, beta, average);
        }

        /** Returns the final F-beta score. */
        @Override
        public DoubleWritable terminate(GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
            FMeasureAggregationBuffer myAggr = (FMeasureAggregationBuffer) agg;
            return new DoubleWritable(myAggr.get());
        }
    }
}

