/*
 * Decompiled with CFR 0.152.
 */
package deepnetts.net.loss;

import deepnetts.net.NeuralNetwork;
import deepnetts.net.loss.LossFunction;
import java.io.Serializable;

/**
 * Mean squared error (MSE) loss function.
 *
 * <p>For every pattern the per-output error {@code predicted[i] - target[i]} is computed and its
 * square is added to a running sum. {@link #getTotal()} reports
 * {@code sumOfSquaredErrors / (2 * patternCount * outputWidth) + regularizationSum}.
 *
 * <p>Instances keep mutable accumulation state and reuse a single error buffer, so they are not
 * safe for concurrent use without external synchronization.
 */
public final class MeanSquaredErrorLoss implements LossFunction, Serializable {

    /** Reusable buffer holding the per-output error of the most recently added pattern. */
    private final float[] outputError;

    /** Sum of squared per-output errors over all patterns added since the last reset. */
    private float totalError = 0.0f;

    /** Number of patterns added since the last reset. */
    private int patternCount = 0;

    /** Regularization term accumulated since the last reset; added to the reported loss. */
    private float regularizationSum = 0.0f;

    /**
     * Creates a loss function whose error buffer matches the width of the network's output layer.
     *
     * @param neuralNet network whose output layer width determines the expected output length
     */
    public MeanSquaredErrorLoss(NeuralNetwork neuralNet) {
        this.outputError = new float[neuralNet.getOutputLayer().getWidth()];
    }

    /**
     * Computes the per-output error for one pattern and adds its squared error to the total.
     *
     * <p>The returned array is this instance's internal buffer. It is overwritten by the next call,
     * so callers must copy it if they need to keep the values.
     *
     * @param predictedOutput network output for the pattern
     * @param targetOutput expected output for the pattern
     * @return the per-output error {@code predicted - target} (internal, reused buffer)
     * @throws IllegalArgumentException if either array's length differs from the output width
     */
    @Override
    public float[] addPatternError(float[] predictedOutput, float[] targetOutput) {
        final int width = outputError.length;
        // A shorter array would leave stale errors from the previous pattern in the returned buffer;
        // a longer one would fail with an opaque index exception. Reject both up front.
        if (predictedOutput.length != width || targetOutput.length != width) {
            throw new IllegalArgumentException("Expected predicted and target outputs of length "
                    + width + " but got " + predictedOutput.length + " and " + targetOutput.length);
        }
        for (int i = 0; i < width; i++) {
            final float err = predictedOutput[i] - targetOutput[i];
            outputError[i] = err;
            totalError += err * err;
        }
        patternCount++;
        return outputError;
    }

    /**
     * Accumulates a regularization term that is added to the value reported by {@link #getTotal()}.
     * The accumulated value is cleared by {@link #reset()}.
     *
     * @param regSum regularization amount to add
     */
    @Override
    public void addRegularizationSum(float regSum) {
        regularizationSum += regSum;
    }

    /**
     * Returns the mean squared error over all patterns added since the last reset (halved,
     * averaged per output), plus the accumulated regularization term.
     *
     * <p>If no patterns have been added, or the output width is zero, the data term is treated as
     * zero and only the regularization term is returned. This avoids returning {@code NaN}.
     *
     * @return current loss value
     */
    @Override
    public float getTotal() {
        if (patternCount == 0 || outputError.length == 0) {
            return regularizationSum;
        }
        // Compute the divisor in floating point: 2 * patternCount * width can overflow an int.
        final float divisor = 2.0f * patternCount * outputError.length;
        return totalError / divisor + regularizationSum;
    }

    /** Clears the accumulated squared error, pattern count and regularization term. */
    @Override
    public void reset() {
        totalError = 0.0f;
        patternCount = 0;
        regularizationSum = 0.0f;
    }
}

