/*
 * Decompiled with CFR 0.152.
 */
package deepnetts.net.layers;

import deepnetts.core.DeepNetts;
import deepnetts.net.layers.AbstractLayer;
import deepnetts.net.layers.ConvolutionalLayer;
import deepnetts.net.layers.FullyConnectedLayer;
import deepnetts.util.Tensor;
import java.util.logging.Logger;

/**
 * Max pooling layer that downsamples each channel of the preceding convolutional layer.
 *
 * <p>A {@code filterWidth x filterHeight} window slides over every input channel with the
 * configured stride (no padding); each output cell holds the maximum value inside its window.
 * The input position of every selected maximum is recorded in {@link #maxIdx}.
 * {@link #applyWeightChanges()} is a no-op because this layer has no weights of its own.
 */
public final class MaxPoolingLayer
extends AbstractLayer {
    /** Width of the pooling window, in input columns. */
    final int filterWidth;
    /** Height of the pooling window, in input rows. */
    final int filterHeight;
    /** Step, in input rows/columns, between neighbouring pooling windows. */
    final int stride;
    /**
     * Input position of each pooled maximum, filled by {@link #forward()}:
     * {@code maxIdx[ch][outRow][outCol][0]} is the input row and {@code [1]} the input column.
     * NOTE(review): not read inside this class; presumably consumed by the preceding layer's
     * backward pass — confirm before changing its layout.
     */
    int[][][][] maxIdx;
    private static final Logger LOG = Logger.getLogger(DeepNetts.class.getName());

    /**
     * Creates a max pooling layer.
     *
     * @param filterWidth  width of the pooling window; must be positive
     * @param filterHeight height of the pooling window; must be positive
     * @param stride       step between neighbouring windows; must be positive
     * @throws IllegalArgumentException if any argument is not positive
     */
    public MaxPoolingLayer(int filterWidth, int filterHeight, int stride) {
        // A zero stride would divide by zero in init(); a zero-sized window would make the
        // forward pass read past the input, so reject such configurations up front.
        if (filterWidth <= 0 || filterHeight <= 0 || stride <= 0) {
            throw new IllegalArgumentException("Filter width, filter height and stride must be positive, got filterWidth="
                    + filterWidth + ", filterHeight=" + filterHeight + ", stride=" + stride);
        }
        this.filterWidth = filterWidth;
        this.filterHeight = filterHeight;
        this.stride = stride;
    }

    /**
     * Binds this layer to the previous layer's outputs and allocates outputs, deltas and
     * {@link #maxIdx} for a {@code ((in - filter) / stride + 1)}-sized output grid per channel.
     *
     * @throws IllegalStateException if the previous layer is not a {@link ConvolutionalLayer},
     *                               or if the pooling window is larger than its output
     */
    @Override
    public final void init() {
        if (!(this.prevLayer instanceof ConvolutionalLayer)) {
            throw new IllegalStateException("Illegal network architecture! MaxPooling can be only after convolutional layer!");
        }
        this.inputs = this.prevLayer.outputs;
        int inRows = this.inputs.getRows();
        int inCols = this.inputs.getCols();
        // Without padding the window must fit at least once. Otherwise integer division
        // (which truncates toward zero) yields a bogus size of 1 or a negative size.
        if (inCols < this.filterWidth || inRows < this.filterHeight) {
            throw new IllegalStateException("Pooling filter " + this.filterWidth + "x" + this.filterHeight
                    + " (width x height) does not fit into input of " + inCols + "x" + inRows + " (cols x rows)");
        }
        this.width = (inCols - this.filterWidth) / this.stride + 1;
        this.height = (inRows - this.filterHeight) / this.stride + 1;
        this.depth = this.prevLayer.getDepth();
        this.outputs = new Tensor(this.height, this.width, this.depth);
        this.deltas = new Tensor(this.height, this.width, this.depth);
        this.maxIdx = new int[this.depth][this.height][this.width][2];
    }

    /** Pools every channel of the inputs into the outputs, recording arg-max positions. */
    @Override
    public void forward() {
        for (int ch = 0; ch < this.depth; ++ch) {
            this.forwardForChannel(ch);
        }
    }

    /**
     * Pools a single channel.
     *
     * @param ch channel index in both inputs and outputs
     */
    private void forwardForChannel(int ch) {
        // Iterate the output grid computed in init() so indices always match the allocated
        // outputs/maxIdx; the window's top-left input cell is outRow/outCol scaled by stride.
        for (int outRow = 0; outRow < this.height; ++outRow) {
            int inRow = outRow * this.stride;
            for (int outCol = 0; outCol < this.width; ++outCol) {
                int inCol = outCol * this.stride;
                // Start from the top-left cell; only a strictly larger value replaces the
                // current max, so ties resolve to the first such cell in row-major order.
                float max = this.inputs.get(inRow, inCol, ch);
                int maxR = inRow;
                int maxC = inCol;
                for (int fr = 0; fr < this.filterHeight; ++fr) {
                    for (int fc = 0; fc < this.filterWidth; ++fc) {
                        float value = this.inputs.get(inRow + fr, inCol + fc, ch);
                        if (max < value) {
                            max = value;
                            maxR = inRow + fr;
                            maxC = inCol + fc;
                        }
                    }
                }
                int[] winner = this.maxIdx[ch][outRow][outCol];
                winner[0] = maxR;
                winner[1] = maxC;
                this.outputs.set(outRow, outCol, ch, max);
            }
        }
    }

    /**
     * Computes this layer's deltas (one per output cell) from the next layer's deltas.
     * Only {@link FullyConnectedLayer} and {@link ConvolutionalLayer} successors are handled.
     */
    @Override
    public void backward() {
        if (this.nextLayer instanceof FullyConnectedLayer) {
            this.backwardFromFullyConnected();
        } else if (this.nextLayer instanceof ConvolutionalLayer) {
            this.backwardFromConvolutional();
        }
        // NOTE(review): any other next-layer type leaves deltas untouched — confirm intended.
    }

    /** Resets deltas, then back-propagates a fully connected successor channel by channel. */
    private void backwardFromFullyConnected() {
        this.deltas.fill(0.0f);
        for (int ch = 0; ch < this.depth; ++ch) {
            this.backwardFromFullyConnectedForChannel(ch);
        }
    }

    /**
     * Accumulates, for every cell of channel {@code ch}, the sum over next-layer neurons of
     * {@code nextDelta[n] * weights.get(col, row, ch, n)}.
     *
     * @param ch channel index of this layer's deltas
     */
    private void backwardFromFullyConnectedForChannel(int ch) {
        int nextSize = this.nextLayer.deltas.getCols();
        for (int row = 0; row < this.deltas.getRows(); ++row) {
            for (int col = 0; col < this.deltas.getCols(); ++col) {
                for (int ndC = 0; ndC < nextSize; ++ndC) {
                    float nextLayerDelta = this.nextLayer.deltas.get(ndC);
                    float weight = this.nextLayer.weights.get(col, row, ch, ndC);
                    this.deltas.add(row, col, ch, nextLayerDelta * weight);
                }
            }
        }
    }

    /** Resets deltas, then back-propagates a convolutional successor channel by channel. */
    private void backwardFromConvolutional() {
        this.deltas.fill(0.0f);
        ConvolutionalLayer nextConvLayer = (ConvolutionalLayer) this.nextLayer;
        for (int ch = 0; ch < this.depth; ++ch) {
            this.backwardFromConvolutionalForChannel(nextConvLayer, ch);
        }
    }

    /**
     * Scatters every next-layer delta back through the matching filter weights into channel
     * {@code fz} of this layer's deltas, skipping taps that fall outside this layer's outputs.
     * Filters are treated as centred at {@code ((w - 1) / 2, (h - 1) / 2)}; NOTE(review): this
     * presumably mirrors the convolutional layer's forward pass — confirm.
     *
     * @param nextConvLayer the following convolutional layer
     * @param fz            channel index in this layer (and filter depth index)
     */
    private void backwardFromConvolutionalForChannel(ConvolutionalLayer nextConvLayer, int fz) {
        int filterCenterX = (nextConvLayer.filterWidth - 1) / 2;
        int filterCenterY = (nextConvLayer.filterHeight - 1) / 2;
        int outRows = this.outputs.getRows();
        int outCols = this.outputs.getCols();
        for (int ndz = 0; ndz < this.nextLayer.deltas.getDepth(); ++ndz) {
            for (int ndr = 0; ndr < this.nextLayer.deltas.getRows(); ++ndr) {
                for (int ndc = 0; ndc < this.nextLayer.deltas.getCols(); ++ndc) {
                    float nextLayerDelta = this.nextLayer.deltas.get(ndr, ndc, ndz);
                    for (int fr = 0; fr < nextConvLayer.filterHeight; ++fr) {
                        int outRow = ndr * nextConvLayer.stride + (fr - filterCenterY);
                        if (outRow < 0 || outRow >= outRows) {
                            continue;
                        }
                        for (int fc = 0; fc < nextConvLayer.filterWidth; ++fc) {
                            int outCol = ndc * nextConvLayer.stride + (fc - filterCenterX);
                            if (outCol < 0 || outCol >= outCols) {
                                continue;
                            }
                            this.deltas.add(outRow, outCol, fz, nextLayerDelta * nextConvLayer.filters[ndz].get(fr, fc, fz));
                        }
                    }
                }
            }
        }
    }

    /** No-op: max pooling has no weights to update. */
    @Override
    public void applyWeightChanges() {
    }

    /** @return width of the pooling window */
    public int getFilterWidth() {
        return this.filterWidth;
    }

    /** @return height of the pooling window */
    public int getFilterHeight() {
        return this.filterHeight;
    }

    /** @return step between neighbouring pooling windows */
    public int getStride() {
        return this.stride;
    }

    @Override
    public String toString() {
        return "Max Pooling Layer { filter width:" + this.filterWidth + ", filter height: " + this.filterHeight + ", stride:" + this.stride + "}";
    }
}

