/*
 * Decompiled with CFR 0.152.
 */
package hivemall.factorization.mf;

import hivemall.UDTFWithOptions;
import hivemall.common.ConversionState;
import hivemall.factorization.mf.FactorizedModel;
import hivemall.factorization.mf.Rating;
import hivemall.factorization.mf.RatingInitializer;
import hivemall.optimizer.EtaEstimator;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.io.FileUtils;
import hivemall.utils.io.NioFixedSegment;
import hivemall.utils.lang.NumberUtils;
import hivemall.utils.lang.Primitives;
import hivemall.utils.math.MathUtils;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapred.Counters;
import org.apache.hadoop.mapred.Reporter;

@Description(name="train_bprmf", value="_FUNC_(INT user, INT posItem, INT negItem [, String options]) - Returns a relation <INT i, FLOAT Pi, FLOAT Qi [, FLOAT Bi]>")
public final class BPRMatrixFactorizationUDTF
extends UDTFWithOptions
implements RatingInitializer {
    private static final Log LOG = LogFactory.getLog(BPRMatrixFactorizationUDTF.class);
    private static final int RECORD_BYTES = 12;
    protected int factor = 10;
    protected float regU = 0.0025f;
    protected float regI = 0.0025f;
    protected float regJ = 0.00125f;
    protected float regBias = 0.01f;
    protected boolean useBiasClause = true;
    protected int iterations = 30;
    protected LossFunction lossFunction;
    protected FactorizedModel.RankInitScheme rankInit;
    protected EtaEstimator etaEstimator;
    protected long count;
    protected ConversionState cvState;
    protected FactorizedModel model;
    protected PrimitiveObjectInspector userOI;
    protected PrimitiveObjectInspector posItemOI;
    protected PrimitiveObjectInspector negItemOI;
    protected NioFixedSegment fileIO;
    protected ByteBuffer inputBuf;
    private long lastWritePos;
    private float[] uProbe;
    private float[] iProbe;
    private float[] jProbe;

    @Override
    protected Options getOptions() {
        Options opts = new Options();
        opts.addOption("k", "factor", true, "The number of latent factor [default: 10] Alias for `-factors`");
        opts.addOption("f", "factors", true, "The number of latent factor [default: 10]");
        opts.addOption("iters", "iterations", true, "The number of iterations [default: 30]");
        opts.addOption("iter", true, "The number of iterations [default: 30] Alias for `-iterations");
        opts.addOption("loss", "loss_function", true, "Loss function [default: lnLogistic, logistic, sigmoid]");
        opts.addOption("rankinit", true, "Initialization strategy of rank matrix [random, gaussian] (default: random)");
        opts.addOption("maxval", "max_init_value", true, "The maximum initial value in the rank matrix [default: 1.0]");
        opts.addOption("min_init_stddev", true, "The minimum standard deviation of initial rank matrix [default: 0.1]");
        opts.addOption("reg", "lambda", true, "The regularization factor [default: 0.0025]");
        opts.addOption("reg_u", "reg_user", true, "The regularization factor for user [default: 0.0025 (reg)]");
        opts.addOption("reg_i", "reg_item", true, "The regularization factor for positive item [default: 0.0025 (reg)]");
        opts.addOption("reg_j", true, "The regularization factor for negative item [default: 0.00125 (reg_i/2) ]");
        opts.addOption("reg_bias", true, "The regularization factor for bias clause [default: 0.01]");
        opts.addOption("disable_bias", "no_bias", false, "Turn off bias clause");
        opts.addOption("eta", true, "The initial learning rate [default: 0.3]");
        opts.addOption("eta0", true, "The initial learning rate [default: 0.1]");
        opts.addOption("t", "total_steps", true, "The total number of training examples");
        opts.addOption("power_t", true, "The exponent for inverse scaling learning rate [default: 0.1]");
        opts.addOption("boldDriver", "bold_driver", false, "Whether to use Bold Driver for learning rate [default: false]");
        opts.addOption("disable_cv", "disable_cvtest", false, "Whether to disable convergence check [default: enabled]");
        opts.addOption("cv_rate", "convergence_rate", true, "Threshold to determine convergence [default: 0.005]");
        return opts;
    }

    @Override
    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
        CommandLine cl = null;
        String lossFuncName = null;
        String rankInitOpt = null;
        float maxInitValue = 1.0f;
        double initStdDev = 0.1;
        boolean conversionCheck = true;
        double convergenceRate = 0.005;
        if (argOIs.length >= 4) {
            String rawArgs = HiveUtils.getConstString(argOIs, 3);
            cl = this.parseOptions(rawArgs);
            this.factor = cl.hasOption("factor") ? Primitives.parseInt(cl.getOptionValue("factor"), this.factor) : Primitives.parseInt(cl.getOptionValue("factors"), this.factor);
            this.iterations = cl.hasOption("iter") ? Primitives.parseInt(cl.getOptionValue("iter"), this.iterations) : Primitives.parseInt(cl.getOptionValue("iterations"), this.iterations);
            if (this.iterations < 1) {
                throw new UDFArgumentException("'-iterations' must be greater than or equals to 1: " + this.iterations);
            }
            lossFuncName = cl.getOptionValue("loss_function");
            float reg = Primitives.parseFloat(cl.getOptionValue("reg"), 0.0025f);
            this.regU = Primitives.parseFloat(cl.getOptionValue("reg_u"), reg);
            this.regI = Primitives.parseFloat(cl.getOptionValue("reg_i"), reg);
            this.regJ = Primitives.parseFloat(cl.getOptionValue("reg_j"), this.regI / 2.0f);
            this.regBias = Primitives.parseFloat(cl.getOptionValue("reg_bias"), this.regBias);
            rankInitOpt = cl.getOptionValue("rankinit");
            maxInitValue = Primitives.parseFloat(cl.getOptionValue("max_init_value"), 1.0f);
            initStdDev = Primitives.parseDouble(cl.getOptionValue("min_init_stddev"), 0.1);
            conversionCheck = !cl.hasOption("disable_cvtest");
            convergenceRate = Primitives.parseDouble(cl.getOptionValue("cv_rate"), convergenceRate);
            this.useBiasClause = !cl.hasOption("no_bias");
        }
        this.lossFunction = LossFunction.resolve(lossFuncName);
        this.rankInit = FactorizedModel.RankInitScheme.resolve(rankInitOpt);
        this.rankInit.setMaxInitValue(maxInitValue);
        initStdDev = Math.max(initStdDev, 1.0 / (double)this.factor);
        this.rankInit.setInitStdDev(initStdDev);
        this.etaEstimator = EtaEstimator.get(cl);
        this.cvState = new ConversionState(conversionCheck, convergenceRate);
        return cl;
    }

    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length != 3 && argOIs.length != 4) {
            this.showHelp("train_bprmf UDTF takes 3 or 4 arguments: INT user, INT posItem, INT negItem [, CONSTANT STRING options]: " + Arrays.toString(argOIs));
        }
        this.userOI = HiveUtils.asIntCompatibleOI(argOIs, 0);
        this.posItemOI = HiveUtils.asIntCompatibleOI(argOIs, 1);
        this.negItemOI = HiveUtils.asIntCompatibleOI(argOIs, 2);
        this.processOptions(argOIs);
        this.model = new FactorizedModel(this, this.factor, this.rankInit);
        this.count = 0L;
        this.lastWritePos = 0L;
        this.uProbe = new float[this.factor];
        this.iProbe = new float[this.factor];
        this.jProbe = new float[this.factor];
        if (this.mapredContext != null && this.iterations > 1) {
            File file;
            try {
                file = File.createTempFile("hivemall_bprmf", ".sgmt");
                file.deleteOnExit();
                if (!file.canWrite()) {
                    throw new UDFArgumentException("Cannot write a temporary file: " + file.getAbsolutePath());
                }
            }
            catch (IOException ioe) {
                throw new UDFArgumentException((Throwable)ioe);
            }
            catch (Throwable e) {
                throw new UDFArgumentException(e);
            }
            this.fileIO = new NioFixedSegment(file, 12, false);
            this.inputBuf = ByteBuffer.allocateDirect(65536);
        }
        ArrayList<String> fieldNames = new ArrayList<String>();
        ArrayList<Object> fieldOIs = new ArrayList<Object>();
        fieldNames.add("idx");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        fieldNames.add("Pu");
        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector((ObjectInspector)PrimitiveObjectInspectorFactory.writableFloatObjectInspector));
        fieldNames.add("Qi");
        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector((ObjectInspector)PrimitiveObjectInspectorFactory.writableFloatObjectInspector));
        if (this.useBiasClause) {
            fieldNames.add("Bi");
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
        }
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    public void process(Object[] args) throws HiveException {
        assert (args.length >= 3) : args.length;
        int u = PrimitiveObjectInspectorUtils.getInt((Object)args[0], (PrimitiveObjectInspector)this.userOI);
        int i = PrimitiveObjectInspectorUtils.getInt((Object)args[1], (PrimitiveObjectInspector)this.posItemOI);
        int j = PrimitiveObjectInspectorUtils.getInt((Object)args[2], (PrimitiveObjectInspector)this.negItemOI);
        BPRMatrixFactorizationUDTF.validateInput(u, i, j);
        this.beforeTrain(this.count, u, i, j);
        ++this.count;
        this.train(u, i, j);
    }

    protected void beforeTrain(long rowNum, int u, int i, int j) throws HiveException {
        if (this.inputBuf != null) {
            assert (this.fileIO != null);
            ByteBuffer buf = this.inputBuf;
            int remain = buf.remaining();
            if (remain < 12) {
                BPRMatrixFactorizationUDTF.writeBuffer(buf, this.fileIO, this.lastWritePos);
                this.lastWritePos = rowNum;
            }
            buf.putInt(u);
            buf.putInt(i);
            buf.putInt(j);
        }
    }

    protected void train(int u, int i, int j) {
        Rating[] user = this.model.getUserVector(u, true);
        Rating[] itemI = this.model.getItemVector(i, true);
        Rating[] itemJ = this.model.getItemVector(j, true);
        this.copyToProbe(user, this.uProbe);
        this.copyToProbe(itemI, this.iProbe);
        this.copyToProbe(itemJ, this.jProbe);
        double x_uij = this.predict(u, i, this.uProbe, this.iProbe) - this.predict(u, j, this.uProbe, this.jProbe);
        double dloss = this.dloss(x_uij, this.lossFunction);
        float eta = this.eta();
        int size = this.factor;
        for (int k = 0; k < size; ++k) {
            float w_uf = this.uProbe[k];
            float h_if = this.iProbe[k];
            float h_jf = this.jProbe[k];
            this.updateUserRating(user[k], w_uf, h_if, h_jf, dloss, eta);
            this.updateItemRating(itemI[k], w_uf, h_if, dloss, eta, this.regI);
            this.updateItemRating(itemJ[k], w_uf, h_jf, -dloss, eta, this.regJ);
        }
        if (this.useBiasClause) {
            this.updateBias(i, j, dloss, eta);
        }
    }

    protected double predict(int user, int item, @Nonnull float[] userProbe, @Nonnull float[] itemProbe) {
        double ret = this.model.getItemBias(item);
        int size = this.factor;
        for (int k = 0; k < size; ++k) {
            ret += (double)(userProbe[k] * itemProbe[k]);
        }
        if (!NumberUtils.isFinite(ret)) {
            throw new IllegalStateException("Detected " + ret + " in predict where user=" + user + " and item=" + item);
        }
        return ret;
    }

    protected double dloss(double x, @Nonnull LossFunction loss) {
        switch (loss) {
            case sigmoid: {
                return 1.0 / (1.0 + Math.exp(x));
            }
            case logistic: {
                double sigmoid = MathUtils.sigmoid(x);
                return sigmoid * (1.0 - sigmoid);
            }
            case lnLogistic: {
                double ex = Math.exp(-x);
                return ex / (1.0 + ex);
            }
        }
        throw new IllegalStateException("Unexpected loss function: " + (Object)((Object)loss));
    }

    protected float eta() {
        return this.etaEstimator.eta(this.count);
    }

    protected void updateUserRating(Rating rating, float w_uf, float h_if, float h_jf, double dloss, float eta) {
        double grad = dloss * (double)(h_if - h_jf) - (double)(this.regU * w_uf);
        float delta = (float)((double)eta * grad);
        float newWeight = w_uf + delta;
        if (!NumberUtils.isFinite(newWeight)) {
            throw new IllegalStateException("Detected " + newWeight + " for w_uf");
        }
        rating.setWeight(newWeight);
        this.cvState.incrLoss(this.regU * w_uf * w_uf);
    }

    protected void updateItemRating(Rating rating, float w_uf, float h_f, double dloss, float eta, float reg) {
        double grad = dloss * (double)w_uf - (double)(reg * h_f);
        float delta = (float)((double)eta * grad);
        float newWeight = h_f + delta;
        if (!NumberUtils.isFinite(newWeight)) {
            throw new IllegalStateException("Detected " + newWeight + " for h_f");
        }
        rating.setWeight(newWeight);
        this.cvState.incrLoss(reg * h_f * h_f);
    }

    protected void updateBias(int i, int j, double dloss, float eta) {
        float Bi = this.model.getItemBias(i);
        double Gi = dloss - (double)(this.regBias * Bi);
        if (!NumberUtils.isFinite(Bi = (float)((double)Bi + (double)eta * Gi))) {
            throw new IllegalStateException("Detected " + Bi + " for Bi");
        }
        this.model.setItemBias(i, Bi);
        this.cvState.incrLoss(this.regBias * Bi * Bi);
        float Bj = this.model.getItemBias(j);
        double Gj = -dloss - (double)(this.regBias * Bj);
        Bj = (float)((double)Bj + (double)eta * Gj);
        if (!NumberUtils.isFinite(Bj)) {
            throw new IllegalStateException("Detected " + Bj + " for Bj");
        }
        this.model.setItemBias(j, Bj);
        this.cvState.incrLoss(this.regBias * Bj * Bj);
    }

    public void close() throws HiveException {
        if (this.model != null) {
            if (this.count == 0L) {
                this.model = null;
                return;
            }
            if (this.iterations > 1) {
                this.runIterativeTraining(this.iterations);
            }
            IntWritable idx = new IntWritable();
            FloatWritable[] Pu = HiveUtils.newFloatArray(this.factor, 0.0f);
            FloatWritable[] Qi = HiveUtils.newFloatArray(this.factor, 0.0f);
            FloatWritable Bi = this.useBiasClause ? new FloatWritable() : null;
            Object[] forwardObj = new Object[]{idx, Pu, Qi, Bi};
            int numForwarded = 0;
            int maxIdx = this.model.getMaxIndex();
            for (int i = this.model.getMinIndex(); i <= maxIdx; ++i) {
                idx.set(i);
                Rating[] userRatings = this.model.getUserVector(i);
                if (userRatings == null) {
                    forwardObj[1] = null;
                } else {
                    forwardObj[1] = Pu;
                    BPRMatrixFactorizationUDTF.copyTo(userRatings, Pu);
                }
                Rating[] itemRatings = this.model.getItemVector(i);
                if (itemRatings == null) {
                    forwardObj[2] = null;
                } else {
                    forwardObj[2] = Qi;
                    BPRMatrixFactorizationUDTF.copyTo(itemRatings, Qi);
                }
                if (this.useBiasClause) {
                    Bi.set(this.model.getItemBias(i));
                }
                this.forward(forwardObj);
                ++numForwarded;
            }
            this.model = null;
            LOG.info((Object)("Forwarded the prediction model of " + numForwarded + " rows. [lastLosses=" + this.cvState.getCumulativeLoss() + ", #trainingExamples=" + this.count + "]"));
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private final void runIterativeTraining(@Nonnegative int iterations) throws HiveException {
        block30: {
            ByteBuffer inputBuf = this.inputBuf;
            NioFixedSegment fileIO = this.fileIO;
            assert (inputBuf != null);
            assert (fileIO != null);
            long numTrainingExamples = this.count;
            Reporter reporter = this.getReporter();
            Counters.Counter iterCounter = reporter == null ? null : reporter.getCounter("hivemall.factorization.mf.BPRMatrixFactorization$Counter", "iteration");
            try {
                int iter;
                if (this.lastWritePos == 0L) {
                    if (inputBuf.position() == 0) {
                        return;
                    }
                    inputBuf.flip();
                    for (iter = 2; iter <= iterations; ++iter) {
                        this.cvState.next();
                        BPRMatrixFactorizationUDTF.reportProgress(reporter);
                        BPRMatrixFactorizationUDTF.setCounterValue(iterCounter, iter);
                        while (inputBuf.remaining() > 0) {
                            int u = inputBuf.getInt();
                            int i = inputBuf.getInt();
                            int j = inputBuf.getInt();
                            ++this.count;
                            this.train(u, i, j);
                        }
                        this.cvState.multiplyLoss(0.5);
                        if (this.cvState.isConverged(numTrainingExamples)) break;
                        if (this.cvState.isLossIncreased()) {
                            this.etaEstimator.update(1.1f);
                        } else {
                            this.etaEstimator.update(0.5f);
                        }
                        inputBuf.rewind();
                    }
                    LOG.info((Object)("Performed " + this.cvState.getCurrentIteration() + " iterations of " + NumberUtils.formatNumber(numTrainingExamples) + " training examples on memory (thus " + NumberUtils.formatNumber(this.count) + " training updates in total) "));
                    break block30;
                }
                if (inputBuf.position() > 0) {
                    BPRMatrixFactorizationUDTF.writeBuffer(inputBuf, fileIO, this.lastWritePos);
                }
                try {
                    fileIO.flush();
                }
                catch (IOException e) {
                    throw new HiveException("Failed to flush a file: " + fileIO.getFile().getAbsolutePath(), (Throwable)e);
                }
                if (LOG.isInfoEnabled()) {
                    File tmpFile = fileIO.getFile();
                    LOG.info((Object)("Wrote " + numTrainingExamples + " records to a temporary file for iterative training: " + tmpFile.getAbsolutePath() + " (" + FileUtils.prettyFileSize(tmpFile) + ")"));
                }
                for (iter = 2; iter <= iterations; ++iter) {
                    this.cvState.next();
                    BPRMatrixFactorizationUDTF.setCounterValue(iterCounter, iter);
                    inputBuf.clear();
                    long seekPos = 0L;
                    while (true) {
                        int remain;
                        int bytesRead;
                        BPRMatrixFactorizationUDTF.reportProgress(reporter);
                        try {
                            bytesRead = fileIO.read(seekPos, inputBuf);
                        }
                        catch (IOException e) {
                            throw new HiveException("Failed to read a file: " + fileIO.getFile().getAbsolutePath(), (Throwable)e);
                        }
                        if (bytesRead == 0) break;
                        assert (bytesRead > 0) : bytesRead;
                        seekPos += (long)bytesRead;
                        inputBuf.flip();
                        assert (remain > 0) : remain;
                        for (remain = inputBuf.remaining(); remain >= 12; remain -= 12) {
                            int u = inputBuf.getInt();
                            int i = inputBuf.getInt();
                            int j = inputBuf.getInt();
                            ++this.count;
                            this.train(u, i, j);
                        }
                        inputBuf.compact();
                    }
                    this.cvState.multiplyLoss(0.5);
                    if (this.cvState.isConverged(numTrainingExamples)) break;
                    if (this.cvState.isLossIncreased()) {
                        this.etaEstimator.update(1.1f);
                        continue;
                    }
                    this.etaEstimator.update(0.5f);
                }
                LOG.info((Object)("Performed " + this.cvState.getCurrentIteration() + " iterations of " + NumberUtils.formatNumber(numTrainingExamples) + " training examples using a secondary storage (thus " + NumberUtils.formatNumber(this.count) + " training updates in total)"));
            }
            finally {
                try {
                    fileIO.close(true);
                }
                catch (IOException e) {
                    throw new HiveException("Failed to close a file: " + fileIO.getFile().getAbsolutePath(), (Throwable)e);
                }
                this.inputBuf = null;
                this.fileIO = null;
            }
        }
    }

    @Override
    public Rating newRating(float v) {
        return new Rating(v);
    }

    private static void validateInput(int u, int i, int j) throws HiveException {
        if (u < 0) {
            throw new HiveException("Illegal u index: " + u);
        }
        if (i < 0) {
            throw new HiveException("Illegal i index: " + i);
        }
        if (j < 0) {
            throw new HiveException("Illegal j index: " + j);
        }
    }

    private static void writeBuffer(@Nonnull ByteBuffer srcBuf, @Nonnull NioFixedSegment dst, long lastWritePos) throws HiveException {
        srcBuf.flip();
        try {
            dst.writeRecords(lastWritePos, srcBuf);
        }
        catch (IOException e) {
            throw new HiveException("Exception causes while writing records to : " + lastWritePos, (Throwable)e);
        }
        srcBuf.clear();
    }

    @Nonnull
    private final void copyToProbe(@Nonnull Rating[] rating, @Nonnull float[] probe) {
        int size = this.factor;
        for (int k = 0; k < size; ++k) {
            probe[k] = rating[k].getWeight();
        }
    }

    private static void copyTo(@Nonnull Rating[] rating, @Nonnull FloatWritable[] dst) {
        int size = rating.length;
        for (int k = 0; k < size; ++k) {
            float w = rating[k].getWeight();
            dst[k].set(w);
        }
    }

    public static enum LossFunction {
        sigmoid,
        logistic,
        lnLogistic;


        @Nonnull
        public static LossFunction resolve(@Nullable String name) {
            if (name == null) {
                return lnLogistic;
            }
            if (name.equalsIgnoreCase("lnLogistic")) {
                return lnLogistic;
            }
            if (name.equalsIgnoreCase("logistic")) {
                return logistic;
            }
            if (name.equalsIgnoreCase("sigmoid")) {
                return sigmoid;
            }
            throw new IllegalArgumentException("Unexpected loss function: " + name);
        }
    }
}

