/*
 * Decompiled with CFR 0.152.
 */
package hivemall.tools;

import hivemall.utils.collections.BoundedPriorityQueue;
import hivemall.utils.hadoop.HiveUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import javax.annotation.Nonnull;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.io.IntWritable;

@Description(name="each_top_k", value="_FUNC_(int K, Object group, double cmpKey, *) - Returns top-K values (or tail-K values when k is less than 0)")
public final class EachTopKUDTF
extends GenericUDTF {
    private ObjectInspector[] argOIs;
    private PrimitiveObjectInspector kOI;
    private ObjectInspector prevGroupOI;
    private PrimitiveObjectInspector cmpKeyOI;
    private boolean _constantK;
    private int _prevK;
    private BoundedPriorityQueue<TupleWithKey> _queue;
    private TupleWithKey _tuple;
    private Object _previousGroup;

    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        int numArgs = argOIs.length;
        if (numArgs < 4) {
            throw new UDFArgumentException("each_top_k(int K, Object group, double cmpKey, *) takes at least 4 arguments: " + numArgs);
        }
        this.argOIs = argOIs;
        this._constantK = ObjectInspectorUtils.isConstantObjectInspector((ObjectInspector)argOIs[0]);
        if (this._constantK) {
            int k = HiveUtils.getAsConstInt(argOIs[0]);
            if (k == 0) {
                throw new UDFArgumentException("k should not be 0");
            }
            this._queue = EachTopKUDTF.getQueue(k);
        } else {
            this.kOI = HiveUtils.asIntCompatibleOI(argOIs[0]);
            this._prevK = 0;
        }
        this.prevGroupOI = ObjectInspectorUtils.getStandardObjectInspector((ObjectInspector)argOIs[1], (ObjectInspectorUtils.ObjectInspectorCopyOption)ObjectInspectorUtils.ObjectInspectorCopyOption.DEFAULT);
        this.cmpKeyOI = HiveUtils.asDoubleCompatibleOI(argOIs[2]);
        this._tuple = null;
        this._previousGroup = null;
        ArrayList<String> fieldNames = new ArrayList<String>(numArgs);
        ArrayList<Object> fieldOIs = new ArrayList<Object>(numArgs);
        fieldNames.add("rank");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        fieldNames.add("key");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
        for (int i = 3; i < numArgs; ++i) {
            fieldNames.add("c" + (i - 2));
            ObjectInspector rawOI = argOIs[i];
            ObjectInspector retOI = ObjectInspectorUtils.getStandardObjectInspector((ObjectInspector)rawOI, (ObjectInspectorUtils.ObjectInspectorCopyOption)ObjectInspectorUtils.ObjectInspectorCopyOption.DEFAULT);
            fieldOIs.add(retOI);
        }
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    @Nonnull
    private static BoundedPriorityQueue<TupleWithKey> getQueue(int k) {
        int sizeK = Math.abs(k);
        Comparator<TupleWithKey> comparator = k < 0 ? Collections.reverseOrder() : new Comparator<TupleWithKey>(){

            @Override
            public int compare(TupleWithKey o1, TupleWithKey o2) {
                return o1.compareTo(o2);
            }
        };
        return new BoundedPriorityQueue<TupleWithKey>(sizeK, comparator);
    }

    public void process(Object[] args) throws HiveException {
        Object[] row;
        Object arg1 = args[1];
        if (!this.isSameGroup(arg1)) {
            Object group;
            this._previousGroup = group = ObjectInspectorUtils.copyToStandardObject((Object)arg1, (ObjectInspector)this.argOIs[1], (ObjectInspectorUtils.ObjectInspectorCopyOption)ObjectInspectorUtils.ObjectInspectorCopyOption.DEFAULT);
            if (this._queue != null) {
                this.drainQueue();
            }
            if (!this._constantK) {
                int k = PrimitiveObjectInspectorUtils.getInt((Object)args[0], (PrimitiveObjectInspector)this.kOI);
                if (k == 0) {
                    return;
                }
                if (k != this._prevK) {
                    this._queue = EachTopKUDTF.getQueue(k);
                    this._prevK = k;
                }
            }
        }
        double key = PrimitiveObjectInspectorUtils.getDouble((Object)args[2], (PrimitiveObjectInspector)this.cmpKeyOI);
        TupleWithKey tuple = this._tuple;
        if (this._tuple == null) {
            row = new Object[args.length - 1];
            this._tuple = tuple = new TupleWithKey(key, row);
        } else {
            row = tuple.getRow();
            tuple.setKey(key);
        }
        for (int i = 3; i < args.length; ++i) {
            Object arg = args[i];
            ObjectInspector argOI = this.argOIs[i];
            row[i - 1] = ObjectInspectorUtils.copyToStandardObject((Object)arg, (ObjectInspector)argOI, (ObjectInspectorUtils.ObjectInspectorCopyOption)ObjectInspectorUtils.ObjectInspectorCopyOption.DEFAULT);
        }
        if (this._queue.offer(tuple)) {
            this._tuple = null;
        }
    }

    private boolean isSameGroup(Object arg1) {
        if (arg1 == null && this._previousGroup == null) {
            return true;
        }
        if (arg1 == null || this._previousGroup == null) {
            return false;
        }
        return ObjectInspectorUtils.compare((Object)arg1, (ObjectInspector)this.argOIs[1], (Object)this._previousGroup, (ObjectInspector)this.prevGroupOI) == 0;
    }

    private void drainQueue() throws HiveException {
        int queueSize = this._queue.size();
        if (queueSize > 0) {
            TupleWithKey[] tuples = new TupleWithKey[queueSize];
            for (int i = 0; i < queueSize; ++i) {
                TupleWithKey tuple = this._queue.poll();
                if (tuple == null) {
                    throw new IllegalStateException("Found null element in the queue");
                }
                tuples[i] = tuple;
            }
            IntWritable rankProbe = new IntWritable(-1);
            DoubleWritable keyProbe = new DoubleWritable(Double.NaN);
            int rank = 0;
            double lastKey = Double.NaN;
            for (int i = queueSize - 1; i >= 0; --i) {
                TupleWithKey tuple = tuples[i];
                tuples[i] = null;
                double key = tuple.getKey();
                if (key != lastKey) {
                    rankProbe.set(++rank);
                    keyProbe.set(key);
                    lastKey = key;
                }
                Object[] row = tuple.getRow();
                row[0] = rankProbe;
                row[1] = keyProbe;
                this.forward(row);
            }
            this._queue.clear();
        }
    }

    public void close() throws HiveException {
        this.drainQueue();
        this._queue = null;
        this._tuple = null;
    }

    private static final class TupleWithKey
    implements Comparable<TupleWithKey> {
        double key;
        Object[] row;

        TupleWithKey(double key, Object[] row) {
            this.key = key;
            this.row = row;
        }

        double getKey() {
            return this.key;
        }

        Object[] getRow() {
            return this.row;
        }

        void setKey(double key) {
            this.key = key;
        }

        @Override
        public int compareTo(TupleWithKey o) {
            return Double.compare(this.key, o.key);
        }
    }
}

