/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.convolution.subsampling;

import java.util.Collections;
import java.util.Map;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.cuda;
import org.bytedeco.javacpp.cudnn;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseCudnnHelper;
import org.deeplearning4j.nn.layers.convolution.CudnnConvolutionHelper;
import org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingHelper;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link SubsamplingHelper} that runs 2D pooling (forward and backward) through the cuDNN
 * bindings.
 *
 * <p>Only {@link PoolingType#MAX} and {@link PoolingType#AVG} with a dilation of 1 are handled.
 * For every other configuration both entry points return {@code null} without touching the GPU;
 * presumably the caller treats {@code null} as "fall back to the built-in implementation"
 * (TODO confirm against the calling layer).
 *
 * <p>The helper owns a single {@link CudnnSubsamplingContext} whose descriptors are reconfigured
 * on every call, so one instance must not be used for concurrent calls.
 */
public class CudnnSubsamplingHelper extends BaseCudnnHelper implements SubsamplingHelper {
    private static final Logger log = LoggerFactory.getLogger(CudnnSubsamplingHelper.class);

    /** Pooling-mode integer used for {@link PoolingType#MAX}; cuDNN names this value CUDNN_POOLING_MAX. */
    private static final int POOLING_MODE_MAX = 0;
    /**
     * Pooling-mode integer used for {@link PoolingType#AVG}; cuDNN names this value
     * CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING.
     */
    private static final int POOLING_MODE_AVERAGE_INCLUDE_PADDING = 1;
    /** NaN option passed to the pooling descriptor; cuDNN names this value CUDNN_PROPAGATE_NAN. */
    private static final int NAN_OPTION_PROPAGATE = 1;
    /** Internal sentinel (never passed to cuDNN) for pooling types this helper does not handle. */
    private static final int POOLING_MODE_UNSUPPORTED = -1;

    private final CudnnSubsamplingContext cudnnContext = new CudnnSubsamplingContext();

    /**
     * Computes the gradient of the pooling operation with respect to its input.
     *
     * <p>The forward pass is re-run via {@link #activate} because the backward cuDNN call needs
     * the pooled output as one of its arguments.
     *
     * @param input           activations that were fed to the layer; indexed as a 4d array
     * @param epsilon         gradient arriving from the layer above; duplicated in 'c' order when
     *                        it is a view or does not have default strides
     * @param kernel          pooling window as {height, width}
     * @param strides         pooling strides as {height, width}
     * @param pad             padding as {height, width}
     * @param poolingType     pooling type; only MAX and AVG are handled
     * @param convolutionMode forwarded to {@code CudnnConvolutionHelper.getCudnnForwardArgs}
     * @param dilation        dilation as {height, width}; must be {1, 1}
     * @param workspaceMgr    workspace manager used to allocate the result
     * @return a pair of an empty gradient (pooling has no parameters) and the input-shaped
     *         epsilon, or {@code null} if the dilation or pooling type is not handled
     */
    public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, int[] kernel, int[] strides, int[] pad, PoolingType poolingType, ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr) {
        // The pooling descriptor configured below has no dilation parameter.
        if (dilation[0] != 1 || dilation[1] != 1) {
            return null;
        }
        // Reject unsupported pooling types before doing any forward-pass or argument preparation work.
        int poolingMode = resolveCudnnPoolingMode(poolingType);
        if (poolingMode == POOLING_MODE_UNSUPPORTED) {
            return null;
        }

        // Pooled output of the forward pass; required as an argument of the backward call.
        INDArray reduced = activate(input, true, kernel, strides, pad, poolingType, convolutionMode, dilation, workspaceMgr);

        // Batch size and channel count are read from the caller's array, before it is replaced below.
        long miniBatch = input.size(0);
        long depth = input.size(1);

        CudnnConvolutionHelper.CudnnForwardArgs args = CudnnConvolutionHelper.getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, poolingType);
        input = args.getInput();
        long inH = input.size(2);
        long inW = input.size(3);
        long[] srcStride = input.stride();
        int[] outSize = args.getOutSize();
        int outH = outSize[0];
        int outW = outSize[1];

        // No entries are ever added: this layer contributes no parameter gradients.
        Gradient retGradient = new DefaultGradient();

        if (!Shape.hasDefaultStridesForShape(epsilon) || epsilon.isView()) {
            epsilon = epsilon.dup('c');
        }
        long[] deltaStride = epsilon.stride();

        flushGridExecutionerQueue();

        configureTensorDescriptor(this.cudnnContext.srcTensorDesc, miniBatch, depth, inH, inW, srcStride);
        // deltaTensorDesc is used for both 'reduced' and 'epsilon' in the backward call below.
        configureTensorDescriptor(this.cudnnContext.deltaTensorDesc, miniBatch, depth, outH, outW, deltaStride);
        configurePoolingDescriptor(poolingMode, kernel, pad, strides);

        INDArray outEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, new int[]{(int) miniBatch, (int) depth, (int) inH, (int) inW}, 'c');
        long[] dstStride = outEpsilon.stride();
        configureTensorDescriptor(this.cudnnContext.dstTensorDesc, miniBatch, depth, inH, inW, dstStride);

        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(input, epsilon, reduced, outEpsilon);
        Pointer srcData = allocator.getPointer(input, context);
        Pointer epsData = allocator.getPointer(epsilon, context);
        Pointer zData = allocator.getPointer(reduced, context);
        Pointer dstData = allocator.getPointer(outEpsilon, context);

        checkCudnn(cudnn.cudnnSetStream(this.cudnnContext, new cuda.CUstream_st(context.getOldStream())));
        // Argument order: forward output, incoming gradient, forward input, then the result buffer.
        checkCudnn(cudnn.cudnnPoolingBackward(this.cudnnContext, this.cudnnContext.poolingDesc, this.alpha,
                this.cudnnContext.deltaTensorDesc, zData,
                this.cudnnContext.deltaTensorDesc, epsData,
                this.cudnnContext.srcTensorDesc, srcData,
                this.beta,
                this.cudnnContext.dstTensorDesc, dstData));

        allocator.registerAction(context, outEpsilon, input, epsilon, reduced);

        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            context.syncOldStream();
        }

        // When the forward args flagged manual bottom/right padding, drop that extra
        // row/column so the returned epsilon excludes it.
        if (args.isManualPadBottom() || args.isManualPadRight()) {
            long keepRows = outEpsilon.size(2) - (args.isManualPadBottom() ? 1 : 0);
            long keepCols = outEpsilon.size(3) - (args.isManualPadRight() ? 1 : 0);
            outEpsilon = outEpsilon.get(NDArrayIndex.all(), NDArrayIndex.all(),
                    NDArrayIndex.interval(0L, keepRows), NDArrayIndex.interval(0L, keepCols));
        }

        return new Pair<Gradient, INDArray>(retGradient, outEpsilon);
    }

    /**
     * Runs the pooling forward pass.
     *
     * @param input           activations to pool; indexed as a 4d array
     * @param training        not read by this implementation
     * @param kernel          pooling window as {height, width}
     * @param strides         pooling strides as {height, width}
     * @param pad             padding as {height, width}
     * @param poolingType     pooling type; only MAX and AVG are handled
     * @param convolutionMode forwarded to {@code CudnnConvolutionHelper.getCudnnForwardArgs}
     * @param dilation        dilation as {height, width}; must be {1, 1}
     * @param workspaceMgr    workspace manager used to allocate the result
     * @return a 'c'-ordered array of shape [miniBatch, channels, outH, outW], or {@code null} if
     *         the dilation or pooling type is not handled
     */
    public INDArray activate(INDArray input, boolean training, int[] kernel, int[] strides, int[] pad, PoolingType poolingType, ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr) {
        // The pooling descriptor configured below has no dilation parameter.
        if (dilation[0] != 1 || dilation[1] != 1) {
            return null;
        }
        // Reject unsupported pooling types before preparing any cuDNN arguments.
        int poolingMode = resolveCudnnPoolingMode(poolingType);
        if (poolingMode == POOLING_MODE_UNSUPPORTED) {
            return null;
        }

        // Batch size and channel count are read from the caller's array, before it is replaced below.
        long miniBatch = input.size(0);
        long inDepth = input.size(1);

        CudnnConvolutionHelper.CudnnForwardArgs args = CudnnConvolutionHelper.getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, poolingType);
        input = args.getInput();
        long inH = input.size(2);
        long inW = input.size(3);
        long[] srcStride = input.stride();
        int[] outSize = args.getOutSize();
        int outH = outSize[0];
        int outW = outSize[1];

        flushGridExecutionerQueue();

        configurePoolingDescriptor(poolingMode, kernel, pad, strides);
        configureTensorDescriptor(this.cudnnContext.srcTensorDesc, miniBatch, inDepth, inH, inW, srcStride);

        INDArray reduced = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, new int[]{(int) miniBatch, (int) inDepth, outH, outW}, 'c');
        long[] dstStride = reduced.stride();
        configureTensorDescriptor(this.cudnnContext.dstTensorDesc, miniBatch, inDepth, outH, outW, dstStride);

        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(input, reduced);
        Pointer srcData = allocator.getPointer(input, context);
        Pointer dstData = allocator.getPointer(reduced, context);

        checkCudnn(cudnn.cudnnSetStream(this.cudnnContext, new cuda.CUstream_st(context.getOldStream())));
        checkCudnn(cudnn.cudnnPoolingForward(this.cudnnContext, this.cudnnContext.poolingDesc, this.alpha,
                this.cudnnContext.srcTensorDesc, srcData,
                this.beta,
                this.cudnnContext.dstTensorDesc, dstData));

        allocator.registerAction(context, reduced, input);

        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            context.syncOldStream();
        }
        return reduced;
    }

    /**
     * Reports memory held by this helper.
     *
     * @return always an empty map
     */
    public Map<String, Long> helperMemoryUse() {
        return Collections.emptyMap();
    }

    /**
     * Maps a pooling type to the integer pooling mode handed to cuDNN.
     *
     * @return the mode for MAX or AVG, or {@link #POOLING_MODE_UNSUPPORTED} for any other type
     */
    private static int resolveCudnnPoolingMode(PoolingType poolingType) {
        switch (poolingType) {
            case AVG:
                return POOLING_MODE_AVERAGE_INCLUDE_PADDING;
            case MAX:
                return POOLING_MODE_MAX;
            default:
                return POOLING_MODE_UNSUPPORTED;
        }
    }

    /** Flushes queued operations when the active executioner is a {@link GridExecutioner}. */
    private static void flushGridExecutionerQueue() {
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
        }
    }

    /** Sets the shared pooling descriptor from the given mode, window, padding and strides. */
    private void configurePoolingDescriptor(int poolingMode, int[] kernel, int[] pad, int[] strides) {
        checkCudnn(cudnn.cudnnSetPooling2dDescriptor(this.cudnnContext.poolingDesc, poolingMode, NAN_OPTION_PROPAGATE,
                kernel[0], kernel[1], pad[0], pad[1], strides[0], strides[1]));
    }

    /**
     * Sets a 4d tensor descriptor from explicit dimensions and strides.
     * Dimensions and strides are narrowed to {@code int} because the cuDNN binding takes ints.
     */
    private void configureTensorDescriptor(cudnn.cudnnTensorStruct descriptor, long n, long c, long h, long w, long[] stride) {
        checkCudnn(cudnn.cudnnSetTensor4dDescriptorEx(descriptor, this.dataType, (int) n, (int) c, (int) h, (int) w,
                (int) stride[0], (int) stride[1], (int) stride[2], (int) stride[3]));
    }

    /**
     * cuDNN context extended with the three tensor descriptors and the pooling descriptor
     * used by the pooling calls.
     */
    private static class CudnnSubsamplingContext
    extends BaseCudnnHelper.CudnnContext {
        private cudnn.cudnnTensorStruct srcTensorDesc = new cudnn.cudnnTensorStruct();
        private cudnn.cudnnTensorStruct dstTensorDesc = new cudnn.cudnnTensorStruct();
        private cudnn.cudnnTensorStruct deltaTensorDesc = new cudnn.cudnnTensorStruct();
        private cudnn.cudnnPoolingStruct poolingDesc = new cudnn.cudnnPoolingStruct();

        /** Creates the native handles and registers a deallocator that destroys them. */
        public CudnnSubsamplingContext() {
            this.createHandles();
            this.deallocator(new Deallocator(this));
        }

        /**
         * Builds a context whose descriptors are constructed from the descriptors of {@code c};
         * presumably they alias the same native handles (TODO confirm JavaCPP pointer semantics).
         */
        public CudnnSubsamplingContext(CudnnSubsamplingContext c) {
            super(c);
            this.srcTensorDesc = new cudnn.cudnnTensorStruct(c.srcTensorDesc);
            this.dstTensorDesc = new cudnn.cudnnTensorStruct(c.dstTensorDesc);
            this.deltaTensorDesc = new cudnn.cudnnTensorStruct(c.deltaTensorDesc);
            this.poolingDesc = new cudnn.cudnnPoolingStruct(c.poolingDesc);
        }

        /** Creates the parent's handles first, then the tensor and pooling descriptors. */
        @Override
        protected void createHandles() {
            super.createHandles();
            CudnnSubsamplingHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor(this.srcTensorDesc));
            CudnnSubsamplingHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor(this.dstTensorDesc));
            CudnnSubsamplingHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor(this.deltaTensorDesc));
            CudnnSubsamplingHelper.checkCudnn(cudnn.cudnnCreatePoolingDescriptor(this.poolingDesc));
        }

        /** Destroys the descriptors owned by this class, then the parent's handles. */
        @Override
        protected void destroyHandles() {
            CudnnSubsamplingHelper.checkCudnn(cudnn.cudnnDestroyPoolingDescriptor(this.poolingDesc));
            CudnnSubsamplingHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor(this.srcTensorDesc));
            CudnnSubsamplingHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor(this.dstTensorDesc));
            CudnnSubsamplingHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor(this.deltaTensorDesc));
            super.destroyHandles();
        }

        /**
         * Deallocator registered by the no-arg constructor; it is built from the owning context
         * via the copying constructor and destroys the handles when invoked.
         */
        private static class Deallocator
        extends CudnnSubsamplingContext
        implements Pointer.Deallocator {
            Deallocator(CudnnSubsamplingContext c) {
                super(c);
            }

            public void deallocate() {
                this.destroyHandles();
            }
        }
    }
}

