/*
 * Decompiled with CFR 0.152.
 */
package hivemall.optimizer;

import hivemall.model.WeightValue;
import hivemall.optimizer.Optimizer;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.math.MathUtils;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.concurrent.NotThreadSafe;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Factory for optimizers whose per-feature state is held in dense {@code float[]} arrays indexed
 * by the integer feature id.
 *
 * <p>Each nested optimizer converts the feature with {@link HiveUtils#parseInt(Object)}, grows its
 * state arrays when the index is beyond the current length, loads that feature's state into a
 * single reused {@link WeightValue} holder, runs the base-class update, and writes the state back.
 * Because the holder is shared per instance, the nested optimizers are {@link NotThreadSafe}.
 */
public final class DenseOptimizerFactory {
    private static final Log LOG = LogFactory.getLog(DenseOptimizerFactory.class);

    /**
     * Upper bound for grown state arrays; same conservative limit as {@code java.util.ArrayList}
     * uses, since lengths right at {@link Integer#MAX_VALUE} may be rejected by the VM.
     */
    private static final int MAX_ARRAY_SIZE = Integer.MAX_VALUE - 8;

    /**
     * Creates a dense optimizer selected by the {@code optimizer} option.
     *
     * <p>Recognized names (case-insensitive): {@code sgd}, {@code momentum}, {@code nesterov},
     * {@code adagrad}, {@code rmsprop}, {@code rmspropgraves} / {@code rmsprop_graves},
     * {@code adadelta}, {@code adam}, {@code nadam}, {@code eve}, {@code adam_hd} /
     * {@code adamhd}. {@code -regularization rda} is only accepted with {@code adagrad}; in that
     * case an {@link AdaGrad} is wrapped in an {@link AdagradRDA}.
     *
     * @param ndims initial length of the per-feature state arrays; they grow on demand
     * @param options hyper-parameters; the caller's map is not modified
     * @return the configured optimizer
     * @throws IllegalArgumentException if {@code optimizer} is missing or unsupported, or if
     *         {@code -regularization rda} is combined with an optimizer other than AdaGrad
     */
    @Nonnull
    public static Optimizer create(@Nonnegative int ndims, @Nonnull Map<String, String> options) {
        final String optimizerName = options.get("optimizer");
        if (optimizerName == null) {
            throw new IllegalArgumentException("`optimizer` not defined");
        }
        // Locale.ROOT: name matching must not depend on the JVM default locale.
        final String name = optimizerName.toLowerCase(Locale.ROOT);
        final boolean rda = "rda".equalsIgnoreCase(options.get("regularization"));
        if (rda && !"adagrad".equals(name)) {
            throw new IllegalArgumentException("`-regularization rda` is only supported for AdaGrad but `-optimizer " + optimizerName + "`. Please specify `-regularization l1` and so on.");
        }

        Map<String, String> effectiveOptions = options;
        final Optimizer.OptimizerBase optimizerImpl;
        if ("sgd".equals(name)) {
            optimizerImpl = new Optimizer.SGD(effectiveOptions);
        } else if ("momentum".equals(name)) {
            optimizerImpl = new Momentum(ndims, effectiveOptions);
        } else if ("nesterov".equals(name)) {
            // Work on a copy so the caller's map is left untouched (and unmodifiable maps work).
            // The added "nesterov" key is what distinguishes this from plain momentum —
            // presumably checked by key presence in Optimizer.Momentum; confirm there.
            effectiveOptions = new HashMap<String, String>(options);
            effectiveOptions.put("nesterov", "");
            optimizerImpl = new Momentum(ndims, effectiveOptions);
        } else if ("adagrad".equals(name)) {
            if (rda) {
                AdaGrad adagrad = new AdaGrad(ndims, effectiveOptions);
                optimizerImpl = new AdagradRDA(ndims, adagrad, effectiveOptions);
            } else {
                optimizerImpl = new AdaGrad(ndims, effectiveOptions);
            }
        } else if ("rmsprop".equals(name)) {
            optimizerImpl = new RMSprop(ndims, effectiveOptions);
        } else if ("rmspropgraves".equals(name) || "rmsprop_graves".equals(name)) {
            optimizerImpl = new RMSpropGraves(ndims, effectiveOptions);
        } else if ("adadelta".equals(name)) {
            optimizerImpl = new AdaDelta(ndims, effectiveOptions);
        } else if ("adam".equals(name)) {
            optimizerImpl = new Adam(ndims, effectiveOptions);
        } else if ("nadam".equals(name)) {
            optimizerImpl = new Nadam(ndims, effectiveOptions);
        } else if ("eve".equals(name)) {
            optimizerImpl = new Eve(ndims, effectiveOptions);
        } else if ("adam_hd".equals(name) || "adamhd".equals(name)) {
            optimizerImpl = new AdamHD(ndims, effectiveOptions);
        } else {
            throw new IllegalArgumentException("Unsupported optimizer name: " + optimizerName);
        }
        if (LOG.isInfoEnabled()) {
            LOG.info("Configured " + optimizerImpl.getOptimizerName() + " as the optimizer: " + effectiveOptions);
            LOG.info("ETA estimator: " + optimizerImpl._eta);
        }
        return optimizerImpl;
    }

    /**
     * Computes the new length for a state array that must be able to hold {@code index}.
     *
     * <p>The length is {@code 2^bitsRequired(index) + 1}. It is evaluated in {@code long}
     * arithmetic because with {@code int} arithmetic a bit count of 31 (presumably indices of
     * {@code 2^30} and above) makes {@code 1 << 31} wrap to a negative length. Results above
     * {@link #MAX_ARRAY_SIZE} are capped there, but never below {@code index + 1}.
     *
     * @param index a non-negative feature index that is {@code >=} the current array length
     * @return the new array length
     */
    private static int newCapacity(final int index) {
        final long grown = (1L << MathUtils.bitsRequired(index)) + 1L;
        if (grown <= MAX_ARRAY_SIZE) {
            return (int) grown;
        }
        return (int) Math.min(Math.max((long) MAX_ARRAY_SIZE, index + 1L), (long) Integer.MAX_VALUE);
    }

    /**
     * AdaGrad with RDA regularization; persists the per-feature sum of gradients densely.
     *
     * <p>NOTE(review): only the sum of gradients is saved/restored per feature here. Any other
     * statistic the base {@code Optimizer.AdagradRDA} update reads from {@code weightValueReused}
     * (e.g. a sum of squared gradients for the AdaGrad step) is not reloaded per feature and would
     * carry over from the previous call — confirm against the base implementation.
     */
    @NotThreadSafe
    static final class AdagradRDA
    extends Optimizer.AdagradRDA {
        @Nonnull
        private final WeightValue.WeightValueParamsF2 weightValueReused = this.newWeightValue(0.0f);
        /** Per-feature sum of gradients, indexed by feature id. */
        @Nonnull
        private float[] sum_of_gradients;

        public AdagradRDA(int ndims, @Nonnull Optimizer.AdaGrad optimizerImpl, @Nonnull Map<String, String> options) {
            super(optimizerImpl, options);
            this.sum_of_gradients = new float[ndims];
        }

        @Override
        protected float update(@Nonnull Object feature, float weight, float gradient) {
            int i = HiveUtils.parseInt(feature);
            this.ensureCapacity(i);
            this.weightValueReused.set(weight);
            this.weightValueReused.setSumOfGradients(this.sum_of_gradients[i]);
            this.update(this.weightValueReused, gradient);
            this.sum_of_gradients[i] = this.weightValueReused.getSumOfGradients();
            return this.weightValueReused.get();
        }

        /** Grows the state array so that {@code index} is addressable. */
        private void ensureCapacity(int index) {
            if (index >= this.sum_of_gradients.length) {
                this.sum_of_gradients = Arrays.copyOf(this.sum_of_gradients, newCapacity(index));
            }
        }
    }

    /** Adam with hypergradient descent; persists the first (m) and second (v) moments densely. */
    @NotThreadSafe
    static final class AdamHD
    extends Optimizer.AdamHD {
        @Nonnull
        private final WeightValue.WeightValueParamsF2 weightValueReused = this.newWeightValue(0.0f);
        /** Per-feature first-moment estimate. */
        @Nonnull
        private float[] val_m;
        /** Per-feature second-moment estimate; always the same length as {@code val_m}. */
        @Nonnull
        private float[] val_v;

        public AdamHD(int ndims, Map<String, String> options) {
            super(options);
            this.val_m = new float[ndims];
            this.val_v = new float[ndims];
        }

        @Override
        protected float update(@Nonnull Object feature, float weight, float gradient) {
            int i = HiveUtils.parseInt(feature);
            this.ensureCapacity(i);
            this.weightValueReused.set(weight);
            this.weightValueReused.setM(this.val_m[i]);
            this.weightValueReused.setV(this.val_v[i]);
            this.update(this.weightValueReused, gradient);
            this.val_m[i] = this.weightValueReused.getM();
            this.val_v[i] = this.weightValueReused.getV();
            return this.weightValueReused.get();
        }

        /** Grows both moment arrays together so that {@code index} is addressable. */
        private void ensureCapacity(int index) {
            if (index >= this.val_m.length) {
                final int newSize = newCapacity(index);
                this.val_m = Arrays.copyOf(this.val_m, newSize);
                this.val_v = Arrays.copyOf(this.val_v, newSize);
            }
        }
    }

    /** Eve optimizer; persists the first (m) and second (v) moments densely. */
    @NotThreadSafe
    static final class Eve
    extends Optimizer.Eve {
        @Nonnull
        private final WeightValue.WeightValueParamsF2 weightValueReused = this.newWeightValue(0.0f);
        /** Per-feature first-moment estimate. */
        @Nonnull
        private float[] val_m;
        /** Per-feature second-moment estimate; always the same length as {@code val_m}. */
        @Nonnull
        private float[] val_v;

        public Eve(int ndims, Map<String, String> options) {
            super(options);
            this.val_m = new float[ndims];
            this.val_v = new float[ndims];
        }

        @Override
        protected float update(@Nonnull Object feature, float weight, float gradient) {
            int i = HiveUtils.parseInt(feature);
            this.ensureCapacity(i);
            this.weightValueReused.set(weight);
            this.weightValueReused.setM(this.val_m[i]);
            this.weightValueReused.setV(this.val_v[i]);
            this.update(this.weightValueReused, gradient);
            this.val_m[i] = this.weightValueReused.getM();
            this.val_v[i] = this.weightValueReused.getV();
            return this.weightValueReused.get();
        }

        /** Grows both moment arrays together so that {@code index} is addressable. */
        private void ensureCapacity(int index) {
            if (index >= this.val_m.length) {
                final int newSize = newCapacity(index);
                this.val_m = Arrays.copyOf(this.val_m, newSize);
                this.val_v = Arrays.copyOf(this.val_v, newSize);
            }
        }
    }

    /** Nadam optimizer; persists the first (m) and second (v) moments densely. */
    @NotThreadSafe
    static final class Nadam
    extends Optimizer.Nadam {
        @Nonnull
        private final WeightValue.WeightValueParamsF2 weightValueReused = this.newWeightValue(0.0f);
        /** Per-feature first-moment estimate. */
        @Nonnull
        private float[] val_m;
        /** Per-feature second-moment estimate; always the same length as {@code val_m}. */
        @Nonnull
        private float[] val_v;

        public Nadam(int ndims, Map<String, String> options) {
            super(options);
            this.val_m = new float[ndims];
            this.val_v = new float[ndims];
        }

        @Override
        protected float update(@Nonnull Object feature, float weight, float gradient) {
            int i = HiveUtils.parseInt(feature);
            this.ensureCapacity(i);
            this.weightValueReused.set(weight);
            this.weightValueReused.setM(this.val_m[i]);
            this.weightValueReused.setV(this.val_v[i]);
            this.update(this.weightValueReused, gradient);
            this.val_m[i] = this.weightValueReused.getM();
            this.val_v[i] = this.weightValueReused.getV();
            return this.weightValueReused.get();
        }

        /** Grows both moment arrays together so that {@code index} is addressable. */
        private void ensureCapacity(int index) {
            if (index >= this.val_m.length) {
                final int newSize = newCapacity(index);
                this.val_m = Arrays.copyOf(this.val_m, newSize);
                this.val_v = Arrays.copyOf(this.val_v, newSize);
            }
        }
    }

    /** Adam optimizer; persists the first (m) and second (v) moments densely. */
    @NotThreadSafe
    static final class Adam
    extends Optimizer.Adam {
        @Nonnull
        private final WeightValue.WeightValueParamsF2 weightValueReused = this.newWeightValue(0.0f);
        /** Per-feature first-moment estimate. */
        @Nonnull
        private float[] val_m;
        /** Per-feature second-moment estimate; always the same length as {@code val_m}. */
        @Nonnull
        private float[] val_v;

        public Adam(int ndims, Map<String, String> options) {
            super(options);
            this.val_m = new float[ndims];
            this.val_v = new float[ndims];
        }

        @Override
        protected float update(@Nonnull Object feature, float weight, float gradient) {
            int i = HiveUtils.parseInt(feature);
            this.ensureCapacity(i);
            this.weightValueReused.set(weight);
            this.weightValueReused.setM(this.val_m[i]);
            this.weightValueReused.setV(this.val_v[i]);
            this.update(this.weightValueReused, gradient);
            this.val_m[i] = this.weightValueReused.getM();
            this.val_v[i] = this.weightValueReused.getV();
            return this.weightValueReused.get();
        }

        /** Grows both moment arrays together so that {@code index} is addressable. */
        private void ensureCapacity(int index) {
            if (index >= this.val_m.length) {
                final int newSize = newCapacity(index);
                this.val_m = Arrays.copyOf(this.val_m, newSize);
                this.val_v = Arrays.copyOf(this.val_v, newSize);
            }
        }
    }

    /** AdaDelta optimizer; persists squared-gradient and squared-delta-x accumulators densely. */
    @NotThreadSafe
    static final class AdaDelta
    extends Optimizer.AdaDelta {
        @Nonnull
        private final WeightValue.WeightValueParamsF2 weightValueReused = this.newWeightValue(0.0f);
        /** Per-feature sum of squared gradients. */
        @Nonnull
        private float[] sum_of_squared_gradients;
        /** Per-feature sum of squared delta-x; same length as {@code sum_of_squared_gradients}. */
        @Nonnull
        private float[] sum_of_squared_delta_x;

        public AdaDelta(int ndims, Map<String, String> options) {
            super(options);
            this.sum_of_squared_gradients = new float[ndims];
            this.sum_of_squared_delta_x = new float[ndims];
        }

        @Override
        protected float update(@Nonnull Object feature, float weight, float gradient) {
            int i = HiveUtils.parseInt(feature);
            this.ensureCapacity(i);
            this.weightValueReused.set(weight);
            this.weightValueReused.setSumOfSquaredGradients(this.sum_of_squared_gradients[i]);
            this.weightValueReused.setSumOfSquaredDeltaX(this.sum_of_squared_delta_x[i]);
            this.update(this.weightValueReused, gradient);
            this.sum_of_squared_gradients[i] = this.weightValueReused.getSumOfSquaredGradients();
            this.sum_of_squared_delta_x[i] = this.weightValueReused.getSumOfSquaredDeltaX();
            return this.weightValueReused.get();
        }

        /** Grows both accumulator arrays together so that {@code index} is addressable. */
        private void ensureCapacity(int index) {
            if (index >= this.sum_of_squared_gradients.length) {
                final int newSize = newCapacity(index);
                this.sum_of_squared_gradients = Arrays.copyOf(this.sum_of_squared_gradients, newSize);
                this.sum_of_squared_delta_x = Arrays.copyOf(this.sum_of_squared_delta_x, newSize);
            }
        }
    }

    /** Graves' RMSprop; persists gradient sum, squared-gradient sum and delta densely. */
    @NotThreadSafe
    static final class RMSpropGraves
    extends Optimizer.RMSpropGraves {
        @Nonnull
        private final WeightValue.WeightValueParamsF3 weightValueReused = this.newWeightValue(0.0f);
        /** Per-feature sum of gradients. */
        @Nonnull
        private float[] sum_of_gradients;
        /** Per-feature sum of squared gradients; same length as {@code sum_of_gradients}. */
        @Nonnull
        private float[] sum_of_squared_gradients;
        /** Per-feature delta; same length as {@code sum_of_gradients}. */
        @Nonnull
        private float[] delta;

        public RMSpropGraves(int ndims, Map<String, String> options) {
            super(options);
            this.sum_of_gradients = new float[ndims];
            this.sum_of_squared_gradients = new float[ndims];
            this.delta = new float[ndims];
        }

        @Override
        protected float update(@Nonnull Object feature, float weight, float gradient) {
            int i = HiveUtils.parseInt(feature);
            this.ensureCapacity(i);
            this.weightValueReused.set(weight);
            this.weightValueReused.setSumOfGradients(this.sum_of_gradients[i]);
            this.weightValueReused.setSumOfSquaredGradients(this.sum_of_squared_gradients[i]);
            this.weightValueReused.setDelta(this.delta[i]);
            this.update(this.weightValueReused, gradient);
            this.sum_of_gradients[i] = this.weightValueReused.getSumOfGradients();
            this.sum_of_squared_gradients[i] = this.weightValueReused.getSumOfSquaredGradients();
            this.delta[i] = this.weightValueReused.getDelta();
            return this.weightValueReused.get();
        }

        /** Grows all three state arrays together so that {@code index} is addressable. */
        private void ensureCapacity(int index) {
            if (index >= this.sum_of_gradients.length) {
                final int newSize = newCapacity(index);
                this.sum_of_gradients = Arrays.copyOf(this.sum_of_gradients, newSize);
                this.sum_of_squared_gradients = Arrays.copyOf(this.sum_of_squared_gradients, newSize);
                this.delta = Arrays.copyOf(this.delta, newSize);
            }
        }
    }

    /** RMSprop optimizer; persists the per-feature sum of squared gradients densely. */
    @NotThreadSafe
    static final class RMSprop
    extends Optimizer.RMSprop {
        @Nonnull
        private final WeightValue.WeightValueParamsF1 weightValueReused = this.newWeightValue(0.0f);
        /** Per-feature sum of squared gradients. */
        @Nonnull
        private float[] sum_of_squared_gradients;

        public RMSprop(int ndims, Map<String, String> options) {
            super(options);
            this.sum_of_squared_gradients = new float[ndims];
        }

        @Override
        protected float update(@Nonnull Object feature, float weight, float gradient) {
            int i = HiveUtils.parseInt(feature);
            this.ensureCapacity(i);
            this.weightValueReused.set(weight);
            this.weightValueReused.setSumOfSquaredGradients(this.sum_of_squared_gradients[i]);
            this.update(this.weightValueReused, gradient);
            this.sum_of_squared_gradients[i] = this.weightValueReused.getSumOfSquaredGradients();
            return this.weightValueReused.get();
        }

        /** Grows the state array so that {@code index} is addressable. */
        private void ensureCapacity(int index) {
            if (index >= this.sum_of_squared_gradients.length) {
                this.sum_of_squared_gradients = Arrays.copyOf(this.sum_of_squared_gradients, newCapacity(index));
            }
        }
    }

    /** AdaGrad optimizer; persists the per-feature sum of squared gradients densely. */
    @NotThreadSafe
    static final class AdaGrad
    extends Optimizer.AdaGrad {
        @Nonnull
        private final WeightValue.WeightValueParamsF1 weightValueReused = this.newWeightValue(0.0f);
        /** Per-feature sum of squared gradients. */
        @Nonnull
        private float[] sum_of_squared_gradients;

        public AdaGrad(int ndims, Map<String, String> options) {
            super(options);
            this.sum_of_squared_gradients = new float[ndims];
        }

        @Override
        protected float update(@Nonnull Object feature, float weight, float gradient) {
            int i = HiveUtils.parseInt(feature);
            this.ensureCapacity(i);
            this.weightValueReused.set(weight);
            this.weightValueReused.setSumOfSquaredGradients(this.sum_of_squared_gradients[i]);
            this.update(this.weightValueReused, gradient);
            this.sum_of_squared_gradients[i] = this.weightValueReused.getSumOfSquaredGradients();
            return this.weightValueReused.get();
        }

        /** Grows the state array so that {@code index} is addressable. */
        private void ensureCapacity(int index) {
            if (index >= this.sum_of_squared_gradients.length) {
                this.sum_of_squared_gradients = Arrays.copyOf(this.sum_of_squared_gradients, newCapacity(index));
            }
        }
    }

    /** Momentum (and, via the "nesterov" option, Nesterov) optimizer; persists delta densely. */
    @NotThreadSafe
    static final class Momentum
    extends Optimizer.Momentum {
        @Nonnull
        private final WeightValue.WeightValueParamsF1 weightValueReused = this.newWeightValue(0.0f);
        /** Per-feature delta carried between updates. */
        @Nonnull
        private float[] delta;

        public Momentum(int ndims, Map<String, String> options) {
            super(options);
            this.delta = new float[ndims];
        }

        @Override
        protected float update(@Nonnull Object feature, float weight, float gradient) {
            int i = HiveUtils.parseInt(feature);
            this.ensureCapacity(i);
            this.weightValueReused.set(weight);
            this.weightValueReused.setDelta(this.delta[i]);
            this.update(this.weightValueReused, gradient);
            this.delta[i] = this.weightValueReused.getDelta();
            return this.weightValueReused.get();
        }

        /** Grows the state array so that {@code index} is addressable. */
        private void ensureCapacity(int index) {
            if (index >= this.delta.length) {
                this.delta = Arrays.copyOf(this.delta, newCapacity(index));
            }
        }
    }
}

