/*
 * Decompiled with CFR 0.152.
 */
package hivemall.topicmodel;

import hivemall.topicmodel.OnlineLDAModel;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.CommandLineUtils;
import hivemall.utils.lang.Primitives;
import hivemall.utils.struct.KeySortablePair;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.annotation.Nonnull;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.IntWritable;

@Description(name="lda_predict", value="_FUNC_(string word, float value, int label, float lambda[, const string options]) - Returns a list which consists of <int label, float prob>")
public final class LDAPredictUDAF
extends AbstractGenericUDAFResolver {
    /**
     * Checks the argument types of an {@code lda_predict} call and creates its evaluator.
     *
     * <p>Expected arguments, in order: a string {@code word}, a numeric {@code value}, an integer
     * {@code label}, a numeric {@code lambda} and, optionally, a string of options.
     *
     * @param typeInfo type information of each argument passed to the function
     * @return a new {@link Evaluator}
     * @throws SemanticException if the number of arguments is neither 4 nor 5
     *         ({@link UDFArgumentLengthException}) or an argument has an unexpected type
     *         ({@link UDFArgumentTypeException})
     */
    public Evaluator getEvaluator(TypeInfo[] typeInfo) throws SemanticException {
        final int numArgs = typeInfo.length;
        if (numArgs != 4 && numArgs != 5) {
            throw new UDFArgumentLengthException("Expected argument length is 4 or 5 but given argument length was " + numArgs);
        }
        if (!HiveUtils.isStringTypeInfo(typeInfo[0])) {
            throw new UDFArgumentTypeException(0, "String type is expected for the first argument word: " + typeInfo[0].getTypeName());
        }
        if (!HiveUtils.isNumberTypeInfo(typeInfo[1])) {
            throw new UDFArgumentTypeException(1, "Number type is expected for the second argument value: " + typeInfo[1].getTypeName());
        }
        if (!HiveUtils.isIntegerTypeInfo(typeInfo[2])) {
            throw new UDFArgumentTypeException(2, "Integer type is expected for the third argument label: " + typeInfo[2].getTypeName());
        }
        if (!HiveUtils.isNumberTypeInfo(typeInfo[3])) {
            throw new UDFArgumentTypeException(3, "Number type is expected for the fourth argument lambda: " + typeInfo[3].getTypeName());
        }
        // The optional fifth argument carries the option string, not lambda.
        if (numArgs == 5 && !HiveUtils.isStringTypeInfo(typeInfo[4])) {
            throw new UDFArgumentTypeException(4, "String type is expected for the fifth argument options: " + typeInfo[4].getTypeName());
        }
        return new Evaluator();
    }

    /**
     * Aggregation state of {@code lda_predict}: the {@code word:value} features seen so far and,
     * for every word, one lambda slot per topic.
     */
    public static class OnlineLDAPredictAggregationBuffer
    extends GenericUDAFEvaluator.AbstractAggregationBuffer {
        /** Placeholder stored in a topic slot until a lambda for that topic is received. */
        private static final float UNSET_LAMBDA = -1.0f;

        /** Collected features, each encoded as {@code word + ":" + value}. */
        private List<String> wcList;
        /** Per-word lambda values indexed by topic; unset slots hold {@link #UNSET_LAMBDA}. */
        private Map<String, List<Float>> lambdaMap;
        private int topics;
        private float alpha;
        private double delta;

        OnlineLDAPredictAggregationBuffer() {
        }

        /**
         * Stores the model options that {@link #get()} passes to {@link OnlineLDAModel}.
         *
         * @param topics number of topics; also the length of every newly created lambda list
         * @param alpha value forwarded to the model constructor
         * @param delta value forwarded to the model constructor
         */
        void setOptions(int topics, float alpha, double delta) {
            this.topics = topics;
            this.alpha = alpha;
            this.delta = delta;
        }

        /** Discards all collected features and lambda values; the options are left untouched. */
        void reset() {
            this.wcList = new ArrayList<String>();
            this.lambdaMap = new HashMap<String, List<Float>>();
        }

        /**
         * Records one input row.
         *
         * @param word the word of the feature
         * @param value the value of the feature
         * @param label topic index the lambda belongs to; must be in {@code [0, topics)}
         * @param lambda lambda value of {@code word} for topic {@code label}
         * @throws IndexOutOfBoundsException if {@code label} is outside {@code [0, topics)}
         */
        void iterate(String word, float value, int label, float lambda) {
            // Validate before touching any state so that a bad row fails with a message that
            // names the offending label instead of a bare list-index error.
            if (label < 0 || label >= this.topics) {
                throw new IndexOutOfBoundsException("label must be in range [0, " + this.topics + ") but got " + label);
            }
            this.wcList.add(word + ":" + value);
            List<Float> lambda_word = this.lambdaMap.get(word);
            if (lambda_word == null) {
                lambda_word = new ArrayList<Float>(Collections.nCopies(this.topics, Float.valueOf(UNSET_LAMBDA)));
                this.lambdaMap.put(word, lambda_word);
            }
            lambda_word.set(label, Float.valueOf(lambda));
        }

        /**
         * Folds another partial aggregation into this buffer.
         *
         * @param o_wcList features of the other partial; appended to this buffer's features
         * @param o_lambdaMap lambda lists of the other partial; for words already known here,
         *        every slot that is set in the other list overwrites the local slot
         */
        void merge(List<String> o_wcList, Map<String, List<Float>> o_lambdaMap) {
            this.wcList.addAll(o_wcList);
            for (Map.Entry<String, List<Float>> e : o_lambdaMap.entrySet()) {
                String o_word = e.getKey();
                List<Float> o_lambda_word = e.getValue();
                List<Float> lambda_word = this.lambdaMap.get(o_word);
                if (lambda_word == null) {
                    // Word not seen locally: adopt the other partial's list as is.
                    this.lambdaMap.put(o_word, o_lambda_word);
                    continue;
                }
                // The local list is updated in place, so it does not need to be put back.
                for (int k = 0; k < this.topics; ++k) {
                    float lambda_k = o_lambda_word.get(k).floatValue();
                    if (lambda_k == UNSET_LAMBDA) {
                        continue;
                    }
                    lambda_word.set(k, Float.valueOf(lambda_k));
                }
            }
        }

        /**
         * Builds an {@link OnlineLDAModel} from the collected lambda values and asks it for the
         * topic distribution of the collected features.
         *
         * @return the array returned by {@code OnlineLDAModel#getTopicDistribution}
         */
        float[] get() {
            OnlineLDAModel model = new OnlineLDAModel(this.topics, this.alpha, this.delta);
            for (Map.Entry<String, List<Float>> e : this.lambdaMap.entrySet()) {
                String word = e.getKey();
                List<Float> lambda_word = e.getValue();
                for (int k = 0; k < this.topics; ++k) {
                    float lambda_k = lambda_word.get(k).floatValue();
                    if (lambda_k == UNSET_LAMBDA) {
                        continue; // no lambda was received for this (word, topic) pair
                    }
                    model.setWordScore(word, k, lambda_k);
                }
            }
            String[] wcArray = this.wcList.toArray(new String[this.wcList.size()]);
            return model.getTopicDistribution(wcArray);
        }
    }

    public static class Evaluator
    extends GenericUDAFEvaluator {
        private PrimitiveObjectInspector wordOI;
        private PrimitiveObjectInspector valueOI;
        private PrimitiveObjectInspector labelOI;
        private PrimitiveObjectInspector lambdaOI;
        private int topics;
        private float alpha;
        private double delta;
        private StructObjectInspector internalMergeOI;
        private StructField wcListField;
        private StructField lambdaMapField;
        private StructField topicsOptionField;
        private StructField alphaOptionField;
        private StructField deltaOptionField;
        private PrimitiveObjectInspector wcListElemOI;
        private StandardListObjectInspector wcListOI;
        private StandardMapObjectInspector lambdaMapOI;
        private PrimitiveObjectInspector lambdaMapKeyOI;
        private StandardListObjectInspector lambdaMapValueOI;
        private PrimitiveObjectInspector lambdaMapValueElemOI;

        /**
         * Declares the command-line style options accepted in the optional fifth argument.
         *
         * <p>The defaults quoted in the help texts mirror the fallback values applied in
         * {@code processOptions}.
         *
         * @return a fresh {@link Options} holding {@code -k/-topics}, {@code -alpha} and
         *         {@code -delta}
         */
        protected Options getOptions() {
            Options opts = new Options();
            opts.addOption("k", "topics", true, "The number of topics [default: 10]");
            opts.addOption("alpha", true, "The hyperparameter for theta [default: 1/k]");
            // processOptions falls back to 0.001 for delta, so advertise that value.
            opts.addOption("delta", true, "Check convergence in the expectation step [default: 1E-3]");
            return opts;
        }

        /**
         * Parses the whitespace-separated option string against {@link #getOptions()} plus an
         * additional {@code -help} flag.
         *
         * <p>When {@code -help} is present, the formatted usage is reported by throwing a
         * {@link UDFArgumentException} whose message is the help text.
         *
         * @param optionValue raw option string, e.g. {@code "-topics 20 -alpha 0.1"}
         * @return the parsed command line
         * @throws UDFArgumentException carrying the help text if {@code -help} was given
         */
        @Nonnull
        protected final CommandLine parseOptions(String optionValue) throws UDFArgumentException {
            String[] args = optionValue.split("\\s+");
            Options opts = this.getOptions();
            opts.addOption("help", false, "Show function help");
            CommandLine cl = CommandLineUtils.parseOptions(args, opts);
            if (!cl.hasOption("help")) {
                return cl;
            }

            // The @Description annotation sits on the enclosing resolver, not on the evaluator,
            // so fall back to it when the evaluator class itself carries none.
            Description funcDesc = this.getClass().getAnnotation(Description.class);
            if (funcDesc == null) {
                funcDesc = LDAPredictUDAF.class.getAnnotation(Description.class);
            }
            final String cmdLineSyntax;
            if (funcDesc == null || funcDesc.name() == null) {
                cmdLineSyntax = this.getClass().getSimpleName();
            } else {
                cmdLineSyntax = funcDesc.value().replace("_FUNC_", funcDesc.name());
            }

            StringWriter sw = new StringWriter();
            sw.write('\n');
            PrintWriter pw = new PrintWriter(sw);
            HelpFormatter formatter = new HelpFormatter();
            // width 74, left padding 1, description padding 3, no header/footer, auto usage
            formatter.printHelp(pw, 74, cmdLineSyntax, null, opts, 1, 3, null, true);
            pw.flush();
            throw new UDFArgumentException(sw.toString());
        }

        /**
         * Initializes {@code topics}, {@code alpha} and {@code delta} from the optional fifth
         * argument, or from built-in defaults when that argument is absent.
         *
         * @param argOIs object inspectors of the function arguments
         * @return the parsed command line, or {@code null} when fewer than five arguments were given
         * @throws UDFArgumentException if the parsed number of topics is smaller than 1, or if
         *         option parsing fails
         */
        protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
            if (argOIs.length < 5) {
                // No option string supplied: use the defaults.
                this.topics = 10;
                this.alpha = 1.0f / this.topics;
                this.delta = 0.001;
                return null;
            }

            final CommandLine parsed = this.parseOptions(HiveUtils.getConstString(argOIs[4]));

            this.topics = Primitives.parseInt(parsed.getOptionValue("topics"), 10);
            if (this.topics < 1) {
                throw new UDFArgumentException("A positive integer MUST be set to an option `-topics`: " + this.topics);
            }

            // alpha falls back to 1/topics, so it has to be resolved after topics.
            final float uniformAlpha = 1.0f / this.topics;
            this.alpha = Primitives.parseFloat(parsed.getOptionValue("alpha"), uniformAlpha);
            this.delta = Primitives.parseDouble(parsed.getOptionValue("delta"), 0.001);
            return parsed;
        }

        /**
         * Prepares the evaluator for the given aggregation mode and reports its output type.
         *
         * <p>In {@code PARTIAL1} and {@code COMPLETE} the parameters are the raw function
         * arguments: options are processed and the argument inspectors are captured. In the other
         * modes the single parameter is the partial-aggregation struct, whose fields
         * ({@code wcList}, {@code lambdaMap}, {@code topics}, {@code alpha}, {@code delta}) are
         * resolved for use in {@code merge}.
         *
         * @param mode aggregation mode
         * @param parameters argument inspectors (raw arguments or the partial struct)
         * @return the partial struct inspector for {@code PARTIAL1}/{@code PARTIAL2}; otherwise a
         *         list inspector over structs of {@code label} (int) and {@code probability} (float)
         * @throws HiveException if the superclass initialization or option processing fails
         */
        public ObjectInspector init(GenericUDAFEvaluator.Mode mode, ObjectInspector[] parameters) throws HiveException {
            assert (parameters.length == 1 || parameters.length == 4 || parameters.length == 5);
            super.init(mode, parameters);

            // Input side.
            if (mode == GenericUDAFEvaluator.Mode.PARTIAL1 || mode == GenericUDAFEvaluator.Mode.COMPLETE) {
                this.processOptions(parameters);
                this.wordOI = HiveUtils.asStringOI(parameters[0]);
                this.valueOI = HiveUtils.asDoubleCompatibleOI(parameters[1]);
                this.labelOI = HiveUtils.asIntegerOI(parameters[2]);
                this.lambdaOI = HiveUtils.asDoubleCompatibleOI(parameters[3]);
            } else {
                StructObjectInspector soi = (StructObjectInspector) parameters[0];
                this.internalMergeOI = soi;
                this.wcListField = soi.getStructFieldRef("wcList");
                this.lambdaMapField = soi.getStructFieldRef("lambdaMap");
                this.topicsOptionField = soi.getStructFieldRef("topics");
                this.alphaOptionField = soi.getStructFieldRef("alpha");
                this.deltaOptionField = soi.getStructFieldRef("delta");
                this.wcListElemOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
                this.wcListOI = ObjectInspectorFactory.getStandardListObjectInspector(this.wcListElemOI);
                this.lambdaMapKeyOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
                // NOTE(review): the lambda elements are read through the Java string inspector
                // although internalMergeOI() declares them with the Java float inspector.
                // Confirm whether this is intentional before changing it.
                this.lambdaMapValueElemOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
                this.lambdaMapValueOI = ObjectInspectorFactory.getStandardListObjectInspector(this.lambdaMapValueElemOI);
                this.lambdaMapOI = ObjectInspectorFactory.getStandardMapObjectInspector(this.lambdaMapKeyOI, this.lambdaMapValueOI);
            }

            // Output side. The result is either a struct or a list inspector, so the common
            // ObjectInspector type is used rather than StructObjectInspector.
            final ObjectInspector outputOI;
            if (mode == GenericUDAFEvaluator.Mode.PARTIAL1 || mode == GenericUDAFEvaluator.Mode.PARTIAL2) {
                outputOI = Evaluator.internalMergeOI();
            } else {
                List<String> fieldNames = new ArrayList<String>();
                List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
                fieldNames.add("label");
                fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
                fieldNames.add("probability");
                fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
                outputOI = ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs));
            }
            return outputOI;
        }

        /**
         * Describes the struct emitted by {@code terminatePartial}.
         *
         * <p>Fields, in order: {@code wcList} (list of string), {@code lambdaMap} (map of string
         * to list of float), {@code topics} (int), {@code alpha} (float), {@code delta} (double).
         *
         * @return the struct inspector for partial aggregation results
         */
        private static StructObjectInspector internalMergeOI() {
            List<String> fieldNames = new ArrayList<String>();
            // Typed as ObjectInspector so that it matches the factory method's parameter type.
            List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();

            fieldNames.add("wcList");
            fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector));

            fieldNames.add("lambdaMap");
            fieldOIs.add(ObjectInspectorFactory.getStandardMapObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector, ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaFloatObjectInspector)));

            fieldNames.add("topics");
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);

            fieldNames.add("alpha");
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);

            fieldNames.add("delta");
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);

            return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
        }

        /**
         * Creates an aggregation buffer and initializes it through
         * {@link #reset(GenericUDAFEvaluator.AggregationBuffer)}.
         *
         * @return the freshly reset buffer
         * @throws HiveException if resetting the buffer fails
         */
        public GenericUDAFEvaluator.AggregationBuffer getNewAggregationBuffer() throws HiveException {
            final OnlineLDAPredictAggregationBuffer buffer = new OnlineLDAPredictAggregationBuffer();
            reset(buffer);
            return buffer;
        }

        /**
         * Clears the buffer's collected data and hands it this evaluator's current
         * {@code topics}, {@code alpha} and {@code delta}.
         *
         * @param agg buffer to reset; must be an {@link OnlineLDAPredictAggregationBuffer}
         * @throws HiveException declared by the overridden method
         */
        public void reset(GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
            final OnlineLDAPredictAggregationBuffer buffer = (OnlineLDAPredictAggregationBuffer) agg;
            // Options and collected data are independent parts of the buffer state.
            buffer.setOptions(this.topics, this.alpha, this.delta);
            buffer.reset();
        }

        /**
         * Feeds one input row into the buffer.
         *
         * <p>A row is skipped when any of its first four arguments is {@code null}.
         *
         * @param agg buffer to update; must be an {@link OnlineLDAPredictAggregationBuffer}
         * @param parameters row arguments in the order word, value, label, lambda
         * @throws HiveException declared by the overridden method
         */
        public void iterate(GenericUDAFEvaluator.AggregationBuffer agg, Object[] parameters) throws HiveException {
            final OnlineLDAPredictAggregationBuffer buffer = (OnlineLDAPredictAggregationBuffer) agg;

            for (int i = 0; i < 4; i++) {
                if (parameters[i] == null) {
                    return;
                }
            }

            final String word = PrimitiveObjectInspectorUtils.getString(parameters[0], this.wordOI);
            final float value = HiveUtils.getFloat(parameters[1], this.valueOI);
            final int label = PrimitiveObjectInspectorUtils.getInt(parameters[2], this.labelOI);
            final float lambda = HiveUtils.getFloat(parameters[3], this.lambdaOI);

            buffer.iterate(word, value, label, lambda);
        }

        /**
         * Exports the buffer as a partial result.
         *
         * @param agg buffer to export; must be an {@link OnlineLDAPredictAggregationBuffer}
         * @return {@code null} when no feature has been collected; otherwise an array holding the
         *         feature list, the lambda map, and the topics, alpha and delta options wrapped in
         *         writables, in that order
         * @throws HiveException declared by the overridden method
         */
        public Object terminatePartial(GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
            final OnlineLDAPredictAggregationBuffer buffer = (OnlineLDAPredictAggregationBuffer) agg;
            if (buffer.wcList.isEmpty()) {
                return null;
            }

            // Slot order follows the field order of the partial struct inspector.
            final Object[] fields = new Object[5];
            fields[0] = buffer.wcList;
            fields[1] = buffer.lambdaMap;
            fields[2] = new IntWritable(buffer.topics);
            fields[3] = new FloatWritable(buffer.alpha);
            fields[4] = new DoubleWritable(buffer.delta);
            return fields;
        }

        /**
         * Decodes a partial result and folds it into the buffer.
         *
         * <p>The partial's {@code topics}, {@code alpha} and {@code delta} also overwrite this
         * evaluator's own option fields and the buffer's options before the data is merged.
         *
         * @param agg buffer to merge into; must be an {@link OnlineLDAPredictAggregationBuffer}
         * @param partial partial-aggregation struct, or {@code null} (ignored)
         * @throws HiveException declared by the overridden method
         */
        public void merge(GenericUDAFEvaluator.AggregationBuffer agg, Object partial) throws HiveException {
            if (partial == null) {
                return;
            }

            // Features: list<string>
            Object wcListObj = this.internalMergeOI.getStructFieldData(partial, this.wcListField);
            List<?> wcListRaw = this.wcListOI.getList(HiveUtils.castLazyBinaryObject(wcListObj));
            List<String> wcList = new ArrayList<String>(wcListRaw.size());
            for (Object rawFeature : wcListRaw) {
                wcList.add(PrimitiveObjectInspectorUtils.getString(rawFeature, this.wcListElemOI));
            }

            // Lambda values: map<string, list<float>>
            Object lambdaMapObj = this.internalMergeOI.getStructFieldData(partial, this.lambdaMapField);
            Map<?, ?> lambdaMapRaw = this.lambdaMapOI.getMap(HiveUtils.castLazyBinaryObject(lambdaMapObj));
            Map<String, List<Float>> lambdaMap = new HashMap<String, List<Float>>();
            for (Map.Entry<?, ?> e : lambdaMapRaw.entrySet()) {
                String word = PrimitiveObjectInspectorUtils.getString(e.getKey(), this.lambdaMapKeyOI);
                List<?> lambdaMapValueRaw = this.lambdaMapValueOI.getList(HiveUtils.castLazyBinaryObject(e.getValue()));
                List<Float> lambda_word = new ArrayList<Float>(lambdaMapValueRaw.size());
                for (Object rawLambda : lambdaMapValueRaw) {
                    lambda_word.add(Float.valueOf(HiveUtils.getFloat(rawLambda, this.lambdaMapValueElemOI)));
                }
                lambdaMap.put(word, lambda_word);
            }

            // Options travel with every partial.
            Object topicsObj = this.internalMergeOI.getStructFieldData(partial, this.topicsOptionField);
            this.topics = PrimitiveObjectInspectorFactory.writableIntObjectInspector.get(topicsObj);
            Object alphaObj = this.internalMergeOI.getStructFieldData(partial, this.alphaOptionField);
            this.alpha = PrimitiveObjectInspectorFactory.writableFloatObjectInspector.get(alphaObj);
            Object deltaObj = this.internalMergeOI.getStructFieldData(partial, this.deltaOptionField);
            this.delta = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(deltaObj);

            OnlineLDAPredictAggregationBuffer myAggr = (OnlineLDAPredictAggregationBuffer) agg;
            // setOptions must precede merge: the buffer's merge loops over its topic count.
            myAggr.setOptions(this.topics, this.alpha, this.delta);
            myAggr.merge(wcList, lambdaMap);
        }

        /**
         * Produces the final result: one {@code (label, probability)} struct per topic.
         *
         * <p>The structs are ordered with {@link Collections#reverseOrder()}, i.e. the reverse of
         * {@code KeySortablePair}'s natural ordering (presumably descending probability; confirm
         * against {@code KeySortablePair}).
         *
         * @param agg buffer to finalize; must be an {@link OnlineLDAPredictAggregationBuffer}
         * @return a list of two-element arrays holding an {@link IntWritable} topic index and a
         *         {@link FloatWritable} value taken from the buffer's topic distribution
         * @throws HiveException declared by the overridden method
         */
        public Object terminate(GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
            final float[] distribution = ((OnlineLDAPredictAggregationBuffer) agg).get();

            // Pair every probability (key) with its topic index (value).
            final List<KeySortablePair<Float, Integer>> ranked = new ArrayList<KeySortablePair<Float, Integer>>(distribution.length);
            for (int topic = 0; topic < distribution.length; topic++) {
                ranked.add(new KeySortablePair<Float, Integer>(Float.valueOf(distribution[topic]), topic));
            }
            Collections.sort(ranked, Collections.reverseOrder());

            final ArrayList<Object[]> result = new ArrayList<Object[]>(ranked.size());
            for (KeySortablePair<Float, Integer> pair : ranked) {
                final IntWritable label = new IntWritable(pair.getValue().intValue());
                final FloatWritable probability = new FloatWritable(pair.getKey().floatValue());
                result.add(new Object[]{label, probability});
            }
            return result;
        }
    }
}

