/*
 * Decompiled with CFR 0.152.
 */
package deepnetts.net.loss;

import deepnetts.net.NeuralNetwork;
import deepnetts.net.loss.LossFunction;
import deepnetts.util.DeepNettsException;
import java.io.Serializable;

/**
 * Binary cross-entropy (log) loss for networks with a single output unit.
 * <p>
 * For a prediction {@code p} and a target {@code y}, the per-pattern loss is
 * {@code -(y * ln(p) + (1 - y) * ln(1 - p))}. {@link #getTotal()} reports the sum of these
 * losses, plus any regularization terms, divided by the number of patterns added since the
 * last {@link #reset()}.
 * <p>
 * Instances hold unsynchronized mutable accumulators and a reused error buffer, so they are
 * not safe to share between threads.
 */
public class BinaryCrossEntropyLoss
implements LossFunction,
Serializable {
    /**
     * Predictions are clamped to {@code [EPSILON, 1 - EPSILON]} before taking logarithms, so a
     * saturated output of exactly 0 or 1 does not yield an infinite or NaN loss.
     */
    private static final double EPSILON = 1e-7;

    /** Single-element buffer reused and returned by every {@link #addPatternError} call. */
    private final float[] outputError;
    /** Sum of (non-negative) per-pattern losses and regularization terms since the last reset. */
    private float totalError;
    /** Number of patterns accumulated since the last reset. */
    private int patternCount = 0;

    /**
     * Creates a loss function for the given network.
     *
     * @param neuralNet network whose output layer must be at most one unit wide
     *                  (presumably a sigmoid output, per the error message; not checked here)
     * @throws DeepNettsException if the network's output layer is wider than one unit
     */
    public BinaryCrossEntropyLoss(NeuralNetwork neuralNet) {
        if (neuralNet.getOutputLayer().getWidth() > 1) {
            throw new DeepNettsException("BinaryCrossEntropyLoss can be only used with networks with single sigmoid output!");
        }
        this.outputError = new float[1];
    }

    /**
     * Adds one pattern's loss to the running total and returns its output error.
     *
     * @param actual network output; only element 0 is read
     * @param target expected output; only element 0 is read
     * @return the internal single-element buffer holding {@code actual[0] - target[0]}; the same
     *         array instance is returned (and overwritten) on every call
     */
    @Override
    public float[] addPatternError(float[] actual, float[] target) {
        final float y = target[0];
        // a - y: NOTE(review) this equals dLoss/dz only when the output activation is a sigmoid.
        this.outputError[0] = actual[0] - y;

        // Clamp only for the log term so that ln(0) (and 0 * -Infinity = NaN) cannot occur.
        final double p = Math.min(Math.max((double) actual[0], EPSILON), 1.0 - EPSILON);
        this.totalError += (float) -(y * Math.log(p) + (1.0 - y) * Math.log(1.0 - p));
        ++this.patternCount;
        return this.outputError;
    }

    /**
     * Adds a regularization term to the accumulated loss. A positive {@code reg} increases the
     * value reported by {@link #getTotal()}.
     *
     * @param reg regularization amount to add
     */
    @Override
    public void addRegularizationSum(float reg) {
        this.totalError += reg;
    }

    /**
     * Returns the mean loss (including regularization) over the patterns added since the last
     * reset.
     *
     * @return accumulated loss divided by the pattern count, or {@code 0} if no pattern has been
     *         added since the last reset
     */
    @Override
    public float getTotal() {
        if (this.patternCount == 0) {
            // Avoid 0/0 = NaN before any pattern has been seen.
            return 0.0f;
        }
        return this.totalError / (float) this.patternCount;
    }

    /** Clears the accumulated loss and pattern count. */
    @Override
    public void reset() {
        this.totalError = 0.0f;
        this.patternCount = 0;
    }
}

