/*
 * Decompiled with CFR 0.152.
 */
package hivemall.factorization.mf;

import hivemall.factorization.mf.OnlineMatrixFactorizationUDTF;
import hivemall.factorization.mf.Rating;
import hivemall.utils.lang.Primitives;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;

/**
 * Hive UDTF {@code train_mf_adagrad}: online matrix factorization whose latent factors, biases
 * and (optionally) global mean rating are updated with AdaGrad.
 *
 * <p>Each parameter carries its own accumulated sum of squared gradients (ratings are created as
 * {@link Rating.RatingWithSquaredGrad}), and every step uses the per-parameter learning rate
 * {@code eta / sqrt(eps + sumOfSquaredGradients)}. The accumulator is stored divided by
 * {@code scale} and multiplied back when the learning rate is computed.
 * NOTE(review): presumably this keeps the stored value small; confirm the intent.
 *
 * <p>The {@code eta} argument that the superclass passes to the {@code update*} methods is
 * intentionally ignored (and shadows the {@link #eta} field there). AdaGrad derives its own step
 * size from the {@code -eta} option and the accumulated gradients.
 */
@Description(name="train_mf_adagrad", value="_FUNC_(INT user, INT item, FLOAT rating [, CONSTANT STRING options]) - Returns a relation consists of <int idx, array<float> Pu, array<float> Qi [, float Bu, float Bi [, float mu]]>")
public final class MatrixFactorizationAdaGradUDTF
extends OnlineMatrixFactorizationUDTF {

    /** Default for {@code -eta}: the initial learning rate. */
    private static final float DEFAULT_ETA = 1.0f;
    /** Default for {@code -eps}: constant added under the square root of the AdaGrad denominator. */
    private static final float DEFAULT_EPS = 1.0f;
    /** Default for {@code -scale}: divisor applied to squared gradients before accumulation. */
    private static final float DEFAULT_SCALE = 100.0f;

    /** Initial learning rate ({@code -eta}); numerator of the per-parameter learning rate. */
    private float eta;
    /** Constant in the AdaGrad denominator ({@code -eps}); keeps the square root away from zero. */
    private float eps;
    /** Scaling factor for the stored sum of squared gradients ({@code -scale}). */
    private float scaling;

    /**
     * Adds the AdaGrad-specific options ({@code -eta/-eta0}, {@code -eps}, {@code -scale}) to the
     * options defined by the superclass.
     */
    @Override
    protected Options getOptions() {
        Options opts = super.getOptions();
        opts.addOption("eta", "eta0", true, "The initial learning rate [default 1.0]");
        opts.addOption("eps", true, "A constant used in the denominator of AdaGrad [default 1.0]");
        opts.addOption("scale", true, "Internal scaling/descaling factor for cumulative weights [100]");
        return opts;
    }

    /**
     * Creates a rating that also carries the (scaled) accumulated squared gradient needed by
     * AdaGrad.
     *
     * @param v initial weight
     * @return a new {@link Rating.RatingWithSquaredGrad} initialised with {@code v}
     */
    @Override
    public Rating newRating(float v) {
        return new Rating.RatingWithSquaredGrad(v);
    }

    /**
     * Parses the AdaGrad hyper-parameters after the superclass has processed its own options.
     * When the superclass returns no command line, the defaults are used.
     *
     * @param argOIs UDTF argument object inspectors, forwarded to the superclass
     * @return the parsed command line from the superclass (may be {@code null})
     * @throws UDFArgumentException if the superclass rejects the arguments, or if {@code -eta},
     *         {@code -eps} or {@code -scale} is not a positive finite number
     */
    @Override
    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
        CommandLine cl = super.processOptions(argOIs);
        if (cl == null) {
            this.eta = DEFAULT_ETA;
            this.eps = DEFAULT_EPS;
            this.scaling = DEFAULT_SCALE;
        } else {
            // Invalid values would otherwise corrupt the model silently:
            //   scale == 0 divides by zero in updateRating (Inf/NaN accumulator -> NaN weights),
            //   eps <= 0 can make sqrt(eps + sum) zero or NaN (Inf * 0 or NaN -> NaN weights),
            //   eta <= 0 stops learning or reverses the update direction.
            this.eta = requirePositive("eta",
                Primitives.parseFloat(cl.getOptionValue("eta"), DEFAULT_ETA));
            this.eps = requirePositive("eps",
                Primitives.parseFloat(cl.getOptionValue("eps"), DEFAULT_EPS));
            this.scaling = requirePositive("scale",
                Primitives.parseFloat(cl.getOptionValue("scale"), DEFAULT_SCALE));
        }
        return cl;
    }

    /**
     * Returns {@code value} if it is a positive, finite number; otherwise rejects the option.
     *
     * @param option option name, used only in the error message
     * @param value parsed option value
     * @return {@code value}, unchanged
     * @throws UDFArgumentException if {@code value} is zero, negative, infinite or NaN
     */
    private static float requirePositive(final String option, final float value)
            throws UDFArgumentException {
        // !(value > 0) is also true for NaN, so NaN is rejected by the same test.
        if (!(value > 0.f) || Float.isInfinite(value)) {
            throw new UDFArgumentException(
                "-" + option + " must be a positive finite number: " + value);
        }
        return value;
    }

    /**
     * Updates an item factor with gradient {@code err * Pu - lambda * Qi} and reports the L2
     * penalty {@code lambda * Qi^2} to {@code cvState}.
     *
     * @param rating item-factor rating to update in place
     * @param Pu corresponding user-factor value
     * @param Qi current item-factor value
     * @param err prediction error from the superclass (presumably observed minus predicted;
     *        confirm in OnlineMatrixFactorizationUDTF)
     * @param eta ignored; AdaGrad computes its own learning rate
     */
    @Override
    protected void updateItemRating(final Rating rating, final float Pu, final float Qi,
            final double err, final float eta) {
        final double gradient = err * Pu - this.lambda * Qi;
        updateRating(rating, Qi, gradient);
        this.cvState.incrLoss(this.lambda * Qi * Qi);
    }

    /**
     * Updates a user factor with gradient {@code err * Qi - lambda * Pu} and reports the L2
     * penalty {@code lambda * Pu^2} to {@code cvState}.
     *
     * @param rating user-factor rating to update in place
     * @param Pu current user-factor value
     * @param Qi corresponding item-factor value
     * @param err prediction error from the superclass
     * @param eta ignored; AdaGrad computes its own learning rate
     */
    @Override
    protected void updateUserRating(final Rating rating, final float Pu, final float Qi,
            final double err, final float eta) {
        final double gradient = err * Qi - this.lambda * Pu;
        updateRating(rating, Pu, gradient);
        this.cvState.incrLoss(this.lambda * Pu * Pu);
    }

    /**
     * Updates the global mean rating using {@code err} directly as the (unregularized) gradient.
     * Must only be called when {@code updateMeanRating} is enabled; this is checked with an
     * assertion.
     *
     * @param err prediction error from the superclass
     * @param eta ignored; AdaGrad computes its own learning rate
     */
    @Override
    protected void updateMeanRating(final double err, final float eta) {
        assert (this.updateMeanRating);
        final Rating mean = this.model.meanRating();
        updateRating(mean, mean.getWeight(), err);
    }

    /**
     * Updates the user bias and the item bias. Each uses gradient {@code err - lambda * bias},
     * and each reports its L2 penalty {@code lambda * bias^2} to {@code cvState}.
     *
     * @param user user index whose bias is updated
     * @param item item index whose bias is updated
     * @param err prediction error from the superclass
     * @param eta ignored; AdaGrad computes its own learning rate
     */
    @Override
    protected void updateBias(final int user, final int item, final double err,
            final float eta) {
        final Rating ratingBu = this.model.userBias(user);
        final float Bu = ratingBu.getWeight();
        updateRating(ratingBu, Bu, err - this.lambda * Bu);
        this.cvState.incrLoss(this.lambda * Bu * Bu);

        final Rating ratingBi = this.model.itemBias(item);
        final float Bi = ratingBi.getWeight();
        updateRating(ratingBi, Bi, err - this.lambda * Bi);
        this.cvState.incrLoss(this.lambda * Bi * Bi);
    }

    /**
     * Applies one AdaGrad step to {@code rating}.
     *
     * <p>The stored accumulator grows by {@code gradient^2 / scaling}. The new weight is
     * {@code oldWeight + eta(accumulator) * gradient}.
     *
     * @param rating rating whose weight and accumulator are updated in place
     * @param oldWeight current weight of {@code rating}
     * @param gradient gradient for this step (the step is taken along its sign)
     */
    private void updateRating(final Rating rating, final float oldWeight, final double gradient) {
        // Accumulate g^2 / scaling rather than g^2; eta(...) multiplies the scaling back in.
        final double gg = gradient * (gradient / this.scaling);
        final double scaledSumOfSquaredGradients = rating.getSumOfSquaredGradients() + gg;
        final float delta = (float) (eta(scaledSumOfSquaredGradients) * gradient);
        rating.setWeight(oldWeight + delta);
        rating.setSumOfSquaredGradients(scaledSumOfSquaredGradients);
    }

    /**
     * Computes the per-parameter AdaGrad learning rate {@code eta / sqrt(eps + sum)}, where
     * {@code sum} is the de-scaled accumulated sum of squared gradients.
     *
     * @param scaledSumOfSquaredGradients accumulator value as stored (divided by
     *        {@code scaling})
     * @return the learning rate for the current step
     */
    private float eta(final double scaledSumOfSquaredGradients) {
        final double sumOfSquaredGradients = scaledSumOfSquaredGradients * this.scaling;
        return this.eta / (float) Math.sqrt(this.eps + sumOfSquaredGradients);
    }
}

