/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.normalization;

import java.util.HashMap;
import java.util.Map;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.cuda;
import org.bytedeco.javacpp.cudnn;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseCudnnHelper;
import org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.JCublasNDArray;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class CudnnBatchNormalizationHelper
extends BaseCudnnHelper
implements BatchNormalizationHelper {
    private static final Logger log = LoggerFactory.getLogger(CudnnBatchNormalizationHelper.class);
    protected final int batchNormMode = 1;
    private CudnnBatchNormalizationContext cudnnContext = new CudnnBatchNormalizationContext();
    private INDArray meanCache;
    private INDArray varCache;
    private double eps;

    public boolean checkSupported(double eps) {
        boolean supported = this.checkSupported();
        if (eps < 1.0E-5) {
            supported = false;
            log.warn("Not supported: eps < CUDNN_BN_MIN_EPSILON (" + eps + " < " + 1.0E-5 + ")");
        }
        return supported;
    }

    public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, int[] shape, INDArray gamma, INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr layerWorkspaceMgr) {
        this.eps = eps;
        int miniBatch = (int)input.size(0);
        int depth = (int)input.size(1);
        int inH = (int)input.size(2);
        int inW = (int)input.size(3);
        boolean isHalf = Nd4j.dataType() == DataBuffer.Type.HALF;
        INDArray gammaOrig = null;
        INDArray dGammaViewOrig = null;
        INDArray dBetaViewOrig = null;
        if (isHalf) {
            gammaOrig = gamma;
            dGammaViewOrig = dGammaView;
            dBetaViewOrig = dBetaView;
            gamma = gamma.convertToFloats();
            dGammaView = dGammaView.convertToFloats();
            dBetaView = dBetaView.convertToFloats();
        }
        DefaultGradient retGradient = new DefaultGradient();
        if (!Shape.hasDefaultStridesForShape((INDArray)epsilon)) {
            epsilon = epsilon.dup('c');
        }
        int[] srcStride = ArrayUtil.toInts((long[])input.stride());
        int[] deltaStride = ArrayUtil.toInts((long[])epsilon.stride());
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            ((GridExecutioner)Nd4j.getExecutioner()).flushQueue();
        }
        CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnSetTensor4dDescriptorEx((cudnn.cudnnTensorStruct)this.cudnnContext.srcTensorDesc, (int)this.dataType, (int)miniBatch, (int)depth, (int)inH, (int)inW, (int)srcStride[0], (int)srcStride[1], (int)srcStride[2], (int)srcStride[3]));
        CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnSetTensor4dDescriptorEx((cudnn.cudnnTensorStruct)this.cudnnContext.deltaTensorDesc, (int)this.dataType, (int)miniBatch, (int)depth, (int)inH, (int)inW, (int)deltaStride[0], (int)deltaStride[1], (int)deltaStride[2], (int)deltaStride[3]));
        INDArray nextEpsilon = layerWorkspaceMgr.createUninitialized((Enum)ArrayType.ACTIVATION_GRAD, new int[]{miniBatch, depth, inH, inW}, 'c');
        int[] dstStride = ArrayUtil.toInts((long[])nextEpsilon.stride());
        CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnSetTensor4dDescriptorEx((cudnn.cudnnTensorStruct)this.cudnnContext.dstTensorDesc, (int)this.dataType, (int)miniBatch, (int)depth, (int)inH, (int)inW, (int)dstStride[0], (int)dstStride[1], (int)dstStride[2], (int)dstStride[3]));
        CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnSetTensor4dDescriptor((cudnn.cudnnTensorStruct)this.cudnnContext.gammaBetaTensorDesc, (int)0, (int)CudnnBatchNormalizationHelper.toCudnnDataType(gamma.data().dataType()), (int)shape[0], (int)shape[1], (int)(shape.length > 2 ? shape[2] : 1), (int)(shape.length > 3 ? shape[3] : 1)));
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareActionAllWrite(new INDArray[]{input, epsilon, nextEpsilon, gamma, dGammaView, dBetaView});
        Pointer srcData = allocator.getPointer(input, context);
        Pointer epsData = allocator.getPointer(epsilon, context);
        Pointer dstData = allocator.getPointer(nextEpsilon, context);
        Pointer gammaData = allocator.getPointer(gamma, context);
        Pointer dGammaData = allocator.getPointer(dGammaView, context);
        Pointer dBetaData = allocator.getPointer(dBetaView, context);
        Pointer meanCacheData = allocator.getPointer(this.meanCache, context);
        Pointer varCacheData = allocator.getPointer(this.varCache, context);
        CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnSetStream((cudnn.cudnnContext)this.cudnnContext, (cuda.CUstream_st)new cuda.CUstream_st((Pointer)context.getOldStream())));
        CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnBatchNormalizationBackward((cudnn.cudnnContext)this.cudnnContext, (int)1, (Pointer)this.alpha, (Pointer)this.beta, (Pointer)this.alpha, (Pointer)this.alpha, (cudnn.cudnnTensorStruct)this.cudnnContext.srcTensorDesc, (Pointer)srcData, (cudnn.cudnnTensorStruct)this.cudnnContext.deltaTensorDesc, (Pointer)epsData, (cudnn.cudnnTensorStruct)this.cudnnContext.dstTensorDesc, (Pointer)dstData, (cudnn.cudnnTensorStruct)this.cudnnContext.gammaBetaTensorDesc, (Pointer)gammaData, (Pointer)dGammaData, (Pointer)dBetaData, (double)eps, (Pointer)meanCacheData, (Pointer)varCacheData));
        allocator.getFlowController().registerActionAllWrite(context, new INDArray[]{input, epsilon, nextEpsilon, gamma, dGammaView, dBetaView});
        retGradient.setGradientFor("gamma", dGammaView);
        retGradient.setGradientFor("beta", dBetaView);
        context.syncOldStream();
        if (isHalf) {
            gammaOrig.assign(((JCublasNDArray)gamma).convertToHalfs());
            dGammaViewOrig.assign(((JCublasNDArray)dGammaView).convertToHalfs());
            dBetaViewOrig.assign(((JCublasNDArray)dBetaView).convertToHalfs());
        }
        return new Pair((Object)retGradient, (Object)nextEpsilon);
    }

    public INDArray preOutput(INDArray x, boolean training, int[] shape, INDArray gamma, INDArray beta, INDArray mean, INDArray var, double decay, double eps, LayerWorkspaceMgr workspaceMgr) {
        this.eps = eps;
        boolean isHalf = Nd4j.dataType() == DataBuffer.Type.HALF;
        INDArray origGamma = gamma;
        INDArray origBeta = beta;
        INDArray origMean = mean;
        INDArray origVar = var;
        if (isHalf) {
            gamma = gamma.convertToFloats();
            beta = beta.convertToFloats();
            mean = mean.convertToFloats();
            var = var.convertToFloats();
        }
        decay = 0.0;
        int miniBatch = (int)x.size(0);
        int inDepth = (int)x.size(1);
        int inH = (int)x.size(2);
        int inW = (int)x.size(3);
        int[] srcStride = ArrayUtil.toInts((long[])x.stride());
        CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnSetTensor4dDescriptorEx((cudnn.cudnnTensorStruct)this.cudnnContext.srcTensorDesc, (int)this.dataType, (int)miniBatch, (int)inDepth, (int)inH, (int)inW, (int)srcStride[0], (int)srcStride[1], (int)srcStride[2], (int)srcStride[3]));
        INDArray activations = workspaceMgr.createUninitialized((Enum)ArrayType.ACTIVATIONS, new int[]{miniBatch, inDepth, inH, inW}, 'c');
        int[] dstStride = ArrayUtil.toInts((long[])activations.stride());
        CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnSetTensor4dDescriptorEx((cudnn.cudnnTensorStruct)this.cudnnContext.dstTensorDesc, (int)this.dataType, (int)miniBatch, (int)inDepth, (int)inH, (int)inW, (int)dstStride[0], (int)dstStride[1], (int)dstStride[2], (int)dstStride[3]));
        CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnSetTensor4dDescriptor((cudnn.cudnnTensorStruct)this.cudnnContext.gammaBetaTensorDesc, (int)0, (int)CudnnBatchNormalizationHelper.toCudnnDataType(mean.data().dataType()), (int)shape[0], (int)shape[1], (int)(shape.length > 2 ? shape[2] : 1), (int)(shape.length > 3 ? shape[3] : 1)));
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareActionAllWrite(new INDArray[]{x, activations, gamma, beta, mean, var});
        Pointer srcData = allocator.getPointer(x, context);
        Pointer dstData = allocator.getPointer(activations, context);
        Pointer gammaData = allocator.getPointer(gamma, context);
        Pointer betaData = allocator.getPointer(beta, context);
        Pointer meanData = allocator.getPointer(mean, context);
        Pointer varData = allocator.getPointer(var, context);
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            ((GridExecutioner)Nd4j.getExecutioner()).flushQueue();
        }
        CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnSetStream((cudnn.cudnnContext)this.cudnnContext, (cuda.CUstream_st)new cuda.CUstream_st((Pointer)context.getOldStream())));
        if (training) {
            Throwable throwable;
            MemoryWorkspace ws;
            if (this.meanCache == null || this.meanCache.length() < mean.length()) {
                this.meanCache = Nd4j.createUninitializedDetached((int)((int)mean.length()));
                if (Nd4j.dataType() == DataBuffer.Type.HALF) {
                    ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces();
                    throwable = null;
                    try {
                        this.meanCache = this.meanCache.convertToFloats();
                    }
                    catch (Throwable throwable2) {
                        throwable = throwable2;
                        throw throwable2;
                    }
                    finally {
                        if (ws != null) {
                            if (throwable != null) {
                                try {
                                    ws.close();
                                }
                                catch (Throwable throwable3) {
                                    throwable.addSuppressed(throwable3);
                                }
                            } else {
                                ws.close();
                            }
                        }
                    }
                }
            }
            if (this.varCache == null || this.varCache.length() < mean.length()) {
                this.varCache = Nd4j.createUninitializedDetached((int)((int)mean.length()));
                if (Nd4j.dataType() == DataBuffer.Type.HALF) {
                    ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces();
                    throwable = null;
                    try {
                        this.varCache = this.varCache.convertToFloats();
                    }
                    catch (Throwable throwable4) {
                        throwable = throwable4;
                        throw throwable4;
                    }
                    finally {
                        if (ws != null) {
                            if (throwable != null) {
                                try {
                                    ws.close();
                                }
                                catch (Throwable throwable5) {
                                    throwable.addSuppressed(throwable5);
                                }
                            } else {
                                ws.close();
                            }
                        }
                    }
                }
            }
            Pointer meanCacheData = allocator.getPointer(this.meanCache, context);
            Pointer varCacheData = allocator.getPointer(this.varCache, context);
            CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnBatchNormalizationForwardTraining((cudnn.cudnnContext)this.cudnnContext, (int)1, (Pointer)this.alpha, (Pointer)this.beta, (cudnn.cudnnTensorStruct)this.cudnnContext.srcTensorDesc, (Pointer)srcData, (cudnn.cudnnTensorStruct)this.cudnnContext.dstTensorDesc, (Pointer)dstData, (cudnn.cudnnTensorStruct)this.cudnnContext.gammaBetaTensorDesc, (Pointer)gammaData, (Pointer)betaData, (double)decay, (Pointer)meanData, (Pointer)varData, (double)eps, (Pointer)meanCacheData, (Pointer)varCacheData));
        } else {
            CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnBatchNormalizationForwardInference((cudnn.cudnnContext)this.cudnnContext, (int)1, (Pointer)this.alpha, (Pointer)this.beta, (cudnn.cudnnTensorStruct)this.cudnnContext.srcTensorDesc, (Pointer)srcData, (cudnn.cudnnTensorStruct)this.cudnnContext.dstTensorDesc, (Pointer)dstData, (cudnn.cudnnTensorStruct)this.cudnnContext.gammaBetaTensorDesc, (Pointer)gammaData, (Pointer)betaData, (Pointer)meanData, (Pointer)varData, (double)eps));
        }
        allocator.getFlowController().registerActionAllWrite(context, new INDArray[]{x, activations, gamma, beta, mean, var});
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            context.syncOldStream();
        }
        context.syncOldStream();
        if (training) {
            AtomicAllocator.getInstance().getAllocationPoint(this.meanCache).tickDeviceWrite();
            AtomicAllocator.getInstance().getAllocationPoint(this.varCache).tickDeviceWrite();
        }
        if (training && isHalf) {
            origMean.assign(((JCublasNDArray)mean).convertToHalfs());
            origVar.assign(((JCublasNDArray)var).convertToHalfs());
            origGamma.assign(((JCublasNDArray)gamma).convertToHalfs());
            origBeta.assign(((JCublasNDArray)beta).convertToHalfs());
        }
        return activations;
    }

    public INDArray getMeanCache() {
        if (Nd4j.dataType() == DataBuffer.Type.HALF) {
            return this.meanCache.convertToHalfs();
        }
        return this.meanCache;
    }

    public INDArray getVarCache() {
        INDArray ret;
        if (Nd4j.dataType() == DataBuffer.Type.HALF) {
            INDArray vc = this.varCache.convertToHalfs();
            ret = vc.mul(vc).rdivi((Number)1.0).subi((Number)this.eps);
        } else {
            ret = this.varCache.mul(this.varCache).rdivi((Number)1.0).subi((Number)this.eps);
        }
        if (Nd4j.dataType() == DataBuffer.Type.HALF) {
            return ret.convertToHalfs();
        }
        return ret;
    }

    public Map<String, Long> helperMemoryUse() {
        HashMap<String, Long> memUse = new HashMap<String, Long>();
        memUse.put("meanCache", this.meanCache == null ? 0L : this.meanCache.length() * (long)this.meanCache.data().getElementSize());
        memUse.put("varCache", this.varCache == null ? 0L : this.varCache.length() * (long)this.varCache.data().getElementSize());
        return memUse;
    }

    private static class CudnnBatchNormalizationContext
    extends BaseCudnnHelper.CudnnContext {
        private cudnn.cudnnTensorStruct srcTensorDesc = new cudnn.cudnnTensorStruct();
        private cudnn.cudnnTensorStruct dstTensorDesc = new cudnn.cudnnTensorStruct();
        private cudnn.cudnnTensorStruct deltaTensorDesc = new cudnn.cudnnTensorStruct();
        private cudnn.cudnnTensorStruct gammaBetaTensorDesc = new cudnn.cudnnTensorStruct();

        public CudnnBatchNormalizationContext() {
            this.createHandles();
            this.deallocator(new Deallocator(this));
        }

        public CudnnBatchNormalizationContext(CudnnBatchNormalizationContext c) {
            super(c);
            this.srcTensorDesc = new cudnn.cudnnTensorStruct((Pointer)c.srcTensorDesc);
            this.dstTensorDesc = new cudnn.cudnnTensorStruct((Pointer)c.dstTensorDesc);
            this.deltaTensorDesc = new cudnn.cudnnTensorStruct((Pointer)c.deltaTensorDesc);
            this.gammaBetaTensorDesc = new cudnn.cudnnTensorStruct((Pointer)c.gammaBetaTensorDesc);
        }

        @Override
        protected void createHandles() {
            super.createHandles();
            CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor((cudnn.cudnnTensorStruct)this.srcTensorDesc));
            CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor((cudnn.cudnnTensorStruct)this.dstTensorDesc));
            CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor((cudnn.cudnnTensorStruct)this.deltaTensorDesc));
            CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor((cudnn.cudnnTensorStruct)this.gammaBetaTensorDesc));
        }

        @Override
        protected void destroyHandles() {
            CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor((cudnn.cudnnTensorStruct)this.srcTensorDesc));
            CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor((cudnn.cudnnTensorStruct)this.dstTensorDesc));
            CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor((cudnn.cudnnTensorStruct)this.deltaTensorDesc));
            CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor((cudnn.cudnnTensorStruct)this.gammaBetaTensorDesc));
            super.destroyHandles();
        }

        private static class Deallocator
        extends CudnnBatchNormalizationContext
        implements Pointer.Deallocator {
            Deallocator(CudnnBatchNormalizationContext c) {
                super(c);
            }

            public void deallocate() {
                this.destroyHandles();
            }
        }
    }
}

