/*
 * Decompiled with CFR 0.152.
 */
package hivemall.factorization.mf;

import hivemall.UDTFWithOptions;
import hivemall.common.ConversionState;
import hivemall.factorization.mf.FactorizedModel;
import hivemall.factorization.mf.Rating;
import hivemall.factorization.mf.RatingInitializer;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.io.FileUtils;
import hivemall.utils.io.NioFixedSegment;
import hivemall.utils.lang.NumberUtils;
import hivemall.utils.lang.Primitives;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapred.Counters;
import org.apache.hadoop.mapred.Reporter;

/**
 * Base class of UDTFs that learn a matrix factorization model online, one
 * {@code (user, item, rating)} row at a time.
 *
 * <p>Each row updates the latent factor vectors of its user ({@code Pu}) and item ({@code Qi})
 * and, unless {@code -no_bias} is given, the user/item biases (plus the global mean rating when
 * {@code -update_mean} is set). The learned model is forwarded from {@link #close()}, one row per
 * index between the model's minimum and maximum index.
 *
 * <p>When {@code -iterations} is greater than 1 and a {@code MapredContext} is available at
 * initialization, every training example is also recorded in a direct byte buffer that spills to
 * a temporary file, so that the additional epochs can be replayed in {@link #close()}.
 */
public abstract class OnlineMatrixFactorizationUDTF
extends UDTFWithOptions
implements RatingInitializer {
    private static final Log logger = LogFactory.getLog(OnlineMatrixFactorizationUDTF.class);
    /** Size of one buffered training example: int user + int item + double rating. */
    private static final int RECORD_BYTES = 16;
    /** Capacity of the in-memory example buffer; a multiple of {@link #RECORD_BYTES}. */
    private static final int INPUT_BUFFER_BYTES = 65536;

    // Hyperparameters; overwritten from the options string in processOptions().
    protected int factor = 10;
    protected float lambda = 0.03f;
    protected float meanRating = 0.0f;
    protected boolean updateMeanRating = false;
    protected int iterations = 1;
    protected boolean useBiasClause = true;
    protected FactorizedModel.RankInitScheme rankInit;

    // Learning state.
    protected FactorizedModel model;
    /** Number of training examples processed so far, including those replayed in extra epochs. */
    protected long count;
    protected ConversionState cvState;

    // Argument inspectors for (user, item, rating).
    protected PrimitiveObjectInspector userOI;
    protected PrimitiveObjectInspector itemOI;
    protected PrimitiveObjectInspector ratingOI;

    // Example buffers used to replay epochs when iterations > 1; null unless allocated in initialize().
    protected NioFixedSegment fileIO;
    protected ByteBuffer inputBuf;
    /** Record index in {@link #fileIO} at which the next spilled chunk of {@link #inputBuf} is written. */
    private long lastWritePos;

    // Scratch arrays holding a snapshot of the current Pu/Qi weights while they are being updated.
    private float[] userProbe;
    private float[] itemProbe;

    @Override
    protected Options getOptions() {
        Options opts = new Options();
        opts.addOption("k", "factor", true, "The number of latent factor [default: 10]  Note this is alias for `factors` option.");
        opts.addOption("f", "factors", true, "The number of latent factor [default: 10]");
        opts.addOption("r", "lambda", true, "The regularization factor [default: 0.03]");
        opts.addOption("mu", "mean_rating", true, "The mean rating [default: 0.0]");
        opts.addOption("update_mean", "update_mu", false, "Whether update (and return) the mean rating or not");
        opts.addOption("rankinit", true, "Initialization strategy of rank matrix [random, gaussian] (default: random)");
        opts.addOption("maxval", "max_init_value", true, "The maximum initial value in the rank matrix [default: 1.0]");
        opts.addOption("min_init_stddev", true, "The minimum standard deviation of initial rank matrix [default: 0.1]");
        opts.addOption("iters", "iterations", true, "The number of iterations [default: 1]");
        opts.addOption("iter", true, "The number of iterations [default: 1] Alias for `-iterations`");
        opts.addOption("disable_cv", "disable_cvtest", false, "Whether to disable convergence check [default: enabled]");
        opts.addOption("cv_rate", "convergence_rate", true, "Threshold to determine convergence [default: 0.005]");
        opts.addOption("disable_bias", "no_bias", false, "Turn off bias clause");
        return opts;
    }

    /**
     * Parses the optional 4th (constant string) argument and configures hyperparameters, the rank
     * initialization scheme and the convergence checker.
     *
     * @param argOIs the UDTF argument inspectors
     * @return the parsed command line, or {@code null} when no options argument was given
     * @throws UDFArgumentException on invalid or conflicting options
     */
    @Override
    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
        CommandLine cl = null;
        String rankInitOpt = null;
        float maxInitValue = 1.0f;
        double initStdDev = 0.1;
        boolean conversionCheck = true;
        double convergenceRate = 0.005;
        if (argOIs.length >= 4) {
            String rawArgs = HiveUtils.getConstString(argOIs, 3);
            cl = this.parseOptions(rawArgs);
            // `-factors` takes precedence over its alias `-factor`.
            String factorOpt = cl.hasOption("factors") ? cl.getOptionValue("factors") : cl.getOptionValue("factor");
            this.factor = Primitives.parseInt(factorOpt, 10);
            if (this.factor < 1) {
                // factor sizes the probe arrays and is used as a divisor below
                throw new UDFArgumentException("'-factors' must be greater than or equal to 1: " + this.factor);
            }
            this.lambda = Primitives.parseFloat(cl.getOptionValue("lambda"), 0.03f);
            this.meanRating = Primitives.parseFloat(cl.getOptionValue("mu"), 0.0f);
            this.updateMeanRating = cl.hasOption("update_mean");
            rankInitOpt = cl.getOptionValue("rankinit");
            maxInitValue = Primitives.parseFloat(cl.getOptionValue("max_init_value"), 1.0f);
            initStdDev = Primitives.parseDouble(cl.getOptionValue("min_init_stddev"), 0.1);
            // `-iter` takes precedence over `-iterations`/`-iters`.
            String iterOpt = cl.hasOption("iter") ? cl.getOptionValue("iter") : cl.getOptionValue("iterations");
            this.iterations = Primitives.parseInt(iterOpt, 1);
            if (this.iterations < 1) {
                throw new UDFArgumentException("'-iterations' must be greater than or equal to 1: " + this.iterations);
            }
            conversionCheck = !cl.hasOption("disable_cvtest");
            convergenceRate = Primitives.parseDouble(cl.getOptionValue("cv_rate"), convergenceRate);
            boolean noBias = cl.hasOption("no_bias");
            this.useBiasClause = !noBias;
            if (noBias && this.updateMeanRating) {
                // the mean rating is only updated as part of the bias clause
                throw new UDFArgumentException("Cannot set both `update_mean` and `no_bias` option");
            }
        }
        this.rankInit = FactorizedModel.RankInitScheme.resolve(rankInitOpt);
        this.rankInit.setMaxInitValue(maxInitValue);
        // The stddev is never smaller than 1/factor.
        initStdDev = Math.max(initStdDev, 1.0 / (double) this.factor);
        this.rankInit.setInitStdDev(initStdDev);
        this.cvState = new ConversionState(conversionCheck, convergenceRate);
        return cl;
    }

    /**
     * Validates the {@code (user, item, rating [, options])} arguments, builds the model, allocates
     * the example buffers when multiple iterations are requested, and declares the output schema
     * {@code idx, Pu, Qi [, Bu, Bi [, mu]]}.
     */
    @Override
    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length < 3) {
            this.showHelp(String.format("%s takes 3 or more arguments: INT user, INT item, FLOAT rating [, CONSTANT STRING options]: %s", this.getClass().getSimpleName(), Arrays.toString(argOIs)));
        }
        this.userOI = HiveUtils.asIntCompatibleOI(argOIs, 0);
        this.itemOI = HiveUtils.asIntCompatibleOI(argOIs, 1);
        this.ratingOI = HiveUtils.asDoubleCompatibleOI(argOIs, 2);
        this.processOptions(argOIs);
        this.model = new FactorizedModel(this, this.factor, this.meanRating, this.rankInit);
        this.count = 0L;
        this.lastWritePos = 0L;
        this.userProbe = new float[this.factor];
        this.itemProbe = new float[this.factor];
        if (this.mapredContext != null && this.iterations > 1) {
            // Examples are recorded only when a MapredContext is present; close() skips the
            // replay when these buffers are absent.
            File file = createTempSegmentFile();
            this.fileIO = new NioFixedSegment(file, RECORD_BYTES, false);
            this.inputBuf = ByteBuffer.allocateDirect(INPUT_BUFFER_BYTES);
        }
        ArrayList<String> fieldNames = new ArrayList<String>();
        ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
        fieldNames.add("idx");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        fieldNames.add("Pu");
        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableFloatObjectInspector));
        fieldNames.add("Qi");
        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableFloatObjectInspector));
        if (this.useBiasClause) {
            fieldNames.add("Bu");
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
            fieldNames.add("Bi");
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
            if (this.updateMeanRating) {
                fieldNames.add("mu");
                fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
            }
        }
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Creates the temporary file that backs the spilled training examples. The file is registered
     * for deletion on JVM exit.
     *
     * @return a writable temporary file
     * @throws UDFArgumentException if the file cannot be created or written
     */
    @Nonnull
    private static File createTempSegmentFile() throws UDFArgumentException {
        try {
            File file = File.createTempFile("hivemall_mf", ".sgmt");
            file.deleteOnExit();
            if (!file.canWrite()) {
                throw new UDFArgumentException("Cannot write a temporary file: " + file.getAbsolutePath());
            }
            return file;
        } catch (IOException ioe) {
            throw new UDFArgumentException(ioe);
        } catch (SecurityException se) {
            throw new UDFArgumentException(se);
        }
    }

    @Override
    public Rating newRating(float v) {
        return new Rating(v);
    }

    /**
     * Trains on one input row.
     *
     * @param args {@code user} (non-negative int), {@code item} (non-negative int) and
     *        {@code rating} (double-compatible)
     * @throws HiveException if the user or item index is negative
     */
    @Override
    public final void process(Object[] args) throws HiveException {
        assert (args.length >= 3) : args.length;
        int user = PrimitiveObjectInspectorUtils.getInt(args[0], this.userOI);
        if (user < 0) {
            throw new HiveException("Illegal user index: " + user);
        }
        int item = PrimitiveObjectInspectorUtils.getInt(args[1], this.itemOI);
        if (item < 0) {
            throw new HiveException("Illegal item index: " + item);
        }
        double rating = PrimitiveObjectInspectorUtils.getDouble(args[2], this.ratingOI);
        this.beforeTrain(this.count, user, item, rating);
        ++this.count;
        this.train(user, item, rating);
    }

    /** Copies the first {@code factor} weights of {@code rating} into the reusable user probe array. */
    @Nonnull
    protected final float[] copyToUserProbe(@Nonnull Rating[] rating) {
        copyWeights(rating, this.userProbe, this.factor);
        return this.userProbe;
    }

    /** Copies the first {@code factor} weights of {@code rating} into the reusable item probe array. */
    @Nonnull
    protected final float[] copyToItemProbe(@Nonnull Rating[] rating) {
        copyWeights(rating, this.itemProbe, this.factor);
        return this.itemProbe;
    }

    /** Copies {@code size} weights from {@code src} into {@code dst}. */
    private static void copyWeights(@Nonnull Rating[] src, @Nonnull float[] dst, int size) {
        for (int k = 0; k < size; ++k) {
            dst[k] = src[k].getWeight();
        }
    }

    /**
     * Performs one update step for a single {@code (user, item, rating)} example. Both factor
     * vectors are updated from a snapshot of their pre-update weights, and the absolute error and
     * squared loss are accumulated into {@link #cvState}.
     */
    protected void train(int user, int item, double rating) throws HiveException {
        Rating[] users = this.model.getUserVector(user, true);
        assert (users != null);
        Rating[] items = this.model.getItemVector(item, true);
        assert (items != null);
        // Snapshot so that the Qi update does not see the already-updated Pu and vice versa.
        float[] userProbe = this.copyToUserProbe(users);
        float[] itemProbe = this.copyToItemProbe(items);
        double err = rating - this.predict(user, item, userProbe, itemProbe);
        this.cvState.incrError(Math.abs(err));
        this.cvState.incrLoss(err * err);
        float eta = this.eta();
        int size = this.factor;
        for (int k = 0; k < size; ++k) {
            float Pu = userProbe[k];
            float Qi = itemProbe[k];
            this.updateItemRating(items[k], Pu, Qi, err, eta);
            this.updateUserRating(users[k], Pu, Qi, err, eta);
        }
        if (this.useBiasClause) {
            this.updateBias(user, item, err, eta);
            if (this.updateMeanRating) {
                this.updateMeanRating(err, eta);
            }
        }
        this.onUpdate(user, item, users, items, err);
    }

    /**
     * Records the example for later epochs when the example buffer is enabled. When the buffer
     * cannot hold another record, it is first spilled to {@link #fileIO} at record index
     * {@link #lastWritePos}.
     *
     * @param rowNum number of examples seen before this one
     */
    protected void beforeTrain(long rowNum, int user, int item, double rating) throws HiveException {
        if (this.inputBuf != null) {
            assert (this.fileIO != null);
            ByteBuffer buf = this.inputBuf;
            if (buf.remaining() < RECORD_BYTES) {
                writeBuffer(buf, this.fileIO, this.lastWritePos);
                this.lastWritePos = rowNum;
            }
            buf.putInt(user);
            buf.putInt(item);
            buf.putDouble(rating);
        }
    }

    /** Hook invoked after each update step; does nothing by default. */
    protected void onUpdate(int user, int item, Rating[] users, Rating[] items, double err) throws HiveException {
    }

    /** Predicts a rating as {@code bias(user, item) + dot(userProbe, itemProbe)}. */
    protected double predict(int user, int item, float[] userProbe, float[] itemProbe) {
        double ret = this.bias(user, item);
        int size = this.factor;
        for (int k = 0; k < size; ++k) {
            ret += (double) (userProbe[k] * itemProbe[k]);
        }
        return ret;
    }

    /**
     * Predicts a rating from the model's current weights.
     *
     * @throws HiveException if the model has no vector for {@code user} or {@code item}
     */
    protected double predict(int user, int item) throws HiveException {
        Rating[] users = this.model.getUserVector(user);
        if (users == null) {
            throw new HiveException("User rating is not found: " + user);
        }
        Rating[] items = this.model.getItemVector(item);
        if (items == null) {
            throw new HiveException("Item rating is not found: " + item);
        }
        double ret = this.bias(user, item);
        int size = this.factor;
        for (int k = 0; k < size; ++k) {
            ret += (double) (users[k].getWeight() * items[k].getWeight());
        }
        return ret;
    }

    /** Returns the mean rating, plus the user and item biases when the bias clause is enabled. */
    protected double bias(int user, int item) {
        if (!this.useBiasClause) {
            return this.model.getMeanRating();
        }
        return this.model.getMeanRating() + this.model.getUserBias(user) + this.model.getItemBias(item);
    }

    /** Learning rate; a constant 1.0 here (presumably overridden by subclasses). */
    protected float eta() {
        return 1.0f;
    }

    /** Applies {@code Qi += eta * (err * Pu - lambda * Qi)} and adds {@code lambda * Qi^2} to the loss. */
    protected void updateItemRating(Rating rating, float Pu, float Qi, double err, float eta) {
        double grad = err * (double) Pu - (double) (this.lambda * Qi);
        float newQi = Qi + (float) ((double) eta * grad);
        rating.setWeight(newQi);
        this.cvState.incrLoss(this.lambda * Qi * Qi);
    }

    /** Applies {@code Pu += eta * (err * Qi - lambda * Pu)} and adds {@code lambda * Pu^2} to the loss. */
    protected void updateUserRating(Rating rating, float Pu, float Qi, double err, float eta) {
        double grad = err * (double) Qi - (double) (this.lambda * Pu);
        float newPu = Pu + (float) ((double) eta * grad);
        rating.setWeight(newPu);
        this.cvState.incrLoss(this.lambda * Pu * Pu);
    }

    /** Applies {@code mu += eta * err} (no regularization). */
    protected void updateMeanRating(double err, float eta) {
        assert (this.updateMeanRating);
        float mean = this.model.getMeanRating();
        mean = (float) ((double) mean + (double) eta * err);
        this.model.setMeanRating(mean);
    }

    /**
     * Applies {@code B += eta * (err - lambda * B)} to the user and item biases. Each update adds
     * {@code lambda * B^2}, computed on the updated bias, to the loss.
     */
    protected void updateBias(int user, int item, double err, float eta) {
        assert (this.useBiasClause);
        float Bu = this.model.getUserBias(user);
        double Gu = err - (double) (this.lambda * Bu);
        Bu = (float) ((double) Bu + (double) eta * Gu);
        this.model.setUserBias(user, Bu);
        this.cvState.incrLoss(this.lambda * Bu * Bu);
        float Bi = this.model.getItemBias(item);
        double Gi = err - (double) (this.lambda * Bi);
        Bi = (float) ((double) Bi + (double) eta * Gi);
        this.model.setItemBias(item, Bi);
        this.cvState.incrLoss(this.lambda * Bi * Bi);
    }

    /**
     * Runs the remaining epochs (if any) and forwards the learned model. Nothing is forwarded when
     * no example was seen. In every case, the example buffer and its temporary file are released.
     */
    @Override
    public void close() throws HiveException {
        if (this.model == null) {
            return;
        }
        if (this.count == 0L) {
            this.model = null;
            // No rows to emit, but the spill file opened in initialize() must still be closed.
            this.releaseInputBuffer();
            return;
        }
        if (this.iterations > 1) {
            if (this.inputBuf != null) {
                this.runIterativeTraining(this.iterations);
            } else {
                // Buffers are only allocated when a MapredContext was present in initialize().
                logger.warn("Training examples were not buffered; performed a single iteration instead of " + this.iterations);
            }
        }
        int numForwarded = this.forwardModel();
        this.model = null;
        logger.info("Forwarded the prediction model of " + numForwarded + " rows. [totalErrors=" + this.cvState.getTotalErrors() + ", lastLosses=" + this.cvState.getCumulativeLoss() + ", #trainingExamples=" + this.count + "]");
    }

    /**
     * Forwards one row per index in {@code [model.getMinIndex(), model.getMaxIndex()]}. A row's
     * {@code Pu}/{@code Qi} is null when the model has no user/item vector for that index.
     *
     * @return the number of forwarded rows
     */
    private int forwardModel() throws HiveException {
        final FactorizedModel model = this.model;
        IntWritable idx = new IntWritable();
        FloatWritable[] Pu = HiveUtils.newFloatArray(this.factor, 0.0f);
        FloatWritable[] Qi = HiveUtils.newFloatArray(this.factor, 0.0f);
        FloatWritable Bu = new FloatWritable();
        FloatWritable Bi = new FloatWritable();
        final Object[] forwardObj;
        if (this.updateMeanRating) {
            assert (this.useBiasClause);
            FloatWritable mu = new FloatWritable(model.getMeanRating());
            forwardObj = new Object[] {idx, Pu, Qi, Bu, Bi, mu};
        } else if (this.useBiasClause) {
            forwardObj = new Object[] {idx, Pu, Qi, Bu, Bi};
        } else {
            forwardObj = new Object[] {idx, Pu, Qi};
        }
        int numForwarded = 0;
        int maxIdx = model.getMaxIndex();
        for (int i = model.getMinIndex(); i <= maxIdx; ++i) {
            idx.set(i);
            Rating[] userRatings = model.getUserVector(i);
            if (userRatings == null) {
                forwardObj[1] = null;
            } else {
                forwardObj[1] = Pu;
                copyTo(userRatings, Pu);
            }
            Rating[] itemRatings = model.getItemVector(i);
            if (itemRatings == null) {
                forwardObj[2] = null;
            } else {
                forwardObj[2] = Qi;
                copyTo(itemRatings, Qi);
            }
            if (this.useBiasClause) {
                Bu.set(model.getUserBias(i));
                Bi.set(model.getItemBias(i));
            }
            this.forward(forwardObj);
            ++numForwarded;
        }
        return numForwarded;
    }

    /**
     * Writes the buffered records in {@code srcBuf} to {@code dst} starting at record index
     * {@code lastWritePos}, then clears the buffer for reuse.
     *
     * @throws HiveException if the write fails
     */
    protected static void writeBuffer(@Nonnull ByteBuffer srcBuf, @Nonnull NioFixedSegment dst, long lastWritePos) throws HiveException {
        srcBuf.flip();
        try {
            dst.writeRecords(lastWritePos, srcBuf);
        } catch (IOException e) {
            throw new HiveException("Exception causes while writing records to : " + lastWritePos, e);
        }
        srcBuf.clear();
    }

    /**
     * Replays the recorded training examples for epochs {@code 2..iterations}. It stops early when
     * {@link #cvState} reports convergence. The examples are read from memory when they never
     * spilled, and from the temporary file otherwise. The example buffer is always released
     * afterwards.
     *
     * @param iterations total number of epochs, including the one already done in process()
     */
    protected final void runIterativeTraining(@Nonnegative int iterations) throws HiveException {
        final ByteBuffer inputBuf = this.inputBuf;
        final NioFixedSegment fileIO = this.fileIO;
        assert (inputBuf != null);
        assert (fileIO != null);
        final long numTrainingExamples = this.count;
        final Reporter reporter = this.getReporter();
        final Counters.Counter iterCounter = reporter == null ? null : reporter.getCounter("hivemall.factorization.mf.MatrixFactorization$Counter", "iteration");
        try {
            if (this.lastWritePos == 0L) {
                // Nothing was spilled: every example is still in the in-memory buffer.
                if (inputBuf.position() == 0) {
                    return;
                }
                this.replayFromMemory(inputBuf, iterations, numTrainingExamples, reporter, iterCounter);
            } else {
                this.replayFromFile(inputBuf, fileIO, iterations, numTrainingExamples, reporter, iterCounter);
            }
        } finally {
            this.releaseInputBuffer();
        }
    }

    /** Replays the examples held entirely in {@code inputBuf} for epochs {@code 2..iterations}. */
    private void replayFromMemory(@Nonnull ByteBuffer inputBuf, int iterations, long numTrainingExamples, Reporter reporter, Counters.Counter iterCounter) throws HiveException {
        inputBuf.flip();
        for (int iter = 2; iter <= iterations; ++iter) {
            this.cvState.next();
            reportProgress(reporter);
            setCounterValue(iterCounter, iter);
            while (inputBuf.remaining() > 0) {
                this.trainRecord(inputBuf);
            }
            this.cvState.multiplyLoss(0.5);
            if (this.cvState.isConverged(numTrainingExamples)) {
                break;
            }
            inputBuf.rewind();
        }
        logger.info("Performed " + this.cvState.getCurrentIteration() + " iterations of " + NumberUtils.formatNumber(numTrainingExamples) + " training examples on memory (thus " + NumberUtils.formatNumber(this.count) + " training updates in total) ");
    }

    /**
     * Flushes the remaining buffered examples to {@code fileIO}. Then, for epochs
     * {@code 2..iterations}, streams the whole file back through {@code inputBuf}.
     */
    private void replayFromFile(@Nonnull ByteBuffer inputBuf, @Nonnull NioFixedSegment fileIO, int iterations, long numTrainingExamples, Reporter reporter, Counters.Counter iterCounter) throws HiveException {
        if (inputBuf.position() > 0) {
            writeBuffer(inputBuf, fileIO, this.lastWritePos);
        }
        try {
            fileIO.flush();
        } catch (IOException e) {
            throw new HiveException("Failed to flush a file: " + fileIO.getFile().getAbsolutePath(), e);
        }
        if (logger.isInfoEnabled()) {
            File tmpFile = fileIO.getFile();
            logger.info("Wrote " + numTrainingExamples + " records to a temporary file for iterative training: " + tmpFile.getAbsolutePath() + " (" + FileUtils.prettyFileSize(tmpFile) + ")");
        }
        for (int iter = 2; iter <= iterations; ++iter) {
            this.cvState.next();
            setCounterValue(iterCounter, iter);
            inputBuf.clear();
            long seekPos = 0L;
            while (true) {
                reportProgress(reporter);
                final int bytesRead;
                try {
                    bytesRead = fileIO.read(seekPos, inputBuf);
                } catch (IOException e) {
                    throw new HiveException("Failed to read a file: " + fileIO.getFile().getAbsolutePath(), e);
                }
                if (bytesRead <= 0) {
                    // A non-positive count is treated as end of data (guards against an EOF marker).
                    break;
                }
                seekPos += (long) bytesRead;
                inputBuf.flip();
                int remain = inputBuf.remaining();
                assert (remain > 0) : remain;
                while (remain >= RECORD_BYTES) {
                    this.trainRecord(inputBuf);
                    remain -= RECORD_BYTES;
                }
                // Keep any trailing partial record at the front for the next read.
                inputBuf.compact();
            }
            this.cvState.multiplyLoss(0.5);
            if (this.cvState.isConverged(numTrainingExamples)) {
                break;
            }
        }
        logger.info("Performed " + this.cvState.getCurrentIteration() + " iterations of " + NumberUtils.formatNumber(numTrainingExamples) + " training examples using a secondary storage (thus " + NumberUtils.formatNumber(this.count) + " training updates in total)");
    }

    /** Decodes one {@code (int user, int item, double rating)} record from {@code buf} and trains on it. */
    private void trainRecord(@Nonnull ByteBuffer buf) throws HiveException {
        int user = buf.getInt();
        int item = buf.getInt();
        double rating = buf.getDouble();
        ++this.count;
        this.train(user, item, rating);
    }

    /**
     * Drops the example buffer and closes its backing segment, if any. The segment is closed with
     * {@code close(true)}, which presumably also deletes the file; verify against NioFixedSegment.
     *
     * @throws HiveException if closing the segment fails
     */
    private void releaseInputBuffer() throws HiveException {
        final NioFixedSegment segment = this.fileIO;
        this.inputBuf = null;
        this.fileIO = null;
        if (segment == null) {
            return;
        }
        try {
            segment.close(true);
        } catch (IOException e) {
            throw new HiveException("Failed to close a file: " + segment.getFile().getAbsolutePath(), e);
        }
    }

    /** Copies every weight of {@code rating} into the corresponding writable of {@code dst}. */
    private static void copyTo(@Nonnull Rating[] rating, @Nonnull FloatWritable[] dst) {
        int size = rating.length;
        for (int k = 0; k < size; ++k) {
            dst[k].set(rating[k].getWeight());
        }
    }
}

