/*
 * Decompiled with CFR 0.152.
 */
package org.neuroph.nnet.learning;

import org.neuroph.core.Connection;
import org.neuroph.core.Layer;
import org.neuroph.core.Neuron;
import org.neuroph.core.transfer.TransferFunction;
import org.neuroph.nnet.learning.LMS;

/**
 * Backpropagation learning rule built on top of {@link LMS}.
 *
 * <p>Each call to {@link #updateNetworkWeights(double[])} turns the output error vector into
 * per-neuron error terms (deltas) for the output neurons, then propagates those deltas
 * backwards through the hidden layers. Every neuron that receives a non-trivial delta here is
 * passed to {@code updateNeuronWeights} (inherited from {@code LMS}; its exact update rule is
 * not defined in this class).
 *
 * <p>NOTE(review): output neurons are passed to {@code updateNeuronWeights} before any
 * hidden-layer delta is computed, and {@link #calculateHiddenNeuronError(Neuron)} reads the
 * current connection weight values. If {@code updateNeuronWeights} applies weight changes
 * immediately (confirm in {@code LMS}), hidden deltas are computed from already-updated
 * weights rather than the ones used in the forward pass.
 */
public class BackPropagation extends LMS {
    private static final long serialVersionUID = 1L;

    /**
     * Propagates {@code outputError} backwards through the network: output neurons first,
     * then the hidden layers from last to first.
     *
     * @param outputError error value for each output neuron, in the iteration order of
     *                    {@code neuralNetwork.getOutputNeurons()}
     */
    @Override
    protected void updateNetworkWeights(double[] outputError) {
        calculateErrorAndUpdateOutputNeurons(outputError);
        calculateErrorAndUpdateHiddenNeurons();
    }

    /**
     * Computes the delta of every output neuron and hands the neuron to
     * {@code updateNeuronWeights}.
     *
     * <p>Output neurons are paired with {@code outputError} entries by position. When an entry
     * is exactly {@code 0.0}, the neuron's error is set to {@code 0.0} and no weight update is
     * requested for it. Otherwise the delta is the error entry multiplied by the
     * transfer-function derivative evaluated at the neuron's net input.
     *
     * @param outputError per-output-neuron error values; it must hold at least one entry per
     *                    output neuron, or an {@link ArrayIndexOutOfBoundsException} is thrown
     */
    protected void calculateErrorAndUpdateOutputNeurons(double[] outputError) {
        int position = 0;
        for (Neuron outputNeuron : neuralNetwork.getOutputNeurons()) {
            double error = outputError[position++];
            if (error != 0.0) {
                TransferFunction activation = outputNeuron.getTransferFunction();
                double delta = error * activation.getDerivative(outputNeuron.getNetInput());
                outputNeuron.setError(delta);
                updateNeuronWeights(outputNeuron);
            } else {
                // Zero error: clear the stored delta and skip the weight update for this neuron.
                outputNeuron.setError(0.0);
            }
        }
    }

    /**
     * Back-propagates deltas through the hidden layers, updating each hidden neuron as soon as
     * its delta is known.
     *
     * <p>Layers are visited from index {@code layers.length - 2} down to index {@code 1}. The
     * first layer (presumably the input layer) is never visited. The last layer (presumably the
     * output layer, already handled by {@link #calculateErrorAndUpdateOutputNeurons(double[])})
     * is also skipped. Walking backwards means downstream deltas are already set when a layer
     * reads them, assuming out-connections only lead to later layers.
     */
    protected void calculateErrorAndUpdateHiddenNeurons() {
        Layer[] layers = neuralNetwork.getLayers();
        int layerIdx = layers.length - 2;
        while (layerIdx >= 1) {
            Layer hiddenLayer = layers[layerIdx];
            for (Neuron hiddenNeuron : hiddenLayer.getNeurons()) {
                hiddenNeuron.setError(calculateHiddenNeuronError(hiddenNeuron));
                updateNeuronWeights(hiddenNeuron);
            }
            layerIdx--;
        }
    }

    /**
     * Computes the delta of a hidden neuron.
     *
     * <p>The delta is the transfer-function derivative at the neuron's net input, multiplied
     * by a sum over all outgoing connections. Each term of that sum is the target neuron's
     * current error times the connection's current weight value.
     *
     * @param neuron hidden neuron whose delta is wanted; the errors of the neurons it feeds
     *               into should already be set
     * @return the delta for {@code neuron}; this method does not store it
     */
    protected double calculateHiddenNeuronError(Neuron neuron) {
        double weightedErrorSum = 0.0;
        for (Connection outgoing : neuron.getOutConnections()) {
            weightedErrorSum += outgoing.getToNeuron().getError() * outgoing.getWeight().value;
        }
        TransferFunction activation = neuron.getTransferFunction();
        double slope = activation.getDerivative(neuron.getNetInput());
        return slope * weightedErrorSum;
    }
}

