/*
 * Decompiled with CFR 0.152.
 */
package deepnetts.net.loss;

import deepnetts.net.NeuralNetwork;
import deepnetts.net.loss.LossFunction;
import java.io.Serializable;

/**
 * Categorical cross-entropy loss.
 *
 * <p>For every pattern, {@link #addPatternError(float[], float[])} accumulates
 * {@code -sum_i t_i * ln(a_i)} (targets {@code t}, actual outputs {@code a}). For one-hot
 * targets this reduces to {@code -ln(a_target)}. {@link #getTotal()} reports the accumulated
 * loss, including any regularization term, averaged over the number of patterns added since
 * construction or the last {@link #reset()}.
 *
 * <p>Instances keep mutable accumulator state and perform no synchronization.
 */
public class CrossEntropyLoss
implements LossFunction,
Serializable {
    /**
     * Lower bound applied to predicted values before taking the logarithm. It makes a zero
     * probability yield a large but finite loss instead of infinity.
     */
    private static final float MIN_PROBABILITY = 1.0e-7f;

    /** Output error returned by {@link #addPatternError}; the same array is reused on every call. */
    private final float[] outputError;
    // No longer needed for the loss computation. It is kept so that the default serialized form
    // (and hence the computed serialVersionUID) of this Serializable class stays unchanged.
    // Holds the index of the last target entry equal to 1.0 in the most recent pattern, or -1.
    private int targetIdx;
    // Running sum of per-pattern log-likelihoods (sum_i t_i * ln(a_i)). getTotal() negates it.
    private float totalError;
    private int patternCount = 0;

    /**
     * Creates the loss function. The output-error buffer is sized to the width of the
     * network's output layer.
     *
     * @param neuralNet network whose output layer width determines the error buffer size
     */
    public CrossEntropyLoss(NeuralNetwork neuralNet) {
        this.outputError = new float[neuralNet.getOutputLayer().getWidth()];
    }

    /**
     * Adds one pattern's cross-entropy to the running total and computes its output error.
     *
     * @param actualOutput network outputs, presumably probabilities (e.g. softmax output) —
     *                     TODO confirm against callers; values below {@link #MIN_PROBABILITY}
     *                     are clamped before the logarithm is taken
     * @param targetOutput expected outputs; zero-valued entries contribute nothing to the loss
     * @return {@code actualOutput[i] - targetOutput[i]} for each output; this is an internal
     *         array that is overwritten on every call
     */
    @Override
    public float[] addPatternError(float[] actualOutput, float[] targetOutput) {
        ++this.patternCount;
        this.targetIdx = -1;
        float logLikelihood = 0.0f;
        for (int i = 0; i < actualOutput.length; ++i) {
            final float target = targetOutput[i];
            this.outputError[i] = actualOutput[i] - target;
            if (target == 0.0f) {
                // Zero-valued targets contribute nothing to the cross-entropy sum.
                continue;
            }
            if (target == 1.0f) {
                this.targetIdx = i;
            }
            // Clamp so that a predicted probability of 0 gives a finite loss instead of -ln(0) = +inf.
            logLikelihood += target * (float) Math.log(Math.max(actualOutput[i], MIN_PROBABILITY));
        }
        this.totalError += logLikelihood;
        return this.outputError;
    }

    /**
     * Adds a regularization term to the loss reported by {@link #getTotal()}.
     *
     * @param reg regularization sum to add to the (not yet averaged) total loss
     */
    @Override
    public void addRegularizationSum(float reg) {
        // totalError holds log-likelihoods and is negated in getTotal(). Subtracting here makes
        // the penalty increase the reported loss; adding it would have decreased the loss.
        this.totalError -= reg;
    }

    /**
     * Returns the mean loss per pattern, including any added regularization term.
     *
     * @return {@code -totalError / patternCount}, or {@code 0} when no pattern has been added
     *         since construction or the last {@link #reset()} (instead of NaN from 0/0)
     */
    @Override
    public float getTotal() {
        if (this.patternCount == 0) {
            return 0.0f;
        }
        return -this.totalError / (float) this.patternCount;
    }

    /** Clears the accumulated loss and pattern count. */
    @Override
    public void reset() {
        this.totalError = 0.0f;
        this.patternCount = 0;
    }
}

