/*
 * Decompiled with CFR 0.152.
 */
package deepnetts.net.layers;

import deepnetts.net.layers.AbstractLayer;
import deepnetts.net.layers.activation.ActivationType;
import deepnetts.net.loss.LossType;
import deepnetts.net.weights.RandomWeights;
import deepnetts.util.Tensor;
import deepnetts.util.Tensors;
import java.util.Arrays;

/**
 * Fully connected output layer of a feed-forward network.
 *
 * <p>Every output neuron is connected to every column of the previous layer's output. The
 * forward pass computes {@code outputs[o] = activation(biases[o] + sum_i inputs[i] * weights[i][o])}.
 * The backward pass derives deltas from {@link #outputErrors}, which must be populated via
 * {@link #setOutputErrors(float[])} (presumably by the loss function) before
 * {@link #backward()} is called.
 */
public class OutputLayer
extends AbstractLayer {
    /** Per-neuron output errors consumed by {@link #backward()}; one entry per output column. */
    protected float[] outputErrors;
    /** Human readable label for each output neuron; {@code labels.length == width}. */
    protected final String[] labels;
    /** Loss function type; selects how deltas are computed in {@link #backward()}. */
    protected LossType lossType;

    /**
     * Creates an output layer with {@code width} neurons and sigmoid activation.
     * Neurons are labelled {@code out0, out1, ...}.
     *
     * @param width number of output neurons, must be positive
     * @throws IllegalArgumentException if {@code width <= 0}
     */
    public OutputLayer(int width) {
        this.width = checkWidth(width);
        this.height = 1;
        this.depth = 1;
        this.labels = defaultLabels(width, "out");
        this.setActivationType(ActivationType.SIGMOID);
    }

    /**
     * Creates an output layer with {@code width} neurons and the given activation.
     * Neurons are labelled {@code Output0, Output1, ...}.
     *
     * @param width number of output neurons, must be positive
     * @param actType activation function applied to each output
     * @throws IllegalArgumentException if {@code width <= 0}
     */
    public OutputLayer(int width, ActivationType actType) {
        this.width = checkWidth(width);
        this.height = 1;
        this.depth = 1;
        this.labels = defaultLabels(width, "Output");
        this.setActivationType(actType);
    }

    /**
     * Creates an output layer with one sigmoid neuron per label.
     *
     * @param outputLabels labels of the output neurons; the array is copied
     * @throws NullPointerException if {@code outputLabels} is null
     * @throws IllegalArgumentException if {@code outputLabels} is empty
     */
    public OutputLayer(String[] outputLabels) {
        this.width = checkWidth(outputLabels.length);
        this.height = 1;
        this.depth = 1;
        // Defensive copy so later changes to the caller's array do not alter this layer.
        this.labels = outputLabels.clone();
        this.setActivationType(ActivationType.SIGMOID);
    }

    /**
     * Creates an output layer with one neuron per label and the given activation.
     *
     * @param outputLabels labels of the output neurons; the array is copied
     * @param actType activation function applied to each output
     * @throws NullPointerException if {@code outputLabels} is null
     * @throws IllegalArgumentException if {@code outputLabels} is empty
     */
    public OutputLayer(String[] outputLabels, ActivationType actType) {
        this(outputLabels);
        this.setActivationType(actType);
    }

    /** Builds {@code count} labels of the form {@code prefix + index}. */
    private static String[] defaultLabels(int count, String prefix) {
        String[] names = new String[count];
        for (int i = 0; i < count; ++i) {
            names[i] = prefix + i;
        }
        return names;
    }

    /** Returns {@code width} unchanged, rejecting non-positive values. */
    private static int checkWidth(int width) {
        if (width <= 0) {
            throw new IllegalArgumentException("Output layer width must be positive, got: " + width);
        }
        return width;
    }

    /**
     * Sets the per-neuron output errors used by the next {@link #backward()} call.
     * The array is stored by reference, not copied.
     *
     * @param outputErrors error for each output neuron
     */
    public final void setOutputErrors(float[] outputErrors) {
        this.outputErrors = outputErrors;
    }

    /**
     * Returns the per-neuron output errors array (the live internal reference).
     *
     * @return output errors
     */
    public final float[] getOutputErrors() {
        return this.outputErrors;
    }

    /**
     * Returns the loss type that drives delta computation in {@link #backward()}.
     *
     * @return loss type, or {@code null} if not set
     */
    public final LossType getLossType() {
        return this.lossType;
    }

    /**
     * Sets the loss type that drives delta computation in {@link #backward()}.
     *
     * @param lossType loss type
     */
    public void setLossType(LossType lossType) {
        this.lossType = lossType;
    }

    /**
     * Allocates outputs, errors, deltas, weights, gradients and biases, and randomly initializes
     * the weights (Xavier) and biases. Requires {@code prevLayer} to be set; its width is used as
     * the number of inputs.
     */
    @Override
    public void init() {
        this.inputs = this.prevLayer.outputs;
        this.outputs = new Tensor(this.width);
        this.outputErrors = new float[this.width];
        this.deltas = new Tensor(this.width);
        int prevLayerWidth = this.prevLayer.getWidth();
        // Weight matrix indexed as (inputCol, outputCol).
        this.weights = new Tensor(prevLayerWidth, this.width);
        this.gradients = new Tensor(prevLayerWidth, this.width);
        this.deltaWeights = new Tensor(prevLayerWidth, this.width);
        this.prevDeltaWeights = new Tensor(prevLayerWidth, this.width);
        RandomWeights.xavier(this.weights.getValues(), prevLayerWidth, this.width);
        this.biases = new float[this.width];
        this.deltaBiases = new float[this.width];
        this.prevDeltaBiases = new float[this.width];
        RandomWeights.randomize(this.biases);
    }

    /**
     * Computes {@code outputs = activation(biases + inputs x weights)} for every output neuron.
     */
    @Override
    public void forward() {
        // Start from the biases, then accumulate the weighted inputs.
        this.outputs.copyFrom(this.biases);
        for (int outCol = 0; outCol < this.outputs.getCols(); ++outCol) {
            for (int inCol = 0; inCol < this.inputs.getCols(); ++inCol) {
                this.outputs.add(outCol, this.inputs.get(inCol) * this.weights.get(inCol, outCol));
            }
        }
        this.outputs.apply(this.activation::getValue);
    }

    /**
     * Computes deltas from {@link #outputErrors}, stores gradients, and accumulates the
     * optimizer's weight and bias changes. In non-batch mode the accumulated changes are reset
     * first, so only the current sample contributes.
     *
     * <p>For {@link LossType#MEAN_SQUARED_ERROR} the delta is {@code error * activation'(output)}.
     * For sigmoid with {@link LossType#CROSS_ENTROPY} the error is used directly as the delta.
     */
    @Override
    public void backward() {
        if (!this.batchMode) {
            this.deltaWeights.fill(0.0f);
            Arrays.fill(this.deltaBiases, 0.0f);
        }
        final int inputCols = this.inputs.getCols();
        for (int deltaCol = 0; deltaCol < this.deltas.getCols(); ++deltaCol) {
            if (this.lossType == LossType.MEAN_SQUARED_ERROR) {
                float delta = this.outputErrors[deltaCol] * this.activation.getPrime(this.outputs.get(deltaCol));
                this.deltas.set(deltaCol, delta);
            } else if (this.activationType == ActivationType.SIGMOID && this.lossType == LossType.CROSS_ENTROPY) {
                // NOTE(review): presumably relies on the sigmoid derivative cancelling against the
                // cross-entropy derivative, so the error is already the delta — confirm with the loss code.
                this.deltas.set(deltaCol, this.outputErrors[deltaCol]);
            }
            // NOTE(review): any other loss/activation combination leaves deltas[deltaCol] unchanged
            // (stale or zero). Subclasses handling other combinations are expected to override this.
            final float delta = this.deltas.get(deltaCol);
            for (int inCol = 0; inCol < inputCols; ++inCol) {
                float grad = delta * this.inputs.get(inCol);
                this.gradients.set(inCol, deltaCol, grad);
                this.deltaWeights.add(inCol, deltaCol, this.optim.calculateDeltaWeight(grad, inCol, deltaCol));
            }
            this.deltaBiases[deltaCol] += this.optim.calculateDeltaBias(delta, deltaCol);
        }
    }

    /**
     * Applies the accumulated weight and bias changes. In batch mode the changes are first
     * averaged over {@code batchSize} and then reset after being applied. The applied changes
     * are remembered in {@code prevDeltaWeights}/{@code prevDeltaBiases}.
     */
    @Override
    public void applyWeightChanges() {
        if (this.batchMode) {
            this.deltaWeights.div(this.batchSize);
            Tensors.div(this.deltaBiases, this.batchSize);
        }
        Tensor.copy(this.deltaWeights, this.prevDeltaWeights);
        this.weights.add(this.deltaWeights);
        Tensor.copy(this.deltaBiases, this.prevDeltaBiases);
        Tensors.add(this.biases, this.deltaBiases);
        if (this.batchMode) {
            this.deltaWeights.fill(0.0f);
            Arrays.fill(this.deltaBiases, 0.0f);
        }
    }

    @Override
    public String toString() {
        return "Output Layer { width:" + this.width + ", activation:" + this.activationType.name() + "}";
    }
}

