/*
 * Decompiled with CFR 0.152.
 */
package deepnetts.net.layers;

import deepnetts.net.layers.OutputLayer;
import deepnetts.net.layers.activation.ActivationType;
import deepnetts.net.weights.RandomWeights;
import deepnetts.util.Tensor;
import java.util.Arrays;

/**
 * Fully connected output layer whose activation is softmax: each output is
 * {@code exp(z_j) / sum_k exp(z_k)} where {@code z = b + x W} is computed from
 * the previous layer's outputs {@code x}, so outputs are non-negative and sum to 1.
 */
public class SoftmaxOutputLayer extends OutputLayer {

    private static final long serialVersionUID = -5557183169491335524L;

    /**
     * Creates a softmax output layer with the given number of output units.
     *
     * @param width number of outputs produced by this layer
     */
    public SoftmaxOutputLayer(int width) {
        super(width);
        setActivationType(ActivationType.SOFTMAX);
    }

    /**
     * Creates a softmax output layer from a set of output labels.
     *
     * @param labels output labels, passed to {@link OutputLayer}; presumably
     *               the layer width is derived from their count (TODO confirm)
     */
    public SoftmaxOutputLayer(String[] labels) {
        super(labels);
        setActivationType(ActivationType.SOFTMAX);
    }

    /**
     * Allocates this layer's buffers and initializes its parameters.
     * <p>
     * The input tensor is the previous layer's output tensor itself (a shared
     * reference, not a copy). Weights form a {@code prevWidth x width} matrix
     * indexed as {@code (input, output)} and are filled by
     * {@link RandomWeights#xavier}; biases are filled by
     * {@link RandomWeights#gaussian} (arguments presumably mean and standard
     * deviation - TODO confirm).
     */
    @Override
    public void init() {
        final int inWidth = prevLayer.getWidth();

        inputs = prevLayer.outputs;
        outputs = new Tensor(width);
        deltas = new Tensor(width);
        outputErrors = new float[width];

        weights = new Tensor(inWidth, width);
        gradients = new Tensor(inWidth, width);
        deltaWeights = new Tensor(inWidth, width);
        prevDeltaWeights = new Tensor(inWidth, width);

        biases = new float[width];
        deltaBiases = new float[width];
        prevDeltaBiases = new float[width];

        // Weights are randomized before biases; keep this order in case both
        // draw from a shared random source (reproducibility).
        RandomWeights.xavier(weights.getValues(), inWidth, width);
        RandomWeights.gaussian(biases, 0.1f, 0.05f);
    }

    /**
     * Computes softmax of the affine projection {@code z = b + x W} into
     * {@code outputs}.
     * <p>
     * The largest logit is subtracted before exponentiation so {@code exp}
     * cannot overflow; softmax is invariant to that shift.
     */
    @Override
    public void forward() {
        final int outCount = outputs.getCols();
        final int inCount = inputs.getCols();

        // Pass 1: logits z_j = b_j + sum_i x_i * w_ij, tracking the largest.
        float maxLogit = Float.NEGATIVE_INFINITY;
        for (int j = 0; j < outCount; j++) {
            outputs.set(j, biases[j]);
            for (int i = 0; i < inCount; i++) {
                outputs.add(j, inputs.get(i) * weights.get(i, j));
            }
            final float logit = outputs.get(j);
            if (logit > maxLogit) {
                maxLogit = logit;
            }
        }

        // Pass 2: shifted exponentials and their running sum.
        float expSum = 0.0f;
        for (int j = 0; j < outCount; j++) {
            final float expValue = (float) Math.exp(outputs.get(j) - maxLogit);
            outputs.set(j, expValue);
            expSum += expValue;
        }

        // Pass 3: normalize so the outputs sum to 1.
        outputs.div(expSum);
    }

    /**
     * Back-propagates the output errors into weight and bias updates.
     * <p>
     * {@code outputErrors} is copied into {@code deltas} and used directly as
     * the per-output delta, with no softmax-derivative factor applied here.
     * NOTE(review): this presumably relies on the loss supplying the combined
     * softmax/cross-entropy gradient - confirm against the loss function.
     * <p>
     * Optimizer-produced deltas are added to {@code deltaWeights} and
     * {@code deltaBiases}. Outside batch mode these accumulators are cleared
     * first, so they hold only this call's update. In batch mode they keep
     * accumulating across calls.
     */
    @Override
    public void backward() {
        if (!batchMode) {
            deltaWeights.fill(0.0f);
            Arrays.fill(deltaBiases, 0.0f);
        }

        deltas.copyFrom(outputErrors);

        final int outCount = outputs.getCols();
        final int inCount = inputs.getCols();
        for (int j = 0; j < outCount; j++) {
            final float delta = deltas.get(j);
            // Optimizer call order (all weights of output j, then its bias) is
            // preserved in case the optimizer keeps per-call state.
            for (int i = 0; i < inCount; i++) {
                final float grad = delta * inputs.get(i);
                gradients.set(i, j, grad);
                final float weightStep = optim.calculateDeltaWeight(grad, i, j);
                deltaWeights.add(i, j, weightStep);
            }
            final float biasStep = optim.calculateDeltaBias(delta, j);
            deltaBiases[j] += biasStep;
        }
    }
}

