/*
 * Decompiled with CFR 0.152.
 */
package hivemall.regression;

import hivemall.annotations.Cite;
import hivemall.model.FeatureValue;
import hivemall.model.IWeightValue;
import hivemall.model.PredictionResult;
import hivemall.model.WeightValue;
import hivemall.optimizer.LossFunctions;
import hivemall.regression.RegressionBaseUDTF;
import hivemall.utils.stats.OnlineVariance;
import javax.annotation.Nonnull;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;

/**
 * Hive UDTF that trains an AROW (Adaptive Regularization of Weight Vectors) regressor.
 *
 * <p>Each training row updates, per feature, a weight and a covariance value stored in the
 * model inherited from {@link RegressionBaseUDTF}. The update coefficient is the plain
 * residual {@code target - predicted}; the nested {@link AROWe} and {@link AROWe2} variants
 * substitute an epsilon-insensitive loss.
 */
@Description(name="train_arow_regr", value="_FUNC_(array<int|bigint|string> features, float target [, constant string options]) - a standard AROW (Adaptive Reguralization of Weight Vectors) regressor that uses `y - w^Tx` for the loss function.", extended="SELECT \n  feature,\n  argmin_kld(weight, covar) as weight\nFROM (\n  SELECT \n     train_arow_regr(features,label) as (feature,weight,covar)\n  FROM \n     training_data\n ) t \nGROUP BY feature")
@Cite(description="K. Crammer, A. Kulesza, and M. Dredze, \"Adaptive Regularization of Weight Vectors\", In Proc. NIPS, 2009.", url="https://papers.nips.cc/paper/3848-adaptive-regularization-of-weight-vectors.pdf")
public class AROWRegressionUDTF extends RegressionBaseUDTF {
    /** Regularization parameter; assigned by {@link #processOptions} (0.1 unless {@code -r} is given). */
    protected float r;

    /**
     * Checks the argument count and then delegates to the base class.
     *
     * @param argOIs object inspectors of the call arguments; exactly 2 or 3 are accepted
     * @return whatever the base class initialization returns
     * @throws UDFArgumentException if the number of arguments is neither 2 nor 3
     */
    @Override
    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        int numArgs = argOIs.length;
        if (numArgs != 2 && numArgs != 3) {
            throw new UDFArgumentException("_FUNC_ takes arguments: List<Int|BigInt|Text> features, float target [, constant string options]");
        }
        return super.initialize(argOIs);
    }

    /**
     * Always returns {@code true}: this learner reads and writes a covariance next to each weight.
     */
    @Override
    protected boolean useCovariance() {
        return true;
    }

    /**
     * Extends the inherited options with {@code -r / --regularization}.
     *
     * @return the base class options plus the regularization option
     */
    @Override
    protected Options getOptions() {
        Options opts = super.getOptions();
        opts.addOption("r", "regularization", true, "Regularization parameter for some r > 0 [default 0.1]");
        return opts;
    }

    /**
     * Lets the base class process the options, then reads the regularization parameter.
     *
     * @param argOIs object inspectors of the call arguments
     * @return the command line produced by the base class; {@code null} is passed through,
     *         in which case the default of 0.1 is used
     * @throws UDFArgumentException if {@code -r} is not a parsable float or is not greater than 0
     */
    @Override
    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
        CommandLine cl = super.processOptions(argOIs);
        float regularization = 0.1f;
        if (cl != null) {
            String rawValue = cl.getOptionValue("r");
            if (rawValue != null) {
                regularization = AROWRegressionUDTF.parseFloatOption("r", rawValue);
                // Phrased as !(x > 0) rather than (x <= 0) so that NaN is rejected too.
                if (!(regularization > 0.0f)) {
                    throw new UDFArgumentException("Regularization parameter must be greater than 0: " + rawValue);
                }
            }
        }
        this.r = regularization;
        return cl;
    }

    /**
     * Parses a user-supplied option value as a float.
     *
     * <p>A malformed value is reported through {@link UDFArgumentException}, the same channel
     * used for out-of-range values, instead of escaping as an unchecked
     * {@link NumberFormatException}.
     *
     * @param optName option name, used only in the error message
     * @param rawValue the option text to parse
     * @return the parsed value
     * @throws UDFArgumentException if {@code rawValue} is not a valid float literal
     */
    private static float parseFloatOption(String optName, String rawValue) throws UDFArgumentException {
        try {
            return Float.parseFloat(rawValue);
        } catch (NumberFormatException e) {
            // The message carries the offending text, which is all the information e held.
            throw new UDFArgumentException("Invalid float value for option '" + optName + "': " + rawValue);
        }
    }

    /**
     * Performs one training step on a single example.
     *
     * <p>Takes the score and variance reported by {@code calcScoreAndVariance}, derives
     * {@code beta = 1 / (variance + r)} and updates every feature using the loss as coefficient.
     *
     * @param features features of the example; {@code null} entries are skipped by {@link #update}
     * @param target the value to regress on
     */
    @Override
    protected void train(@Nonnull FeatureValue[] features, float target) {
        PredictionResult margin = this.calcScoreAndVariance(features);
        float predicted = margin.getScore();
        float loss = this.loss(target, predicted);
        float variance = margin.getVariance();
        float beta = 1.0f / (variance + this.r);
        this.update(features, loss, beta);
    }

    /**
     * Signed residual used as the update coefficient.
     *
     * @param target the expected value
     * @param predicted the current model output
     * @return {@code target - predicted}
     */
    protected float loss(float target, float predicted) {
        return target - predicted;
    }

    /**
     * Replaces the stored weight/covariance of every non-null feature with the updated pair.
     *
     * @param features features of the current example; {@code null} entries are ignored
     * @param coeff signed update coefficient
     * @param beta confidence term, {@code 1 / (variance + r)} as computed by the callers in this file
     */
    @Override
    protected void update(@Nonnull FeatureValue[] features, float coeff, float beta) {
        for (FeatureValue f : features) {
            if (f == null) {
                continue;
            }
            Object key = f.getFeature();
            float x = f.getValueAsFloat();
            // Must be typed as IWeightValue: getNewWeight does not accept a plain Object.
            IWeightValue oldWeight = this.model.get(key);
            IWeightValue newWeight = AROWRegressionUDTF.getNewWeight(oldWeight, x, coeff, beta);
            this.model.set(key, newWeight);
        }
    }

    /**
     * Computes the updated weight and covariance for one feature.
     *
     * <pre>
     *   w'   = w   + coeff * beta * (cov * x)
     *   cov' = cov - beta * (cov * x)^2
     * </pre>
     *
     * @param old the stored value, or {@code null} for a feature not seen before
     *            (treated as weight 0 and covariance 1)
     * @param x the feature value
     * @param coeff signed update coefficient
     * @param beta confidence term
     * @return a new weight value holding {@code w'} and {@code cov'}
     */
    private static IWeightValue getNewWeight(IWeightValue old, float x, float coeff, float beta) {
        float oldWeight;
        float oldCov;
        if (old == null) {
            oldWeight = 0.0f;
            oldCov = 1.0f;
        } else {
            oldWeight = old.get();
            oldCov = old.getCovariance();
        }
        float covX = oldCov * x;
        float newWeight = oldWeight + coeff * covX * beta;
        float newCov = oldCov - beta * covX * covX;
        return new WeightValue.WeightValueWithCovar(newWeight, newCov);
    }

    /**
     * AROW variant that only updates when the epsilon-insensitive loss is positive.
     */
    @Description(name="train_arowe_regr", value="_FUNC_(array<int|bigint|string> features, float target [, constant string options]) - a refined version of AROW (Adaptive Reguralization of Weight Vectors) regressor that usages epsilon-insensitive hinge loss `|w^t - y| - epsilon` for the loss function", extended="SELECT \n  feature,\n  argmin_kld(weight, covar) as weight\nFROM (\n  SELECT \n     train_arowe_regr(features,label) as (feature,weight,covar)\n  FROM \n     training_data\n ) t \nGROUP BY feature")
    public static class AROWe extends AROWRegressionUDTF {
        /** Sensitivity threshold; assigned by {@link #processOptions} (0.1 unless {@code -e} is given). */
        protected float epsilon;

        /**
         * Extends the inherited options with {@code -e / --epsilon}.
         *
         * @return the parent options plus the epsilon option
         */
        @Override
        protected Options getOptions() {
            Options opts = super.getOptions();
            opts.addOption("e", "epsilon", true, "Sensitivity to prediction mistakes [default 0.1]");
            return opts;
        }

        /**
         * Lets the parent process the options, then reads epsilon.
         *
         * @param argOIs object inspectors of the call arguments
         * @return the command line produced by the parent; {@code null} is passed through,
         *         in which case the default of 0.1 is used
         * @throws UDFArgumentException if the epsilon value is not a parsable float, or if the
         *         parent rejects its own options
         */
        @Override
        protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
            CommandLine cl = super.processOptions(argOIs);
            float eps = 0.1f;
            if (cl != null) {
                String rawValue = cl.getOptionValue("epsilon");
                if (rawValue != null) {
                    eps = AROWRegressionUDTF.parseFloatOption("epsilon", rawValue);
                }
            }
            this.epsilon = eps;
            return cl;
        }

        /**
         * Performs one training step, skipping the update when the loss is not positive.
         *
         * <p>The update coefficient has the magnitude of the loss and is positive only when
         * {@code target - predicted} is greater than 0.
         *
         * @param features features of the example
         * @param target the value to regress on
         */
        @Override
        protected void train(@Nonnull FeatureValue[] features, float target) {
            this.preTrain(target);
            PredictionResult margin = this.calcScoreAndVariance(features);
            float predicted = margin.getScore();
            float loss = this.loss(target, predicted);
            if (!(loss > 0.0f)) {
                // Covers zero, negative and NaN losses, exactly like the guard "loss > 0".
                return;
            }
            float coeff = target - predicted > 0.0f ? loss : -loss;
            float variance = margin.getVariance();
            float beta = 1.0f / (variance + this.r);
            this.update(features, coeff, beta);
        }

        /**
         * Hook invoked at the start of every {@link #train} call; does nothing here.
         *
         * @param target the target value of the example about to be trained on
         */
        protected void preTrain(float target) {
        }

        /**
         * Loss computed by {@code LossFunctions.epsilonInsensitiveLoss} with the configured epsilon.
         *
         * @param target the expected value
         * @param predicted the current model output
         * @return the loss; {@link #train} only updates the model when it is greater than 0
         */
        @Override
        protected float loss(float target, float predicted) {
            return LossFunctions.epsilonInsensitiveLoss(predicted, target, this.epsilon);
        }
    }

    /**
     * {@link AROWe} variant whose insensitivity threshold is epsilon scaled by the standard
     * deviation of the targets observed so far.
     */
    @Description(name="train_arowe2_regr", value="_FUNC_(array<int|bigint|string> features, float target [, constant string options]) - a refined version of AROW (Adaptive Reguralization of Weight Vectors) regressor that usages adaptive epsilon-insensitive hinge loss `|w^t - y| - epsilon * stddev` for the loss function", extended="SELECT \n  feature,\n  argmin_kld(weight, covar) as weight\nFROM (\n  SELECT \n     train_arowe2_regr(features,label) as (feature,weight,covar)\n  FROM \n     training_data\n ) t \nGROUP BY feature")
    public static class AROWe2 extends AROWe {
        /** Running statistics over the targets passed to {@link #preTrain}; created in {@link #initialize}. */
        private OnlineVariance targetStdDev;

        /**
         * Creates a fresh target-statistics accumulator, then delegates to the parent.
         *
         * @param argOIs object inspectors of the call arguments
         * @return whatever the parent initialization returns
         * @throws UDFArgumentException if the parent rejects the arguments
         */
        @Override
        public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
            this.targetStdDev = new OnlineVariance();
            return super.initialize(argOIs);
        }

        /**
         * Feeds the target into the running statistics before the loss is evaluated.
         *
         * @param target the target value of the example about to be trained on
         */
        @Override
        protected void preTrain(float target) {
            this.targetStdDev.handle(target);
        }

        /**
         * Same loss function as {@link AROWe}, with the threshold {@code epsilon * stddev(targets)}.
         *
         * @param target the expected value
         * @param predicted the current model output
         * @return the loss for the scaled threshold
         */
        @Override
        protected float loss(float target, float predicted) {
            float stddev = (float)this.targetStdDev.stddev();
            float threshold = this.epsilon * stddev;
            return LossFunctions.epsilonInsensitiveLoss(predicted, target, threshold);
        }
    }
}

