/*
 * Decompiled with CFR 0.152.
 */
package hivemall.recommend;

import hivemall.UDTFWithOptions;
import hivemall.annotations.VisibleForTesting;
import hivemall.common.ConversionState;
import hivemall.utils.collections.Fastutil;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.io.FileUtils;
import hivemall.utils.io.NioStatefulSegment;
import hivemall.utils.lang.NumberUtils;
import hivemall.utils.lang.Primitives;
import hivemall.utils.lang.mutable.MutableDouble;
import hivemall.utils.lang.mutable.MutableInt;
import hivemall.utils.lang.mutable.MutableObject;
import it.unimi.dsi.fastutil.ints.Int2FloatMap;
import it.unimi.dsi.fastutil.ints.Int2FloatOpenHashMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import it.unimi.dsi.fastutil.ints.IntIterator;
import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Map;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import matrix4j.matrix.FloatMatrix;
import matrix4j.matrix.sparse.floats.DoKFloatMatrix;
import matrix4j.vector.VectorProcedure;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapred.Counters;
import org.apache.hadoop.mapred.Reporter;

/**
 * Trains the item-item weight matrix {@code W} of a SLIM model ({@code train_slim}).
 *
 * <p>Every input row pairs an axis item {@code i} with another item {@code j}:
 * <ul>
 * <li>for {@code i}, its ratings {@code r_i} and the KNN ratings {@code topKRatesOfI}, keyed by
 * user and then by item;</li>
 * <li>for {@code j}, its ratings {@code r_j}.</li>
 * </ul>
 * Each row triggers one coordinate-descent update of {@code W[i][j]}.
 *
 * <p>When more than one iteration is requested, the KNN ratings of every axis item are serialized
 * into a direct buffer and replayed in {@link #close()}. The buffer is spilled to a temporary file
 * when it fills up.
 *
 * <p>Finally the non-zero cells of {@code W} are forwarded as {@code (j, nn, w)} rows. As written by
 * {@code forwardModel()}, {@code j} holds the row index, {@code nn} the column index and {@code w}
 * the weight.
 */
@Description(name="train_slim", value="_FUNC_( int i, map<int, double> r_i, map<int, map<int, double>> topKRatesOfI, int j, map<int, double> r_j [, constant string options]) - Returns row index, column index and non-zero weight value of prediction model")
public class SlimUDTF
extends UDTFWithOptions {
    private static final Log logger = LogFactory.getLog(SlimUDTF.class);

    /** Size in bytes of the direct buffer that stages KNN records for iterative training (8 MiB). */
    private static final int INPUT_BUFFER_BYTES = 8 * 1024 * 1024;
    /** Serialized size of an {@code int} field in a KNN record. */
    private static final int INT_BYTES = 4;
    /** Serialized size of a {@code float} field in a KNN record. */
    private static final int FLOAT_BYTES = 4;

    private PrimitiveObjectInspector itemIOI;
    private PrimitiveObjectInspector itemJOI;
    private MapObjectInspector riOI;
    private MapObjectInspector rjOI;
    private MapObjectInspector knnItemsOI;
    private PrimitiveObjectInspector knnItemsKeyOI;
    private MapObjectInspector knnItemsValueOI;
    private PrimitiveObjectInspector knnItemsValueKeyOI;
    private PrimitiveObjectInspector knnItemsValueValueOI;
    private PrimitiveObjectInspector riKeyOI;
    private PrimitiveObjectInspector riValueOI;
    private PrimitiveObjectInspector rjKeyOI;
    private PrimitiveObjectInspector rjValueOI;

    // hyper-parameters (see getOptions())
    private double l1;
    private double l2;
    private int numIterations;

    /** Learned weights; allocated on the first call to process(). */
    private transient DoKFloatMatrix _weightMatrix;
    /** Axis item of the previous row; r_i and KNN entries are re-parsed only when it changes. */
    private int _previousItemId;
    @Nullable
    private transient Int2FloatMap _ri;
    @Nullable
    private transient Int2ObjectMap<Int2FloatMap> _kNNi;
    /** Total number of (user, item) ratings held in {@link #_kNNi}. */
    @Nullable
    private transient MutableInt _nnzKNNi;
    /** Ratings of every item seen so far; only kept when numIterations >= 2 for replaying. */
    @Nullable
    private transient FloatMatrix _dataMatrix;
    /** Temporary file that receives KNN records when the staging buffer fills up. */
    private transient NioStatefulSegment _fileIO;
    /** Staging buffer for serialized KNN records. */
    private transient ByteBuffer _inputBuf;
    private ConversionState _cvState;
    private long _observedTrainingExamples;

    /**
     * Validates the argument object inspectors and declares the output struct.
     *
     * @param argOIs 5 or 6 inspectors: {@code int i, map<int,double> r_i,
     *        map<int,map<int,double>> topKRatesOfI, int j, map<int,double> r_j [, const string options]}
     * @return a struct inspector with fields {@code j:int, nn:int, w:float}
     * @throws UDFArgumentException when the argument count or types are invalid
     */
    @Override
    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        final int numArgs = argOIs.length;
        if (numArgs == 1 && HiveUtils.isConstString(argOIs[0])) {
            // a lone constant string is parsed as options (e.g. to report usage)
            String rawArgs = HiveUtils.getConstString(argOIs[0]);
            parseOptions(rawArgs);
        }
        if (numArgs != 5 && numArgs != 6) {
            throw new UDFArgumentException("_FUNC_ takes 5 or 6 arguments: int i, map<int, double> r_i, map<int, map<int, double>> topKRatesOfI, int j, map<int, double> r_j [, constant string options]: " + Arrays.toString(argOIs));
        }

        this.itemIOI = HiveUtils.asIntCompatibleOI(argOIs[0]);

        this.riOI = HiveUtils.asMapOI(argOIs[1]);
        this.riKeyOI = HiveUtils.asIntCompatibleOI(riOI.getMapKeyObjectInspector());
        this.riValueOI = HiveUtils.asPrimitiveObjectInspector(riOI.getMapValueObjectInspector());

        this.knnItemsOI = HiveUtils.asMapOI(argOIs[2]);
        this.knnItemsKeyOI = HiveUtils.asIntCompatibleOI(knnItemsOI.getMapKeyObjectInspector());
        this.knnItemsValueOI = HiveUtils.asMapOI(knnItemsOI.getMapValueObjectInspector());
        this.knnItemsValueKeyOI = HiveUtils.asIntCompatibleOI(knnItemsValueOI.getMapKeyObjectInspector());
        this.knnItemsValueValueOI = HiveUtils.asDoubleCompatibleOI(knnItemsValueOI.getMapValueObjectInspector());

        this.itemJOI = HiveUtils.asIntCompatibleOI(argOIs[3]);

        this.rjOI = HiveUtils.asMapOI(argOIs[4]);
        this.rjKeyOI = HiveUtils.asIntCompatibleOI(rjOI.getMapKeyObjectInspector());
        this.rjValueOI = HiveUtils.asPrimitiveObjectInspector(rjOI.getMapValueObjectInspector());

        processOptions(argOIs);

        this._observedTrainingExamples = 0L;
        this._previousItemId = Integer.MIN_VALUE;
        this._weightMatrix = null;
        this._dataMatrix = null;

        final ArrayList<String> fieldNames = new ArrayList<String>();
        final ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
        fieldNames.add("j");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        fieldNames.add("nn");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        fieldNames.add("w");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    @Override
    protected Options getOptions() {
        Options opts = new Options();
        opts.addOption("l1", "l1coefficient", true, "Coefficient for l1 regularizer [default: 0.001]");
        opts.addOption("l2", "l2coefficient", true, "Coefficient for l2 regularizer [default: 0.0005]");
        opts.addOption("iters", "iterations", true, "The number of iterations for coordinate descent [default: 30]");
        opts.addOption("disable_cv", "disable_cvtest", false, "Whether to disable convergence check [default: enabled]");
        opts.addOption("cv_rate", "convergence_rate", true, "Threshold to determine convergence [default: 0.005]");
        return opts;
    }

    /**
     * Parses the optional 6th (constant string) argument and stores the hyper-parameters.
     *
     * @return the parsed command line, or {@code null} when no options argument was given
     * @throws UDFArgumentException when an option value is out of range
     */
    @Override
    protected CommandLine processOptions(@Nonnull ObjectInspector[] argOIs) throws UDFArgumentException {
        CommandLine cl = null;
        double l1 = 0.001;
        double l2 = 5.0E-4;
        int numIterations = 30;
        boolean conversionCheck = true;
        double cv_rate = 0.005;
        if (argOIs.length >= 6) {
            String rawArgs = HiveUtils.getConstString(argOIs[5]);
            cl = parseOptions(rawArgs);
            l1 = Primitives.parseDouble(cl.getOptionValue("l1"), l1);
            if (l1 < 0.0) {
                throw new UDFArgumentException("Argument `double l1` must be non-negative: " + l1);
            }
            l2 = Primitives.parseDouble(cl.getOptionValue("l2"), l2);
            if (l2 < 0.0) {
                throw new UDFArgumentException("Argument `double l2` must be non-negative: " + l2);
            }
            numIterations = Primitives.parseInt(cl.getOptionValue("iters"), numIterations);
            if (numIterations <= 0) {
                throw new UDFArgumentException("Argument `int iters` must be greater than 0: " + numIterations);
            }
            conversionCheck = !cl.hasOption("disable_cvtest");
            cv_rate = Primitives.parseDouble(cl.getOptionValue("cv_rate"), cv_rate);
            if (cv_rate <= 0.0) {
                throw new UDFArgumentException("Argument `double cv_rate` must be greater than 0.0: " + cv_rate);
            }
        }
        this.l1 = l1;
        this.l2 = l2;
        this.numIterations = numIterations;
        this._cvState = new ConversionState(conversionCheck, cv_rate);
        return cl;
    }

    /**
     * Consumes one {@code (i, r_i, topKRatesOfI, j, r_j)} row and updates {@code W[i][j]}.
     *
     * <p>{@code r_i} and {@code topKRatesOfI} are parsed only when {@code i} differs from the
     * previous row's {@code i}. When iterative training is enabled, the KNN ratings of each new
     * axis item are also recorded for replay in {@link #close()}.
     */
    @Override
    public void process(@Nonnull Object[] args) throws HiveException {
        if (_weightMatrix == null) {
            // lazily allocate training state on the first row
            this._weightMatrix = new DoKFloatMatrix();
            if (numIterations >= 2) {
                // ratings are retained so that replayTrain() can read r_i / r_j later
                this._dataMatrix = new DoKFloatMatrix();
            }
            this._nnzKNNi = new MutableInt();
        }

        final int itemI = PrimitiveObjectInspectorUtils.getInt(args[0], itemIOI);
        if (itemI != _previousItemId || _ri == null) {
            this._ri = int2floatMap(itemI, riOI.getMap(args[1]), riKeyOI, riValueOI, _dataMatrix, _ri);
            this._kNNi = kNNentries(args[2], knnItemsOI, knnItemsKeyOI, knnItemsValueOI,
                knnItemsValueKeyOI, knnItemsValueValueOI, _kNNi, _nnzKNNi);
            final int numKNNItems = _nnzKNNi.getValue();
            if (numIterations >= 2 && numKNNItems >= 1) {
                recordTrainingInput(itemI, _kNNi, numKNNItems);
            }
            this._previousItemId = itemI;
        }

        final int itemJ = PrimitiveObjectInspectorUtils.getInt(args[3], itemJOI);
        final Int2FloatMap rj = int2floatMap(itemJ, rjOI.getMap(args[4]), rjKeyOI, rjValueOI, _dataMatrix);
        train(itemI, _ri, _kNNi, itemJ, rj);
        _observedTrainingExamples++;
    }

    /**
     * Appends the KNN ratings of {@code itemI} to the staging buffer, spilling the buffer to a
     * temporary file when the record does not fit in the remaining space.
     *
     * <p>Record layout: {@code [recordBytes:int][itemI:int][#users:int]} followed, per user, by
     * {@code [user:int][#items:int]} and {@code #items} pairs of {@code [item:int][rating:float]}.
     * {@code recordBytes} excludes its own 4 bytes.
     *
     * @throws HiveException when a single record does not fit in the staging buffer, or on I/O failure
     */
    private void recordTrainingInput(final int itemI, @Nonnull final Int2ObjectMap<Int2FloatMap> knnItems,
            final int numKNNItems) throws HiveException {
        ByteBuffer buf = this._inputBuf;
        NioStatefulSegment dst = this._fileIO;
        if (buf == null) {
            final File file;
            try {
                file = File.createTempFile("hivemall_slim", ".sgmt");
                file.deleteOnExit();
                if (!file.canWrite()) {
                    throw new UDFArgumentException("Cannot write a temporary file: " + file.getAbsolutePath());
                }
            } catch (IOException ioe) {
                throw new UDFArgumentException(ioe);
            }
            this._inputBuf = buf = ByteBuffer.allocateDirect(INPUT_BUFFER_BYTES);
            this._fileIO = dst = new NioStatefulSegment(file, false);
        }

        final int recordBytes = INT_BYTES + INT_BYTES // itemI, #users
                + 2 * INT_BYTES * knnItems.size() // user, #items per user
                + (INT_BYTES + FLOAT_BYTES) * numKNNItems; // (item, rating) pairs
        final int requiredBytes = INT_BYTES + recordBytes; // plus the recordBytes header itself
        if (requiredBytes > buf.capacity()) {
            // such a record could neither be buffered nor re-read from the spill file
            throw new HiveException("KNN entries of item " + itemI + " require " + requiredBytes
                    + " bytes, which exceeds the buffer capacity of " + buf.capacity() + " bytes");
        }
        if (buf.remaining() < requiredBytes) {
            writeBuffer(buf, dst);
        }

        buf.putInt(recordBytes);
        buf.putInt(itemI);
        buf.putInt(knnItems.size());
        for (Int2ObjectMap.Entry<Int2FloatMap> entry : Fastutil.fastIterable(knnItems)) {
            buf.putInt(entry.getIntKey());
            final Int2FloatMap ru = entry.getValue();
            buf.putInt(ru.size());
            for (Int2FloatMap.Entry e2 : Fastutil.fastIterable(ru)) {
                buf.putInt(e2.getIntKey());
                buf.putFloat(e2.getFloatValue());
            }
        }
    }

    /** Writes the filled part of {@code srcBuf} to {@code dst} and clears the buffer. */
    private static void writeBuffer(@Nonnull ByteBuffer srcBuf, @Nonnull NioStatefulSegment dst) throws HiveException {
        srcBuf.flip();
        try {
            dst.write(srcBuf);
        } catch (IOException e) {
            throw new HiveException("Exception causes while writing a buffer to file", e);
        }
        srcBuf.clear();
    }

    /**
     * Online coordinate update of {@code W[itemI][itemJ]} from in-memory rating maps.
     *
     * <p>Accumulates, over the users who rated {@code itemJ}, the residual-weighted gradient and the
     * squared ratings. It then stores the L1/L2-regularized, zero-clipped update computed by
     * {@link #getUpdateTerm}. The per-example loss is reported to the convergence state.
     */
    private void train(int itemI, @Nonnull Int2FloatMap ri, @Nonnull Int2ObjectMap<Int2FloatMap> kNNi,
            int itemJ, @Nonnull Int2FloatMap rj) {
        final DoKFloatMatrix W = this._weightMatrix;
        final int N = rj.size();
        if (N == 0) {
            return;
        }
        double gradSum = 0.0;
        double rateSum = 0.0;
        double lossSum = 0.0;
        for (Int2FloatMap.Entry e : Fastutil.fastIterable(rj)) {
            final int user = e.getIntKey();
            final double ruj = e.getFloatValue();
            final double rui = ri.get(user);
            final double eui = rui - predict(user, itemI, kNNi, itemJ, W);
            gradSum += ruj * eui;
            rateSum += ruj * ruj;
            lossSum += eui * eui;
        }
        gradSum /= N;
        rateSum /= N;
        final double wij = W.get(itemI, itemJ, 0.0);
        final double loss = lossSum / N + 0.5 * l2 * wij * wij + l1 * wij;
        _cvState.incrLoss(loss);
        W.set(itemI, itemJ, getUpdateTerm(gradSum, rateSum, l1, l2));
    }

    /**
     * Replay variant of the coordinate update that reads {@code r_i} and {@code r_j} from the
     * retained data matrix instead of per-row maps.
     */
    private void train(final int itemI, final @Nonnull Int2ObjectMap<Int2FloatMap> knnItems, final int itemJ) {
        final FloatMatrix A = this._dataMatrix;
        final DoKFloatMatrix W = this._weightMatrix;
        final int N = A.numColumns(itemJ);
        if (N == 0) {
            return;
        }
        final MutableDouble mutableGradSum = new MutableDouble(0.0);
        final MutableDouble mutableRateSum = new MutableDouble(0.0);
        final MutableDouble mutableLossSum = new MutableDouble(0.0);
        A.eachNonZeroInRow(itemJ, new VectorProcedure() {
            @Override
            public void apply(int user, double ruj) {
                double rui = A.get(itemI, user, 0.0);
                double eui = rui - predict(user, itemI, knnItems, itemJ, W);
                mutableGradSum.addValue(ruj * eui);
                mutableRateSum.addValue(ruj * ruj);
                mutableLossSum.addValue(eui * eui);
            }
        });
        final double gradSum = mutableGradSum.getValue() / N;
        final double rateSum = mutableRateSum.getValue() / N;
        final double wij = W.get(itemI, itemJ, 0.0);
        final double loss = mutableLossSum.getValue() / N + 0.5 * l2 * wij * wij + l1 * wij;
        _cvState.incrLoss(loss);
        W.set(itemI, itemJ, getUpdateTerm(gradSum, rateSum, l1, l2));
    }

    /**
     * Predicts the rating of {@code user} for {@code itemI} as the sum of
     * {@code r[user][k] * W[itemI][k]} over the user's KNN ratings, skipping {@code k == excludeIndex}.
     *
     * @return 0 when {@code user} has no KNN entry
     */
    private static double predict(int user, int itemI, @Nonnull Int2ObjectMap<Int2FloatMap> knnItems,
            int excludeIndex, @Nonnull FloatMatrix weightMatrix) {
        final Int2FloatMap kNNu = knnItems.get(user);
        if (kNNu == null) {
            return 0.0;
        }
        double pred = 0.0;
        for (Int2FloatMap.Entry e : Fastutil.fastIterable(kNNu)) {
            final int itemK = e.getIntKey();
            if (itemK == excludeIndex) {
                continue;
            }
            pred += (double) e.getFloatValue() * weightMatrix.get(itemI, itemK, 0.0);
        }
        return pred;
    }

    /**
     * Closed-form coordinate update: soft-thresholds {@code gradSum} by {@code l1}, divides by
     * {@code rateSum + l2}, and clips negative results to 0 (weights are kept non-negative).
     */
    private static double getUpdateTerm(double gradSum, double rateSum, double l1, double l2) {
        double update = 0.0;
        if (Math.abs(gradSum) > l1) {
            final double shrunk = gradSum > 0.0 ? gradSum - l1 : gradSum + l1;
            update = shrunk / (rateSum + l2);
            if (update < 0.0) {
                update = 0.0;
            }
        }
        return update;
    }

    /** Runs the remaining iterations (if any) and forwards the learned weights. */
    @Override
    public void close() throws HiveException {
        finalizeTraining();
        forwardModel();
        this._weightMatrix = null;
    }

    /** Replays the recorded KNN entries when more than one iteration was requested. */
    @VisibleForTesting
    void finalizeTraining() throws HiveException {
        if (numIterations > 1) {
            this._ri = null;
            this._kNNi = null;
            runIterativeTraining();
            this._dataMatrix = null;
        }
    }

    /**
     * Replays the recorded KNN records for iterations {@code 2..numIterations}, stopping early
     * once the convergence state reports convergence.
     *
     * <p>Records are read back from memory when nothing was spilled. Otherwise they are streamed
     * from the temporary file through the staging buffer. The temporary file is closed (with
     * {@code close(true)}) and the buffer released afterwards.
     */
    private void runIterativeTraining() throws HiveException {
        final ByteBuffer buf = this._inputBuf;
        final NioStatefulSegment dst = this._fileIO;
        if (buf == null || dst == null) {
            // no KNN record was ever recorded (e.g. no input row): nothing to replay
            return;
        }
        final Reporter reporter = getReporter();
        final Counters.Counter iterCounter = reporter == null ? null
                : reporter.getCounter("hivemall.recommend.slim$Counter", "iteration");
        try {
            if (dst.getPosition() == 0L) {
                // every record is still in the staging buffer
                if (buf.position() == 0) {
                    return; // no training example
                }
                buf.flip();
                for (int iter = 2; iter <= numIterations; iter++) {
                    _cvState.next();
                    reportProgress(reporter);
                    setCounterValue(iterCounter, iter);
                    while (buf.remaining() > 0) {
                        final int recordBytes = buf.getInt();
                        assert (recordBytes > 0) : recordBytes;
                        replayTrain(buf);
                    }
                    buf.rewind();
                    if (_cvState.isConverged(_observedTrainingExamples)) {
                        break;
                    }
                }
                logger.info("Performed " + _cvState.getCurrentIteration() + " iterations of " + NumberUtils.formatNumber(_observedTrainingExamples) + " training examples on memory (thus " + NumberUtils.formatNumber(_observedTrainingExamples * (long) _cvState.getCurrentIteration()) + " training updates in total) ");
            } else {
                // spill whatever is still buffered; checking position() (not remaining()) so a
                // completely full buffer is written out too
                if (buf.position() > 0) {
                    writeBuffer(buf, dst);
                }
                try {
                    dst.flush();
                } catch (IOException e) {
                    throw new HiveException("Failed to flush a file: " + dst.getFile().getAbsolutePath(), e);
                }
                if (logger.isInfoEnabled()) {
                    File tmpFile = dst.getFile();
                    logger.info("Wrote KNN entries of axis items to a temporary file for iterative training: " + tmpFile.getAbsolutePath() + " (" + FileUtils.prettyFileSize(tmpFile) + ")");
                }
                for (int iter = 2; iter <= numIterations; iter++) {
                    _cvState.next();
                    setCounterValue(iterCounter, iter);
                    buf.clear();
                    dst.resetPosition();
                    while (true) {
                        reportProgress(reporter);
                        final int bytesRead;
                        try {
                            bytesRead = dst.read(buf);
                        } catch (IOException e) {
                            throw new HiveException("Failed to read a file: " + dst.getFile().getAbsolutePath(), e);
                        }
                        if (bytesRead <= 0) {
                            break; // reached the end of the file
                        }
                        buf.flip();
                        int remain = buf.remaining();
                        if (remain < INT_BYTES) {
                            throw new HiveException("Illegal file format was detected");
                        }
                        while (remain >= INT_BYTES) {
                            final int pos = buf.position();
                            final int recordBytes = buf.getInt();
                            remain -= INT_BYTES;
                            if (remain < recordBytes) {
                                // partial record: rewind to its header and finish after the next read
                                buf.position(pos);
                                break;
                            }
                            replayTrain(buf);
                            remain -= recordBytes;
                        }
                        buf.compact();
                    }
                    if (_cvState.isConverged(_observedTrainingExamples)) {
                        break;
                    }
                }
                logger.info("Performed " + _cvState.getCurrentIteration() + " iterations of " + NumberUtils.formatNumber(_observedTrainingExamples) + " training examples on memory and KNNi data on secondary storage (thus " + NumberUtils.formatNumber(_observedTrainingExamples * (long) _cvState.getCurrentIteration()) + " training updates in total) ");
            }
        } catch (Throwable e) {
            throw new HiveException("Exception caused in the iterative training", e);
        } finally {
            try {
                dst.close(true);
            } catch (IOException e) {
                throw new HiveException("Failed to close a file: " + dst.getFile().getAbsolutePath(), e);
            }
            this._inputBuf = null;
            this._fileIO = null;
        }
    }

    /**
     * Decodes one KNN record (positioned just after its length header) and runs the replay
     * coordinate update of {@code W[itemI][j]} for every item {@code j} appearing in it.
     */
    private void replayTrain(@Nonnull ByteBuffer buf) {
        final int itemI = buf.getInt();
        final int knnSize = buf.getInt();
        final Int2ObjectMap<Int2FloatMap> knnItems = new Int2ObjectOpenHashMap<Int2FloatMap>(knnSize);
        final IntOpenHashSet pairItems = new IntOpenHashSet();
        for (int i = 0; i < knnSize; i++) {
            final int user = buf.getInt();
            final int ruSize = buf.getInt();
            final Int2FloatOpenHashMap ru = new Int2FloatOpenHashMap(ruSize);
            ru.defaultReturnValue(0.0f);
            for (int j = 0; j < ruSize; j++) {
                final int itemK = buf.getInt();
                pairItems.add(itemK);
                ru.put(itemK, buf.getFloat());
            }
            knnItems.put(user, ru);
        }
        for (IntIterator it = pairItems.iterator(); it.hasNext();) {
            train(itemI, knnItems, it.nextInt());
        }
    }

    /**
     * Forwards every non-zero cell {@code W[i][j]} as {@code (i, j, w)}. The first forwarding error
     * is rethrown after all cells have been visited.
     */
    private void forwardModel() throws HiveException {
        final DoKFloatMatrix W = this._weightMatrix;
        if (W == null) {
            return; // process() was never called: no model to emit
        }
        final IntWritable f0 = new IntWritable();
        final IntWritable f1 = new IntWritable();
        final FloatWritable f2 = new FloatWritable();
        final Object[] forwardObj = new Object[] {f0, f1, f2};
        final HiveException[] firstError = new HiveException[1];
        W.eachNonZeroCell(new VectorProcedure() {
            @Override
            public void apply(int i, int j, float value) {
                if (value == 0.0f) {
                    return;
                }
                f0.set(i);
                f1.set(j);
                f2.set(value);
                try {
                    SlimUDTF.this.forward(forwardObj);
                } catch (HiveException e) {
                    if (firstError[0] == null) {
                        firstError[0] = e;
                    }
                }
            }
        });
        if (firstError[0] != null) {
            throw firstError[0];
        }
        logger.info("Forwarded SLIM's weights matrix");
    }

    /**
     * Converts the Hive map {@code user -> (item -> rating)} into fastutil maps, reusing
     * {@code knnItems} when given, and stores the total number of ratings in {@code nnzKNNi}.
     * Zero ratings are dropped.
     */
    @Nonnull
    private static Int2ObjectMap<Int2FloatMap> kNNentries(@Nonnull Object kNNiObj, @Nonnull MapObjectInspector knnItemsOI,
            @Nonnull PrimitiveObjectInspector knnItemsKeyOI, @Nonnull MapObjectInspector knnItemsValueOI,
            @Nonnull PrimitiveObjectInspector knnItemsValueKeyOI, @Nonnull PrimitiveObjectInspector knnItemsValueValueOI,
            @Nullable Int2ObjectMap<Int2FloatMap> knnItems, @Nonnull MutableInt nnzKNNi) {
        if (knnItems == null) {
            knnItems = new Int2ObjectOpenHashMap<Int2FloatMap>(1024);
        } else {
            knnItems.clear();
        }
        int numElementOfKNNItems = 0;
        for (Map.Entry<?, ?> entry : knnItemsOI.getMap(kNNiObj).entrySet()) {
            final int user = PrimitiveObjectInspectorUtils.getInt(entry.getKey(), knnItemsKeyOI);
            final Int2FloatMap ru = int2floatMap(knnItemsValueOI.getMap(entry.getValue()), knnItemsValueKeyOI, knnItemsValueValueOI);
            knnItems.put(user, ru);
            numElementOfKNNItems += ru.size();
        }
        nnzKNNi.setValue(numElementOfKNNItems);
        return knnItems;
    }

    /** Converts a Hive {@code map<int, number>} into an int->float map, dropping zero values. */
    @Nonnull
    private static Int2FloatMap int2floatMap(@Nonnull Map<?, ?> map, @Nonnull PrimitiveObjectInspector keyOI,
            @Nonnull PrimitiveObjectInspector valueOI) {
        final Int2FloatOpenHashMap result = new Int2FloatOpenHashMap(map.size());
        result.defaultReturnValue(0.0f);
        for (Map.Entry<?, ?> entry : map.entrySet()) {
            final float v = PrimitiveObjectInspectorUtils.getFloat(entry.getValue(), valueOI);
            if (v == 0.0f) {
                continue;
            }
            final int k = PrimitiveObjectInspectorUtils.getInt(entry.getKey(), keyOI);
            result.put(k, v);
        }
        return result;
    }

    /** Same as the 6-argument overload with a freshly allocated destination map. */
    @Nonnull
    private static Int2FloatMap int2floatMap(int item, @Nonnull Map<?, ?> map, @Nonnull PrimitiveObjectInspector keyOI,
            @Nonnull PrimitiveObjectInspector valueOI, @Nullable FloatMatrix dataMatrix) {
        return int2floatMap(item, map, keyOI, valueOI, dataMatrix, null);
    }

    /**
     * Converts the ratings of {@code item} ({@code user -> rating}) into {@code dst} (allocated
     * when null, cleared otherwise), dropping zero ratings. When {@code dataMatrix} is non-null each
     * kept rating is also stored at {@code dataMatrix[item][user]}.
     */
    @Nonnull
    private static Int2FloatMap int2floatMap(int item, @Nonnull Map<?, ?> map, @Nonnull PrimitiveObjectInspector keyOI,
            @Nonnull PrimitiveObjectInspector valueOI, @Nullable FloatMatrix dataMatrix, @Nullable Int2FloatMap dst) {
        if (dst == null) {
            dst = new Int2FloatOpenHashMap(map.size());
            dst.defaultReturnValue(0.0f);
        } else {
            dst.clear();
        }
        for (Map.Entry<?, ?> entry : map.entrySet()) {
            final float rating = PrimitiveObjectInspectorUtils.getFloat(entry.getValue(), valueOI);
            if (rating == 0.0f) {
                continue;
            }
            final int user = PrimitiveObjectInspectorUtils.getInt(entry.getKey(), keyOI);
            dst.put(user, rating);
            if (dataMatrix != null) {
                dataMatrix.set(item, user, rating);
            }
        }
        return dst;
    }
}

