Interface FunctionBatchNorm<T extends Tensor<T>>

All Superinterfaces:
BatchNorm, Function<T>
All Known Implementing Classes:
FunctionBatchNorm_F32, FunctionBatchNorm_F64, SpatialBatchNorm_F32, SpatialBatchNorm_F64

public interface FunctionBatchNorm<T extends Tensor<T>> extends Function<T>, BatchNorm

Forward-only implementation of Batch Normalization. It applies a previously computed linear transform, chosen so that the training data has zero mean and a standard deviation (stdev) of one at the output. The optional gamma and beta transform can also be applied.

See BatchNorm for a general discussion of Batch Normalization.

  • Method Details

    • forward

      void forward(T input, T output)

      Applies batch normalization to each variable in the input.

      Either two or four variables are stored in the parameter tensor as interleaved variables. If BatchNorm.hasGammaBeta() returns true, then mean, variance, gamma, and beta are saved. Otherwise only mean and variance are saved. This is also the order in which the variables are interleaved.

       Summary Table
       -------------------------------------------------
       Input   shape = (N, d[i], ... , d[k])
       Output  shape = (N, d[i], ... , d[k])
       Params  shape = (d[i], ... , d[k], M)
       -------------------------------------------------
       N    = Size of mini-batch
       d[i] = length of a dimension
       M    = Number of parameters: 2, or 4 if gamma-beta is being used,
             in order of: mean, variance  OR mean, variance, gamma, beta
       

      NOTE: Interleaving is used instead of multiple tensors to improve memory locality, which reduces cache misses.
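
      The interleaved layout above can be illustrated with a minimal sketch on flat arrays. This is not the library's implementation: the class name, the method signature, the use of plain double arrays in place of tensors, and the EPS constant added to the variance are all assumptions made for illustration. The non-batch dimensions d[i], ... , d[k] are flattened into a single length D.

```java
/** Hypothetical illustration only; not part of the documented library. */
final class InterleavedBatchNormSketch {
    // Assumed small constant that guards against division by zero.
    // The documentation above does not specify a value.
    static final double EPS = 1e-5;

    /**
     * @param input        flattened tensor of shape (N, D), where D = d[i] * ... * d[k]
     * @param params       flattened tensor of shape (D, M), interleaved as
     *                     mean, variance [, gamma, beta] for each variable
     * @param N            mini-batch size
     * @param D            number of variables per batch element
     * @param hasGammaBeta true if M = 4, false if M = 2
     * @return newly allocated output with the same shape as the input
     */
    static double[] forward(double[] input, double[] params,
                            int N, int D, boolean hasGammaBeta) {
        int M = hasGammaBeta ? 4 : 2;
        if (input.length != N * D || params.length != D * M) {
            throw new IllegalArgumentException("Unexpected tensor length");
        }
        double[] output = new double[input.length];
        for (int batch = 0; batch < N; batch++) {
            for (int j = 0; j < D; j++) {
                // All parameters for variable j sit next to each other in memory.
                int p = j * M;
                double mean = params[p];
                double variance = params[p + 1];
                double value = (input[batch * D + j] - mean) / Math.sqrt(variance + EPS);
                if (hasGammaBeta) {
                    // Optional scale (gamma) and shift (beta).
                    value = value * params[p + 2] + params[p + 3];
                }
                output[batch * D + j] = value;
            }
        }
        return output;
    }
}
```

      Because the two or four values for a variable are adjacent, the inner loop reads one contiguous run of the parameter array per variable, which is the memory-locality benefit the note describes.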

      Specified by:
      forward in interface Function<T extends Tensor<T>>
      Parameters:
      input - Input tensor with a shape of (N, d[i], ... , d[k]), where N is the mini-batch size.
      output - Output tensor. Same shape as the input tensor. Modified.
    • setParameters

      void setParameters(List<T> parameters)
      See forward(T, T) for a description of parameters.
      Specified by:
      setParameters in interface Function<T extends Tensor<T>>
      Parameters:
      parameters - Variable tensor with a shape of (d[i], ... , d[k], M), where M is 2 or 4. Not modified.
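
      As a rough sketch, statistics that were computed separately could be packed into the interleaved order described under forward(T, T) as shown here. The class and method names are hypothetical, plain double arrays stand in for tensors, and wrapping the result in a tensor and a List for setParameters is left out because it depends on the concrete tensor type.

```java
/** Hypothetical helper for illustration; not part of the documented library. */
final class InterleaveParamsSketch {
    /**
     * Packs per-variable statistics into a flattened (D, M) array.
     * Pass null for both gamma and beta to produce the M = 2 layout.
     */
    static double[] interleave(double[] mean, double[] variance,
                               double[] gamma, double[] beta) {
        boolean hasGammaBeta = gamma != null && beta != null;
        int M = hasGammaBeta ? 4 : 2;
        int D = mean.length;
        if (variance.length != D
                || (hasGammaBeta && (gamma.length != D || beta.length != D))) {
            throw new IllegalArgumentException("All arrays must have the same length");
        }
        double[] params = new double[D * M];
        for (int j = 0; j < D; j++) {
            // Order within each group: mean, variance, then gamma, beta if present.
            params[j * M] = mean[j];
            params[j * M + 1] = variance[j];
            if (hasGammaBeta) {
                params[j * M + 2] = gamma[j];
                params[j * M + 3] = beta[j];
            }
        }
        return params;
    }
}
```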