/*
 * Decompiled with CFR 0.152.
 */
package deepnetts.net.layers;

import deepnetts.core.DeepNetts;
import deepnetts.net.layers.AbstractLayer;
import deepnetts.net.layers.ConvolutionalLayer;
import deepnetts.net.layers.InputLayer;
import deepnetts.net.layers.MaxPoolingLayer;
import deepnetts.net.layers.OutputLayer;
import deepnetts.net.layers.activation.ActivationType;
import deepnetts.net.weights.RandomWeights;
import deepnetts.util.DeepNettsException;
import deepnetts.util.Tensor;
import deepnetts.util.Tensors;
import java.util.Arrays;
import java.util.logging.Logger;

/**
 * Fully connected (dense) layer: every output unit receives a weighted sum of
 * all values produced by the previous layer, plus a bias, passed through the
 * layer's activation function.
 *
 * <p>Allowed neighbours (checked in {@link #init()}): the previous layer must be
 * an {@link InputLayer}, a {@code FullyConnectedLayer}, a {@link MaxPoolingLayer}
 * or a {@link ConvolutionalLayer}; the next layer must be a
 * {@code FullyConnectedLayer} or an {@link OutputLayer}.
 *
 * <p>When the previous layer is "flat" (a fully connected layer, or an input
 * layer with height and depth of 1) the weights are a 2D tensor indexed
 * {@code (inCol, outCol)}. Otherwise the weights are a 4D tensor indexed
 * {@code (inCol, inRow, inDepth, outCol)}.
 */
public final class FullyConnectedLayer
extends AbstractLayer {
    private static final Logger LOG = Logger.getLogger(DeepNetts.class.getName());

    /** Value every bias is set to by {@link #init()}, regardless of activation type. */
    private static final float INITIAL_BIAS = 0.1f;

    /**
     * Creates a fully connected layer with the given number of units and a
     * sigmoid activation.
     *
     * @param width number of output units
     */
    public FullyConnectedLayer(int width) {
        this.width = width;
        this.height = 1;
        this.depth = 1;
        this.setActivationType(ActivationType.SIGMOID);
    }

    /**
     * Creates a fully connected layer with the given number of units and
     * activation.
     *
     * @param width number of output units
     * @param actType activation function applied to each output unit
     */
    public FullyConnectedLayer(int width, ActivationType actType) {
        this(width);
        this.setActivationType(actType);
    }

    /**
     * Validates the neighbouring layers and allocates outputs, deltas, weights,
     * optimizer buffers and biases. Weights are randomized with He
     * initialization for RELU/LEAKY_RELU and Xavier initialization otherwise.
     * All biases start at {@value #INITIAL_BIAS}.
     *
     * @throws DeepNettsException if the previous or next layer is of an unsupported type
     */
    @Override
    public void init() {
        if (!(this.prevLayer instanceof InputLayer || this.prevLayer instanceof FullyConnectedLayer || this.prevLayer instanceof MaxPoolingLayer || this.prevLayer instanceof ConvolutionalLayer)) {
            throw new DeepNettsException("Bad network architecture! Fully Connected Layer can be connected only to Input, FullyConnected, Maxpooling or Convolutional layer as previous layer.");
        }
        if (!(this.nextLayer instanceof FullyConnectedLayer) && !(this.nextLayer instanceof OutputLayer)) {
            throw new DeepNettsException("Bad network architecture! Fully Connected Layer can only be connected only to Fully Connected Layer or Output layer as next layer");
        }
        this.inputs = this.prevLayer.outputs;
        this.outputs = new Tensor(this.width);
        this.deltas = new Tensor(this.width);
        if (this.isPrevLayerFlat()) {
            int fanIn = this.prevLayer.width;
            this.weights = new Tensor(fanIn, this.width);
            this.deltaWeights = new Tensor(fanIn, this.width);
            this.gradients = new Tensor(fanIn, this.width);
            this.prevDeltaWeights = new Tensor(fanIn, this.width);
            this.prevGradSqrSum = new Tensor(fanIn, this.width);
            this.prevDeltaWeightSqrSum = new Tensor(fanIn, this.width);
            this.prevBiasSqrSum = new Tensor(this.width);
            this.prevDeltaBiasSqrSum = new Tensor(this.width);
            this.randomizeWeights(fanIn);
        } else if (this.isPrevLayer3D()) {
            int w = this.prevLayer.width;
            int h = this.prevLayer.height;
            int d = this.prevLayer.depth;
            this.weights = new Tensor(w, h, d, this.width);
            this.deltaWeights = new Tensor(w, h, d, this.width);
            this.gradients = new Tensor(w, h, d, this.width);
            this.prevDeltaWeights = new Tensor(w, h, d, this.width);
            this.prevGradSqrSum = new Tensor(w, h, d, this.width);
            this.prevBiasSqrSum = new Tensor(this.width);
            this.prevDeltaWeightSqrSum = new Tensor(w, h, d, this.width);
            this.prevDeltaBiasSqrSum = new Tensor(this.width);
            int totalInputs = this.prevLayer.getWidth() * this.prevLayer.getHeight() * this.prevLayer.getDepth();
            this.randomizeWeights(totalInputs);
        }
        this.biases = new float[this.width];
        this.deltaBiases = new float[this.width];
        this.prevDeltaBiases = new float[this.width];
        Tensor.fill(this.biases, INITIAL_BIAS);
    }

    /**
     * Fills {@code weights} with random values. RELU-family activations use He
     * initialization and all others use Xavier initialization.
     *
     * @param fanIn number of inputs feeding each output unit. He initialization
     *              scales by fan-in, so both the flat and 3D branches pass it here.
     */
    private void randomizeWeights(int fanIn) {
        if (this.activationType == ActivationType.RELU || this.activationType == ActivationType.LEAKY_RELU) {
            RandomWeights.he(this.weights.getValues(), fanIn);
        } else {
            RandomWeights.xavier(this.weights.getValues(), fanIn, this.width);
        }
    }

    /**
     * Returns whether the previous layer produces a 1D output: a fully connected
     * layer, or an input layer with height 1 and depth 1.
     */
    private boolean isPrevLayerFlat() {
        return this.prevLayer instanceof FullyConnectedLayer
                || (this.prevLayer instanceof InputLayer && this.prevLayer.height == 1 && this.prevLayer.depth == 1);
    }

    /**
     * Returns whether the previous layer may produce a 3D (width x height x
     * depth) output. Callers check {@link #isPrevLayerFlat()} first, so a flat
     * {@link InputLayer} never reaches the 3D code path.
     */
    private boolean isPrevLayer3D() {
        return this.prevLayer instanceof MaxPoolingLayer
                || this.prevLayer instanceof ConvolutionalLayer
                || this.prevLayer instanceof InputLayer;
    }

    /**
     * Computes {@code outputs[j] = activation(bias[j] + sum_i inputs[i] * weights[i][j])}
     * for every output unit. For 3D inputs, the sum runs over all
     * (row, col, depth) positions.
     */
    @Override
    public void forward() {
        this.outputs.copyFrom(this.biases);
        if (this.isPrevLayerFlat()) {
            for (int outCol = 0; outCol < this.outputs.getCols(); ++outCol) {
                for (int inCol = 0; inCol < this.inputs.getCols(); ++inCol) {
                    this.outputs.add(outCol, this.inputs.get(inCol) * this.weights.get(inCol, outCol));
                }
                this.outputs.set(outCol, this.activation.getValue(this.outputs.get(outCol)));
            }
        } else if (this.isPrevLayer3D()) {
            this.forwardFrom3DLayer();
        }
    }

    /** Forward pass for every output unit when the input is 3D. */
    private void forwardFrom3DLayer() {
        for (int outCol = 0; outCol < this.outputs.getCols(); ++outCol) {
            this.forwardFrom3DLayerForCell(outCol);
        }
    }

    /**
     * Adds the weighted sum over all 3D inputs to output unit {@code outCol}
     * (already holding its bias), then applies the activation.
     */
    private void forwardFrom3DLayerForCell(int outCol) {
        for (int inDepth = 0; inDepth < this.inputs.getDepth(); ++inDepth) {
            for (int inRow = 0; inRow < this.inputs.getRows(); ++inRow) {
                for (int inCol = 0; inCol < this.inputs.getCols(); ++inCol) {
                    this.outputs.add(outCol, this.inputs.get(inRow, inCol, inDepth) * this.weights.get(inCol, inRow, inDepth, outCol));
                }
            }
        }
        this.outputs.set(outCol, this.activation.getValue(this.outputs.get(outCol)));
    }

    /**
     * Backpropagates the next layer's deltas into this layer's deltas. It then
     * computes gradients and accumulates optimizer-produced weight and bias
     * changes.
     *
     * <p>Outside batch mode, the accumulated changes are cleared first, so each
     * call starts from zero. In batch mode they keep accumulating until
     * {@link #applyWeightChanges()} runs.
     */
    @Override
    public void backward() {
        if (!this.batchMode) {
            this.deltaWeights.fill(0.0f);
            Arrays.fill(this.deltaBiases, 0.0f);
        }
        this.deltas.fill(0.0f);
        // Assumes the next layer's weights are 2D, indexed (thisUnit, nextUnit).
        for (int deltaCol = 0; deltaCol < this.deltas.getCols(); ++deltaCol) {
            for (int ndCol = 0; ndCol < this.nextLayer.deltas.getCols(); ++ndCol) {
                this.deltas.add(deltaCol, this.nextLayer.deltas.get(ndCol) * this.nextLayer.weights.get(deltaCol, ndCol));
            }
            float delta = this.deltas.get(deltaCol) * this.activation.getPrime(this.outputs.get(deltaCol));
            this.deltas.set(deltaCol, delta);
        }
        if (this.isPrevLayerFlat()) {
            for (int deltaCol = 0; deltaCol < this.deltas.getCols(); ++deltaCol) {
                float delta = this.deltas.get(deltaCol);
                for (int inCol = 0; inCol < this.inputs.getCols(); ++inCol) {
                    float grad = delta * this.inputs.get(inCol);
                    this.gradients.set(inCol, deltaCol, grad);
                    float deltaWeight = this.optim.calculateDeltaWeight(grad, inCol, deltaCol);
                    this.deltaWeights.add(inCol, deltaCol, deltaWeight);
                }
                this.deltaBiases[deltaCol] += this.optim.calculateDeltaBias(delta, deltaCol);
            }
        } else if (this.isPrevLayer3D()) {
            this.backwardTo3DLayer();
        }
    }

    /** Computes gradients and weight/bias changes for every unit when the input is 3D. */
    private void backwardTo3DLayer() {
        for (int deltaCol = 0; deltaCol < this.deltas.getCols(); ++deltaCol) {
            this.backwardTo3DLayerForCell(deltaCol);
        }
    }

    /**
     * Computes gradients and accumulates weight changes for all 3D inputs of
     * unit {@code deltaCol}, then accumulates its bias change.
     */
    private void backwardTo3DLayerForCell(int deltaCol) {
        float delta = this.deltas.get(deltaCol);
        for (int inDepth = 0; inDepth < this.inputs.getDepth(); ++inDepth) {
            for (int inCol = 0; inCol < this.inputs.getCols(); ++inCol) {
                for (int inRow = 0; inRow < this.inputs.getRows(); ++inRow) {
                    float grad = delta * this.inputs.get(inRow, inCol, inDepth);
                    this.gradients.set(inCol, inRow, inDepth, deltaCol, grad);
                    float deltaWeight = this.optim.calculateDeltaWeight(grad, inCol, inRow, inDepth, deltaCol);
                    this.deltaWeights.add(inCol, inRow, inDepth, deltaCol, deltaWeight);
                }
            }
        }
        this.deltaBiases[deltaCol] += this.optim.calculateDeltaBias(delta, deltaCol);
    }

    /**
     * Adds the accumulated weight and bias changes to the weights and biases.
     * The applied changes are first saved into {@code prevDeltaWeights} and
     * {@code prevDeltaBiases}.
     *
     * <p>In batch mode, the changes are averaged over {@code batchSize} before
     * being applied. They are cleared afterwards.
     */
    @Override
    public void applyWeightChanges() {
        if (this.batchMode) {
            this.deltaWeights.div(this.batchSize);
            Tensors.div(this.deltaBiases, this.batchSize);
        }
        Tensor.copy(this.deltaWeights, this.prevDeltaWeights);
        Tensor.copy(this.deltaBiases, this.prevDeltaBiases);
        this.weights.add(this.deltaWeights);
        Tensors.add(this.biases, this.deltaBiases);
        if (this.batchMode) {
            this.deltaWeights.fill(0.0f);
            Tensor.fill(this.deltaBiases, 0.0f);
        }
    }

    @Override
    public String toString() {
        return "Fully Connected Layer { width:" + this.width + " activation:" + this.activationType.name() + "}";
    }
}

