public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper
Nested classes/interfaces inherited from class `BaseCudnnHelper`: `BaseCudnnHelper.CudnnContext`, `BaseCudnnHelper.DataCache`, `BaseCudnnHelper.TensorArray`

| Modifier and Type | Field and Description |
|---|---|
| `protected int` | `batchNormMode` |

Fields inherited from class `BaseCudnnHelper`: `alpha`, `beta`, `dataType`, `dataTypeSize`, `sizeInBytes`, `TENSOR_FORMAT`

| Constructor and Description |
|---|
| `CudnnBatchNormalizationHelper()` |

| Modifier and Type | Method and Description |
|---|---|
| `org.nd4j.linalg.primitives.Pair<org.deeplearning4j.nn.gradient.Gradient,INDArray>` | `backpropGradient(INDArray input, INDArray epsilon, int[] shape, INDArray gamma, INDArray dGammaView, INDArray dBetaView, double eps, org.deeplearning4j.nn.workspace.LayerWorkspaceMgr layerWorkspaceMgr)` |
| `boolean` | `checkSupported(double eps)` |
| `INDArray` | `getMeanCache()` |
| `INDArray` | `getVarCache()` |
| `Map<String,Long>` | `helperMemoryUse()` |
| `INDArray` | `preOutput(INDArray x, boolean training, int[] shape, INDArray gamma, INDArray beta, INDArray mean, INDArray var, double decay, double eps, org.deeplearning4j.nn.workspace.LayerWorkspaceMgr workspaceMgr)` |

Methods inherited from class `BaseCudnnHelper`: `adaptForTensorDescr`, `checkCuda`, `checkCudnn`, `checkSupported`, `toCudnnDataType`

protected final int batchNormMode
public boolean checkSupported(double eps)
Specified by: `checkSupported` in interface `org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper`

`public org.nd4j.linalg.primitives.Pair<org.deeplearning4j.nn.gradient.Gradient,INDArray> backpropGradient(INDArray input, INDArray epsilon, int[] shape, INDArray gamma, INDArray dGammaView, INDArray dBetaView, double eps, org.deeplearning4j.nn.workspace.LayerWorkspaceMgr layerWorkspaceMgr)`
Specified by: `backpropGradient` in interface `org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper`

`public INDArray preOutput(INDArray x, boolean training, int[] shape, INDArray gamma, INDArray beta, INDArray mean, INDArray var, double decay, double eps, org.deeplearning4j.nn.workspace.LayerWorkspaceMgr workspaceMgr)`
Specified by: `preOutput` in interface `org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper`

public INDArray getMeanCache()
Specified by: `getMeanCache` in interface `org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper`

public INDArray getVarCache()
Specified by: `getVarCache` in interface `org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper`

Copyright © 2018. All rights reserved.