/*
 * Decompiled with CFR 0.152.
 */
package hivemall.tools.map;

import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.Preconditions;
import hivemall.utils.lang.StringUtils;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;

@Description(name="map_roulette", value="_FUNC_(Map<K, number> map [, (const) int/bigint seed]) - Returns a map key based on weighted random sampling of map values. Average of values is used for null values", extended="-- `map_roulette(map<key, number> [, integer seed])` returns key by weighted random selection\nSELECT \n  map_roulette(to_map(a, b)) -- 25% Tom, 21% Zhang, 54% Wang\nFROM ( -- see https://issues.apache.org/jira/browse/HIVE-17406\n  select 'Wang' as a, 54 as b\n  union all\n  select 'Zhang' as a, 21 as b\n  union all\n  select 'Tom' as a, 25 as b\n) tmp;\n> Wang\n\n-- Weight random selection with using filling nulls with the average value\nSELECT\n  map_roulette(map(1, 0.5, 'Wang', null)), -- 50% Wang, 50% 1\n  map_roulette(map(1, 0.5, 'Wang', null, 'Zhang', null)) -- 1/3 Wang, 1/3 1, 1/3 Zhang\n;\n\n-- NULL will be returned if every key is null\nSELECT \n  map_roulette(map()),\n  map_roulette(map(null, null, null, null));\n> NULL    NULL\n\n-- Return NULL if all weights are zero\nSELECT\n  map_roulette(map(1, 0)),\n  map_roulette(map(1, 0, '5', 0))\n;\n> NULL    NULL\n\n-- map_roulette does not support non-numeric weights or negative weights.\nSELECT map_roulette(map('Wong', 'A string', 'Zhao', 2));\n> HiveException: Error evaluating map_roulette(map('Wong':'A string','Zhao':2))\nSELECT map_roulette(map('Wong', 'A string', 'Zhao', 2));\n> UDFArgumentException: Map value must be greather than or equals to zero: -2")
@UDFType(deterministic=false, stateful=false)
public final class MapRouletteUDF
extends GenericUDF {
    private transient MapObjectInspector mapOI;
    private transient PrimitiveObjectInspector valueOI;
    @Nullable
    private transient PrimitiveObjectInspector seedOI;
    @Nullable
    private transient Random _rand;

    public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length != 1 && argOIs.length != 2) {
            throw new UDFArgumentLengthException("Expected exactly one argument for map_roulette: " + argOIs.length);
        }
        if (argOIs[0].getCategory() != ObjectInspector.Category.MAP) {
            throw new UDFArgumentTypeException(0, "Only map type argument is accepted but got " + argOIs[0].getTypeName());
        }
        this.mapOI = HiveUtils.asMapOI(argOIs[0]);
        this.valueOI = HiveUtils.asDoubleCompatibleOI(this.mapOI.getMapValueObjectInspector());
        if (argOIs.length == 2) {
            ObjectInspector argOI1 = argOIs[1];
            if (!HiveUtils.isIntegerOI(argOI1)) {
                throw new UDFArgumentException("The second argument of map_roulette must be integer type: " + argOI1.getTypeName());
            }
            if (ObjectInspectorUtils.isConstantObjectInspector((ObjectInspector)argOI1)) {
                long seed = HiveUtils.getAsConstLong(argOI1);
                this._rand = new Random(seed);
            } else {
                this.seedOI = HiveUtils.asLongCompatibleOI(argOI1);
            }
        } else {
            this._rand = new Random();
        }
        return this.mapOI.getMapKeyObjectInspector();
    }

    @Nullable
    public Object evaluate(GenericUDF.DeferredObject[] arguments) throws HiveException {
        Map<Object, Double> input;
        Random rand = this._rand;
        if (rand == null) {
            Object arg1 = arguments[1].get();
            if (arg1 == null) {
                rand = new Random();
            } else {
                long seed = HiveUtils.getLong(arg1, this.seedOI);
                rand = new Random(seed);
            }
        }
        if ((input = MapRouletteUDF.getObjectDoubleMap(arguments[0], this.mapOI, this.valueOI)) == null) {
            return null;
        }
        return MapRouletteUDF.rouletteWheelSelection(input, rand);
    }

    @Nullable
    private static Map<Object, Double> getObjectDoubleMap(@Nonnull GenericUDF.DeferredObject argument, @Nonnull MapObjectInspector mapOI, @Nonnull PrimitiveObjectInspector valueOI) throws HiveException {
        Map m = mapOI.getMap(argument.get());
        if (m == null) {
            return null;
        }
        int size = m.size();
        if (size == 0) {
            return null;
        }
        HashMap<Object, Double> result = new HashMap<Object, Double>(size);
        double sum = 0.0;
        int cnt = 0;
        for (Map.Entry entry : m.entrySet()) {
            Object value;
            Object key = entry.getKey();
            if (key == null || (value = entry.getValue()) == null) continue;
            double v = PrimitiveObjectInspectorUtils.convertPrimitiveToDouble(value, (PrimitiveObjectInspector)valueOI);
            if (v < 0.0) {
                throw new UDFArgumentException("Map value must be greather than or equals to zero: " + entry.getValue());
            }
            result.put(key, v);
            sum += v;
            ++cnt;
        }
        if (result.isEmpty()) {
            return null;
        }
        if (result.size() < m.size()) {
            Double avg = sum / (double)cnt;
            for (Map.Entry entry : m.entrySet()) {
                Object key = entry.getKey();
                if (key == null || entry.getValue() != null) continue;
                result.put(key, avg);
            }
        }
        return result;
    }

    @Nullable
    private static Object rouletteWheelSelection(@Nonnull Map<Object, Double> m, @Nonnull Random rnd) {
        Preconditions.checkArgument(!m.isEmpty());
        double sum = 0.0;
        for (Double v : m.values()) {
            sum += v.doubleValue();
        }
        double r = rnd.nextDouble() * sum;
        double s = 0.0;
        for (Map.Entry<Object, Double> e : m.entrySet()) {
            Object k = e.getKey();
            double v = e.getValue();
            if (!((s += v) > r)) continue;
            return k;
        }
        return null;
    }

    public String getDisplayString(String[] children) {
        return "map_roulette(" + StringUtils.join(children, ',') + ")";
    }
}

