/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.convolution;

import java.util.Arrays;
import java.util.Collections;
import java.util.Map;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.SizeTPointer;
import org.bytedeco.javacpp.cuda;
import org.bytedeco.javacpp.cudnn;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseCudnnHelper;
import org.deeplearning4j.nn.layers.convolution.ConvolutionHelper;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.util.ConvolutionUtils;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.util.OneTimeLogger;
import org.nd4j.util.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class CudnnConvolutionHelper
extends BaseCudnnHelper
implements ConvolutionHelper {
    /** SLF4J logger used for trace output about algorithm selection and workspace (re)allocation. */
    private static final Logger log = LoggerFactory.getLogger(CudnnConvolutionHelper.class);
    /**
     * Per-helper cuDNN context. It is passed as the handle argument of every cuDNN call in this class and
     * carries the descriptors those calls configure (srcTensorDesc, dstTensorDesc, deltaTensorDesc,
     * biasTensorDesc, filterDesc, convDesc, activationDesc). The descriptors are mutable state shared
     * between the forward, backward and activation methods.
     */
    private CudnnConvolutionContext cudnnContext = new CudnnConvolutionContext();

    /**
     * Runs the cuDNN backward pass of a 2D convolution.
     * <p>
     * Three cuDNN kernels are launched, in this order: {@code cudnnConvolutionBackwardBias} (written into
     * {@code biasGradView}), {@code cudnnConvolutionBackwardFilter} (written into {@code weightGradView}) and
     * {@code cudnnConvolutionBackwardData} (written into a freshly created {@code ACTIVATION_GRAD} array).
     * The shared helper workspace stored under {@code LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY} is replaced by a
     * larger one when either kernel needs more bytes than its current capacity.
     *
     * @param input           layer input; replaced internally by {@code getCudnnForwardArgs(...).getInput()}
     * @param weights         filter array; sizes 0..3 are read as outDepth, inDepth, kH, kW
     * @param delta           gradient w.r.t. the layer output; duplicated when
     *                        {@code Shape.strideDescendingCAscendingF} rejects its layout
     * @param kernel          kernel size, forwarded to {@code getCudnnForwardArgs} and to error messages
     * @param strides         convolution strides (height, width)
     * @param pad             padding (height, width) handed to the cuDNN convolution descriptor
     * @param biasGradView    destination of the bias gradient
     * @param weightGradView  destination of the filter gradient
     * @param afn             not used by this implementation
     * @param mode            algorithm selection mode
     * @param bwdFilterAlgo   filter algorithm; honoured only when {@code mode} is {@code USER_SPECIFIED} and
     *                        {@code bwdDataAlgo} is also non-null
     * @param bwdDataAlgo     data algorithm; honoured only when {@code mode} is {@code USER_SPECIFIED} and
     *                        {@code bwdFilterAlgo} is also non-null
     * @param convolutionMode convolution mode, forwarded to {@code getCudnnForwardArgs}
     * @param dilation        dilation (height, width)
     * @param workspaceMgr    source of the output array and of the shared cuDNN workspace
     * @return the gradient holder (keys {@code "b"} and {@code "W"}) paired with the epsilon for the layer
     *         below; the epsilon is cropped by one row/column when the forward args report manual padding
     * @throws IllegalArgumentException if a user-specified algorithm has no mapping
     * @throws RuntimeException         if any cuDNN call returns a non-zero status
     */
    public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray weights, INDArray delta, int[] kernel,
            int[] strides, int[] pad, INDArray biasGradView, INDArray weightGradView, IActivation afn,
            ConvolutionLayer.AlgoMode mode, ConvolutionLayer.BwdFilterAlgo bwdFilterAlgo,
            ConvolutionLayer.BwdDataAlgo bwdDataAlgo, ConvolutionMode convolutionMode, int[] dilation,
            LayerWorkspaceMgr workspaceMgr) {
        final CudnnConvolutionContext ctx = this.cudnnContext;

        final long miniBatch = input.size(0);
        final long outDepth = weights.size(0);
        final long inDepth = weights.size(1);
        final long kH = weights.size(2);
        final long kW = weights.size(3);

        final CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, null);
        input = args.getInput();
        final long inH = input.size(2);
        final long inW = input.size(3);
        final long[] srcStride = input.stride();
        final int[] outSize = args.getOutSize();
        final int outH = outSize[0];
        final int outW = outSize[1];

        // The delta descriptor below is built from delta's own strides, so normalise the layout first.
        if (!Shape.strideDescendingCAscendingF(delta)) {
            delta = delta.dup();
        }
        final long[] deltaStride = delta.stride();

        // Single-element arrays: the cuDNN "get algorithm" calls write their choice into element 0.
        final int[] filterAlgo = new int[1];
        final int[] dataAlgo = new int[1];

        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
        }

        int code = cudnn.cudnnSetTensor4dDescriptorEx(ctx.srcTensorDesc, dataType,
                (int) miniBatch, (int) inDepth, (int) inH, (int) inW,
                (int) srcStride[0], (int) srcStride[1], (int) srcStride[2], (int) srcStride[3]);
        checkCudnn(false, "cudnnSetTensor4dDescriptorEx", code, input, weights, null, delta, kernel, strides, pad,
                mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);

        code = cudnn.cudnnSetTensor4dDescriptorEx(ctx.deltaTensorDesc, dataType,
                (int) miniBatch, (int) outDepth, outH, outW,
                (int) deltaStride[0], (int) deltaStride[1], (int) deltaStride[2], (int) deltaStride[3]);
        checkCudnn(false, "cudnnSetTensor4dDescriptorEx", code, input, weights, null, delta, kernel, strides, pad,
                mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);

        code = cudnn.cudnnSetConvolution2dDescriptor(ctx.convDesc, pad[0], pad[1], strides[0], strides[1],
                dilation[0], dilation[1], 1, dataType);
        checkCudnn(false, "cudnnSetConvolution2dDescriptor", code, input, weights, null, delta, kernel, strides, pad,
                mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);

        code = cudnn.cudnnSetFilter4dDescriptor(ctx.filterDesc, dataType, 0,
                (int) outDepth, (int) inDepth, (int) kH, (int) kW);
        checkCudnn(false, "cudnnSetFilter4dDescriptor", code, input, weights, null, delta, kernel, strides, pad,
                mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);

        final boolean algosGivenByUser = mode == ConvolutionLayer.AlgoMode.USER_SPECIFIED
                && bwdFilterAlgo != null
                && bwdDataAlgo != null;
        if (algosGivenByUser) {
            filterAlgo[0] = toCudnnBwdFilterAlgo(bwdFilterAlgo);
            dataAlgo[0] = toCudnnBwdDataAlgo(bwdDataAlgo);
        } else {
            // Preference argument: 0 when no workspace may be used, 1 otherwise.
            final int preference = mode == ConvolutionLayer.AlgoMode.NO_WORKSPACE ? 0 : 1;

            code = cudnn.cudnnGetConvolutionBackwardFilterAlgorithm(ctx, ctx.srcTensorDesc, ctx.deltaTensorDesc,
                    ctx.convDesc, ctx.filterDesc, preference, 0L, filterAlgo);
            checkCudnn(false, "cudnnGetConvolutionBackwardFilterAlgorithm", code, input, weights, null, delta, kernel,
                    strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);

            code = cudnn.cudnnGetConvolutionBackwardDataAlgorithm(ctx, ctx.filterDesc, ctx.deltaTensorDesc,
                    ctx.convDesc, ctx.srcTensorDesc, preference, 0L, dataAlgo);
            checkCudnn(false, "cudnnGetConvolutionBackwardDataAlgorithm", code, input, weights, null, delta, kernel,
                    strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
        }

        if (log.isTraceEnabled()) {
            // The numeric algorithm code is used as a position into the enum's values() array.
            ConvolutionLayer.BwdFilterAlgo chosenFilterAlgo = ConvolutionLayer.BwdFilterAlgo.values()[filterAlgo[0]];
            ConvolutionLayer.BwdDataAlgo chosenDataAlgo = ConvolutionLayer.BwdDataAlgo.values()[dataAlgo[0]];
            log.trace("CudnnConvolutionHelper backward algorithm selection: mode {}, filter algorithm {}, data algorithm {}",
                    mode, chosenFilterAlgo, chosenDataAlgo);
        }

        INDArray epsNext = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD,
                new int[] {(int) miniBatch, (int) inDepth, (int) inH, (int) inW}, 'c');
        final long[] dstStride = epsNext.stride();

        final AtomicAllocator allocator = AtomicAllocator.getInstance();
        final CudaContext context = allocator.getFlowController()
                .prepareActionAllWrite(input, weights, weightGradView, biasGradView, delta, epsNext);
        final Pointer srcData = allocator.getPointer(input, context);
        final Pointer filterData = allocator.getPointer(weights, context);
        final Pointer filterGradData = allocator.getPointer(weightGradView, context);
        final Pointer biasGradData = allocator.getPointer(biasGradView, context);
        final Pointer deltaData = allocator.getPointer(delta, context);
        final Pointer dstData = allocator.getPointer(epsNext, context);

        code = cudnn.cudnnSetStream(ctx, new cuda.CUstream_st(context.getOldStream()));
        checkCudnn(false, "cudnnSetStream", code, input, weights, null, delta, kernel, strides, pad,
                mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);

        code = cudnn.cudnnSetTensor4dDescriptorEx(ctx.dstTensorDesc, dataType,
                (int) miniBatch, (int) inDepth, (int) inH, (int) inW,
                (int) dstStride[0], (int) dstStride[1], (int) dstStride[2], (int) dstStride[3]);
        checkCudnn(false, "cudnnSetTensor4dDescriptorEx", code, input, weights, null, delta, kernel, strides, pad,
                mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);

        // sizeInBytes is a single shared output slot, so the filter size is copied out before the data query.
        code = cudnn.cudnnGetConvolutionBackwardFilterWorkspaceSize(ctx, ctx.srcTensorDesc, ctx.deltaTensorDesc,
                ctx.convDesc, ctx.filterDesc, filterAlgo[0], sizeInBytes);
        checkCudnn(false, "cudnnGetConvolutionBackwardFilterWorkspaceSize", code, input, weights, null, delta, kernel,
                strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
        final long filterWorkspaceBytes = sizeInBytes.get(0L);

        code = cudnn.cudnnGetConvolutionBackwardDataWorkspaceSize(ctx, ctx.filterDesc, ctx.deltaTensorDesc,
                ctx.convDesc, ctx.dstTensorDesc, dataAlgo[0], sizeInBytes);
        checkCudnn(false, "cudnnGetConvolutionBackwardDataWorkspaceSize", code, input, weights, null, delta, kernel,
                strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);

        BaseCudnnHelper.DataCache workSpace =
                (BaseCudnnHelper.DataCache) workspaceMgr.getHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY);
        final long dataWorkspaceBytes = sizeInBytes.get(0L);

        // One workspace serves both kernels, so it has to fit the larger of the two requirements.
        final long requiredBytes = Math.max(filterWorkspaceBytes, dataWorkspaceBytes);
        if (workSpace == null || requiredBytes > workSpace.capacity()) {
            if (log.isTraceEnabled()) {
                if (workSpace == null) {
                    log.trace("CudnnConvolutionHelper backpropGradient: Allocating initial workspace of size {} ({})",
                            requiredBytes,
                            StringUtils.TraditionalBinaryPrefix.long2String(requiredBytes, "B", 2));
                } else {
                    log.trace("CudnnConvolutionHelper backpropGradient: Deallocating workspace of size {} ({}), allocating new workspace of size {} ({})",
                            workSpace.capacity(),
                            StringUtils.TraditionalBinaryPrefix.long2String(workSpace.capacity(), "B", 2),
                            requiredBytes,
                            StringUtils.TraditionalBinaryPrefix.long2String(requiredBytes, "B", 2));
                }
            }
            if (workSpace != null) {
                workSpace.deallocate();
            }
            workSpace = new BaseCudnnHelper.DataCache(requiredBytes);
            workspaceMgr.setHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY, workSpace);
        }

        code = cudnn.cudnnSetTensor4dDescriptor(ctx.biasTensorDesc, 0, dataType, 1, (int) outDepth, 1, 1);
        checkCudnn(false, "cudnnSetTensor4dDescriptor", code, input, weights, null, delta, kernel, strides, pad,
                mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);

        code = cudnn.cudnnConvolutionBackwardBias(ctx, alpha, ctx.deltaTensorDesc, deltaData, beta,
                ctx.biasTensorDesc, biasGradData);
        checkCudnn(false, "cudnnConvolutionBackwardBias", code, input, weights, null, delta, kernel, strides, pad,
                mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);

        code = cudnn.cudnnConvolutionBackwardFilter(ctx, alpha, ctx.srcTensorDesc, srcData, ctx.deltaTensorDesc,
                deltaData, ctx.convDesc, filterAlgo[0], workSpace, workSpace.capacity(), beta, ctx.filterDesc,
                filterGradData);
        checkCudnn(false, "cudnnConvolutionBackwardFilter", code, input, weights, null, delta, kernel, strides, pad,
                mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);

        code = cudnn.cudnnConvolutionBackwardData(ctx, alpha, ctx.filterDesc, filterData, ctx.deltaTensorDesc,
                deltaData, ctx.convDesc, dataAlgo[0], workSpace, workSpace.capacity(), beta, ctx.dstTensorDesc,
                dstData);
        checkCudnn(false, "cudnnConvolutionBackwardData", code, input, weights, null, delta, kernel, strides, pad,
                mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);

        allocator.getFlowController()
                .registerActionAllWrite(context, input, weights, weightGradView, biasGradView, delta, epsNext);

        final DefaultGradient retGradient = new DefaultGradient();
        retGradient.setGradientFor("b", biasGradView);
        retGradient.setGradientFor("W", weightGradView, 'c');

        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            context.syncOldStream();
        }

        // Drop the extra bottom row / right column when the forward args report manual padding.
        if (args.isManualPadBottom() || args.isManualPadRight()) {
            final long keptRows = epsNext.size(2) - (long) (args.isManualPadBottom() ? 1 : 0);
            final long keptCols = epsNext.size(3) - (long) (args.isManualPadRight() ? 1 : 0);
            epsNext = epsNext.get(NDArrayIndex.all(), NDArrayIndex.all(),
                    NDArrayIndex.interval(0L, keptRows), NDArrayIndex.interval(0L, keptCols));
        }

        return new Pair<Gradient, INDArray>(retGradient, epsNext);
    }

    /** Maps a user-specified backward-filter algorithm to the integer code handed to cuDNN. */
    private static int toCudnnBwdFilterAlgo(ConvolutionLayer.BwdFilterAlgo algo) {
        switch (algo) {
            case ALGO_0:
                return 0;
            case ALGO_1:
                return 1;
            case FFT:
                return 2;
            case ALGO_3:
                return 3;
            case WINOGRAD:
                return 4;
            case WINOGRAD_NONFUSED:
                return 5;
            case FFT_TILING:
                return 6;
            case COUNT:
                return 7;
            default:
                throw new IllegalArgumentException("Unknown BwdFilterAlgo: " + algo);
        }
    }

    /** Maps a user-specified backward-data algorithm to the integer code handed to cuDNN. */
    private static int toCudnnBwdDataAlgo(ConvolutionLayer.BwdDataAlgo algo) {
        switch (algo) {
            case ALGO_0:
                return 0;
            case ALGO_1:
                return 1;
            case FFT:
                return 2;
            case FFT_TILING:
                return 3;
            case WINOGRAD:
                return 4;
            case WINOGRAD_NONFUSED:
                return 5;
            case COUNT:
                return 6;
            default:
                throw new IllegalArgumentException("Unknown BwdDataAlgo: " + algo);
        }
    }

    /**
     * Runs the cuDNN forward convolution and adds the bias, returning the pre-activation output.
     * <p>
     * The output array is created through {@code workspaceMgr} as {@code ArrayType.ACTIVATIONS} with shape
     * {@code [miniBatch, outDepth, outSize[0], outSize[1]]}. The forward algorithm is taken from
     * {@code fwdAlgo} when {@code mode} is {@code USER_SPECIFIED} and {@code fwdAlgo} is non-null; otherwise
     * cuDNN is asked to pick one, and if that query fails the method logs a one-time warning and falls back
     * to algorithm code 0 ({@code IMPLICIT_GEMM}). The shared helper workspace stored under
     * {@code LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY} is replaced when it is missing or too small.
     *
     * @param input           layer input; replaced internally by {@code getCudnnForwardArgs(...).getInput()}
     * @param weights         filter array; sizes 0..3 are read as outDepth, inDepth, kH, kW
     * @param bias            bias array added to the convolution result via {@code cudnnAddTensor}
     * @param kernel          kernel size, forwarded to {@code getCudnnForwardArgs} and to error messages
     * @param strides         convolution strides (height, width)
     * @param pad             padding (height, width) handed to the cuDNN convolution descriptor
     * @param mode            algorithm selection mode
     * @param fwdAlgo         forward algorithm, honoured only in {@code USER_SPECIFIED} mode
     * @param convolutionMode convolution mode, forwarded to {@code getCudnnForwardArgs}
     * @param dilation        dilation (height, width)
     * @param workspaceMgr    source of the output array and of the shared cuDNN workspace
     * @return the output array holding convolution result plus bias
     * @throws IllegalArgumentException if a user-specified algorithm has no mapping
     * @throws RuntimeException         if a checked cuDNN call returns a non-zero status
     */
    public INDArray preOutput(INDArray input, INDArray weights, INDArray bias, int[] kernel, int[] strides, int[] pad,
            ConvolutionLayer.AlgoMode mode, ConvolutionLayer.FwdAlgo fwdAlgo, ConvolutionMode convolutionMode,
            int[] dilation, LayerWorkspaceMgr workspaceMgr) {
        final CudnnConvolutionContext ctx = this.cudnnContext;

        final long miniBatch = input.size(0);
        final long outDepth = weights.size(0);
        final long inDepth = weights.size(1);
        final long kH = weights.size(2);
        final long kW = weights.size(3);

        final CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, null);
        input = args.getInput();
        final long inH = input.size(2);
        final long inW = input.size(3);
        final long[] srcStride = input.stride();
        final int[] outSize = args.getOutSize();

        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
        }

        final INDArray z = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS,
                new int[] {(int) miniBatch, (int) outDepth, outSize[0], outSize[1]});

        int code = cudnn.cudnnSetTensor4dDescriptorEx(ctx.srcTensorDesc, dataType,
                (int) miniBatch, (int) inDepth, (int) inH, (int) inW,
                (int) srcStride[0], (int) srcStride[1], (int) srcStride[2], (int) srcStride[3]);
        checkCudnn(true, "cudnnSetTensor4dDescriptorEx", code, input, weights, bias, null, kernel, strides, pad,
                mode, fwdAlgo, null, null, convolutionMode, dilation);

        code = cudnn.cudnnSetFilter4dDescriptor(ctx.filterDesc, dataType, 0,
                (int) outDepth, (int) inDepth, (int) kH, (int) kW);
        checkCudnn(true, "cudnnSetFilter4dDescriptor", code, input, weights, bias, null, kernel, strides, pad,
                mode, fwdAlgo, null, null, convolutionMode, dilation);

        code = cudnn.cudnnSetConvolution2dDescriptor(ctx.convDesc, pad[0], pad[1], strides[0], strides[1],
                dilation[0], dilation[1], 1, dataType);
        checkCudnn(true, "cudnnSetConvolution2dDescriptor", code, input, weights, bias, null, kernel, strides, pad,
                mode, fwdAlgo, null, null, convolutionMode, dilation);

        // Single-element array: the cuDNN "get algorithm" call writes its choice into element 0.
        final int[] algo = new int[1];
        final long[] dstStride = z.stride();

        code = cudnn.cudnnSetTensor4dDescriptorEx(ctx.dstTensorDesc, dataType,
                (int) miniBatch, (int) outDepth, outSize[0], outSize[1],
                (int) dstStride[0], (int) dstStride[1], (int) dstStride[2], (int) dstStride[3]);
        checkCudnn(true, "cudnnSetTensor4dDescriptorEx", code, input, weights, bias, null, kernel, strides, pad,
                mode, fwdAlgo, null, null, convolutionMode, dilation);

        if (mode == ConvolutionLayer.AlgoMode.USER_SPECIFIED && fwdAlgo != null) {
            algo[0] = toCudnnFwdAlgo(fwdAlgo);
        } else {
            // Preference argument: 0 when no workspace may be used, 1 otherwise.
            final int preference = mode == ConvolutionLayer.AlgoMode.NO_WORKSPACE ? 0 : 1;
            code = cudnn.cudnnGetConvolutionForwardAlgorithm(ctx, ctx.srcTensorDesc, ctx.filterDesc, ctx.convDesc,
                    ctx.dstTensorDesc, preference, 0L, algo);
            if (code != 0) {
                // A failed query is not fatal here: switch to an explicit algorithm and carry on.
                OneTimeLogger.warn(log, "Error getting CuDNN forward algorithm - falling back on IMPLICIT_GEMM");
                mode = ConvolutionLayer.AlgoMode.USER_SPECIFIED;
                fwdAlgo = ConvolutionLayer.FwdAlgo.IMPLICIT_GEMM;
                algo[0] = 0;
            }
        }

        if (log.isTraceEnabled()) {
            // The numeric algorithm code is used as a position into the enum's values() array.
            ConvolutionLayer.FwdAlgo chosenAlgo = ConvolutionLayer.FwdAlgo.values()[algo[0]];
            log.trace("CudnnConvolutionHelper forward algorithm selection: mode {}, algorithm {}", mode, chosenAlgo);
        }

        final AtomicAllocator allocator = AtomicAllocator.getInstance();
        final CudaContext context = allocator.getFlowController().prepareAction(z, input, weights, bias);
        final Pointer srcData = allocator.getPointer(input, context);
        final Pointer filterData = allocator.getPointer(weights, context);
        final Pointer biasData = allocator.getPointer(bias, context);
        final Pointer dstData = allocator.getPointer(z, context);

        code = cudnn.cudnnSetStream(ctx, new cuda.CUstream_st(context.getOldStream()));
        checkCudnn(true, "cudnnSetStream", code, input, weights, bias, null, kernel, strides, pad,
                mode, fwdAlgo, null, null, convolutionMode, dilation);

        code = cudnn.cudnnGetConvolutionForwardWorkspaceSize(ctx, ctx.srcTensorDesc, ctx.filterDesc, ctx.convDesc,
                ctx.dstTensorDesc, algo[0], sizeInBytes);
        checkCudnn(true, "cudnnGetConvolutionForwardWorkspaceSize", code, input, weights, bias, null, kernel, strides,
                pad, mode, fwdAlgo, null, null, convolutionMode, dilation);

        BaseCudnnHelper.DataCache workSpace =
                (BaseCudnnHelper.DataCache) workspaceMgr.getHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY);
        final long requiredBytes = sizeInBytes.get(0L);
        if (workSpace == null || requiredBytes > workSpace.capacity()) {
            if (log.isTraceEnabled()) {
                if (workSpace == null) {
                    log.trace("CudnnConvolutionHelper preOutput: allocating initial workspace of size {} ({})",
                            requiredBytes,
                            StringUtils.TraditionalBinaryPrefix.long2String(requiredBytes, "B", 2));
                } else {
                    log.trace("CudnnConvolutionHelper preOutput: Deallocating workspace of size {} ({}), allocating new workspace of size {} ({})",
                            workSpace.capacity(),
                            StringUtils.TraditionalBinaryPrefix.long2String(workSpace.capacity(), "B", 2),
                            requiredBytes,
                            StringUtils.TraditionalBinaryPrefix.long2String(requiredBytes, "B", 2));
                }
            }
            if (workSpace != null) {
                workSpace.deallocate();
            }
            workSpace = new BaseCudnnHelper.DataCache(requiredBytes);
            workspaceMgr.setHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY, workSpace);
        }

        code = cudnn.cudnnConvolutionForward(ctx, alpha, ctx.srcTensorDesc, srcData, ctx.filterDesc, filterData,
                ctx.convDesc, algo[0], workSpace, workSpace.capacity(), beta, ctx.dstTensorDesc, dstData);
        checkCudnn(true, "cudnnConvolutionForward", code, input, weights, bias, null, kernel, strides, pad,
                mode, fwdAlgo, null, null, convolutionMode, dilation);

        code = cudnn.cudnnSetTensor4dDescriptor(ctx.biasTensorDesc, 0, dataType, 1, (int) outDepth, 1, 1);
        checkCudnn(true, "cudnnSetTensor4dDescriptor", code, input, weights, bias, null, kernel, strides, pad,
                mode, fwdAlgo, null, null, convolutionMode, dilation);

        // alpha is passed for both scaling factors, so the bias is accumulated onto the existing result.
        code = cudnn.cudnnAddTensor(ctx, alpha, ctx.biasTensorDesc, biasData, alpha, ctx.dstTensorDesc, dstData);
        checkCudnn(true, "cudnnAddTensor", code, input, weights, bias, null, kernel, strides, pad,
                mode, fwdAlgo, null, null, convolutionMode, dilation);

        allocator.registerAction(context, z, input, weights, bias);

        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            context.syncOldStream();
        }
        return z;
    }

    /** Maps a user-specified forward algorithm to the integer code handed to cuDNN. */
    private static int toCudnnFwdAlgo(ConvolutionLayer.FwdAlgo algo) {
        switch (algo) {
            case IMPLICIT_GEMM:
                return 0;
            case IMPLICIT_PRECOMP_GEMM:
                return 1;
            case GEMM:
                return 2;
            case DIRECT:
                return 3;
            case FFT:
                return 4;
            case FFT_TILING:
                return 5;
            case WINOGRAD:
                return 6;
            case WINOGRAD_NONFUSED:
                return 7;
            case COUNT:
                return 8;
            default:
                throw new IllegalArgumentException("Unknown FwdAlgo: " + algo);
        }
    }

    /**
     * Turns a non-zero cuDNN status code into a {@link RuntimeException} whose message lists the failing
     * step and the full convolution configuration; a zero code is a no-op.
     *
     * @param forward         {@code true} for a forward-pass call, {@code false} for a backward-pass call;
     *                        selects which shapes and algorithms are included in the message
     * @param step            name of the cuDNN function that produced {@code code}
     * @param code            status returned by that function
     * @param input           input array (its shape is reported)
     * @param weights         weight array (its shape is reported)
     * @param bias            bias array, may be {@code null}
     * @param delta           output gradient; dereferenced only when {@code forward} is {@code false}
     * @param kernel          kernel size
     * @param strides         strides
     * @param pad             padding
     * @param mode            algorithm selection mode
     * @param fwdAlgo         forward algorithm, reported for forward calls
     * @param bwdFilterAlgo   backward filter algorithm, reported for backward calls
     * @param bwdDataAlgo     backward data algorithm, reported for backward calls
     * @param convolutionMode convolution mode
     * @param dilation        dilation
     * @throws RuntimeException if {@code code} is non-zero
     */
    private void checkCudnn(boolean forward, String step, int code, INDArray input, INDArray weights, INDArray bias,
            INDArray delta, int[] kernel, int[] strides, int[] pad, ConvolutionLayer.AlgoMode mode,
            ConvolutionLayer.FwdAlgo fwdAlgo, ConvolutionLayer.BwdFilterAlgo bwdFilterAlgo,
            ConvolutionLayer.BwdDataAlgo bwdDataAlgo, ConvolutionMode convolutionMode, int[] dilation) {
        if (code == 0) {
            return;
        }

        final StringBuilder message = new StringBuilder();

        // What failed and where.
        message.append("CuDNN error = ").append(code);
        message.append(": ").append(cudnn.cudnnGetErrorString(code).getString());
        message.append(" during ").append(forward ? "forward pass" : "backward pass");
        message.append(" - step ").append(step);

        // Array shapes; a missing bias is rendered as the text "null".
        message.append(": inputShape=").append(Arrays.toString(input.shape()));
        message.append(", weightsShape=").append(Arrays.toString(weights.shape()));
        final String biasShape = bias == null ? null : Arrays.toString(bias.shape());
        message.append(", biasShape=").append(biasShape);
        if (!forward) {
            message.append(", gradientShape=").append(Arrays.toString(delta.shape()));
        }

        // Convolution geometry and algorithm configuration.
        message.append(", kernel=").append(Arrays.toString(kernel));
        message.append(", stride=").append(Arrays.toString(strides));
        message.append(", padding=").append(Arrays.toString(pad));
        message.append(", dilation=").append(Arrays.toString(dilation));
        message.append(", AlgoMode=").append(mode);
        if (forward) {
            message.append(", fwdAlgo=").append(fwdAlgo);
        } else {
            message.append(", bwdFilterAlgo=").append(bwdFilterAlgo);
            message.append(", bwdDataAlgo=").append(bwdDataAlgo);
        }
        message.append(", convolutionMode=").append(convolutionMode);

        throw new RuntimeException(message.toString());
    }

    /**
     * Applies an activation function to {@code z} in place through cuDNN.
     * <p>
     * The activation is selected by {@code afn.toString()}: {@code "identity"} leaves {@code z} untouched;
     * {@code "sigmoid"}, {@code "relu"} and {@code "tanh"} go through {@code cudnnActivationForward};
     * {@code "softmax"} and {@code "logsoftmax"} go through {@code cudnnSoftmaxForward}. Any other name is
     * not handled here and yields {@code null}.
     * <p>
     * NOTE(review): this method never configures {@code cudnnContext.dstTensorDesc} itself, so it assumes
     * the descriptor still describes {@code z} (e.g. from a preceding {@code preOutput} call) - confirm.
     *
     * @param z   pre-activation array, used as both source and destination of the cuDNN call
     * @param afn activation function; only its {@code toString()} value is consulted
     * @return {@code z} for the handled activations, {@code null} for any other activation name
     */
    public INDArray activate(INDArray z, IActivation afn) {
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
        }

        final AtomicAllocator allocator = AtomicAllocator.getInstance();
        final CudaContext context = allocator.getFlowController().prepareAction(z);
        final Pointer dstData = allocator.getPointer(z, context);
        checkCudnn(cudnn.cudnnSetStream(this.cudnnContext, new cuda.CUstream_st(context.getOldStream())));

        // The integer arguments below are the raw mode/algorithm codes passed to cuDNN
        // (presumably the cudnnActivationMode_t / cudnnSoftmaxAlgorithm_t values - confirm against cudnn.h).
        INDArray activation = z;
        switch (afn.toString()) {
            case "identity":
                break;
            case "sigmoid":
                runActivationInPlace(0, dstData);
                break;
            case "relu":
                runActivationInPlace(1, dstData);
                break;
            case "tanh":
                runActivationInPlace(2, dstData);
                break;
            case "softmax":
                runSoftmaxInPlace(1, dstData);
                break;
            case "logsoftmax":
                runSoftmaxInPlace(2, dstData);
                break;
            default:
                // Not handled by this helper; signalled to the caller through a null result.
                activation = null;
        }

        allocator.registerAction(context, activation);

        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            context.syncOldStream();
        }
        return activation;
    }

    /**
     * Configures the shared activation descriptor with the given mode code and runs
     * {@code cudnnActivationForward} with {@code data} as both input and output.
     */
    private void runActivationInPlace(int activationMode, Pointer data) {
        checkCudnn(cudnn.cudnnSetActivationDescriptor(this.cudnnContext.activationDesc, activationMode, 1, 0.0));
        checkCudnn(cudnn.cudnnActivationForward(this.cudnnContext, this.cudnnContext.activationDesc, alpha,
                this.cudnnContext.dstTensorDesc, data, beta, this.cudnnContext.dstTensorDesc, data));
    }

    /**
     * Runs {@code cudnnSoftmaxForward} with the given algorithm code (and mode code 1) using {@code data}
     * as both input and output.
     */
    private void runSoftmaxInPlace(int softmaxAlgo, Pointer data) {
        checkCudnn(cudnn.cudnnSoftmaxForward(this.cudnnContext, softmaxAlgo, 1, alpha,
                this.cudnnContext.dstTensorDesc, data, beta, this.cudnnContext.dstTensorDesc, data));
    }

    /**
     * Builds the {@link CudnnForwardArgs} (input array, padding, output size and manual-padding
     * flags) for a forward pass.
     *
     * <p>The input is first replaced by a {@code 'c'}-ordered copy when it is a view or does not
     * have default strides for its shape. For {@link ConvolutionMode#Same} the supplied
     * {@code padding} is ignored and recomputed as the top/left padding; when the bottom/right
     * padding differs, the input is copied into a new array that is one element larger along
     * dimension 2 and/or 3, and the corresponding flags are set. For every other mode the supplied
     * {@code padding} is passed through unchanged.
     *
     * @param input           input array; dimensions 2 and 3 are read as height and width
     * @param kernel          kernel size
     * @param strides         strides
     * @param padding         padding, used as-is unless {@code convolutionMode} is {@code Same}
     * @param dilation        dilation
     * @param convolutionMode convolution mode
     * @param poolingType     pooling type, may be {@code null}; only {@link PoolingType#MAX}
     *                        changes the fill value of a manually padded input
     * @return the prepared arguments; {@code getOrigInput()} is always the array passed in
     */
    public static CudnnForwardArgs getCudnnForwardArgs(INDArray input, int[] kernel, int[] strides, int[] padding, int[] dilation, ConvolutionMode convolutionMode, PoolingType poolingType) {
        final INDArray origInput = input;
        if (input.isView() || !Shape.hasDefaultStridesForShape(input)) {
            input = input.dup('c');
        }

        final long inH = input.size(2);
        final long inW = input.size(3);

        if (convolutionMode != ConvolutionMode.Same) {
            int[] explicitOutSize = ConvolutionUtils.getOutputSize(input, kernel, strides, padding, convolutionMode, dilation);
            return new CudnnForwardArgs(false, false, input, origInput, padding, explicitOutSize);
        }

        int[] outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation);
        padding = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[]{(int) inH, (int) inW}, kernel, strides, dilation);
        int[] bottomRight = ConvolutionUtils.getSameModeBottomRightPadding(outSize, new int[]{(int) inH, (int) inW}, kernel, strides, dilation);
        if (Arrays.equals(padding, bottomRight)) {
            // Top/left and bottom/right padding agree: nothing to pad by hand.
            return new CudnnForwardArgs(false, false, input, origInput, padding, outSize);
        }

        // NOTE(review): presumably needed because cuDNN only accepts one padding value per
        // spatial dimension - confirm against the cuDNN convolution descriptor documentation.
        boolean padBottom = padding[0] != bottomRight[0];
        boolean padRight = padding[1] != bottomRight[1];
        long[] paddedShape = {input.size(0), input.size(1), inH + (padBottom ? 1L : 0L), inW + (padRight ? 1L : 0L)};

        INDArray padded;
        if (poolingType == PoolingType.MAX) {
            // Presumably -infinity keeps the extra row/column from ever being picked as a maximum.
            padded = Nd4j.valueArrayOf(paddedShape, Double.NEGATIVE_INFINITY);
        } else {
            padded = Nd4j.create(paddedShape);
        }

        // Copy the original values into the top-left region of the enlarged array.
        INDArrayIndex[] originalRegion = {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0L, inH), NDArrayIndex.interval(0L, inW)};
        padded.put(originalRegion, input);
        return new CudnnForwardArgs(padBottom, padRight, padded, origInput, padding, outSize);
    }

    /**
     * Reports the memory use of this helper.
     *
     * @return an empty map; this implementation reports no entries
     */
    public Map<String, Long> helperMemoryUse() {
        final Map<String, Long> noEntries = Collections.emptyMap();
        return noEntries;
    }

    /**
     * Mutable holder for the values produced by
     * {@code getCudnnForwardArgs}: the manual-padding flags, the (possibly copied or padded)
     * input, the original input, the padding and the output size.
     *
     * <p>Arrays are stored and returned by reference; no defensive copies are made.
     */
    public static class CudnnForwardArgs {
        private boolean manualPadBottom;
        private boolean manualPadRight;
        private INDArray input;
        private INDArray origInput;
        private int[] padding;
        private int[] outSize;

        /**
         * Creates a holder with all six values set.
         *
         * @param manualPadBottom whether an extra bottom row was added to the input
         * @param manualPadRight  whether an extra right column was added to the input
         * @param input           the input to use
         * @param origInput       the input as originally supplied
         * @param padding         the padding
         * @param outSize         the output size
         */
        public CudnnForwardArgs(boolean manualPadBottom, boolean manualPadRight, INDArray input, INDArray origInput, int[] padding, int[] outSize) {
            this.manualPadBottom = manualPadBottom;
            this.manualPadRight = manualPadRight;
            this.input = input;
            this.origInput = origInput;
            this.padding = padding;
            this.outSize = outSize;
        }

        // ----- accessors -----

        public boolean isManualPadBottom() {
            return manualPadBottom;
        }

        public boolean isManualPadRight() {
            return manualPadRight;
        }

        public INDArray getInput() {
            return input;
        }

        public INDArray getOrigInput() {
            return origInput;
        }

        public int[] getPadding() {
            return padding;
        }

        public int[] getOutSize() {
            return outSize;
        }

        // ----- mutators -----

        public void setManualPadBottom(boolean manualPadBottom) {
            this.manualPadBottom = manualPadBottom;
        }

        public void setManualPadRight(boolean manualPadRight) {
            this.manualPadRight = manualPadRight;
        }

        public void setInput(INDArray input) {
            this.input = input;
        }

        public void setOrigInput(INDArray origInput) {
            this.origInput = origInput;
        }

        public void setPadding(int[] padding) {
            this.padding = padding;
        }

        public void setOutSize(int[] outSize) {
            this.outSize = outSize;
        }

        /**
         * Compares all six properties through their getters; arrays are compared by content.
         *
         * @param o the object to compare with
         * @return {@code true} if {@code o} is a {@code CudnnForwardArgs} that accepts this
         *         instance via {@link #canEqual(Object)} and all properties match
         */
        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof CudnnForwardArgs)) {
                return false;
            }
            CudnnForwardArgs that = (CudnnForwardArgs) o;
            if (!that.canEqual(this)) {
                return false;
            }
            if (isManualPadBottom() != that.isManualPadBottom()) {
                return false;
            }
            if (isManualPadRight() != that.isManualPadRight()) {
                return false;
            }

            INDArray myInput = getInput();
            INDArray theirInput = that.getInput();
            boolean sameInput = (myInput == null) ? (theirInput == null) : myInput.equals(theirInput);
            if (!sameInput) {
                return false;
            }

            INDArray myOrigInput = getOrigInput();
            INDArray theirOrigInput = that.getOrigInput();
            boolean sameOrigInput = (myOrigInput == null) ? (theirOrigInput == null) : myOrigInput.equals(theirOrigInput);
            if (!sameOrigInput) {
                return false;
            }

            return Arrays.equals(getPadding(), that.getPadding())
                    && Arrays.equals(getOutSize(), that.getOutSize());
        }

        /**
         * Hook used by {@link #equals(Object)} to let subclasses restrict equality.
         *
         * @param other the candidate object
         * @return {@code true} if {@code other} is a {@code CudnnForwardArgs}
         */
        protected boolean canEqual(Object other) {
            return other instanceof CudnnForwardArgs;
        }

        /**
         * Hash over the same six properties that {@link #equals(Object)} compares.
         *
         * @return the hash code
         */
        @Override
        public int hashCode() {
            final int prime = 59;
            int hash = 1;
            hash = hash * prime + (isManualPadBottom() ? 79 : 97);
            hash = hash * prime + (isManualPadRight() ? 79 : 97);

            INDArray in = getInput();
            hash = hash * prime + (in == null ? 43 : in.hashCode());

            INDArray orig = getOrigInput();
            hash = hash * prime + (orig == null ? 43 : orig.hashCode());

            hash = hash * prime + Arrays.hashCode(getPadding());
            hash = hash * prime + Arrays.hashCode(getOutSize());
            return hash;
        }

        /**
         * @return a string listing every property, with arrays rendered via
         *         {@link Arrays#toString(int[])}
         */
        @Override
        public String toString() {
            return "CudnnConvolutionHelper.CudnnForwardArgs(manualPadBottom=" + isManualPadBottom()
                    + ", manualPadRight=" + isManualPadRight()
                    + ", input=" + getInput()
                    + ", origInput=" + getOrigInput()
                    + ", padding=" + Arrays.toString(getPadding())
                    + ", outSize=" + Arrays.toString(getOutSize())
                    + ")";
        }
    }

    /**
     * Context holding the cuDNN descriptors used by this helper: four tensor descriptors
     * (source, destination, bias, delta) plus one filter, one convolution and one activation
     * descriptor.
     */
    private static class CudnnConvolutionContext extends BaseCudnnHelper.CudnnContext {
        private cudnn.cudnnTensorStruct srcTensorDesc = new cudnn.cudnnTensorStruct();
        private cudnn.cudnnTensorStruct dstTensorDesc = new cudnn.cudnnTensorStruct();
        private cudnn.cudnnTensorStruct biasTensorDesc = new cudnn.cudnnTensorStruct();
        private cudnn.cudnnTensorStruct deltaTensorDesc = new cudnn.cudnnTensorStruct();
        private cudnn.cudnnFilterStruct filterDesc = new cudnn.cudnnFilterStruct();
        private cudnn.cudnnConvolutionStruct convDesc = new cudnn.cudnnConvolutionStruct();
        private cudnn.cudnnActivationStruct activationDesc = new cudnn.cudnnActivationStruct();

        /**
         * Creates all handles and registers a {@link Deallocator} built from this context.
         */
        public CudnnConvolutionContext() {
            createHandles();
            deallocator(new Deallocator(this));
        }

        /**
         * Copy constructor: wraps each descriptor of {@code c} in a new struct object created
         * from the existing one. No handles are created here.
         *
         * @param c the context whose descriptors are wrapped
         */
        public CudnnConvolutionContext(CudnnConvolutionContext c) {
            super(c);
            // The Pointer casts are kept so the Pointer-taking constructors are selected.
            srcTensorDesc = new cudnn.cudnnTensorStruct((Pointer) c.srcTensorDesc);
            dstTensorDesc = new cudnn.cudnnTensorStruct((Pointer) c.dstTensorDesc);
            biasTensorDesc = new cudnn.cudnnTensorStruct((Pointer) c.biasTensorDesc);
            deltaTensorDesc = new cudnn.cudnnTensorStruct((Pointer) c.deltaTensorDesc);
            filterDesc = new cudnn.cudnnFilterStruct((Pointer) c.filterDesc);
            convDesc = new cudnn.cudnnConvolutionStruct((Pointer) c.convDesc);
            activationDesc = new cudnn.cudnnActivationStruct((Pointer) c.activationDesc);
        }

        /** Returns the tensor descriptors in the order source, destination, bias, delta. */
        private cudnn.cudnnTensorStruct[] tensorDescriptors() {
            return new cudnn.cudnnTensorStruct[]{srcTensorDesc, dstTensorDesc, biasTensorDesc, deltaTensorDesc};
        }

        /**
         * Creates the parent's handles, then the tensor, filter, convolution and activation
         * descriptors, checking the status of every call.
         */
        @Override
        protected void createHandles() {
            super.createHandles();
            for (cudnn.cudnnTensorStruct tensorDesc : tensorDescriptors()) {
                CudnnConvolutionHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor(tensorDesc));
            }
            CudnnConvolutionHelper.checkCudnn(cudnn.cudnnCreateFilterDescriptor(filterDesc));
            CudnnConvolutionHelper.checkCudnn(cudnn.cudnnCreateConvolutionDescriptor(convDesc));
            CudnnConvolutionHelper.checkCudnn(cudnn.cudnnCreateActivationDescriptor(activationDesc));
        }

        /**
         * Destroys the activation, convolution, filter and tensor descriptors, checking the
         * status of every call, and then destroys the parent's handles.
         */
        @Override
        protected void destroyHandles() {
            CudnnConvolutionHelper.checkCudnn(cudnn.cudnnDestroyActivationDescriptor(activationDesc));
            CudnnConvolutionHelper.checkCudnn(cudnn.cudnnDestroyConvolutionDescriptor(convDesc));
            CudnnConvolutionHelper.checkCudnn(cudnn.cudnnDestroyFilterDescriptor(filterDesc));
            for (cudnn.cudnnTensorStruct tensorDesc : tensorDescriptors()) {
                CudnnConvolutionHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor(tensorDesc));
            }
            super.destroyHandles();
        }

        /**
         * Deallocator built through the copy constructor; {@link #deallocate()} destroys the
         * handles it wraps.
         * NOTE(review): presumably a separate copy is used so the deallocator does not keep the
         * owning context reachable - confirm against the JavaCPP deallocator contract.
         */
        private static class Deallocator extends CudnnConvolutionContext implements Pointer.Deallocator {
            Deallocator(CudnnConvolutionContext c) {
                super(c);
            }

            public void deallocate() {
                destroyHandles();
            }
        }
    }
}

