/*
 * Decompiled with CFR 0.152.
 */
package deepnetts.net.train.opt;

import deepnetts.net.layers.AbstractLayer;
import deepnetts.net.train.opt.Optimizer;
import deepnetts.util.DeepNettsException;
import deepnetts.util.Tensor;
import java.io.Serializable;

/**
 * Optimizer that applies the classical momentum update to a layer's weights and biases.
 *
 * <p>For every parameter the returned delta is
 * {@code -learningRate * gradient + momentum * previousDelta}, where the previous delta is
 * read from the tensor/array obtained from the layer at construction time.
 *
 * <p>NOTE(review): this class is {@code Serializable} without an explicit
 * {@code serialVersionUID}; changing field modifiers or non-private members alters the
 * computed UID and breaks deserialization of previously saved networks.
 */
public final class MomentumOptimizer
implements Optimizer,
Serializable {
    /** Position of the row index in a 2-D {@code idxs} argument. */
    public static final int ROW_IDX = 0;
    /** Position of the column index in a 2-D {@code idxs} argument. */
    public static final int COL_IDX = 1;
    private float momentum;
    private float learningRate;
    // Previous weight/bias deltas, obtained from the layer in the constructor.
    // Presumably the layer's own storage (shared, not copied) — TODO confirm in AbstractLayer.
    private final Tensor prevDeltaWeights;
    private final float[] prevDeltaBiases;
    AbstractLayer layer;

    /**
     * Creates a momentum optimizer bound to {@code layer}.
     *
     * <p>Learning rate and momentum are copied as primitive values here, so later changes
     * to the layer's settings are not seen by this instance.
     *
     * @param layer layer whose parameters this optimizer computes deltas for
     */
    public MomentumOptimizer(AbstractLayer layer) {
        this.layer = layer;
        this.learningRate = layer.getLearningRate();
        this.momentum = layer.getMomentum();
        this.prevDeltaWeights = layer.getPrevDeltaWeights();
        this.prevDeltaBiases = layer.getPrevDeltaBiases();
    }

    /**
     * Computes the momentum delta for a single weight.
     *
     * <p>With four indices the previous delta is looked up in the 4-D weight tensor and the
     * result is checked for NaN. Any other index count uses only the first two indices
     * (row, column) for a 2-D lookup; fewer than two indices throws
     * {@link ArrayIndexOutOfBoundsException}.
     *
     * @param grad gradient of the loss with respect to the weight
     * @param idxs position of the weight in the weight tensor
     * @return the weight delta {@code -learningRate * grad + momentum * prevDelta}
     * @throws DeepNettsException if a 4-D delta evaluates to NaN
     */
    @Override
    public float calculateDeltaWeight(float grad, int ... idxs) {
        if (idxs.length == 4) {
            float dw = -this.learningRate * grad + this.momentum * this.prevDeltaWeights.get(idxs[0], idxs[1], idxs[2], idxs[3]);
            // NaN is unequal to everything, itself included, so `dw == Float.NaN` is always
            // false; Float.isNaN is the only correct test.
            if (Float.isNaN(dw)) {
                throw new DeepNettsException("NaN in momentum!!!" + this.layer.getClass());
            }
            return dw;
        }
        // Two indices (and, as before, any other count) address a 2-D weight matrix.
        return -this.learningRate * grad + this.momentum * this.prevDeltaWeights.get(idxs[ROW_IDX], idxs[COL_IDX]);
    }

    /**
     * Computes the momentum delta for a single bias.
     *
     * @param gradient gradient of the loss with respect to the bias
     * @param idx index of the bias
     * @return the bias delta {@code -learningRate * gradient + momentum * prevDeltaBiases[idx]}
     */
    @Override
    public float calculateDeltaBias(float gradient, int idx) {
        return -this.learningRate * gradient + this.momentum * this.prevDeltaBiases[idx];
    }
}

