/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.cuda.dropout;

import com.jakewharton.byteunits.BinaryByteUnit;
import org.bytedeco.cuda.cudart.CUstream_st;
import org.bytedeco.cuda.cudnn.cudnnContext;
import org.bytedeco.cuda.cudnn.cudnnDropoutStruct;
import org.bytedeco.cuda.cudnn.cudnnTensorStruct;
import org.bytedeco.cuda.global.cudnn;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.SizeTPointer;
import org.deeplearning4j.cuda.BaseCudnnHelper;
import org.deeplearning4j.nn.conf.dropout.DropoutHelper;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class CudnnDropoutHelper
extends BaseCudnnHelper
implements DropoutHelper {
    private static final Logger log = LoggerFactory.getLogger(CudnnDropoutHelper.class);
    private CudnnDropoutContext cudnnContext = new CudnnDropoutContext();
    private boolean initializedDescriptor = false;
    private BaseCudnnHelper.DataCache rngStates;
    private BaseCudnnHelper.DataCache mask;
    private SizeTPointer stateSizeBytesPtr;
    private SizeTPointer reserveSizeBytesPtr;
    private float lastInitializedP;

    public CudnnDropoutHelper(DataType dataType) {
        super(dataType);
    }

    public void applyDropout(INDArray input, INDArray resultArray, double dropoutInputRetainProb) {
        float p = (float)(1.0 - dropoutInputRetainProb);
        int[] inShape = CudnnDropoutHelper.adaptForTensorDescr(ArrayUtil.toInts((long[])input.shape()));
        int[] inStride = CudnnDropoutHelper.adaptForTensorDescr(ArrayUtil.toInts((long[])input.stride()));
        CudnnDropoutHelper.checkCudnn(cudnn.cudnnSetTensorNdDescriptor((cudnnTensorStruct)this.cudnnContext.xTensorDesc, (int)this.dataType, (int)inShape.length, (int[])inShape, (int[])inStride));
        int[] outShape = CudnnDropoutHelper.adaptForTensorDescr(ArrayUtil.toInts((long[])resultArray.shape()));
        int[] outStride = CudnnDropoutHelper.adaptForTensorDescr(ArrayUtil.toInts((long[])resultArray.stride()));
        CudnnDropoutHelper.checkCudnn(cudnn.cudnnSetTensorNdDescriptor((cudnnTensorStruct)this.cudnnContext.yTensorDesc, (int)this.dataType, (int)outShape.length, (int[])outShape, (int[])outStride));
        if (this.stateSizeBytesPtr == null) {
            this.stateSizeBytesPtr = new SizeTPointer(1L);
            this.reserveSizeBytesPtr = new SizeTPointer(1L);
        }
        CudnnDropoutHelper.checkCudnn(cudnn.cudnnDropoutGetStatesSize((cudnnContext)this.cudnnContext, (SizeTPointer)this.stateSizeBytesPtr));
        long rngStateSizeBytes = this.stateSizeBytesPtr.get();
        CudnnDropoutHelper.checkCudnn(cudnn.cudnnDropoutGetReserveSpaceSize((cudnnTensorStruct)this.cudnnContext.xTensorDesc, (SizeTPointer)this.reserveSizeBytesPtr));
        long maskReserveSizeBytes = this.reserveSizeBytesPtr.get();
        if (this.rngStates == null || this.rngStates.capacity() < rngStateSizeBytes) {
            if (log.isTraceEnabled()) {
                if (this.rngStates == null) {
                    log.trace("CudnnDropoutHelper: Allocating intial RNG states workspace of size {} ({})", (Object)rngStateSizeBytes, (Object)BinaryByteUnit.format((long)rngStateSizeBytes, (String)"#.00"));
                } else {
                    log.trace("CudnnDropoutHelper: Deallocating RNG states of size {} ({}), allocating new workspace of size {} ({})", new Object[]{this.rngStates.capacity(), BinaryByteUnit.format((long)this.rngStates.capacity(), (String)"#.00"), rngStateSizeBytes, BinaryByteUnit.format((long)rngStateSizeBytes, (String)"#.00")});
                }
            }
            if (this.rngStates != null) {
                this.rngStates.deallocate();
            }
            this.rngStates = new BaseCudnnHelper.DataCache(rngStateSizeBytes);
            this.initializedDescriptor = false;
        }
        if (this.mask == null || this.mask.capacity() < maskReserveSizeBytes) {
            if (log.isTraceEnabled()) {
                if (this.mask == null) {
                    log.trace("CudnnDropoutHelper: Allocating intial mask array of size {} ({})", (Object)maskReserveSizeBytes, (Object)BinaryByteUnit.format((long)maskReserveSizeBytes, (String)"#.00"));
                } else {
                    log.trace("CudnnDropoutHelper: Deallocating mask array of size {} ({}), allocating new mask array of size {} ({})", new Object[]{this.mask.capacity(), BinaryByteUnit.format((long)this.mask.capacity(), (String)"#.00"), maskReserveSizeBytes, BinaryByteUnit.format((long)maskReserveSizeBytes, (String)"#.00")});
                }
            }
            if (this.mask != null) {
                this.mask.deallocate();
            }
            this.mask = new BaseCudnnHelper.DataCache(maskReserveSizeBytes);
        }
        if (!this.initializedDescriptor || p != this.lastInitializedP) {
            if (log.isTraceEnabled()) {
                log.trace("CudnnDropoutHelper: (re)initializing dropout descriptor");
            }
            long seed = Nd4j.getRandom().nextLong();
            this.lastInitializedP = p;
            CudnnDropoutHelper.checkCudnn(cudnn.cudnnSetDropoutDescriptor((cudnnDropoutStruct)this.cudnnContext.dropoutDesc, (cudnnContext)this.cudnnContext, (float)p, (Pointer)this.rngStates, (long)this.rngStates.capacity(), (long)seed));
            this.initializedDescriptor = true;
        }
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(input, new INDArray[]{resultArray});
        Pointer xPtr = allocator.getPointer(input, context);
        Pointer yPtr = allocator.getPointer(resultArray, context);
        CudnnDropoutHelper.checkCudnn(cudnn.cudnnSetStream((cudnnContext)this.cudnnContext, (CUstream_st)new CUstream_st(context.getCublasStream())));
        CudnnDropoutHelper.checkCudnn(cudnn.cudnnDropoutForward((cudnnContext)this.cudnnContext, (cudnnDropoutStruct)this.cudnnContext.dropoutDesc, (cudnnTensorStruct)this.cudnnContext.xTensorDesc, (Pointer)xPtr, (cudnnTensorStruct)this.cudnnContext.yTensorDesc, (Pointer)yPtr, (Pointer)this.mask, (long)this.mask.capacity()));
        allocator.registerAction(context, input, new INDArray[]{resultArray});
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            context.syncOldStream();
        }
    }

    public void backprop(INDArray gradAtOutput, INDArray gradAtInput) {
        int[] gradAtOutShape = CudnnDropoutHelper.adaptForTensorDescr(ArrayUtil.toInts((long[])gradAtOutput.shape()));
        int[] gradAtOutStride = CudnnDropoutHelper.adaptForTensorDescr(ArrayUtil.toInts((long[])gradAtOutput.stride()));
        CudnnDropoutHelper.checkCudnn(cudnn.cudnnSetTensorNdDescriptor((cudnnTensorStruct)this.cudnnContext.dyTensorDesc, (int)this.dataType, (int)gradAtOutShape.length, (int[])gradAtOutShape, (int[])gradAtOutStride));
        int[] gradAtInShape = CudnnDropoutHelper.adaptForTensorDescr(ArrayUtil.toInts((long[])gradAtInput.shape()));
        int[] gradAtInStride = CudnnDropoutHelper.adaptForTensorDescr(ArrayUtil.toInts((long[])gradAtInput.stride()));
        CudnnDropoutHelper.checkCudnn(cudnn.cudnnSetTensorNdDescriptor((cudnnTensorStruct)this.cudnnContext.dxTensorDesc, (int)this.dataType, (int)gradAtInShape.length, (int[])gradAtInShape, (int[])gradAtInStride));
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(gradAtOutput, new INDArray[]{gradAtInput});
        Pointer dyPtr = allocator.getPointer(gradAtOutput, context);
        Pointer dxPtr = allocator.getPointer(gradAtInput, context);
        CudnnDropoutHelper.checkCudnn(cudnn.cudnnDropoutBackward((cudnnContext)this.cudnnContext, (cudnnDropoutStruct)this.cudnnContext.dropoutDesc, (cudnnTensorStruct)this.cudnnContext.dyTensorDesc, (Pointer)dyPtr, (cudnnTensorStruct)this.cudnnContext.dxTensorDesc, (Pointer)dxPtr, (Pointer)this.mask, (long)this.mask.capacity()));
        allocator.registerAction(context, gradAtOutput, new INDArray[]{gradAtInput});
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            context.syncOldStream();
        }
    }

    public CudnnDropoutContext getCudnnContext() {
        return this.cudnnContext;
    }

    public boolean isInitializedDescriptor() {
        return this.initializedDescriptor;
    }

    public BaseCudnnHelper.DataCache getRngStates() {
        return this.rngStates;
    }

    public BaseCudnnHelper.DataCache getMask() {
        return this.mask;
    }

    public SizeTPointer getStateSizeBytesPtr() {
        return this.stateSizeBytesPtr;
    }

    public SizeTPointer getReserveSizeBytesPtr() {
        return this.reserveSizeBytesPtr;
    }

    public float getLastInitializedP() {
        return this.lastInitializedP;
    }

    public void setCudnnContext(CudnnDropoutContext cudnnContext2) {
        this.cudnnContext = cudnnContext2;
    }

    public void setInitializedDescriptor(boolean initializedDescriptor) {
        this.initializedDescriptor = initializedDescriptor;
    }

    public void setRngStates(BaseCudnnHelper.DataCache rngStates) {
        this.rngStates = rngStates;
    }

    public void setMask(BaseCudnnHelper.DataCache mask) {
        this.mask = mask;
    }

    public void setStateSizeBytesPtr(SizeTPointer stateSizeBytesPtr) {
        this.stateSizeBytesPtr = stateSizeBytesPtr;
    }

    public void setReserveSizeBytesPtr(SizeTPointer reserveSizeBytesPtr) {
        this.reserveSizeBytesPtr = reserveSizeBytesPtr;
    }

    public void setLastInitializedP(float lastInitializedP) {
        this.lastInitializedP = lastInitializedP;
    }

    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof CudnnDropoutHelper)) {
            return false;
        }
        CudnnDropoutHelper other = (CudnnDropoutHelper)o;
        if (!other.canEqual(this)) {
            return false;
        }
        CudnnDropoutContext this$cudnnContext = this.getCudnnContext();
        CudnnDropoutContext other$cudnnContext = other.getCudnnContext();
        if (this$cudnnContext == null ? other$cudnnContext != null : !((Object)((Object)this$cudnnContext)).equals((Object)other$cudnnContext)) {
            return false;
        }
        if (this.isInitializedDescriptor() != other.isInitializedDescriptor()) {
            return false;
        }
        BaseCudnnHelper.DataCache this$rngStates = this.getRngStates();
        BaseCudnnHelper.DataCache other$rngStates = other.getRngStates();
        if (this$rngStates == null ? other$rngStates != null : !((Object)((Object)this$rngStates)).equals((Object)other$rngStates)) {
            return false;
        }
        BaseCudnnHelper.DataCache this$mask = this.getMask();
        BaseCudnnHelper.DataCache other$mask = other.getMask();
        if (this$mask == null ? other$mask != null : !((Object)((Object)this$mask)).equals((Object)other$mask)) {
            return false;
        }
        SizeTPointer this$stateSizeBytesPtr = this.getStateSizeBytesPtr();
        SizeTPointer other$stateSizeBytesPtr = other.getStateSizeBytesPtr();
        if (this$stateSizeBytesPtr == null ? other$stateSizeBytesPtr != null : !this$stateSizeBytesPtr.equals(other$stateSizeBytesPtr)) {
            return false;
        }
        SizeTPointer this$reserveSizeBytesPtr = this.getReserveSizeBytesPtr();
        SizeTPointer other$reserveSizeBytesPtr = other.getReserveSizeBytesPtr();
        if (this$reserveSizeBytesPtr == null ? other$reserveSizeBytesPtr != null : !this$reserveSizeBytesPtr.equals(other$reserveSizeBytesPtr)) {
            return false;
        }
        return Float.compare(this.getLastInitializedP(), other.getLastInitializedP()) == 0;
    }

    protected boolean canEqual(Object other) {
        return other instanceof CudnnDropoutHelper;
    }

    public int hashCode() {
        int PRIME = 59;
        int result = 1;
        CudnnDropoutContext $cudnnContext = this.getCudnnContext();
        result = result * 59 + ($cudnnContext == null ? 43 : ((Object)((Object)$cudnnContext)).hashCode());
        result = result * 59 + (this.isInitializedDescriptor() ? 79 : 97);
        BaseCudnnHelper.DataCache $rngStates = this.getRngStates();
        result = result * 59 + ($rngStates == null ? 43 : ((Object)((Object)$rngStates)).hashCode());
        BaseCudnnHelper.DataCache $mask = this.getMask();
        result = result * 59 + ($mask == null ? 43 : ((Object)((Object)$mask)).hashCode());
        SizeTPointer $stateSizeBytesPtr = this.getStateSizeBytesPtr();
        result = result * 59 + ($stateSizeBytesPtr == null ? 43 : $stateSizeBytesPtr.hashCode());
        SizeTPointer $reserveSizeBytesPtr = this.getReserveSizeBytesPtr();
        result = result * 59 + ($reserveSizeBytesPtr == null ? 43 : $reserveSizeBytesPtr.hashCode());
        result = result * 59 + Float.floatToIntBits(this.getLastInitializedP());
        return result;
    }

    public String toString() {
        return "CudnnDropoutHelper(cudnnContext=" + (Object)((Object)this.getCudnnContext()) + ", initializedDescriptor=" + this.isInitializedDescriptor() + ", rngStates=" + (Object)((Object)this.getRngStates()) + ", mask=" + (Object)((Object)this.getMask()) + ", stateSizeBytesPtr=" + this.getStateSizeBytesPtr() + ", reserveSizeBytesPtr=" + this.getReserveSizeBytesPtr() + ", lastInitializedP=" + this.getLastInitializedP() + ")";
    }

    private static class CudnnDropoutContext
    extends BaseCudnnHelper.CudnnContext {
        private cudnnTensorStruct xTensorDesc = new cudnnTensorStruct();
        private cudnnTensorStruct dxTensorDesc = new cudnnTensorStruct();
        private cudnnTensorStruct yTensorDesc = new cudnnTensorStruct();
        private cudnnTensorStruct dyTensorDesc = new cudnnTensorStruct();
        private cudnnDropoutStruct dropoutDesc = new cudnnDropoutStruct();

        public CudnnDropoutContext() {
            this.createHandles();
            this.deallocator(new Deallocator(this));
        }

        public CudnnDropoutContext(CudnnDropoutContext c) {
            super(c);
            this.xTensorDesc = new cudnnTensorStruct((Pointer)c.xTensorDesc);
            this.dxTensorDesc = new cudnnTensorStruct((Pointer)c.dxTensorDesc);
            this.yTensorDesc = new cudnnTensorStruct((Pointer)c.yTensorDesc);
            this.dyTensorDesc = new cudnnTensorStruct((Pointer)c.dyTensorDesc);
            this.dropoutDesc = new cudnnDropoutStruct((Pointer)c.dropoutDesc);
        }

        @Override
        protected void createHandles() {
            super.createHandles();
            CudnnDropoutHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor((cudnnTensorStruct)this.xTensorDesc));
            CudnnDropoutHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor((cudnnTensorStruct)this.dxTensorDesc));
            CudnnDropoutHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor((cudnnTensorStruct)this.yTensorDesc));
            CudnnDropoutHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor((cudnnTensorStruct)this.dyTensorDesc));
            CudnnDropoutHelper.checkCudnn(cudnn.cudnnCreateDropoutDescriptor((cudnnDropoutStruct)this.dropoutDesc));
        }

        @Override
        protected void destroyHandles() {
            CudnnDropoutHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor((cudnnTensorStruct)this.xTensorDesc));
            CudnnDropoutHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor((cudnnTensorStruct)this.dxTensorDesc));
            CudnnDropoutHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor((cudnnTensorStruct)this.yTensorDesc));
            CudnnDropoutHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor((cudnnTensorStruct)this.dyTensorDesc));
            CudnnDropoutHelper.checkCudnn(cudnn.cudnnDestroyDropoutDescriptor((cudnnDropoutStruct)this.dropoutDesc));
            super.destroyHandles();
        }

        private static class Deallocator
        extends CudnnDropoutContext
        implements Pointer.Deallocator {
            Deallocator(CudnnDropoutContext c) {
                super(c);
            }

            public void deallocate() {
                this.destroyHandles();
            }
        }
    }
}

