/*
 * Decompiled with CFR 0.152.
 */
package hivemall.optimizer;

import hivemall.model.IWeightValue;
import hivemall.model.WeightValue;
import hivemall.optimizer.EtaEstimator;
import hivemall.optimizer.Regularization;
import hivemall.utils.lang.Primitives;
import hivemall.utils.math.MathUtils;
import java.util.HashMap;
import java.util.Map;
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.concurrent.NotThreadSafe;

public interface Optimizer {
    /**
     * Applies one optimization update for a single feature.
     *
     * <p>The decompiled parameter names carry no meaning; their roles below follow the
     * parameter names used by {@code OptimizerBase.update(Object, float, float, float)}.
     *
     * @param var1 the feature whose weight is being updated
     * @param var2 the current weight of that feature
     * @param var3 the loss associated with this update
     * @param var4 the gradient for that feature
     * @return the value computed by the concrete optimizer; in {@code SGD} this is the weight
     *         read back after the update has been applied
     */
    public float update(@Nonnull Object var1, float var2, float var3, float var4);

    /**
     * Advances the optimizer to its next step. {@code OptimizerBase} implements this by
     * incrementing its step counter.
     */
    public void proceedStep();

    /**
     * Returns the identifier of this optimizer (for example {@code "sgd"} or {@code "adam"}).
     *
     * @return the optimizer name
     */
    @Nonnull
    public String getOptimizerName();

    /**
     * Returns the hyper-parameters this optimizer is configured with, keyed by name.
     *
     * @return a map from hyper-parameter name to its value
     */
    @Nonnull
    public Map<String, Object> getHyperParameters();

    /**
     * Optimizer reported as {@code "adagrad_rda"}.
     *
     * <p>It accumulates the sum of gradients per weight. When the absolute accumulated gradient
     * divided by the step count falls below {@code lambda}, the weight and its accumulators are
     * reset to zero; otherwise the new weight is derived from the delta computed by the wrapped
     * {@link AdaGrad} instance.
     */
    public static abstract class AdagradRDA extends OptimizerBase {
        /** Threshold subtracted from the averaged absolute gradient sum (option {@code lambda}). */
        private final float lambda;
        /** AdaGrad instance that supplies the per-weight scaling via {@code computeDelta}. */
        @Nonnull
        private final AdaGrad optimizerImpl;

        /**
         * @param optimizerImpl AdaGrad optimizer used to scale the update
         * @param options option map; {@code lambda} defaults to {@code 1e-6} when absent
         */
        public AdagradRDA(@Nonnull AdaGrad optimizerImpl, @Nonnull Map<String, String> options) {
            super(options);
            this.optimizerImpl = optimizerImpl;
            this.lambda = Primitives.parseFloat(options.get("lambda"), 1.0E-6f);
        }

        /** Creates a weight holder with two auxiliary float slots, both initialised to zero. */
        @Override
        protected WeightValue.WeightValueParamsF2 newWeightValue(float weight) {
            return new WeightValue.WeightValueParamsF2(weight, 0.0f, 0.0f);
        }

        /**
         * Updates {@code weight} in place and returns its new value.
         *
         * @param weight mutable weight state; its value and gradient accumulators are rewritten
         * @param gradient gradient for this update
         * @return the new weight, or {@code 0} when the weight was truncated
         */
        @Override
        protected float update(@Nonnull IWeightValue weight, float gradient) {
            final float accumulated = weight.getSumOfGradients() + gradient;

            // Sign of the accumulated gradient; anything that is not strictly positive maps to -1.
            final float sign;
            if (accumulated > 0.0f) {
                sign = 1.0f;
            } else {
                sign = -1.0f;
            }

            // |accumulated| / step - lambda
            final float shrunkMean = sign * accumulated / (float)this._numStep - this.lambda;
            if (shrunkMean < 0.0f) {
                // Truncate: the weight and both accumulators are cleared.
                weight.set(0.0f);
                weight.setSumOfSquaredGradients(0.0f);
                weight.setSumOfGradients(0.0f);
                return 0.0f;
            }

            // The learning rate is read before the AdaGrad delta is computed, as the delta
            // computation mutates the squared-gradient accumulator of the weight.
            final float learningRate = this._eta.eta(this._numStep);
            final float steps = (float)this._numStep;
            final float scaled = this.optimizerImpl.computeDelta(weight, shrunkMean);
            final float updated = -1.0f * sign * learningRate * steps * scaled;

            weight.set(updated);
            weight.setSumOfGradients(accumulated);
            return updated;
        }

        @Override
        public String getOptimizerName() {
            return "adagrad_rda";
        }

        /**
         * Returns the wrapped AdaGrad's hyper-parameters with {@code optimizer} overwritten by
         * this optimizer's name and {@code lambda} added.
         */
        @Override
        public Map<String, Object> getHyperParameters() {
            final Map<String, Object> hyperParams = this.optimizerImpl.getHyperParameters();
            hyperParams.put("optimizer", this.getOptimizerName());
            hyperParams.put("lambda", Float.valueOf(this.lambda));
            return hyperParams;
        }
    }

    /**
     * Adam variant reported as {@code "adam_hd"} whose step size {@code alpha} is rescaled on
     * every delta computation: it shrinks by a factor {@code 1 - beta} when the current gradient
     * and the previous update direction have the same sign, and grows by {@code 1 + beta} when
     * their signs differ.
     */
    public static abstract class AdamHD extends Adam {
        /** Relative adjustment applied to {@code alpha} per update (option {@code beta}). */
        private final float beta;
        /** Update direction (before scaling by alpha) from the previous {@code computeDelta} call. */
        protected double deltaU = 0.0;

        /**
         * @param options option map; {@code alpha} defaults to {@code 0.02} and {@code beta} to
         *        {@code 1e-6} when absent
         */
        public AdamHD(@Nonnull Map<String, String> options) {
            super(options);
            this.alpha = Primitives.parseFloat(options.get("alpha"), 0.02f);
            this.beta = Primitives.parseFloat(options.get("beta"), 1.0E-6f);
        }

        /**
         * Defaults {@code eta} to {@code "fixed"} and {@code eta0} to {@code "1.0"} when the
         * caller did not specify them. Note that the defaults are written into the given map.
         */
        @Override
        protected final EtaEstimator getEtaEstimator(@Nonnull Map<String, String> options) {
            setIfMissing(options, "eta", "fixed");
            setIfMissing(options, "eta0", "1.0");
            return super.getEtaEstimator(options);
        }

        /** Stores {@code value} under {@code key} only when the map has no entry for the key. */
        private static void setIfMissing(Map<String, String> options, String key, String value) {
            if (options.containsKey(key)) {
                return;
            }
            options.put(key, value);
        }

        /**
         * Adjusts and returns the step size based on the sign of
         * {@code gradient * deltaU}; a product of exactly zero leaves it untouched.
         */
        private float alpha(float gradient, double deltaU) {
            final double agreement = (double)gradient * deltaU;
            if (agreement > 0.0) {
                this.alpha = this.alpha * (1.0f - this.beta);
            } else if (agreement < 0.0) {
                this.alpha = this.alpha * (1.0f + this.beta);
            }
            return this.alpha;
        }

        /**
         * Computes the update for {@code weight} and stores the refreshed first and second
         * moment estimates on it.
         *
         * @param weight weight state holding the moment estimates {@code m} and {@code v}
         * @param gradient gradient for this update
         * @return the delta to be applied to the weight
         */
        @Override
        protected float computeDelta(@Nonnull IWeightValue weight, float gradient) {
            final boolean decaying = this.decay != 0.0f;
            if (decaying) {
                final float current = weight.get();
                gradient += this.decay * current;
            }

            // Exponential moving averages of the gradient and of its square.
            final float firstMoment = this.beta1 * weight.getM() + (1.0f - this.beta1) * gradient;
            final float secondMoment = this.beta2 * weight.getV() + (float)((double)(1.0f - this.beta2) * MathUtils.square(gradient));

            // Bias-corrected estimates for the current step.
            final double correctedFirst = (double)firstMoment / (1.0 - Math.pow(this.beta1, this._numStep));
            final double correctedSecond = (double)secondMoment / (1.0 - Math.pow(this.beta2, this._numStep));

            // The step size is adapted against the direction of the PREVIOUS update,
            // so this must run before this.deltaU is overwritten below.
            final float stepSize = this.alpha(gradient, this.deltaU);

            final double direction = correctedFirst / (Math.sqrt(correctedSecond) + (double)this.eps);
            float delta = (float)((double)stepSize * direction);
            this.deltaU = direction;

            if (decaying) {
                final float current = weight.get();
                delta += this.decay * current;
            }

            weight.setM(firstMoment);
            weight.setV(secondMoment);
            return delta;
        }

        @Override
        public String getOptimizerName() {
            return "adam_hd";
        }

        /** Adds {@code beta} to the hyper-parameters reported by the parent class. */
        @Override
        public Map<String, Object> getHyperParameters() {
            final Map<String, Object> hyperParams = super.getHyperParameters();
            hyperParams.put("beta", Float.valueOf(this.beta));
            return hyperParams;
        }
    }

    /**
     * Adam variant reported as {@code "eve"} that additionally divides the step size by a
     * smoothed, clipped measure of the relative change between the current and previous loss.
     */
    public static abstract class Eve extends Adam {
        /** Smoothing factor for the loss-change term (option {@code beta3}). */
        protected final float beta3;
        /** Upper clipping bound for the relative loss change (option {@code c}). */
        private float c = 10.0f;
        /** Lower clipping bound; always {@code 1 / c} after construction. */
        private float inv_c = 0.1f;
        /** Loss passed to the update call currently in progress. */
        private float currLoss;
        /** Loss passed to the most recently completed update call. */
        private float prevLoss = 0.0f;
        /** Smoothed loss-change term from the previous adjustment. */
        private double prevDt = 1.0;

        /**
         * @param options option map; {@code beta3} defaults to {@code 0.999} and {@code c} to
         *        {@code 10} when absent
         */
        public Eve(@Nonnull Map<String, String> options) {
            super(options);
            this.beta3 = Primitives.parseFloat(options.get("beta3"), 0.999f);
            this.c = Primitives.parseFloat(options.get("c"), 10.0f);
            this.inv_c = 1.0f / this.c;
        }

        /**
         * Returns the bias-fixed step size, divided by the smoothed loss-change term when the
         * step counter is past 1 and the loss differs from the previous one.
         */
        @Override
        protected double alpha() {
            final double firstFix = 1.0 - Math.pow(this.beta1, this._numStep);
            final double secondFix = 1.0 - Math.pow(this.beta2, this._numStep);
            final double stepSize = (double)this.alpha * (Math.sqrt(secondFix) / firstFix);

            // Nothing to adapt on the first step or when the loss did not change.
            if (this._numStep <= 1L || this.currLoss == this.prevLoss) {
                return stepSize;
            }

            // Relative loss change, computed in float precision before widening.
            final float gap = Math.abs(this.currLoss - this.prevLoss);
            final float smaller = Math.min(this.currLoss, this.prevLoss);
            final double relativeChange = gap / smaller;
            final double clipped = MathUtils.clip(relativeChange, (double)this.inv_c, (double)this.c);

            final double smoothed = (double)this.beta3 * this.prevDt + (1.0 - (double)this.beta3) * clipped;
            this.prevDt = smoothed;
            return stepSize / smoothed;
        }

        /**
         * Records {@code loss} for use by {@link #alpha()} during this update, delegates to the
         * three-argument update, and then remembers {@code loss} as the previous loss.
         */
        @Override
        public float update(Object feature, float weight, float loss, float gradient) {
            this.currLoss = loss;
            final float result = this.update(feature, weight, gradient);
            this.prevLoss = loss;
            return result;
        }

        @Override
        public String getOptimizerName() {
            return "eve";
        }

        /** Adds {@code beta3} and {@code c} to the hyper-parameters reported by the parent. */
        @Override
        public Map<String, Object> getHyperParameters() {
            final Map<String, Object> hyperParams = super.getHyperParameters();
            hyperParams.put("beta3", Float.valueOf(this.beta3));
            hyperParams.put("c", Float.valueOf(this.c));
            return hyperParams;
        }
    }

    /**
     * Optimizer reported as {@code "nadam"}: moving averages of the gradient and its square
     * combined with a momentum coefficient that follows a schedule recomputed on every step.
     */
    public static abstract class Nadam extends OptimizerBase {
        protected float alpha;
        protected final float beta1;
        protected final float beta2;
        protected final float eps;
        protected final float decay;
        protected final float scheduleDecay;
        /** Scheduled momentum coefficient for the current step. */
        protected double mu_t;
        /** Scheduled momentum coefficient for the step after the current one. */
        protected double mu_t_1;
        /** Product of the scheduled coefficients up to and including the current step. */
        protected double mu_product = 1.0;
        /** {@link #mu_product} additionally multiplied by {@link #mu_t_1}. */
        protected double mu_product_next = 1.0;

        /**
         * @param options option map; defaults are {@code alpha=1}, {@code beta1=0.9},
         *        {@code beta2=0.999}, {@code eps=1e-8}, {@code decay=0} and
         *        {@code scheduleDecay=0.004}
         */
        public Nadam(@Nonnull Map<String, String> options) {
            super(options);
            this.alpha = Primitives.parseFloat(options.get("alpha"), 1.0f);
            this.beta1 = Primitives.parseFloat(options.get("beta1"), 0.9f);
            this.beta2 = Primitives.parseFloat(options.get("beta2"), 0.999f);
            this.eps = Primitives.parseFloat(options.get("eps"), 1.0E-8f);
            this.decay = Primitives.parseFloat(options.get("decay"), 0.0f);
            this.scheduleDecay = Primitives.parseFloat(options.get("scheduleDecay"), 0.004f);
        }

        /** Creates a weight holder with two auxiliary float slots, both initialised to zero. */
        @Override
        protected WeightValue.WeightValueParamsF2 newWeightValue(float weight) {
            return new WeightValue.WeightValueParamsF2(weight, 0.0f, 0.0f);
        }

        /**
         * Increments the step counter and refreshes the scheduled momentum coefficients and
         * their running products.
         */
        @Override
        public void proceedStep() {
            final long step = this._numStep + 1L;
            this._numStep = step;

            final double productSoFar = this.mu_product;
            // The current-step term is scaled in float precision, the next-step term in double.
            final double current = this.scheduledMomentum((float)step * this.scheduleDecay);
            final double upcoming = this.scheduledMomentum(((double)step + 1.0) * (double)this.scheduleDecay);

            this.mu_t = current;
            this.mu_t_1 = upcoming;
            this.mu_product = productSoFar * current;
            this.mu_product_next = productSoFar * current * upcoming;
        }

        /** Returns {@code beta1 * (1 - 0.5 * 0.96^(floor(scaledStep) + 1))}. */
        private double scheduledMomentum(double scaledStep) {
            return (double)this.beta1 * (1.0 - 0.5 * Math.pow(0.96, Math.floor(scaledStep) + 1.0));
        }

        /** Returns {@code sqrt(1 - beta2^t) / (1 - beta1^t)}. */
        private double biasFix(long t) {
            final double firstFix = 1.0 - Math.pow(this.beta1, t);
            final double secondFix = 1.0 - Math.pow(this.beta2, t);
            return Math.sqrt(secondFix) / firstFix;
        }

        /** Returns the estimator's learning rate at step {@code t} multiplied by the bias fix. */
        @Override
        protected float eta(long t) {
            final double fix = this.biasFix(t);
            final float learningRate = this._eta.eta(t);
            return (float)((double)learningRate * fix);
        }

        /** Returns {@code alpha} multiplied by the bias fix for the current step. */
        protected double alpha() {
            return (double)this.alpha * this.biasFix(this._numStep);
        }

        /**
         * Computes the update for {@code weight} and stores the refreshed moment estimates
         * {@code m} and {@code v} on it.
         *
         * @param weight weight state holding the moment estimates
         * @param gradient gradient for this update
         * @return the delta to be applied to the weight
         */
        @Override
        protected float computeDelta(@Nonnull IWeightValue weight, float gradient) {
            if (this.decay != 0.0f) {
                final float current = weight.get();
                gradient += this.decay * current;
            }

            final float firstMoment = this.beta1 * weight.getM() + (1.0f - this.beta1) * gradient;
            final double correctedFirst = (double)firstMoment / (1.0 - this.mu_product_next);

            final float secondMoment = this.beta2 * weight.getV() + (float)((1.0 - (double)this.beta2) * MathUtils.square(gradient));
            final double correctedSecond = (double)secondMoment / (1.0 - Math.pow(this.beta2, this._numStep));

            // Blend the corrected gradient with the corrected first moment using the
            // current and next scheduled momentum coefficients.
            final double correctedGradient = (double)gradient / (1.0 - this.mu_product);
            final double blended = (1.0 - this.mu_t) * correctedGradient + this.mu_t_1 * correctedFirst;

            final double direction = blended / (Math.sqrt(correctedSecond) + (double)this.eps);
            final double stepSize = this.alpha();
            float delta = (float)(stepSize * direction);

            if ((double)this.decay != 0.0) {
                final float current = weight.get();
                delta += this.decay * current;
            }

            weight.setM(firstMoment);
            weight.setV(secondMoment);
            return delta;
        }

        @Override
        public String getOptimizerName() {
            return "nadam";
        }

        /** Adds this optimizer's own options to the hyper-parameters reported by the parent. */
        @Override
        public Map<String, Object> getHyperParameters() {
            final Map<String, Object> hyperParams = super.getHyperParameters();
            hyperParams.put("alpha", Float.valueOf(this.alpha));
            hyperParams.put("beta1", Float.valueOf(this.beta1));
            hyperParams.put("beta2", Float.valueOf(this.beta2));
            hyperParams.put("eps", Float.valueOf(this.eps));
            hyperParams.put("decay", Float.valueOf(this.decay));
            hyperParams.put("scheduleDecay", Float.valueOf(this.scheduleDecay));
            return hyperParams;
        }
    }

    /**
     * Optimizer reported as {@code "adam"} (or {@code "adam-amsgrad"} when the {@code amsgrad}
     * option is present): scales a moving average of the gradient by the square root of a
     * moving average of the squared gradient.
     */
    public static abstract class Adam extends OptimizerBase {
        protected float alpha;
        protected final float beta1;
        protected final float beta2;
        protected final float eps;
        protected final float decay;
        /** True when the option map contains the key {@code amsgrad}, whatever its value. */
        protected final boolean amsgrad;
        /**
         * Largest second-moment value seen so far when {@link #amsgrad} is enabled. This is a
         * single field of the optimizer, not a per-weight value.
         */
        protected float max_vhat = Float.MIN_VALUE;

        /**
         * @param options option map; defaults are {@code alpha=1}, {@code beta1=0.9},
         *        {@code beta2=0.999}, {@code eps=1e-8} and {@code decay=0}
         */
        public Adam(@Nonnull Map<String, String> options) {
            super(options);
            this.alpha = Primitives.parseFloat(options.get("alpha"), 1.0f);
            this.beta1 = Primitives.parseFloat(options.get("beta1"), 0.9f);
            this.beta2 = Primitives.parseFloat(options.get("beta2"), 0.999f);
            this.eps = Primitives.parseFloat(options.get("eps"), 1.0E-8f);
            this.decay = Primitives.parseFloat(options.get("decay"), 0.0f);
            this.amsgrad = options.containsKey("amsgrad");
        }

        /** Creates a weight holder with two auxiliary float slots, both initialised to zero. */
        @Override
        protected WeightValue.WeightValueParamsF2 newWeightValue(float weight) {
            return new WeightValue.WeightValueParamsF2(weight, 0.0f, 0.0f);
        }

        /** Returns {@code sqrt(1 - beta2^t) / (1 - beta1^t)}. */
        private double biasFix(long t) {
            final double firstFix = 1.0 - Math.pow(this.beta1, t);
            final double secondFix = 1.0 - Math.pow(this.beta2, t);
            return Math.sqrt(secondFix) / firstFix;
        }

        /** Returns the estimator's learning rate at step {@code t} multiplied by the bias fix. */
        @Override
        protected float eta(long t) {
            final double fix = this.biasFix(t);
            final float learningRate = this._eta.eta(t);
            return (float)((double)learningRate * fix);
        }

        /** Returns {@code alpha} multiplied by the bias fix for the current step. */
        protected double alpha() {
            return (double)this.alpha * this.biasFix(this._numStep);
        }

        /**
         * Computes the update for {@code weight} and stores the refreshed moment estimates
         * {@code m} and {@code v} on it.
         *
         * @param weight weight state holding the moment estimates
         * @param gradient gradient for this update
         * @return the delta to be applied to the weight
         */
        @Override
        protected float computeDelta(@Nonnull IWeightValue weight, float gradient) {
            final boolean decaying = this.decay != 0.0f;
            if (decaying) {
                final float current = weight.get();
                gradient += this.decay * current;
            }

            final float firstMoment = this.beta1 * weight.getM() + (1.0f - this.beta1) * gradient;
            final float secondMoment = this.beta2 * weight.getV() + (float)((double)(1.0f - this.beta2) * MathUtils.square(gradient));

            // With amsgrad the denominator never uses a value below the running maximum.
            // The stored second moment itself is left unchanged.
            float denominatorMoment = secondMoment;
            if (this.amsgrad) {
                if (this.max_vhat < secondMoment) {
                    this.max_vhat = secondMoment;
                } else {
                    denominatorMoment = this.max_vhat;
                }
            }

            final double direction = (double)firstMoment / (Math.sqrt(denominatorMoment) + (double)this.eps);
            final double stepSize = this.alpha();
            float delta = (float)(stepSize * direction);

            if (decaying) {
                final float current = weight.get();
                delta += this.decay * current;
            }

            weight.setM(firstMoment);
            weight.setV(secondMoment);
            return delta;
        }

        @Override
        public String getOptimizerName() {
            if (this.amsgrad) {
                return "adam-amsgrad";
            }
            return "adam";
        }

        /** Adds this optimizer's own options to the hyper-parameters reported by the parent. */
        @Override
        public Map<String, Object> getHyperParameters() {
            final Map<String, Object> hyperParams = super.getHyperParameters();
            hyperParams.put("alpha", Float.valueOf(this.alpha));
            hyperParams.put("beta1", Float.valueOf(this.beta1));
            hyperParams.put("beta2", Float.valueOf(this.beta2));
            hyperParams.put("eps", Float.valueOf(this.eps));
            hyperParams.put("decay", Float.valueOf(this.decay));
            return hyperParams;
        }
    }

    /**
     * Optimizer reported as {@code "adadelta"}: scales the gradient by the ratio between a
     * decayed average of squared past deltas and a decayed average of squared gradients.
     */
    public static abstract class AdaDelta extends OptimizerBase {
        /** Decay rate of both moving averages (option {@code decay}). */
        private final float decay;
        /** Constant added to numerator and denominator under the square root. */
        private final float eps;
        /** Factor by which the squared-gradient average is divided before being stored. */
        private final float scale;

        /**
         * @param options option map; defaults are {@code decay=0.95}, {@code eps=1e-6} and
         *        {@code scale=100}
         */
        public AdaDelta(@Nonnull Map<String, String> options) {
            super(options);
            this.decay = Primitives.parseFloat(options.get("decay"), 0.95f);
            this.eps = Primitives.parseFloat(options.get("eps"), 1.0E-6f);
            this.scale = Primitives.parseFloat(options.get("scale"), 100.0f);
        }

        /** Creates a weight holder with two auxiliary float slots, both initialised to zero. */
        @Override
        protected WeightValue.WeightValueParamsF2 newWeightValue(float weight) {
            return new WeightValue.WeightValueParamsF2(weight, 0.0f, 0.0f);
        }

        /**
         * Defaults {@code eta} to {@code "fixed"} and {@code eta0} to {@code "1.0"} when the
         * caller did not specify them. Note that the defaults are written into the given map.
         */
        @Override
        protected final EtaEstimator getEtaEstimator(@Nonnull Map<String, String> options) {
            setIfMissing(options, "eta", "fixed");
            setIfMissing(options, "eta0", "1.0");
            return super.getEtaEstimator(options);
        }

        /** Stores {@code value} under {@code key} only when the map has no entry for the key. */
        private static void setIfMissing(Map<String, String> options, String key, String value) {
            if (options.containsKey(key)) {
                return;
            }
            options.put(key, value);
        }

        /**
         * Computes the delta for {@code weight} and stores both refreshed moving averages on it.
         *
         * @param weight weight state holding the two moving averages
         * @param gradient gradient for this update
         * @return the delta to be applied to the weight
         */
        @Override
        protected float computeDelta(@Nonnull IWeightValue weight, float gradient) {
            final float prevScaledSqGrad = weight.getSumOfSquaredGradients();
            final float prevSqDelta = weight.getSumOfSquaredDeltaX();

            // Squared-gradient average, kept divided by scale.
            final float nextScaledSqGrad = this.decay * prevScaledSqGrad + (1.0f - this.decay) * gradient * (gradient / this.scale);

            // sqrt((E[dx^2] + eps) / (E[g^2] + eps)), with the gradient average scaled back up.
            final double numerator = (double)(prevSqDelta + this.eps);
            final double denominator = (double)nextScaledSqGrad * (double)this.scale + (double)this.eps;
            final float ratio = (float)Math.sqrt(numerator / denominator);
            final float delta = ratio * gradient;

            final float nextSqDelta = this.decay * prevSqDelta + (1.0f - this.decay) * delta * delta;

            weight.setSumOfSquaredGradients(nextScaledSqGrad);
            weight.setSumOfSquaredDeltaX(nextSqDelta);
            return delta;
        }

        @Override
        public String getOptimizerName() {
            return "adadelta";
        }

        /** Adds {@code decay}, {@code eps} and {@code scale} to the parent's hyper-parameters. */
        @Override
        public Map<String, Object> getHyperParameters() {
            final Map<String, Object> hyperParams = super.getHyperParameters();
            hyperParams.put("decay", Float.valueOf(this.decay));
            hyperParams.put("eps", Float.valueOf(this.eps));
            hyperParams.put("scale", Float.valueOf(this.scale));
            return hyperParams;
        }
    }

    /**
     * Optimizer reported as {@code "rmsprop_graves"}: divides the gradient by the square root of
     * a running variance estimate {@code E[g^2] - E[g]^2 + eps} and adds a momentum term.
     *
     * <p>Both running averages are stored on the weight divided by {@code scale} and multiplied
     * back before use.
     */
    public static abstract class RMSpropGraves
    extends OptimizerBase {
        /** Decay rate of both running averages (option {@code decay}). */
        private final float decay;
        /** Multiplier of the normalised gradient (option {@code alpha}). */
        private final float alpha;
        /** Multiplier of the previous delta (option {@code momentum}). */
        private final float momentum;
        /** Constant added under the square root (option {@code eps}). */
        private final float eps;
        /** Factor by which the running averages are divided before being stored. */
        private final float scale;

        /**
         * @param options option map; defaults are {@code decay=0.95}, {@code alpha=1},
         *        {@code momentum=0.9}, {@code eps=1} and {@code scale=100}
         */
        public RMSpropGraves(@Nonnull Map<String, String> options) {
            super(options);
            this.decay = Primitives.parseFloat(options.get("decay"), 0.95f);
            this.alpha = Primitives.parseFloat(options.get("alpha"), 1.0f);
            this.momentum = Primitives.parseFloat(options.get("momentum"), 0.9f);
            this.eps = Primitives.parseFloat(options.get("eps"), 1.0f);
            this.scale = Primitives.parseFloat(options.get("scale"), 100.0f);
        }

        /** Creates a weight holder with three auxiliary float slots, all initialised to zero. */
        @Override
        protected WeightValue.WeightValueParamsF3 newWeightValue(float weight) {
            return new WeightValue.WeightValueParamsF3(weight, 0.0f, 0.0f, 0.0f);
        }

        /**
         * Computes the delta for {@code weight} and stores the refreshed running averages and
         * the new delta on it.
         *
         * @param weight weight state holding the running averages and the previous delta
         * @param gradient gradient for this update
         * @return the delta to be applied to the weight
         */
        @Override
        protected float computeDelta(@Nonnull IWeightValue weight, float gradient) {
            // Running average of squared gradients, stored divided by scale.
            float old_scaled_n = weight.getSumOfSquaredGradients();
            float new_scaled_n = this.decay * old_scaled_n + (1.0f - this.decay) * gradient * (gradient / this.scale);
            weight.setSumOfSquaredGradients(new_scaled_n);

            // Running average of gradients, stored divided by scale.
            float old_scaled_g = weight.getSumOfGradients();
            float new_scaled_g = this.decay * old_scaled_g + (1.0f - this.decay) * gradient / this.scale;
            weight.setSumOfGradients(new_scaled_g);

            // Both averages must come from the SAME step: combining the previous E[g^2] with
            // the current E[g] can make E[g^2] - E[g]^2 negative (e.g. on the very first
            // update, where the previous E[g^2] is 0), and the square root of a negative
            // radicand is NaN, which would then be stored as the weight's delta.
            double n = (double)new_scaled_n * (double)this.scale;
            double g = (double)new_scaled_g * (double)this.scale;
            // Clamp at zero so float rounding cannot push the variance estimate below zero.
            double variance = Math.max(0.0, n - g * g);

            float oldDelta = weight.getDelta();
            float delta = this.momentum * oldDelta + this.alpha * (float)((double)gradient / Math.sqrt(variance + (double)this.eps));
            weight.setDelta(delta);
            return delta;
        }

        @Override
        public String getOptimizerName() {
            return "rmsprop_graves";
        }

        /**
         * Adds {@code decay}, {@code alpha}, {@code momentum}, {@code eps} and {@code scale} to
         * the hyper-parameters reported by the parent class.
         */
        @Override
        public Map<String, Object> getHyperParameters() {
            Map<String, Object> params = super.getHyperParameters();
            params.put("decay", Float.valueOf(this.decay));
            params.put("alpha", Float.valueOf(this.alpha));
            params.put("momentum", Float.valueOf(this.momentum));
            params.put("eps", Float.valueOf(this.eps));
            // Reported for parity with the other optimizers that take a scale option.
            params.put("scale", Float.valueOf(this.scale));
            return params;
        }
    }

    /**
     * Optimizer reported as {@code "rmsprop"}: divides the gradient by the square root of a
     * decayed average of squared gradients.
     */
    public static abstract class RMSprop extends OptimizerBase {
        /** Decay rate of the squared-gradient average (option {@code decay}). */
        private final float decay;
        /** Constant added under the square root (option {@code eps}). */
        private final float eps;
        /** Factor by which the squared-gradient average is divided before being stored. */
        private final float scale;

        /**
         * @param options option map; defaults are {@code decay=0.95}, {@code eps=1} and
         *        {@code scale=100}
         */
        public RMSprop(@Nonnull Map<String, String> options) {
            super(options);
            this.decay = Primitives.parseFloat(options.get("decay"), 0.95f);
            this.eps = Primitives.parseFloat(options.get("eps"), 1.0f);
            this.scale = Primitives.parseFloat(options.get("scale"), 100.0f);
        }

        /** Creates a weight holder with one auxiliary float slot initialised to zero. */
        @Override
        protected WeightValue.WeightValueParamsF1 newWeightValue(float weight) {
            return new WeightValue.WeightValueParamsF1(weight, 0.0f);
        }

        /**
         * Stores the refreshed squared-gradient average on {@code weight} and returns the
         * gradient normalised by the average as it was BEFORE this update.
         *
         * @param weight weight state holding the squared-gradient average
         * @param gradient gradient for this update
         * @return the delta to be applied to the weight
         */
        @Override
        protected float computeDelta(@Nonnull IWeightValue weight, float gradient) {
            final float previous = weight.getSumOfSquaredGradients();
            final float contribution = (1.0f - this.decay) * gradient * (gradient / this.scale);
            weight.setSumOfSquaredGradients(this.decay * previous + contribution);

            final double radicand = (double)this.eps + (double)previous * (double)this.scale;
            return (float)((double)gradient / Math.sqrt(radicand));
        }

        @Override
        public String getOptimizerName() {
            return "rmsprop";
        }

        /** Adds {@code decay}, {@code eps} and {@code scale} to the parent's hyper-parameters. */
        @Override
        public Map<String, Object> getHyperParameters() {
            final Map<String, Object> hyperParams = super.getHyperParameters();
            hyperParams.put("decay", Float.valueOf(this.decay));
            hyperParams.put("eps", Float.valueOf(this.eps));
            hyperParams.put("scale", Float.valueOf(this.scale));
            return hyperParams;
        }
    }

    /**
     * Optimizer reported as {@code "adagrad"}: divides the gradient by the square root of the
     * accumulated sum of squared gradients.
     */
    public static abstract class AdaGrad extends OptimizerBase {
        /** Constant added under the square root (option {@code eps}). */
        private final float eps;
        /** Factor by which the squared-gradient sum is divided before being stored. */
        private final float scale;

        /**
         * @param options option map; defaults are {@code eps=1} and {@code scale=100}
         */
        public AdaGrad(@Nonnull Map<String, String> options) {
            super(options);
            this.eps = Primitives.parseFloat(options.get("eps"), 1.0f);
            this.scale = Primitives.parseFloat(options.get("scale"), 100.0f);
        }

        /** Creates a weight holder with one auxiliary float slot initialised to zero. */
        @Override
        protected WeightValue.WeightValueParamsF1 newWeightValue(float weight) {
            return new WeightValue.WeightValueParamsF1(weight, 0.0f);
        }

        /**
         * Adds the scaled squared gradient to the accumulator on {@code weight} and returns the
         * gradient normalised by the accumulator as it was BEFORE this update.
         *
         * @param weight weight state holding the squared-gradient accumulator
         * @param gradient gradient for this update
         * @return the delta to be applied to the weight
         */
        @Override
        protected float computeDelta(@Nonnull IWeightValue weight, float gradient) {
            final float previous = weight.getSumOfSquaredGradients();
            final float contribution = gradient * (gradient / this.scale);
            weight.setSumOfSquaredGradients(previous + contribution);

            final double radicand = (double)this.eps + (double)previous * (double)this.scale;
            return (float)((double)gradient / Math.sqrt(radicand));
        }

        @Override
        public String getOptimizerName() {
            return "adagrad";
        }

        /** Adds {@code eps} and {@code scale} to the parent's hyper-parameters. */
        @Override
        public Map<String, Object> getHyperParameters() {
            final Map<String, Object> hyperParams = super.getHyperParameters();
            hyperParams.put("eps", Float.valueOf(this.eps));
            hyperParams.put("scale", Float.valueOf(this.scale));
            return hyperParams;
        }
    }

    /**
     * Optimizer reported as {@code "momentum"}, or {@code "nesterov"} when the {@code nesterov}
     * option is present: keeps a per-weight velocity that blends the previous delta with the
     * current gradient.
     */
    public static abstract class Momentum extends OptimizerBase {
        /** Scratch weight holder created at construction; not read anywhere in this class. */
        @Nonnull
        private final WeightValue.WeightValueParamsF1 weightValueReused = this.newWeightValue(0.0f);
        /** True when the option map contains the key {@code nesterov}, whatever its value. */
        private final boolean nesterov;
        /** Multiplier of the gradient (option {@code alpha}). */
        private final float alpha;
        /** Multiplier of the previous delta (option {@code momentum}). */
        private final float momentum;

        /**
         * @param options option map; defaults are {@code alpha=1} and {@code momentum=0.9}
         */
        public Momentum(@Nonnull Map<String, String> options) {
            super(options);
            this.nesterov = options.containsKey("nesterov");
            this.alpha = Primitives.parseFloat(options.get("alpha"), 1.0f);
            this.momentum = Primitives.parseFloat(options.get("momentum"), 0.9f);
        }

        /** Creates a weight holder with one auxiliary float slot initialised to zero. */
        @Override
        protected WeightValue.WeightValueParamsF1 newWeightValue(float weight) {
            return new WeightValue.WeightValueParamsF1(weight, 0.0f);
        }

        /**
         * Stores the new velocity on {@code weight} and returns the delta to apply.
         *
         * @param weight weight state holding the previous delta
         * @param gradient gradient for this update
         * @return the velocity itself, or its look-ahead form when {@code nesterov} is enabled
         */
        @Override
        protected float computeDelta(@Nonnull IWeightValue weight, float gradient) {
            final float velocity = this.momentum * weight.getDelta() + this.alpha * gradient;
            weight.setDelta(velocity);

            if (!this.nesterov) {
                return velocity;
            }
            return this.momentum * this.momentum * velocity + (1.0f + this.momentum) * this.alpha * gradient;
        }

        @Override
        public String getOptimizerName() {
            if (this.nesterov) {
                return "nesterov";
            }
            return "momentum";
        }

        /** Adds {@code nesterov}, {@code alpha} and {@code momentum} to the parent's parameters. */
        @Override
        public Map<String, Object> getHyperParameters() {
            final Map<String, Object> hyperParams = super.getHyperParameters();
            hyperParams.put("nesterov", this.nesterov);
            hyperParams.put("alpha", Float.valueOf(this.alpha));
            hyperParams.put("momentum", Float.valueOf(this.momentum));
            return hyperParams;
        }
    }

    /**
     * Optimizer reported as {@code "sgd"}. It keeps no per-feature state: every update runs
     * through one reusable weight holder and relies on the base-class delta, which is the
     * gradient itself.
     */
    public static final class SGD extends OptimizerBase {
        /** Single weight holder reused for every update call. */
        private final IWeightValue weightValueReused;

        /**
         * @param options option map forwarded to the base class
         */
        public SGD(@Nonnull Map<String, String> options) {
            super(options);
            this.weightValueReused = this.newWeightValue(0.0f);
        }

        /** Creates a plain weight holder without auxiliary slots. */
        @Override
        protected WeightValue newWeightValue(float weight) {
            return new WeightValue(weight);
        }

        /**
         * Loads {@code weight} into the reusable holder, applies the update and returns the
         * value read back from the holder.
         *
         * @param feature the feature being updated; not used by this implementation
         * @param weight current weight of the feature
         * @param gradient gradient for this update
         * @return the holder's value after the update
         */
        @Override
        protected float update(@Nonnull Object feature, float weight, float gradient) {
            final IWeightValue scratch = this.weightValueReused;
            scratch.set(weight);
            this.update(scratch, gradient);
            return scratch.get();
        }

        @Override
        public String getOptimizerName() {
            return "sgd";
        }
    }

    /**
     * Shared skeleton for the optimizers in this file: owns the learning-rate estimator, the
     * regularization and the step counter, and implements the generic
     * "compute delta, regularize, apply" update on a weight holder.
     *
     * <p>Subclasses customise the behaviour by overriding {@link #computeDelta},
     * {@link #eta(long)} and {@link #getEtaEstimator}, and must provide the per-feature
     * {@code update(Object, float, float)}.
     */
    @NotThreadSafe
    public static abstract class OptimizerBase
    implements Optimizer {
        /** Learning-rate estimator obtained from {@link #getEtaEstimator}. */
        @Nonnull
        protected final EtaEstimator _eta;
        /** Regularization obtained from the option map. */
        @Nonnull
        protected final Regularization _reg;
        /** Number of steps taken so far; advanced by {@link #proceedStep()}. */
        @Nonnegative
        protected long _numStep = 0L;

        /**
         * @param options option map handed to the learning-rate estimator and the
         *        regularization factories
         */
        public OptimizerBase(@Nonnull Map<String, String> options) {
            // NOTE: getEtaEstimator is overridable and runs before subclass fields are
            // initialised, so overrides must not depend on subclass instance state.
            this._eta = this.getEtaEstimator(options);
            this._reg = Regularization.get(options);
        }

        /**
         * Creates the weight holder type this optimizer operates on.
         *
         * @param var1 initial weight value
         * @return a new weight holder
         */
        @Nonnull
        protected abstract IWeightValue newWeightValue(float var1);

        /**
         * Builds the learning-rate estimator from the options. Subclasses override this to
         * inject their own defaults before delegating here.
         */
        @Nonnull
        protected EtaEstimator getEtaEstimator(@Nonnull Map<String, String> options) {
            return EtaEstimator.get(options);
        }

        /** Increments the step counter by one. */
        @Override
        public void proceedStep() {
            ++this._numStep;
        }

        /**
         * Default four-argument update: ignores {@code loss} and delegates to the
         * three-argument form.
         */
        @Override
        public float update(@Nonnull Object feature, float weight, float loss, float gradient) {
            return this.update(feature, weight, gradient);
        }

        /**
         * Applies one update for a single feature. Parameter roles follow the four-argument
         * form above without the loss: feature, current weight, gradient.
         *
         * @param var1 the feature being updated
         * @param var2 the current weight of the feature
         * @param var3 the gradient for the feature
         * @return the value computed by the concrete optimizer
         */
        protected abstract float update(@Nonnull Object var1, float var2, float var3);

        /**
         * Updates {@code weight} in place as
         * {@code old - eta(step) * regularize(old, computeDelta(weight, gradient))}.
         *
         * @param weight weight holder to update
         * @param gradient gradient for this update
         * @return the new weight value
         */
        protected float update(@Nonnull IWeightValue weight, float gradient) {
            float oldWeight = weight.get();
            float delta = this.computeDelta(weight, gradient);
            float eta = this.eta(this._numStep);
            float reg = this._reg.regularize(oldWeight, delta);
            float newWeight = oldWeight - eta * reg;
            weight.set(newWeight);
            return newWeight;
        }

        /**
         * Returns the learning rate for step {@code t}.
         *
         * <p>The estimator is queried with the requested step rather than the internal
         * counter, in line with the overrides of this method elsewhere in the file. The
         * update path in this class always passes the current step counter, so its results
         * are unaffected.
         *
         * @param t step to evaluate the learning rate at
         * @return the learning rate at step {@code t}
         */
        protected float eta(long t) {
            return this._eta.eta(t);
        }

        /**
         * Returns the delta to feed into regularization; the default is the raw gradient.
         *
         * @param weight weight holder, available for subclasses that keep per-weight state
         * @param gradient gradient for this update
         * @return the delta for this update
         */
        protected float computeDelta(@Nonnull IWeightValue weight, float gradient) {
            return gradient;
        }

        /**
         * Returns a fresh mutable map containing the optimizer name under {@code optimizer}
         * plus whatever the learning-rate estimator and the regularization add to it.
         */
        @Override
        public Map<String, Object> getHyperParameters() {
            HashMap<String, Object> params = new HashMap<String, Object>();
            params.put("optimizer", this.getOptimizerName());
            this._eta.getHyperParameters(params);
            this._reg.getHyperParameters(params);
            return params;
        }
    }
}

