/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.recurrent;

import java.util.HashMap;
import java.util.Map;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerPointer;
import org.bytedeco.javacpp.SizeTPointer;
import org.bytedeco.javacpp.cuda;
import org.bytedeco.javacpp.cudnn;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseCudnnHelper;
import org.deeplearning4j.nn.layers.recurrent.FwdPassReturn;
import org.deeplearning4j.nn.layers.recurrent.LSTMHelper;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.activations.impl.ActivationTanH;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.util.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class CudnnLSTMHelper
extends BaseCudnnHelper
implements LSTMHelper {
    private static final Logger log = LoggerFactory.getLogger(CudnnLSTMHelper.class);
    // Fixed cuDNN RNN configuration for this helper.
    // NOTE(review): some call sites may pass the equivalent literal values instead of these
    // constants — keep them in sync.
    /** Number of stacked RNN layers configured in the cuDNN RNN descriptor. */
    protected static final int NUM_LAYERS = 1;
    /** Dropout probability used for the cuDNN dropout descriptor (dropout disabled). */
    protected static final float DROPOUT = 0.0f;
    /** Whether the RNN runs in both directions; presumably maps to the cuDNN direction argument — confirm. */
    protected static final boolean BIDIRECTIONAL = false;
    /** cuDNN RNN mode value; presumably CUDNN_LSTM in cudnnRNNMode_t — confirm against the cuDNN headers. */
    protected static final int RNN_MODE = 2;
    /** Number of linear layers (weight matrices) iterated per LSTM layer when copying weights to/from cuDNN. */
    protected static final int NUM_LINEAR_LAYERS = 8;
    /** cuDNN handle together with the RNN-specific tensor/filter/RNN/dropout descriptors. */
    private CudnnLSTMContext cudnnContext = new CudnnLSTMContext();
    /** Per-time-step input tensor descriptors; grown in activate when a longer sequence arrives. */
    private BaseCudnnHelper.TensorArray xDesc = new BaseCudnnHelper.TensorArray();
    /** Per-time-step output tensor descriptors. */
    private BaseCudnnHelper.TensorArray yDesc = new BaseCudnnHelper.TensorArray();
    /** Per-time-step input-gradient tensor descriptors. */
    private BaseCudnnHelper.TensorArray dxDesc = new BaseCudnnHelper.TensorArray();
    /** Per-time-step output-gradient tensor descriptors. */
    private BaseCudnnHelper.TensorArray dyDesc = new BaseCudnnHelper.TensorArray();
    /** State buffer handed to cudnnSetDropoutDescriptor. */
    private BaseCudnnHelper.DataCache stateSpace = new BaseCudnnHelper.DataCache();
    /** Reserve buffer filled by cudnnRNNForwardTraining and consumed by the backward calls. */
    private BaseCudnnHelper.DataCache reserveSpace = new BaseCudnnHelper.DataCache();
    /** cuDNN's packed weight buffer: filled from the DL4J parameters in activate, reused for weight gradients in backprop. */
    private BaseCudnnHelper.DataCache weightsSpace = new BaseCudnnHelper.DataCache();
    /** Whether the dropout descriptor has already been configured (checked in activate). */
    private boolean initializedDropoutDescriptor = false;

    /**
     * Returns {@code arr} itself when it is already a non-view, 'c'-ordered array with
     * C-style descending strides; otherwise returns a fresh 'c'-ordered copy.
     *
     * @param arr the array to normalize
     * @return an array whose buffer can be handed to cuDNN as a dense C-order tensor
     */
    private static INDArray toCOrder(INDArray arr) {
        boolean usableAsIs = !arr.isView()
                && arr.ordering() == 'c'
                && Shape.strideDescendingCAscendingF((INDArray)arr);
        return usableAsIs ? arr : arr.dup('c');
    }

    /**
     * Decides whether this cuDNN helper can run the given LSTM configuration.
     *
     * <p>Every condition is evaluated (no early return) so that each unsupported feature is logged.
     *
     * @param gateActivationFn       must be {@link ActivationSigmoid}
     * @param activationFn           must be {@link ActivationTanH}
     * @param hasPeepholeConnections must be false
     * @return true only if the base cuDNN check passes and all three conditions hold
     */
    public boolean checkSupported(IActivation gateActivationFn, IActivation activationFn, boolean hasPeepholeConnections) {
        boolean baseSupported = this.checkSupported();

        boolean sigmoidGates = gateActivationFn instanceof ActivationSigmoid;
        if (!sigmoidGates) {
            log.warn("Not supported: Gate activation functions != ActivationSigmoid");
        }

        boolean tanhActivation = activationFn instanceof ActivationTanH;
        if (!tanhActivation) {
            log.warn("Not supported: Layer activation functions != ActivationTanH");
        }

        if (hasPeepholeConnections) {
            log.warn("Not supported: LSTM layers with peephole connections");
        }

        return baseSupported && sigmoidGates && tanhActivation && !hasPeepholeConnections;
    }

    /**
     * Back-propagates through the LSTM layer using cuDNN.
     *
     * <p>Runs {@code cudnnRNNBackwardData} (producing the epsilon for the previous layer) and then
     * {@code cudnnRNNBackwardWeights}. It then copies the per-gate weight and bias gradients out of
     * cuDNN's packed buffer into the DL4J gradient views.
     *
     * <p>The descriptors, {@code weightsSpace} and {@code reserveSpace} used here are the ones set up
     * by the preceding {@link #activate} call on this helper. {@code weightsSpace} is overwritten
     * with the weight gradients.
     *
     * <p>{@code conf}, {@code gateActivationFn}, {@code forwards}, {@code maskArray} and
     * {@code hasPeepholeConnections} are not read by this implementation.
     *
     * @param input               layer input, permuted with (2, 0, 1) into a time-major array
     * @param recurrentWeights    recurrent weights; {@code size(0)} is used as the hidden layer size
     * @param inputWeights        input weights; {@code size(0)} is used as the previous layer size
     * @param epsilon             gradient w.r.t. the layer output, permuted with (2, 0, 1)
     * @param truncatedBPTT       if true, only the last {@code tbpttBackwardLength} time steps are processed
     * @param tbpttBackwardLength number of trailing time steps processed when {@code truncatedBPTT} is set
     * @param fwdPass             forward output plus the initial hidden/cell state of the forward pass
     * @param gradientViews       views receiving the input-weight, recurrent-weight and bias gradients
     * @param workspaceMgr        supplies the epsilon array and the shared cuDNN helper workspace
     * @return the gradient and the epsilon for the previous layer (dx permuted with (1, 2, 0)); time
     *         steps skipped by truncation receive a zero epsilon
     */
    public Pair<Gradient, INDArray> backpropGradient(NeuralNetConfiguration conf, IActivation gateActivationFn, INDArray input, INDArray recurrentWeights, INDArray inputWeights, INDArray epsilon, boolean truncatedBPTT, int tbpttBackwardLength, FwdPassReturn fwdPass, boolean forwards, String inputWeightKey, String recurrentWeightKey, String biasWeightKey, Map<String, INDArray> gradientViews, INDArray maskArray, boolean hasPeepholeConnections, LayerWorkspaceMgr workspaceMgr) {
        long hiddenLayerSize = recurrentWeights.size(0);
        long prevLayerSize = inputWeights.size(0);
        long inputLayerSize = input.size(1);
        long miniBatchSize = epsilon.size(0);
        boolean is2dInput = epsilon.rank() < 3;
        long timeSeriesLength = is2dInput ? 1L : epsilon.size(2);
        // First time step included in the backward pass: with truncated BPTT only the last
        // tbpttBackwardLength steps are processed.
        long firstStep = truncatedBPTT ? Math.max(0L, timeSeriesLength - (long)tbpttBackwardLength) : 0L;

        // cuDNN works on time-major arrays: [timeSeriesLength, miniBatchSize, features]
        INDArray x = CudnnLSTMHelper.toCOrder(input.permute(new int[]{2, 0, 1}));
        INDArray dy = CudnnLSTMHelper.toCOrder(epsilon.permute(new int[]{2, 0, 1}));
        INDArray dx = workspaceMgr.createUninitialized((Enum)ArrayType.ACTIVATION_GRAD, new long[]{timeSeriesLength, miniBatchSize, prevLayerSize}, 'c');
        if (firstStep > 0L) {
            // cuDNN only writes dx for the processed steps; give the skipped leading steps a zero
            // epsilon instead of uninitialized memory.
            dx.assign(0);
        }
        INDArray iwGradientsOut = gradientViews.get(inputWeightKey);
        INDArray rwGradientsOut = gradientViews.get(recurrentWeightKey);
        INDArray bGradientsOut = gradientViews.get(biasWeightKey);
        INDArray outputActivations = CudnnLSTMHelper.toCOrder(fwdPass.fwdPassOutput.permute(new int[]{2, 0, 1}));
        INDArray prevStepMemCellState = CudnnLSTMHelper.toCOrder(fwdPass.prevMemCell);
        INDArray prevStepActivations = CudnnLSTMHelper.toCOrder(fwdPass.prevAct);
        Nd4j.getExecutioner().commit();

        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareActionAllWrite(new INDArray[]{x, dy, dx, outputActivations, prevStepMemCellState, prevStepActivations, iwGradientsOut, rwGradientsOut, bGradientsOut});
        Pointer xData = allocator.getPointer(x, context);
        Pointer dyData = allocator.getPointer(dy, context);
        Pointer dxData = allocator.getPointer(dx, context);
        Pointer outputActivationsData = allocator.getPointer(outputActivations, context);
        Pointer prevMemCellStateData = allocator.getPointer(prevStepMemCellState, context);
        Pointer prevStepActivationsData = allocator.getPointer(prevStepActivations, context);
        Pointer iwGradientsOutData = allocator.getPointer(iwGradientsOut, context);
        Pointer rwGradientsOutData = allocator.getPointer(rwGradientsOut, context);
        Pointer bGradientsOutData = allocator.getPointer(bGradientsOut, context);
        cuda.CUstream_st stream = new cuda.CUstream_st((Pointer)context.getOldStream());
        CudnnLSTMHelper.checkCudnn(cudnn.cudnnSetStream((cudnn.cudnnContext)this.cudnnContext, (cuda.CUstream_st)stream));

        if (truncatedBPTT) {
            // Byte offset of time step firstStep in each time-major buffer. x and dx carry
            // inputLayerSize / prevLayerSize features per step; dy and the outputs carry hiddenLayerSize.
            // NOTE(review): for firstStep > 0 the initial states (fwdPass.prevAct / prevMemCell) are
            // still those before step 0, and reserveSpace comes from the full-length forward pass.
            // Confirm that callers only truncate with firstStep == 0, or accept the approximation.
            long hiddenOffset = firstStep * miniBatchSize * hiddenLayerSize * (long)this.dataTypeSize;
            xData.position(firstStep * miniBatchSize * inputLayerSize * (long)this.dataTypeSize);
            dxData.position(firstStep * miniBatchSize * prevLayerSize * (long)this.dataTypeSize);
            dyData.position(hiddenOffset);
            outputActivationsData.position(hiddenOffset);
            timeSeriesLength = Math.min(timeSeriesLength, (long)tbpttBackwardLength);
        }

        // The helper workspace is looked up under a shared key, and every activate() call resets its
        // limit. Re-query the size this layer needs and grow the workspace if necessary.
        CudnnLSTMHelper.checkCudnn(cudnn.cudnnGetRNNWorkspaceSize((cudnn.cudnnContext)this.cudnnContext, (cudnn.cudnnRNNStruct)this.cudnnContext.rnnDesc, (int)timeSeriesLength, (PointerPointer)this.xDesc, (SizeTPointer)this.sizeInBytes));
        long workSize = this.sizeInBytes.get(0L);
        BaseCudnnHelper.DataCache workSpace = (BaseCudnnHelper.DataCache)workspaceMgr.getHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY);
        if (workSpace == null || workSize > workSpace.capacity()) {
            if (workSpace != null) {
                workSpace.deallocate();
            }
            workSpace = new BaseCudnnHelper.DataCache(workSize);
            workspaceMgr.setHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY, (Pointer)workSpace);
        }
        workSpace.limit(workSize);

        cudnn.cudnnTensorStruct xDesc0 = (cudnn.cudnnTensorStruct)this.xDesc.get(cudnn.cudnnTensorStruct.class, 0L);
        CudnnLSTMHelper.checkCudnn(cudnn.cudnnRNNBackwardData((cudnn.cudnnContext)this.cudnnContext, (cudnn.cudnnRNNStruct)this.cudnnContext.rnnDesc, (int)timeSeriesLength, (PointerPointer)this.yDesc, (Pointer)outputActivationsData, (PointerPointer)this.dyDesc, (Pointer)dyData, (cudnn.cudnnTensorStruct)this.cudnnContext.dhyDesc, null, (cudnn.cudnnTensorStruct)this.cudnnContext.dcyDesc, null, (cudnn.cudnnFilterStruct)this.cudnnContext.wDesc, (Pointer)this.weightsSpace, (cudnn.cudnnTensorStruct)this.cudnnContext.hxDesc, (Pointer)prevStepActivationsData, (cudnn.cudnnTensorStruct)this.cudnnContext.cxDesc, (Pointer)prevMemCellStateData, (PointerPointer)this.dxDesc, (Pointer)dxData, (cudnn.cudnnTensorStruct)this.cudnnContext.dhxDesc, null, (cudnn.cudnnTensorStruct)this.cudnnContext.dcxDesc, null, (Pointer)workSpace, (long)workSpace.limit(), (Pointer)this.reserveSpace, (long)this.reserveSpace.limit()));
        // The weights are no longer needed after BackwardData, so weightsSpace is zeroed and reused
        // as the destination of the weight gradients.
        CudnnLSTMHelper.checkCuda(cuda.cudaMemsetAsync((Pointer)this.weightsSpace, (int)0, (long)this.weightsSpace.limit(), (cuda.CUstream_st)stream));
        CudnnLSTMHelper.checkCudnn(cudnn.cudnnRNNBackwardWeights((cudnn.cudnnContext)this.cudnnContext, (cudnn.cudnnRNNStruct)this.cudnnContext.rnnDesc, (int)timeSeriesLength, (PointerPointer)this.xDesc, (Pointer)xData, (cudnn.cudnnTensorStruct)this.cudnnContext.hxDesc, (Pointer)prevStepActivationsData, (PointerPointer)this.yDesc, (Pointer)outputActivationsData, (Pointer)workSpace, (long)workSpace.limit(), (cudnn.cudnnFilterStruct)this.cudnnContext.dwDesc, (Pointer)this.weightsSpace, (Pointer)this.reserveSpace, (long)this.reserveSpace.limit()));

        // Column block in the DL4J gradient views for each cuDNN linear layer (ID % 4). IDs 0-3 are
        // input-weight matrices and IDs 4-7 are recurrent-weight matrices.
        int[] gateBlock = {3, 1, 0, 2};
        // Output arrays of cudnnGetFilterNdDescriptor; the values are not read afterwards.
        int[] dataType = new int[1];
        int[] format = new int[1];
        int[] nbDims = new int[1];
        int[] filterDimA = new int[3];
        Pointer linLayerMat = new Pointer();
        Pointer linLayerBias = new Pointer();
        for (int layer = 0; layer < NUM_LAYERS; ++layer) {
            for (int linLayerID = 0; linLayerID < NUM_LINEAR_LAYERS; ++linLayerID) {
                CudnnLSTMHelper.checkCudnn(cudnn.cudnnGetRNNLinLayerMatrixParams((cudnn.cudnnContext)this.cudnnContext, (cudnn.cudnnRNNStruct)this.cudnnContext.rnnDesc, (int)layer, (cudnn.cudnnTensorStruct)xDesc0, (cudnn.cudnnFilterStruct)this.cudnnContext.wDesc, (Pointer)this.weightsSpace, (int)linLayerID, (cudnn.cudnnFilterStruct)this.cudnnContext.linLayerMatDesc, (Pointer)linLayerMat));
                CudnnLSTMHelper.checkCudnn(cudnn.cudnnGetFilterNdDescriptor((cudnn.cudnnFilterStruct)this.cudnnContext.linLayerMatDesc, (int)3, (int[])dataType, (int[])format, (int[])nbDims, (int[])filterDimA));
                CudnnLSTMHelper.checkCudnn(cudnn.cudnnGetRNNLinLayerBiasParams((cudnn.cudnnContext)this.cudnnContext, (cudnn.cudnnRNNStruct)this.cudnnContext.rnnDesc, (int)layer, (cudnn.cudnnTensorStruct)xDesc0, (cudnn.cudnnFilterStruct)this.cudnnContext.wDesc, (Pointer)this.weightsSpace, (int)linLayerID, (cudnn.cudnnFilterStruct)this.cudnnContext.linLayerBiasDesc, (Pointer)linLayerBias));
                CudnnLSTMHelper.checkCudnn(cudnn.cudnnGetFilterNdDescriptor((cudnn.cudnnFilterStruct)this.cudnnContext.linLayerBiasDesc, (int)3, (int[])dataType, (int[])format, (int[])nbDims, (int[])filterDimA));

                boolean inputMatrix = linLayerID < 4;
                Pointer data = inputMatrix ? iwGradientsOutData : rwGradientsOutData;
                long size = inputMatrix ? inputLayerSize : hiddenLayerSize;
                long position = gateBlock[linLayerID % 4];
                long matBytes = size * hiddenLayerSize * (long)this.dataTypeSize;
                CudnnLSTMHelper.checkCuda(cuda.cudaMemcpyAsync((Pointer)data.position(position * matBytes), (Pointer)linLayerMat, matBytes, cuda.cudaMemcpyDeviceToDevice, (cuda.CUstream_st)stream));
                if (inputMatrix) {
                    // The bias view holds one hiddenLayerSize block per gate; only the input-side
                    // bias gradients (IDs 0-3) are copied into it.
                    long biasBytes = hiddenLayerSize * (long)this.dataTypeSize;
                    CudnnLSTMHelper.checkCuda(cuda.cudaMemcpyAsync((Pointer)bGradientsOutData.position(position * biasBytes), (Pointer)linLayerBias, biasBytes, cuda.cudaMemcpyDeviceToDevice, (cuda.CUstream_st)stream));
                }
            }
        }
        allocator.getFlowController().registerActionAllWrite(context, new INDArray[]{x, dy, dx, outputActivations, prevStepMemCellState, prevStepActivations, iwGradientsOut, rwGradientsOut, bGradientsOut});

        DefaultGradient retGradient = new DefaultGradient();
        retGradient.gradientForVariable().put(inputWeightKey, iwGradientsOut);
        retGradient.gradientForVariable().put(recurrentWeightKey, rwGradientsOut);
        retGradient.gradientForVariable().put(biasWeightKey, bGradientsOut);
        INDArray epsilonNext = dx.permute(new int[]{1, 2, 0});
        return new Pair((Object)retGradient, (Object)epsilonNext);
    }

    /**
     * Runs the LSTM forward pass with cuDNN.
     *
     * <p>Configures the per-time-step tensor descriptors, the hidden/cell state descriptors, and the
     * dropout and RNN descriptors. It also sizes the cached weight, workspace and reserve buffers.
     * The DL4J input weights, recurrent weights and biases are then copied gate by gate into cuDNN's
     * packed weight buffer. Finally it runs either {@code cudnnRNNForwardTraining} (when
     * {@code training}) or {@code cudnnRNNForwardInference}.
     *
     * <p>{@code layer}, {@code conf}, {@code gateActivationFn}, {@code forBackprop}, {@code forwards},
     * {@code inputWeightKey}, {@code maskArray} and {@code hasPeepholeConnections} are not read by
     * this implementation.
     *
     * @param input                 layer input; size(0) is the mini-batch size, size(1) the input size
     *                              and size(2) (for rank &gt;= 3) the number of time steps
     * @param recurrentWeights      recurrent weights; size(0) is used as the hidden layer size
     * @param inputWeights          input weights copied into the cuDNN weight buffer
     * @param biases                gate biases copied into the cuDNN weight buffer
     * @param training              selects cudnnRNNForwardTraining (which fills reserveSpace for the
     *                              backward pass) instead of cudnnRNNForwardInference
     * @param prevOutputActivations initial hidden state
     * @param prevMemCellState      initial cell state
     * @param workspaceMgr          supplies the output array and the shared cuDNN helper workspace
     * @return the output (permuted back with (1, 2, 0)), the final activations/cell state, and the
     *         initial states that were used
     */
    public FwdPassReturn activate(Layer layer, NeuralNetConfiguration conf, IActivation gateActivationFn, INDArray input, INDArray recurrentWeights, INDArray inputWeights, INDArray biases, boolean training, INDArray prevOutputActivations, INDArray prevMemCellState, boolean forBackprop, boolean forwards, String inputWeightKey, INDArray maskArray, boolean hasPeepholeConnections, LayerWorkspaceMgr workspaceMgr) {
        boolean is2dInput = input.rank() < 3;
        long timeSeriesLength = is2dInput ? 1L : input.size(2);
        long hiddenLayerSize = recurrentWeights.size(0);
        long miniBatchSize = input.size(0);
        long inputLayerSize = input.size(1);
        // cuDNN works on time-major arrays: [timeSeriesLength, miniBatchSize, features]
        INDArray x = CudnnLSTMHelper.toCOrder(input.permute(new int[]{2, 0, 1}));
        INDArray linInputWeights = inputWeights;
        INDArray linRecurrentWeights = recurrentWeights;
        INDArray linBiases = biases;
        INDArray prevAct = CudnnLSTMHelper.toCOrder(prevOutputActivations);
        INDArray prevMemCell = CudnnLSTMHelper.toCOrder(prevMemCellState);
        INDArray outputActivations = workspaceMgr.createUninitialized((Enum)ArrayType.ACTIVATIONS, new long[]{timeSeriesLength, miniBatchSize, hiddenLayerSize}, 'c');
        INDArray finalMemCellState = Nd4j.createUninitialized((long[])new long[]{miniBatchSize, hiddenLayerSize}, (char)'c');
        INDArray finalStepActivations = Nd4j.createUninitialized((long[])new long[]{miniBatchSize, hiddenLayerSize}, (char)'c');
        Nd4j.getExecutioner().commit();

        // Grow the per-time-step descriptor arrays when the sequence is longer than their capacity.
        if (timeSeriesLength > this.xDesc.capacity()) {
            this.xDesc.deallocate();
            this.xDesc = new BaseCudnnHelper.TensorArray(timeSeriesLength);
        }
        if (timeSeriesLength > this.yDesc.capacity()) {
            this.yDesc.deallocate();
            this.yDesc = new BaseCudnnHelper.TensorArray(timeSeriesLength);
        }
        if (timeSeriesLength > this.dxDesc.capacity()) {
            this.dxDesc.deallocate();
            this.dxDesc = new BaseCudnnHelper.TensorArray(timeSeriesLength);
        }
        if (timeSeriesLength > this.dyDesc.capacity()) {
            this.dyDesc.deallocate();
            this.dyDesc = new BaseCudnnHelper.TensorArray(timeSeriesLength);
        }

        // Every time step uses the same 3-d [miniBatch, features, 1] shape, so the dims are built once.
        int[] dimA = new int[]{(int)miniBatchSize, (int)inputLayerSize, 1};
        int[] strideA = new int[]{dimA[2] * dimA[1], dimA[2], 1};
        int[] dimB = new int[]{(int)miniBatchSize, (int)hiddenLayerSize, 1};
        int[] strideB = new int[]{dimB[2] * dimB[1], dimB[2], 1};
        for (int i = 0; (long)i < timeSeriesLength; ++i) {
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnSetTensorNdDescriptor((cudnn.cudnnTensorStruct)this.xDesc.get(cudnn.cudnnTensorStruct.class, i), (int)this.dataType, (int)3, (int[])dimA, (int[])strideA));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnSetTensorNdDescriptor((cudnn.cudnnTensorStruct)this.dxDesc.get(cudnn.cudnnTensorStruct.class, i), (int)this.dataType, (int)3, (int[])dimA, (int[])strideA));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnSetTensorNdDescriptor((cudnn.cudnnTensorStruct)this.yDesc.get(cudnn.cudnnTensorStruct.class, i), (int)this.dataType, (int)3, (int[])dimB, (int[])strideB));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnSetTensorNdDescriptor((cudnn.cudnnTensorStruct)this.dyDesc.get(cudnn.cudnnTensorStruct.class, i), (int)this.dataType, (int)3, (int[])dimB, (int[])strideB));
        }

        // Hidden/cell state tensors (and their gradients) share the shape [1 layer, miniBatch, hidden].
        int[] dimC = new int[]{1, (int)miniBatchSize, (int)hiddenLayerSize};
        int[] strideC = new int[]{dimC[2] * dimC[1], dimC[2], 1};
        CudnnLSTMHelper.checkCudnn(cudnn.cudnnSetTensorNdDescriptor((cudnn.cudnnTensorStruct)this.cudnnContext.hxDesc, (int)this.dataType, (int)3, (int[])dimC, (int[])strideC));
        CudnnLSTMHelper.checkCudnn(cudnn.cudnnSetTensorNdDescriptor((cudnn.cudnnTensorStruct)this.cudnnContext.cxDesc, (int)this.dataType, (int)3, (int[])dimC, (int[])strideC));
        CudnnLSTMHelper.checkCudnn(cudnn.cudnnSetTensorNdDescriptor((cudnn.cudnnTensorStruct)this.cudnnContext.hyDesc, (int)this.dataType, (int)3, (int[])dimC, (int[])strideC));
        CudnnLSTMHelper.checkCudnn(cudnn.cudnnSetTensorNdDescriptor((cudnn.cudnnTensorStruct)this.cudnnContext.cyDesc, (int)this.dataType, (int)3, (int[])dimC, (int[])strideC));
        CudnnLSTMHelper.checkCudnn(cudnn.cudnnSetTensorNdDescriptor((cudnn.cudnnTensorStruct)this.cudnnContext.dhxDesc, (int)this.dataType, (int)3, (int[])dimC, (int[])strideC));
        CudnnLSTMHelper.checkCudnn(cudnn.cudnnSetTensorNdDescriptor((cudnn.cudnnTensorStruct)this.cudnnContext.dcxDesc, (int)this.dataType, (int)3, (int[])dimC, (int[])strideC));
        CudnnLSTMHelper.checkCudnn(cudnn.cudnnSetTensorNdDescriptor((cudnn.cudnnTensorStruct)this.cudnnContext.dhyDesc, (int)this.dataType, (int)3, (int[])dimC, (int[])strideC));
        CudnnLSTMHelper.checkCudnn(cudnn.cudnnSetTensorNdDescriptor((cudnn.cudnnTensorStruct)this.cudnnContext.dcyDesc, (int)this.dataType, (int)3, (int[])dimC, (int[])strideC));

        // Dropout state buffer and descriptor.
        CudnnLSTMHelper.checkCudnn(cudnn.cudnnDropoutGetStatesSize((cudnn.cudnnContext)this.cudnnContext, (SizeTPointer)this.sizeInBytes));
        long stateSize = this.sizeInBytes.get(0L);
        if (stateSize > this.stateSpace.capacity()) {
            this.stateSpace.deallocate();
            this.stateSpace = new BaseCudnnHelper.DataCache(stateSize);
            // The dropout descriptor is given the state buffer, so it must be set up again for the new one.
            this.initializedDropoutDescriptor = false;
        }
        this.stateSpace.limit(stateSize);
        if (!this.initializedDropoutDescriptor) {
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnSetDropoutDescriptor((cudnn.cudnnDropoutStruct)this.cudnnContext.dropoutDesc, (cudnn.cudnnContext)this.cudnnContext, (float)DROPOUT, (Pointer)this.stateSpace, (long)stateSize, (long)Nd4j.getRandom().getSeed()));
            // Previously never set, which re-initialized the dropout descriptor on every call.
            this.initializedDropoutDescriptor = true;
        }
        CudnnLSTMHelper.checkCudnn(cudnn.cudnnSetRNNDescriptor_v6((cudnn.cudnnContext)this.cudnnContext, (cudnn.cudnnRNNStruct)this.cudnnContext.rnnDesc, (int)hiddenLayerSize, (int)NUM_LAYERS, (cudnn.cudnnDropoutStruct)this.cudnnContext.dropoutDesc, (int)0, (int)0, (int)RNN_MODE, (int)0, (int)this.dataType));

        // Packed weight buffer and its (1-d) filter descriptors.
        cudnn.cudnnTensorStruct xDesc0 = (cudnn.cudnnTensorStruct)this.xDesc.get(cudnn.cudnnTensorStruct.class, 0L);
        CudnnLSTMHelper.checkCudnn(cudnn.cudnnGetRNNParamsSize((cudnn.cudnnContext)this.cudnnContext, (cudnn.cudnnRNNStruct)this.cudnnContext.rnnDesc, (cudnn.cudnnTensorStruct)xDesc0, (SizeTPointer)this.sizeInBytes, (int)this.dataType));
        long weightsSize = this.sizeInBytes.get(0L);
        if (weightsSize > this.weightsSpace.capacity()) {
            this.weightsSpace.deallocate();
            this.weightsSpace = new BaseCudnnHelper.DataCache(weightsSize);
        }
        this.weightsSpace.limit(weightsSize);
        // Divide before narrowing so that a large byte count cannot overflow the int cast.
        int[] dimW = new int[]{(int)(weightsSize / this.dataTypeSize), 1, 1};
        CudnnLSTMHelper.checkCudnn(cudnn.cudnnSetFilterNdDescriptor((cudnn.cudnnFilterStruct)this.cudnnContext.wDesc, (int)this.dataType, (int)0, (int)3, (int[])dimW));
        CudnnLSTMHelper.checkCudnn(cudnn.cudnnSetFilterNdDescriptor((cudnn.cudnnFilterStruct)this.cudnnContext.dwDesc, (int)this.dataType, (int)0, (int)3, (int[])dimW));

        // Shared helper workspace (stored in workspaceMgr under CUDNN_WORKSPACE_KEY).
        CudnnLSTMHelper.checkCudnn(cudnn.cudnnGetRNNWorkspaceSize((cudnn.cudnnContext)this.cudnnContext, (cudnn.cudnnRNNStruct)this.cudnnContext.rnnDesc, (int)timeSeriesLength, (PointerPointer)this.xDesc, (SizeTPointer)this.sizeInBytes));
        long workSize = this.sizeInBytes.get(0L);
        BaseCudnnHelper.DataCache workSpace = (BaseCudnnHelper.DataCache)workspaceMgr.getHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY);
        if (workSpace == null || workSize > workSpace.capacity()) {
            if (log.isTraceEnabled()) {
                if (workSpace == null) {
                    log.trace("CudnnLSTMHelper activate: Allocating initial workspace of size {} ({})", (Object)workSize, (Object)StringUtils.TraditionalBinaryPrefix.long2String((long)workSize, (String)"B", (int)2));
                } else {
                    log.trace("CudnnLSTMHelper activate: Deallocating workspace of size {} ({}), allocating new workspace of size {} ({})", new Object[]{workSpace.capacity(), StringUtils.TraditionalBinaryPrefix.long2String((long)workSpace.capacity(), (String)"B", (int)2), workSize, StringUtils.TraditionalBinaryPrefix.long2String((long)workSize, (String)"B", (int)2)});
                }
            }
            if (workSpace != null) {
                workSpace.deallocate();
            }
            workSpace = new BaseCudnnHelper.DataCache(workSize);
            workspaceMgr.setHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY, (Pointer)workSpace);
        }
        workSpace.limit(workSize);

        // Reserve space written by forward training and consumed by the backward calls.
        CudnnLSTMHelper.checkCudnn(cudnn.cudnnGetRNNTrainingReserveSize((cudnn.cudnnContext)this.cudnnContext, (cudnn.cudnnRNNStruct)this.cudnnContext.rnnDesc, (int)timeSeriesLength, (PointerPointer)this.xDesc, (SizeTPointer)this.sizeInBytes));
        long reserveSize = this.sizeInBytes.get(0L);
        if (reserveSize > this.reserveSpace.capacity()) {
            this.reserveSpace.deallocate();
            this.reserveSpace = new BaseCudnnHelper.DataCache(reserveSize);
        }
        this.reserveSpace.limit(reserveSize);

        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareActionAllWrite(new INDArray[]{x, linInputWeights, linRecurrentWeights, linBiases, prevAct, prevMemCell, outputActivations, finalMemCellState, finalStepActivations});
        Pointer xData = allocator.getPointer(x, context);
        Pointer linInputWeightsData = allocator.getPointer(linInputWeights, context);
        Pointer linRecurrentWeightsData = allocator.getPointer(linRecurrentWeights, context);
        Pointer linBiasesData = allocator.getPointer(linBiases, context);
        Pointer prevActData = allocator.getPointer(prevAct, context);
        Pointer prevMemCellData = allocator.getPointer(prevMemCell, context);
        Pointer outputActivationsData = allocator.getPointer(outputActivations, context);
        Pointer finalMemCellStateData = allocator.getPointer(finalMemCellState, context);
        Pointer finalTimeStepActivationsData = allocator.getPointer(finalStepActivations, context);
        cuda.CUstream_st stream = new cuda.CUstream_st((Pointer)context.getOldStream());
        CudnnLSTMHelper.checkCudnn(cudnn.cudnnSetStream((cudnn.cudnnContext)this.cudnnContext, (cuda.CUstream_st)stream));
        // Zero the packed buffer first; regions not copied below (e.g. the recurrent biases, IDs 4-7) stay zero.
        CudnnLSTMHelper.checkCuda(cuda.cudaMemsetAsync((Pointer)this.weightsSpace, (int)0, (long)this.weightsSpace.limit(), (cuda.CUstream_st)stream));

        // Column block in the DL4J parameter views for each cuDNN linear layer (ID % 4). IDs 0-3 are
        // input-weight matrices and IDs 4-7 are recurrent-weight matrices.
        int[] gateBlock = {3, 1, 0, 2};
        // Output arrays of cudnnGetFilterNdDescriptor; the values are not read afterwards.
        int[] dataType = new int[1];
        int[] format = new int[1];
        int[] nbDims = new int[1];
        int[] filterDimA = new int[3];
        Pointer linLayerMat = new Pointer();
        Pointer linLayerBias = new Pointer();
        for (int layerNum = 0; layerNum < NUM_LAYERS; ++layerNum) {
            for (int linLayerID = 0; linLayerID < NUM_LINEAR_LAYERS; ++linLayerID) {
                CudnnLSTMHelper.checkCudnn(cudnn.cudnnGetRNNLinLayerMatrixParams((cudnn.cudnnContext)this.cudnnContext, (cudnn.cudnnRNNStruct)this.cudnnContext.rnnDesc, (int)layerNum, (cudnn.cudnnTensorStruct)xDesc0, (cudnn.cudnnFilterStruct)this.cudnnContext.wDesc, (Pointer)this.weightsSpace, (int)linLayerID, (cudnn.cudnnFilterStruct)this.cudnnContext.linLayerMatDesc, (Pointer)linLayerMat));
                CudnnLSTMHelper.checkCudnn(cudnn.cudnnGetFilterNdDescriptor((cudnn.cudnnFilterStruct)this.cudnnContext.linLayerMatDesc, (int)3, (int[])dataType, (int[])format, (int[])nbDims, (int[])filterDimA));
                CudnnLSTMHelper.checkCudnn(cudnn.cudnnGetRNNLinLayerBiasParams((cudnn.cudnnContext)this.cudnnContext, (cudnn.cudnnRNNStruct)this.cudnnContext.rnnDesc, (int)layerNum, (cudnn.cudnnTensorStruct)xDesc0, (cudnn.cudnnFilterStruct)this.cudnnContext.wDesc, (Pointer)this.weightsSpace, (int)linLayerID, (cudnn.cudnnFilterStruct)this.cudnnContext.linLayerBiasDesc, (Pointer)linLayerBias));
                CudnnLSTMHelper.checkCudnn(cudnn.cudnnGetFilterNdDescriptor((cudnn.cudnnFilterStruct)this.cudnnContext.linLayerBiasDesc, (int)3, (int[])dataType, (int[])format, (int[])nbDims, (int[])filterDimA));

                boolean inputMatrix = linLayerID < 4;
                Pointer data = inputMatrix ? linInputWeightsData : linRecurrentWeightsData;
                long size = inputMatrix ? inputLayerSize : hiddenLayerSize;
                long position = gateBlock[linLayerID % 4];
                long matBytes = size * hiddenLayerSize * (long)this.dataTypeSize;
                CudnnLSTMHelper.checkCuda(cuda.cudaMemcpyAsync((Pointer)linLayerMat, (Pointer)data.position(position * matBytes), matBytes, cuda.cudaMemcpyDeviceToDevice, (cuda.CUstream_st)stream));
                if (inputMatrix) {
                    // The DL4J bias view holds one hiddenLayerSize block per gate; it is copied into the
                    // input-side bias slots (IDs 0-3) only.
                    long biasBytes = hiddenLayerSize * (long)this.dataTypeSize;
                    CudnnLSTMHelper.checkCuda(cuda.cudaMemcpyAsync((Pointer)linLayerBias, (Pointer)linBiasesData.position(position * biasBytes), biasBytes, cuda.cudaMemcpyDeviceToDevice, (cuda.CUstream_st)stream));
                }
            }
        }

        if (training) {
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnRNNForwardTraining((cudnn.cudnnContext)this.cudnnContext, (cudnn.cudnnRNNStruct)this.cudnnContext.rnnDesc, (int)timeSeriesLength, (PointerPointer)this.xDesc, (Pointer)xData, (cudnn.cudnnTensorStruct)this.cudnnContext.hxDesc, (Pointer)prevActData, (cudnn.cudnnTensorStruct)this.cudnnContext.cxDesc, (Pointer)prevMemCellData, (cudnn.cudnnFilterStruct)this.cudnnContext.wDesc, (Pointer)this.weightsSpace, (PointerPointer)this.yDesc, (Pointer)outputActivationsData, (cudnn.cudnnTensorStruct)this.cudnnContext.hyDesc, (Pointer)finalTimeStepActivationsData, (cudnn.cudnnTensorStruct)this.cudnnContext.cyDesc, (Pointer)finalMemCellStateData, (Pointer)workSpace, (long)workSpace.limit(), (Pointer)this.reserveSpace, (long)this.reserveSpace.limit()));
        } else {
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnRNNForwardInference((cudnn.cudnnContext)this.cudnnContext, (cudnn.cudnnRNNStruct)this.cudnnContext.rnnDesc, (int)timeSeriesLength, (PointerPointer)this.xDesc, (Pointer)xData, (cudnn.cudnnTensorStruct)this.cudnnContext.hxDesc, (Pointer)prevActData, (cudnn.cudnnTensorStruct)this.cudnnContext.cxDesc, (Pointer)prevMemCellData, (cudnn.cudnnFilterStruct)this.cudnnContext.wDesc, (Pointer)this.weightsSpace, (PointerPointer)this.yDesc, (Pointer)outputActivationsData, (cudnn.cudnnTensorStruct)this.cudnnContext.hyDesc, (Pointer)finalTimeStepActivationsData, (cudnn.cudnnTensorStruct)this.cudnnContext.cyDesc, (Pointer)finalMemCellStateData, (Pointer)workSpace, (long)workSpace.limit()));
        }
        allocator.getFlowController().registerActionAllWrite(context, new INDArray[]{x, linInputWeights, linRecurrentWeights, linBiases, prevAct, prevMemCell, outputActivations, finalMemCellState, finalStepActivations});

        FwdPassReturn toReturn = new FwdPassReturn();
        toReturn.fwdPassOutput = outputActivations.permute(new int[]{1, 2, 0});
        toReturn.lastAct = finalStepActivations;
        toReturn.lastMemCell = finalMemCellState;
        toReturn.prevAct = prevAct;
        toReturn.prevMemCell = prevMemCell;
        return toReturn;
    }

    /**
     * Reports the capacity of the buffers this helper caches between calls.
     *
     * <p>The buffers are allocated in this class with byte counts obtained from cuDNN size queries.
     * The reported values are presumably bytes (confirm against {@code DataCache.capacity()}).
     *
     * @return a map from buffer name ("stateSpace", "reserveSpace", "weightsSpace") to its capacity
     */
    public Map<String, Long> helperMemoryUse() {
        Map<String, Long> memUse = new HashMap<String, Long>();
        // Key was previously misspelled as "stateStace".
        memUse.put("stateSpace", this.stateSpace.capacity());
        memUse.put("reserveSpace", this.reserveSpace.capacity());
        memUse.put("weightsSpace", this.weightsSpace.capacity());
        return memUse;
    }

    private static class CudnnLSTMContext
    extends BaseCudnnHelper.CudnnContext {
        // Initial hidden (hx) / cell (cx) state descriptors, paired with the previous activations and
        // memory cell state in the forward and backward calls.
        private cudnn.cudnnTensorStruct hxDesc = new cudnn.cudnnTensorStruct();
        private cudnn.cudnnTensorStruct cxDesc = new cudnn.cudnnTensorStruct();
        // Final hidden (hy) / cell (cy) state descriptors, paired with the last-step outputs of the forward pass.
        private cudnn.cudnnTensorStruct hyDesc = new cudnn.cudnnTensorStruct();
        private cudnn.cudnnTensorStruct cyDesc = new cudnn.cudnnTensorStruct();
        // Gradient counterparts of the state descriptors; the backward-data call passes null data for all four.
        // All eight state descriptors are set to the shape [1, miniBatch, hidden] in activate.
        private cudnn.cudnnTensorStruct dhxDesc = new cudnn.cudnnTensorStruct();
        private cudnn.cudnnTensorStruct dcxDesc = new cudnn.cudnnTensorStruct();
        private cudnn.cudnnTensorStruct dhyDesc = new cudnn.cudnnTensorStruct();
        private cudnn.cudnnTensorStruct dcyDesc = new cudnn.cudnnTensorStruct();
        // Filter descriptors for the packed weight buffer (wDesc) and its gradient (dwDesc).
        private cudnn.cudnnFilterStruct wDesc = new cudnn.cudnnFilterStruct();
        private cudnn.cudnnFilterStruct dwDesc = new cudnn.cudnnFilterStruct();
        // Filled by cudnnGetRNNLinLayerMatrixParams / cudnnGetRNNLinLayerBiasParams for each linear layer.
        private cudnn.cudnnFilterStruct linLayerMatDesc = new cudnn.cudnnFilterStruct();
        private cudnn.cudnnFilterStruct linLayerBiasDesc = new cudnn.cudnnFilterStruct();
        // RNN descriptor (configured with cudnnSetRNNDescriptor_v6) and the dropout descriptor it references.
        private cudnn.cudnnRNNStruct rnnDesc = new cudnn.cudnnRNNStruct();
        private cudnn.cudnnDropoutStruct dropoutDesc = new cudnn.cudnnDropoutStruct();
        // NOTE(review): not referenced by the helper methods visible in this chunk.
        private cudnn.cudnnActivationStruct activationDesc = new cudnn.cudnnActivationStruct();

        public CudnnLSTMContext() {
            this.createHandles();
            this.deallocator(new Deallocator(this));
        }

        public CudnnLSTMContext(CudnnLSTMContext c) {
            super(c);
            this.hxDesc = new cudnn.cudnnTensorStruct((Pointer)c.hxDesc);
            this.cxDesc = new cudnn.cudnnTensorStruct((Pointer)c.cxDesc);
            this.hyDesc = new cudnn.cudnnTensorStruct((Pointer)c.hyDesc);
            this.cyDesc = new cudnn.cudnnTensorStruct((Pointer)c.cyDesc);
            this.dhxDesc = new cudnn.cudnnTensorStruct((Pointer)c.dhxDesc);
            this.dcxDesc = new cudnn.cudnnTensorStruct((Pointer)c.dcxDesc);
            this.dhyDesc = new cudnn.cudnnTensorStruct((Pointer)c.dhyDesc);
            this.dcyDesc = new cudnn.cudnnTensorStruct((Pointer)c.dcyDesc);
            this.wDesc = new cudnn.cudnnFilterStruct((Pointer)c.wDesc);
            this.dwDesc = new cudnn.cudnnFilterStruct((Pointer)c.dwDesc);
            this.linLayerMatDesc = new cudnn.cudnnFilterStruct((Pointer)c.linLayerMatDesc);
            this.linLayerBiasDesc = new cudnn.cudnnFilterStruct((Pointer)c.linLayerBiasDesc);
            this.rnnDesc = new cudnn.cudnnRNNStruct((Pointer)c.rnnDesc);
            this.dropoutDesc = new cudnn.cudnnDropoutStruct((Pointer)c.dropoutDesc);
            this.activationDesc = new cudnn.cudnnActivationStruct((Pointer)c.activationDesc);
        }

        @Override
        protected void createHandles() {
            super.createHandles();
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor((cudnn.cudnnTensorStruct)this.hxDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor((cudnn.cudnnTensorStruct)this.cxDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor((cudnn.cudnnTensorStruct)this.hyDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor((cudnn.cudnnTensorStruct)this.cyDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor((cudnn.cudnnTensorStruct)this.dhxDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor((cudnn.cudnnTensorStruct)this.dcxDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor((cudnn.cudnnTensorStruct)this.dhyDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor((cudnn.cudnnTensorStruct)this.dcyDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnCreateFilterDescriptor((cudnn.cudnnFilterStruct)this.wDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnCreateFilterDescriptor((cudnn.cudnnFilterStruct)this.dwDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnCreateFilterDescriptor((cudnn.cudnnFilterStruct)this.linLayerMatDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnCreateFilterDescriptor((cudnn.cudnnFilterStruct)this.linLayerBiasDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnCreateRNNDescriptor((cudnn.cudnnRNNStruct)this.rnnDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnCreateDropoutDescriptor((cudnn.cudnnDropoutStruct)this.dropoutDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnCreateActivationDescriptor((cudnn.cudnnActivationStruct)this.activationDesc));
        }

        @Override
        protected void destroyHandles() {
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnDestroyActivationDescriptor((cudnn.cudnnActivationStruct)this.activationDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnDestroyDropoutDescriptor((cudnn.cudnnDropoutStruct)this.dropoutDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnDestroyRNNDescriptor((cudnn.cudnnRNNStruct)this.rnnDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnDestroyFilterDescriptor((cudnn.cudnnFilterStruct)this.wDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnDestroyFilterDescriptor((cudnn.cudnnFilterStruct)this.dwDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnDestroyFilterDescriptor((cudnn.cudnnFilterStruct)this.linLayerMatDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnDestroyFilterDescriptor((cudnn.cudnnFilterStruct)this.linLayerBiasDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor((cudnn.cudnnTensorStruct)this.hxDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor((cudnn.cudnnTensorStruct)this.cxDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor((cudnn.cudnnTensorStruct)this.hyDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor((cudnn.cudnnTensorStruct)this.cyDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor((cudnn.cudnnTensorStruct)this.dhxDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor((cudnn.cudnnTensorStruct)this.dcxDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor((cudnn.cudnnTensorStruct)this.dhyDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor((cudnn.cudnnTensorStruct)this.dcyDesc));
            super.destroyHandles();
        }

        private static class Deallocator
        extends CudnnLSTMContext
        implements Pointer.Deallocator {
            Deallocator(CudnnLSTMContext c) {
                super(c);
            }

            public void deallocate() {
                this.destroyHandles();
            }
        }
    }
}

