/*
 * Decompiled with CFR 0.152.
 */
package org.neuroph.nnet.learning;

import org.neuroph.core.Connection;
import org.neuroph.core.Layer;
import org.neuroph.core.Neuron;
import org.neuroph.core.Weight;
import org.neuroph.nnet.learning.BackPropagation;

/**
 * Back-propagation learning rule with a momentum term.
 *
 * <p>For every input connection of a neuron, the weight change is
 * {@code learningRate * error * input + momentum * (w - wPrev)}, where {@code wPrev} is the
 * value this weight had when it was last updated by this rule. {@code wPrev} is stored per
 * weight in a {@link MomentumWeightTrainingData} instance attached via the weight's training data.
 */
public class MomentumBackpropagation extends BackPropagation {
    private static final long serialVersionUID = 1L;

    /** Coefficient applied to the previous weight change; defaults to 0.25. */
    protected double momentum = 0.25;

    /**
     * Updates the weights of all input connections of {@code neuron}.
     *
     * <p>Connections whose input is exactly {@code 0.0} are skipped. When the rule is not in
     * batch mode the change is applied to {@code weight.value} immediately and also stored in
     * {@code weight.weightChange}; in batch mode it is only accumulated into
     * {@code weight.weightChange} (presumably applied later by the base class — confirm in
     * {@code BackPropagation}).
     *
     * @param neuron neuron whose input weights are updated
     */
    @Override
    public void updateNeuronWeights(Neuron neuron) {
        // The neuron's error is the same for every one of its input connections.
        double neuronError = neuron.getError();
        for (Connection connection : neuron.getInputConnections()) {
            double input = connection.getInput();
            if (input == 0.0) {
                // Zero input yields no gradient; previousValue is left unchanged for this weight.
                continue;
            }
            Weight weight = connection.getWeight();
            MomentumWeightTrainingData trainingData = momentumDataOf(weight);
            double weightChange = this.learningRate * neuronError * input
                    + this.momentum * (weight.value - trainingData.previousValue);
            // Record the value before this update so the next call can compute the momentum term.
            trainingData.previousValue = weight.value;
            if (this.isInBatchMode()) {
                weight.weightChange += weightChange;
            } else {
                weight.weightChange = weightChange;
                weight.value += weightChange;
            }
        }
    }

    /**
     * Returns the momentum coefficient.
     *
     * @return current momentum
     */
    public double getMomentum() {
        return this.momentum;
    }

    /**
     * Sets the momentum coefficient applied to the previous weight change.
     *
     * @param momentum new momentum value
     */
    public void setMomentum(double momentum) {
        this.momentum = momentum;
    }

    /**
     * Runs the base-class start hook, then attaches fresh momentum training data to every
     * input-connection weight of every neuron in the network.
     */
    @Override
    protected void onStart() {
        super.onStart();
        for (Layer layer : this.neuralNetwork.getLayers()) {
            for (Neuron neuron : layer.getNeurons()) {
                for (Connection connection : neuron.getInputConnections()) {
                    Weight weight = connection.getWeight();
                    weight.setTrainingData(newTrainingData(weight));
                }
            }
        }
    }

    /**
     * Returns the momentum data attached to {@code weight}. If none is attached (for example
     * when the weight was not covered by {@link #onStart()}), attaches and returns freshly
     * seeded data instead of failing with a cast or null-pointer error.
     */
    private static MomentumWeightTrainingData momentumDataOf(Weight weight) {
        Object data = weight.getTrainingData();
        if (data instanceof MomentumWeightTrainingData) {
            return (MomentumWeightTrainingData) data;
        }
        MomentumWeightTrainingData fresh = newTrainingData(weight);
        weight.setTrainingData(fresh);
        return fresh;
    }

    /**
     * Creates momentum data seeded with the weight's current value. Seeding with the current
     * value (instead of the default 0.0) makes the first momentum term zero rather than
     * {@code momentum * weight.value}.
     */
    private static MomentumWeightTrainingData newTrainingData(Weight weight) {
        MomentumWeightTrainingData data = new MomentumWeightTrainingData();
        data.previousValue = weight.value;
        return data;
    }

    /** Per-weight state for momentum learning: the weight value before its last update. */
    public static class MomentumWeightTrainingData {
        /** Weight value recorded at the previous update of this weight. */
        public double previousValue;
    }
}

