/*
 * Decompiled with CFR 0.152.
 */
package deepnetts.net.layers;

import deepnetts.core.DeepNetts;
import deepnetts.net.layers.AbstractLayer;
import deepnetts.net.layers.FullyConnectedLayer;
import deepnetts.net.layers.InputLayer;
import deepnetts.net.layers.MaxPoolingLayer;
import deepnetts.net.layers.activation.ActivationFunction;
import deepnetts.net.layers.activation.ActivationType;
import deepnetts.net.weights.RandomWeights;
import deepnetts.util.DeepNettsException;
import deepnetts.util.Tensor;
import deepnetts.util.Tensors;
import java.util.logging.Logger;

/**
 * Convolutional layer: applies {@code depth} learned filters to the previous layer's 3D output,
 * adds a per-filter bias and passes each result through the configured activation function.
 *
 * <p>Each filter spans the full depth of the previous layer and is centred on every
 * {@code stride}-th input position, starting at position 0. Filter taps that fall outside the input
 * are skipped. The output therefore has {@code ceil(inputSize / stride)} positions in each spatial
 * dimension and one channel per filter.
 *
 * <p>This layer may only follow an {@link InputLayer}, another {@code ConvolutionalLayer} or a
 * {@link MaxPoolingLayer}.
 */
public final class ConvolutionalLayer
extends AbstractLayer {
    /** One filter per output channel, each allocated as filterHeight x filterWidth x filterDepth. */
    Tensor[] filters;
    /** Accumulated weight changes for each filter (same dimensions as the filters). */
    Tensor[] deltaWeights;
    /** Copy of the weight changes applied in the most recent {@link #applyWeightChanges()} call. */
    Tensor[] prevDeltaWeights;
    /** Allocated per filter in {@link #init()}; not read within this class. */
    Tensor[] prevGradSums;
    int filterWidth;
    int filterHeight;
    /** Depth of every filter; set to the previous layer's depth in {@link #init()}. */
    int filterDepth;
    /** Step between consecutive filter positions, in both rows and columns. */
    int stride = 1;
    /** Not read within this class. */
    int padding = 0;
    /** Column offset of the filter centre, {@code (filterWidth - 1) / 2}. */
    int fCenterX;
    /** Row offset of the filter centre, {@code (filterHeight - 1) / 2}. */
    int fCenterY;
    /**
     * Max positions taken from the following max pooling layer:
     * {@code maxIdx[ch][poolRow][poolCol]} holds {row, col} in this layer's output.
     */
    int[][][][] maxIdx;
    private static final Logger LOG = Logger.getLogger(DeepNetts.class.getName());

    /**
     * Creates a convolutional layer with stride 1 and TANH activation.
     *
     * @param filterWidth  filter width
     * @param filterHeight filter height
     * @param channels     number of filters (output channels)
     */
    public ConvolutionalLayer(int filterWidth, int filterHeight, int channels) {
        this(filterWidth, filterHeight, channels, 1, ActivationType.TANH);
    }

    /**
     * Creates a convolutional layer with stride 1.
     *
     * @param filterWidth    filter width
     * @param filterHeight   filter height
     * @param channels       number of filters (output channels)
     * @param activationType activation function applied to every output
     */
    public ConvolutionalLayer(int filterWidth, int filterHeight, int channels, ActivationType activationType) {
        this(filterWidth, filterHeight, channels, 1, activationType);
    }

    /**
     * Creates a convolutional layer.
     *
     * @param filterWidth    filter width
     * @param filterHeight   filter height
     * @param channels       number of filters (output channels)
     * @param stride         step between filter positions; must be at least 1
     * @param activationType activation function applied to every output
     * @throws IllegalArgumentException if {@code stride < 1}
     */
    public ConvolutionalLayer(int filterWidth, int filterHeight, int channels, int stride, ActivationType activationType) {
        // A non-positive stride would divide by zero in init() or loop forever in forward().
        if (stride < 1) {
            throw new IllegalArgumentException("Stride must be at least 1, got: " + stride);
        }
        this.filterWidth = filterWidth;
        this.filterHeight = filterHeight;
        this.depth = channels;
        this.stride = stride;
        this.activationType = activationType;
        this.activation = ActivationFunction.create(activationType);
    }

    /**
     * Connects this layer to the previous layer, sizes the outputs and deltas, and allocates
     * randomly initialised filters plus biases set to 0.1.
     *
     * @throws DeepNettsException if the previous layer is not an input, convolutional or max pooling layer
     */
    @Override
    public void init() {
        if (!(this.prevLayer instanceof InputLayer || this.prevLayer instanceof ConvolutionalLayer || this.prevLayer instanceof MaxPoolingLayer)) {
            throw new DeepNettsException("Illegal architecture: convolutional layer can be used only after input, convolutional or maxpooling layer");
        }
        this.inputs = this.prevLayer.outputs;
        // forward() visits input positions 0, stride, 2*stride, ..., which is ceil(size / stride)
        // positions. Floor division would under-allocate whenever size is not a multiple of stride.
        this.width = ceilDiv(this.prevLayer.getWidth(), this.stride);
        this.height = ceilDiv(this.prevLayer.getHeight(), this.stride);
        this.fCenterX = (this.filterWidth - 1) / 2;
        this.fCenterY = (this.filterHeight - 1) / 2;
        this.outputs = new Tensor(this.height, this.width, this.depth);
        this.deltas = new Tensor(this.height, this.width, this.depth);
        this.filterDepth = this.prevLayer.getDepth();
        this.filters = new Tensor[this.depth];
        this.deltaWeights = new Tensor[this.depth];
        this.prevDeltaWeights = new Tensor[this.depth];
        this.prevGradSums = new Tensor[this.depth];
        int inputCount = (this.filterWidth * this.filterHeight + 1) * this.filterDepth;
        for (int ch = 0; ch < this.filters.length; ++ch) {
            this.filters[ch] = new Tensor(this.filterHeight, this.filterWidth, this.filterDepth);
            RandomWeights.uniform(this.filters[ch].getValues(), inputCount);
            this.deltaWeights[ch] = new Tensor(this.filterHeight, this.filterWidth, this.filterDepth);
            this.prevDeltaWeights[ch] = new Tensor(this.filterHeight, this.filterWidth, this.filterDepth);
            this.prevGradSums[ch] = new Tensor(this.filterHeight, this.filterWidth, this.filterDepth);
        }
        this.biases = new float[this.depth];
        this.deltaBiases = new float[this.depth];
        this.prevDeltaBiases = new float[this.depth];
        this.prevBiasSqrSum = new Tensor(this.depth);
        Tensor.fill(this.biases, 0.1f);
    }

    /** Returns {@code ceil(size / stride)} for a non-negative size and a positive stride. */
    private static int ceilDiv(int size, int stride) {
        return (size + stride - 1) / stride;
    }

    /** Computes the activated outputs for every channel. */
    @Override
    public void forward() {
        for (int ch = 0; ch < this.depth; ++ch) {
            this.forwardForChannel(ch);
        }
    }

    /**
     * Convolves filter {@code ch} over the inputs and writes activated values into output channel {@code ch}.
     * Filter taps that fall outside the input are skipped.
     */
    private void forwardForChannel(int ch) {
        final Tensor filter = this.filters[ch];
        final int inRows = this.inputs.getRows();
        final int inCols = this.inputs.getCols();
        for (int inRow = 0, outRow = 0; inRow < inRows; inRow += this.stride, ++outRow) {
            for (int inCol = 0, outCol = 0; inCol < inCols; inCol += this.stride, ++outCol) {
                float sum = this.biases[ch];
                for (int fz = 0; fz < this.filterDepth; ++fz) {
                    for (int fr = 0; fr < this.filterHeight; ++fr) {
                        int cr = inRow + (fr - this.fCenterY);
                        if (cr < 0 || cr >= inRows) continue;
                        for (int fc = 0; fc < this.filterWidth; ++fc) {
                            int cc = inCol + (fc - this.fCenterX);
                            if (cc < 0 || cc >= inCols) continue;
                            sum += this.inputs.get(cr, cc, fz) * filter.get(fr, fc, fz);
                        }
                    }
                }
                this.outputs.set(outRow, outCol, ch, this.activation.getValue(sum));
            }
        }
    }

    /**
     * Back-propagates deltas from the next layer and accumulates delta weights.
     * Handles fully connected, max pooling and convolutional next layers; for any other
     * next layer type nothing is done.
     */
    @Override
    public void backward() {
        if (this.nextLayer instanceof FullyConnectedLayer) {
            this.backwardFromFullyConnected();
        }
        if (this.nextLayer instanceof MaxPoolingLayer) {
            this.backwardFromMaxPooling();
        }
        if (this.nextLayer instanceof ConvolutionalLayer) {
            this.backwardFromConvolutional();
        }
    }

    /** Clears the deltas and back-propagates from a following fully connected layer, channel by channel. */
    private void backwardFromFullyConnected() {
        this.deltas.fill(0.0f);
        for (int ch = 0; ch < this.depth; ++ch) {
            this.backwardFromFullyConnectedForChannel(ch);
        }
    }

    /**
     * Sums the next layer's deltas weighted by its connections to each output of channel {@code ch},
     * scaled by the activation derivative.
     */
    private void backwardFromFullyConnectedForChannel(int ch) {
        for (int row = 0; row < this.height; ++row) {
            for (int col = 0; col < this.width; ++col) {
                float actDerivative = this.activation.getPrime(this.outputs.get(row, col, ch));
                for (int ndC = 0; ndC < this.nextLayer.deltas.getCols(); ++ndC) {
                    float delta = this.nextLayer.deltas.get(ndC) * this.nextLayer.weights.get(col, row, ch, ndC) * actDerivative;
                    this.deltas.add(row, col, ch, delta);
                }
            }
        }
        this.calculateDeltaWeightsForChannel(ch);
    }

    /** Clears the deltas and routes deltas of a following max pooling layer back to the max positions. */
    private void backwardFromMaxPooling() {
        MaxPoolingLayer nextPoolLayer = (MaxPoolingLayer) this.nextLayer;
        this.maxIdx = nextPoolLayer.maxIdx;
        this.deltas.fill(0.0f);
        for (int ch = 0; ch < this.depth; ++ch) {
            this.backwardFromMaxPoolingForChannel(ch);
        }
    }

    /** Sets the delta of each recorded max position in channel {@code ch}; all other deltas stay zero. */
    private void backwardFromMaxPoolingForChannel(int ch) {
        for (int dr = 0; dr < this.nextLayer.deltas.getRows(); ++dr) {
            for (int dc = 0; dc < this.nextLayer.deltas.getCols(); ++dc) {
                float nextLayerDelta = this.nextLayer.deltas.get(dr, dc, ch);
                int maxR = this.maxIdx[ch][dr][dc][0];
                int maxC = this.maxIdx[ch][dr][dc][1];
                float derivative = this.activation.getPrime(this.outputs.get(maxR, maxC, ch));
                this.deltas.set(maxR, maxC, ch, nextLayerDelta * derivative);
            }
        }
        this.calculateDeltaWeightsForChannel(ch);
    }

    /** Clears the deltas and back-propagates from a following convolutional layer, channel by channel. */
    private void backwardFromConvolutional() {
        this.deltas.fill(0.0f);
        for (int ch = 0; ch < this.depth; ++ch) {
            this.backwardFromConvolutionalForChannel(ch);
        }
    }

    /**
     * Accumulates deltas for channel {@code fz} from a following convolutional layer. Each next-layer
     * delta is spread over the positions its filter covered, weighted by that filter's weights for
     * depth slice {@code fz}, and scaled by the activation derivative. The channel is then normalised
     * once by the next layer's filter volume.
     */
    private void backwardFromConvolutionalForChannel(int fz) {
        ConvolutionalLayer nextConvLayer = (ConvolutionalLayer) this.nextLayer;
        int filterCenterX = (nextConvLayer.filterWidth - 1) / 2;
        int filterCenterY = (nextConvLayer.filterHeight - 1) / 2;
        for (int ndZ = 0; ndZ < this.nextLayer.deltas.getDepth(); ++ndZ) {
            for (int ndRow = 0; ndRow < this.nextLayer.deltas.getRows(); ++ndRow) {
                for (int ndCol = 0; ndCol < this.nextLayer.deltas.getCols(); ++ndCol) {
                    float nextLayerDelta = this.nextLayer.deltas.get(ndRow, ndCol, ndZ);
                    for (int fr = 0; fr < nextConvLayer.filterHeight; ++fr) {
                        for (int fc = 0; fc < nextConvLayer.filterWidth; ++fc) {
                            int row = ndRow * nextConvLayer.stride + (fr - filterCenterY);
                            int col = ndCol * nextConvLayer.stride + (fc - filterCenterX);
                            if (row < 0 || row >= this.outputs.getRows() || col < 0 || col >= this.outputs.getCols()) continue;
                            float derivative = this.activation.getPrime(this.outputs.get(row, col, fz));
                            this.deltas.add(row, col, fz, nextLayerDelta * nextConvLayer.filters[ndZ].get(fr, fc, fz) * derivative);
                        }
                    }
                }
            }
        }
        // Normalise this channel exactly once. Dividing the whole tensor inside the loop (as before)
        // compounded the division and also rescaled channels that were already finished.
        float filterVolume = nextConvLayer.filterWidth * nextConvLayer.filterHeight * nextConvLayer.filterDepth;
        for (int row = 0; row < this.deltas.getRows(); ++row) {
            for (int col = 0; col < this.deltas.getCols(); ++col) {
                this.deltas.set(row, col, fz, this.deltas.get(row, col, fz) / filterVolume);
            }
        }
        this.calculateDeltaWeightsForChannel(fz);
    }

    /**
     * Accumulates weight and bias changes for filter {@code ch} from this layer's deltas and inputs.
     * Each position's contribution is divided by the number of output positions ({@code width * height}).
     * Outside batch mode the accumulators are cleared first.
     *
     * @throws DeepNettsException if the configured optimizer is not SGD
     */
    private void calculateDeltaWeightsForChannel(int ch) {
        if (!this.batchMode) {
            this.deltaWeights[ch].fill(0.0f);
            this.deltaBiases[ch] = 0.0f;
        }
        final float divisor = this.width * this.height;
        // Shared empty index array; previously a new one was allocated in the innermost loop.
        final int[] noIdx = new int[0];
        for (int deltaRow = 0; deltaRow < this.deltas.getRows(); ++deltaRow) {
            for (int deltaCol = 0; deltaCol < this.deltas.getCols(); ++deltaCol) {
                final float delta = this.deltas.get(deltaRow, deltaCol, ch);
                for (int fz = 0; fz < this.filterDepth; ++fz) {
                    for (int fr = 0; fr < this.filterHeight; ++fr) {
                        for (int fc = 0; fc < this.filterWidth; ++fc) {
                            int inRow = deltaRow * this.stride + fr - this.fCenterY;
                            int inCol = deltaCol * this.stride + fc - this.fCenterX;
                            if (inRow < 0 || inRow >= this.inputs.getRows() || inCol < 0 || inCol >= this.inputs.getCols()) continue;
                            float grad = delta * this.inputs.get(inRow, inCol, fz);
                            float deltaWeight;
                            switch (this.optimizerType) {
                                case SGD: {
                                    deltaWeight = this.optim.calculateDeltaWeight(grad, noIdx);
                                    break;
                                }
                                default: {
                                    throw new DeepNettsException("Optimizer not supported!");
                                }
                            }
                            this.deltaWeights[ch].add(fr, fc, fz, deltaWeight / divisor);
                        }
                    }
                }
                float deltaBias;
                switch (this.optimizerType) {
                    case SGD: {
                        // NOTE(review): deltaCol is passed as the second argument, as in the original;
                        // confirm this is what the optimizer expects.
                        deltaBias = this.optim.calculateDeltaBias(delta, deltaCol);
                        break;
                    }
                    default: {
                        throw new DeepNettsException("Optimizer not supported!");
                    }
                }
                // Scale the new contribution, as for the weights above. Previously the running sum
                // was divided on every position instead.
                this.deltaBiases[ch] += deltaBias / divisor;
            }
        }
    }

    /**
     * Adds the accumulated changes to the filters and biases and records them as the previous changes.
     * In batch mode the changes are first averaged over the batch size and reset afterwards.
     */
    @Override
    public void applyWeightChanges() {
        if (this.batchMode) {
            Tensors.div(this.deltaBiases, this.batchSize);
        }
        Tensor.copy(this.deltaBiases, this.prevDeltaBiases);
        for (int ch = 0; ch < this.depth; ++ch) {
            if (this.batchMode) {
                this.deltaWeights[ch].div(this.batchSize);
            }
            Tensor.copy(this.deltaWeights[ch], this.prevDeltaWeights[ch]);
            this.filters[ch].add(this.deltaWeights[ch]);
            this.biases[ch] += this.deltaBiases[ch];
            if (this.batchMode) {
                this.deltaWeights[ch].fill(0.0f);
            }
        }
        if (this.batchMode) {
            Tensor.fill(this.deltaBiases, 0.0f);
        }
    }

    /** Returns the filter array itself, not a copy. */
    public Tensor[] getFilters() {
        return this.filters;
    }

    /** Replaces the filter array with the given one, which is stored without copying. */
    public void setFilters(Tensor[] filters) {
        this.filters = filters;
    }

    /**
     * Loads filter values from a string. Filters are separated by {@code ';'} and values within a
     * filter by {@code ','}. Each filter must provide at least filterWidth * filterHeight * filterDepth values.
     *
     * @param filtersStr serialized filter values
     * @throws DeepNettsException if there are fewer filters or values than required
     * @throws NumberFormatException if a value is not a valid float
     */
    public void setFilters(String filtersStr) {
        String[] strVals = filtersStr.split(";");
        if (strVals.length < this.filters.length) {
            throw new DeepNettsException("Expected " + this.filters.length + " filters but got " + strVals.length);
        }
        int filterSize = this.filterWidth * this.filterHeight * this.filterDepth;
        for (int i = 0; i < this.filters.length; ++i) {
            String[] vals = strVals[i].split(",");
            if (vals.length < filterSize) {
                throw new DeepNettsException("Filter " + i + " has " + vals.length + " values, expected " + filterSize);
            }
            float[] filterValues = new float[filterSize];
            for (int k = 0; k < filterSize; ++k) {
                filterValues[k] = Float.parseFloat(vals[k]);
            }
            this.filters[i].setValues(filterValues);
        }
    }

    public int getFilterWidth() {
        return this.filterWidth;
    }

    public int getFilterHeight() {
        return this.filterHeight;
    }

    public int getFilterDepth() {
        return this.filterDepth;
    }

    public int getStride() {
        return this.stride;
    }

    /** Returns the per-filter accumulated weight changes (the internal array, not a copy). */
    public Tensor[] getFilterDeltaWeights() {
        return this.deltaWeights;
    }

    @Override
    public String toString() {
        return "Convolutional Layer { filter width:" + this.filterWidth + ", filter height: " + this.filterHeight + ", channels: " + this.depth + ", stride: " + this.stride + ", activation: " + this.activationType.name() + "}";
    }
}

