/*
 * Decompiled with CFR 0.152.
 */
package hivemall.ftvec.trans;

import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.WritableUtils;
import hivemall.utils.lang.Identifier;
import hivemall.utils.lang.Preconditions;
import java.util.ArrayList;
import java.util.Map;
import java.util.Set;
import javax.annotation.CheckForNull;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.Writable;

/**
 * Resolver for the {@code onehot_encoding} aggregate function.
 *
 * <p>For each primitive argument column, the aggregation collects the distinct non-null values
 * into an {@link Identifier}. On termination it returns one map per column, from value
 * (as a {@link Writable}) to integer id. Each column's ids are shifted by the total number of
 * distinct values in the preceding columns.
 */
@Description(name="onehot_encoding", value="_FUNC_(PRIMITIVE feature, ...) - Compute onehot encoded label for each feature", extended="WITH mapping as (\n  select \n    m.f1, m.f2 \n  from (\n    select onehot_encoding(species, category) m\n    from test\n  ) tmp\n)\nselect\n  array(m.f1[t.species],m.f2[t.category],feature('count',count)) as sparse_features\nfrom\n  test t\n  CROSS JOIN mapping m;\n\n[\"2\",\"8\",\"count:9\"]\n[\"5\",\"8\",\"count:10\"]\n[\"1\",\"6\",\"count:101\"]")
@UDFType(deterministic=true, stateful=true)
public final class OnehotEncodingUDAF
extends AbstractGenericUDAFResolver {

    /**
     * Validates the argument types and returns the evaluator.
     *
     * @param argTypes type of each argument; at least one is required, and all must be primitive
     * @return a new {@link GenericUDAFOnehotEncodingEvaluator}
     * @throws UDFArgumentException if no argument is given
     * @throws UDFArgumentTypeException if an argument type is null or not primitive
     */
    @Override
    public GenericUDAFEvaluator getEvaluator(@Nonnull TypeInfo[] argTypes) throws SemanticException {
        final int numFeatures = argTypes.length;
        if (numFeatures == 0) {
            throw new UDFArgumentException("_FUNC_ requires at least 1 argument");
        }
        for (int i = 0; i < numFeatures; ++i) {
            final TypeInfo argType = argTypes[i];
            if (argType == null) {
                throw new UDFArgumentTypeException(i, "Null type is found. Only primitive type arguments are accepted.");
            }
            if (argType.getCategory() != ObjectInspector.Category.PRIMITIVE) {
                // Report the actual (1-based) position of the offending argument.
                throw new UDFArgumentTypeException(i, "Only primitive type arguments are accepted but "
                        + argType.getTypeName() + " was passed as parameter " + (i + 1) + ".");
            }
        }
        return new GenericUDAFOnehotEncodingEvaluator();
    }

    /**
     * Aggregation state: one {@link Identifier} per argument column, created lazily on the
     * first {@link #iterate} or {@link #merge} call.
     */
    public static final class EncodingBuffer
    extends GenericUDAFEvaluator.AbstractAggregationBuffer {
        @Nullable
        private Identifier<Writable>[] identifiers;

        /** Allocates the identifier array; generic array creation requires an unchecked cast. */
        @SuppressWarnings("unchecked")
        @Nonnull
        private static Identifier<Writable>[] newIdentifiers(int length) {
            return (Identifier<Writable>[]) new Identifier[length];
        }

        /** Drops all collected state. */
        void reset() {
            this.identifiers = null;
        }

        /**
         * Registers the non-null values of one input row, one value per column.
         *
         * @param args raw argument values of the row; null entries are skipped
         * @param inputOIs inspectors for {@code args}; must have the same length
         * @throws HiveException if a value cannot be copied to a Writable
         */
        void iterate(@Nonnull Object[] args, @Nonnull PrimitiveObjectInspector[] inputOIs) throws HiveException {
            Preconditions.checkArgument(args.length == inputOIs.length);
            final int length = args.length;
            if (this.identifiers == null) {
                this.identifiers = newIdentifiers(length);
                for (int i = 0; i < length; ++i) {
                    this.identifiers[i] = new Identifier<Writable>(1);
                }
            }
            for (int i = 0; i < length; ++i) {
                final Object arg = args[i];
                if (arg == null) {
                    continue;
                }
                // Copy, because the value object is owned by Hive and may be reused for later rows.
                final Writable writable = WritableUtils.copyToWritable(arg, inputOIs[i]);
                this.identifiers[i].put(writable);
            }
        }

        /**
         * Builds the partial result: for each column, a list of the distinct values seen so far.
         *
         * @return one {@code ArrayList<Writable>} per column, or {@code null} if nothing was
         *         aggregated
         */
        @Nullable
        Object[] partial() throws HiveException {
            if (this.identifiers == null) {
                return null;
            }
            final int length = this.identifiers.length;
            final Object[] partial = new Object[length];
            for (int i = 0; i < length; ++i) {
                final Set<Writable> keys = this.identifiers[i].getMap().keySet();
                final ArrayList<Writable> list = new ArrayList<Writable>(keys.size());
                for (Writable e : keys) {
                    Preconditions.checkNotNull(e);
                    list.add(e);
                }
                partial[i] = list;
            }
            return partial;
        }

        /**
         * Folds a partial result (a struct of per-column value lists) into this buffer.
         *
         * @param partial the partial struct produced by {@link #partial()}
         * @param mergeOI inspector for {@code partial}
         * @param fields struct field refs, one per column
         * @param fieldOIs list inspectors for each field; same length as {@code fields}
         */
        void merge(@Nonnull Object partial, @Nonnull StructObjectInspector mergeOI, @Nonnull StructField[] fields, @Nonnull ListObjectInspector[] fieldOIs) {
            Preconditions.checkArgument(fields.length == fieldOIs.length);
            final int numFields = fieldOIs.length;
            if (this.identifiers == null) {
                this.identifiers = newIdentifiers(numFields);
            }
            Preconditions.checkArgument(fields.length == this.identifiers.length);
            for (int i = 0; i < numFields; ++i) {
                Identifier<Writable> id = this.identifiers[i];
                if (id == null) {
                    id = new Identifier<Writable>(1);
                    this.identifiers[i] = id;
                }
                final Object fieldData = mergeOI.getStructFieldData(partial, fields[i]);
                final ListObjectInspector fieldOI = fieldOIs[i];
                final ObjectInspector elemOI = fieldOI.getListElementObjectInspector();
                final int size = fieldOI.getListLength(fieldData);
                for (int j = 0; j < size; ++j) {
                    final Object o = fieldOI.getListElement(fieldData, j);
                    Preconditions.checkNotNull(o);
                    // Deserialized elements may be reused by Hive across rows; store a private
                    // Writable copy so existing hash keys are never mutated afterwards.
                    final Object copied = ObjectInspectorUtils.copyToStandardObject(o, elemOI,
                        ObjectInspectorUtils.ObjectInspectorCopyOption.WRITABLE);
                    id.valueOf((Writable) copied);
                }
            }
        }

        /**
         * Produces the final per-column value-to-id maps.
         *
         * <p>The ids of column {@code i} are shifted by the sum of the map sizes of columns
         * {@code 0..i-1}. Assuming each {@link Identifier} hands out consecutive ids starting at
         * its constructor argument (1 here), the ids do not overlap across columns.
         *
         * <p>NOTE(review): this rewrites the values of each identifier's map in place, so it
         * should be called at most once per aggregation.
         *
         * @return one {@code Map<Writable, Integer>} per column, or {@code null} if nothing was
         *         aggregated
         */
        @Nullable
        Object[] terminate() {
            if (this.identifiers == null) {
                return null;
            }
            final Object[] ret = new Object[this.identifiers.length];
            int offset = 0;
            for (int i = 0; i < this.identifiers.length; ++i) {
                final Map<Writable, Integer> m = this.identifiers[i].getMap();
                if (offset != 0) {
                    for (Map.Entry<Writable, Integer> e : m.entrySet()) {
                        final int original = e.getValue();
                        e.setValue(offset + original);
                    }
                }
                ret[i] = m;
                offset += m.size();
            }
            return ret;
        }
    }

    /**
     * Evaluator for {@code onehot_encoding}.
     *
     * <p>The partial result is a struct with fields {@code f0..f(n-1)}, each a list of distinct
     * values. The final result is a struct with fields {@code f1..fn}, each a map from value to
     * integer id.
     */
    public static final class GenericUDAFOnehotEncodingEvaluator
    extends GenericUDAFEvaluator {
        private PrimitiveObjectInspector[] inputElemOIs;
        private StructObjectInspector mergeOI;
        private StructField[] fields;
        private ListObjectInspector[] fieldOIs;

        /**
         * Captures the input inspectors for mode {@code m} and returns the output inspector.
         */
        @Override
        public ObjectInspector init(GenericUDAFEvaluator.Mode m, ObjectInspector[] argOIs) throws HiveException {
            super.init(m, argOIs);
            if (m == GenericUDAFEvaluator.Mode.PARTIAL1 || m == GenericUDAFEvaluator.Mode.COMPLETE) {
                // Raw rows: one primitive inspector per argument.
                this.inputElemOIs = new PrimitiveObjectInspector[argOIs.length];
                for (int i = 0; i < argOIs.length; ++i) {
                    this.inputElemOIs[i] = HiveUtils.asPrimitiveObjectInspector(argOIs[i]);
                }
            } else {
                // Partial results: a single struct argument with list fields f0..f(n-1).
                Preconditions.checkArgument(argOIs.length == 1);
                this.mergeOI = HiveUtils.asStructOI(argOIs[0]);
                final int numFields = this.mergeOI.getAllStructFieldRefs().size();
                this.fields = new StructField[numFields];
                this.fieldOIs = new ListObjectInspector[numFields];
                this.inputElemOIs = new PrimitiveObjectInspector[numFields];
                for (int i = 0; i < numFields; ++i) {
                    final StructField field = this.mergeOI.getStructFieldRef("f" + String.valueOf(i));
                    final ListObjectInspector fieldOI = HiveUtils.asListOI(field.getFieldObjectInspector());
                    this.fields[i] = field;
                    this.fieldOIs[i] = fieldOI;
                    this.inputElemOIs[i] = HiveUtils.asPrimitiveObjectInspector(fieldOI.getListElementObjectInspector());
                }
            }
            switch (m) {
                case PARTIAL1:
                case PARTIAL2:
                    return internalMergeOutputOI(this.inputElemOIs);
                case COMPLETE:
                case FINAL:
                    return terminalOutputOI(this.inputElemOIs);
                default:
                    throw new IllegalStateException("Illegal mode: " + m);
            }
        }

        /** Struct of {@code f0..f(n-1)}, each a list of Writable copies of the input type. */
        @Nonnull
        private static StructObjectInspector internalMergeOutputOI(@CheckForNull PrimitiveObjectInspector[] inputOIs) throws UDFArgumentException {
            Preconditions.checkNotNull(inputOIs);
            final int numOIs = inputOIs.length;
            final ArrayList<String> fieldNames = new ArrayList<String>(numOIs);
            final ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(numOIs);
            for (int i = 0; i < numOIs; ++i) {
                fieldNames.add("f" + String.valueOf(i));
                final ObjectInspector elemOI = ObjectInspectorUtils.getStandardObjectInspector(inputOIs[i],
                    ObjectInspectorUtils.ObjectInspectorCopyOption.WRITABLE);
                fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(elemOI));
            }
            return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
        }

        /** Struct of {@code f1..fn}, each a map from a Writable of the input type to a Java int. */
        @Nonnull
        private static StructObjectInspector terminalOutputOI(@CheckForNull PrimitiveObjectInspector[] inputOIs) {
            Preconditions.checkNotNull(inputOIs);
            Preconditions.checkArgument(inputOIs.length >= 1, inputOIs.length);
            final ArrayList<String> fieldNames = new ArrayList<String>(inputOIs.length);
            final ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(inputOIs.length);
            for (int i = 0; i < inputOIs.length; ++i) {
                // Final output fields are 1-based (f1, f2, ...), as used in the @Description example.
                fieldNames.add("f" + String.valueOf(i + 1));
                final ObjectInspector keyOI = ObjectInspectorUtils.getStandardObjectInspector(inputOIs[i],
                    ObjectInspectorUtils.ObjectInspectorCopyOption.WRITABLE);
                final StandardMapObjectInspector mapOI = ObjectInspectorFactory.getStandardMapObjectInspector(keyOI,
                    PrimitiveObjectInspectorFactory.javaIntObjectInspector);
                fieldOIs.add(mapOI);
            }
            return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
        }

        @Override
        public GenericUDAFEvaluator.AggregationBuffer getNewAggregationBuffer() throws HiveException {
            final EncodingBuffer buf = new EncodingBuffer();
            this.reset(buf);
            return buf;
        }

        @Override
        public void reset(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer) throws HiveException {
            ((EncodingBuffer) aggregationBuffer).reset();
        }

        @Override
        public void iterate(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer, Object[] parameters) throws HiveException {
            Preconditions.checkNotNull(this.inputElemOIs);
            ((EncodingBuffer) aggregationBuffer).iterate(parameters, this.inputElemOIs);
        }

        @Override
        public Object[] terminatePartial(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer) throws HiveException {
            return ((EncodingBuffer) aggregationBuffer).partial();
        }

        @Override
        public void merge(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer, Object partial) throws HiveException {
            if (partial == null) {
                return;
            }
            ((EncodingBuffer) aggregationBuffer).merge(partial, this.mergeOI, this.fields, this.fieldOIs);
        }

        @Override
        public Object[] terminate(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer) throws HiveException {
            return ((EncodingBuffer) aggregationBuffer).terminate();
        }
    }
}

