Class ResNetV1

java.lang.Object
ai.djl.basicmodelzoo.cv.classification.ResNetV1

public final class ResNetV1 extends Object
ResNetV1 contains a generic implementation of ResNet, adapted by Antti-Pekka Hynninen from https://github.com/tornadomeet/ResNet/blob/master/symbol_resnet.py (original author Wei Wu).

It implements the original ResNet, the network that won ILSVRC 2015, from Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, "Deep Residual Learning for Image Recognition".
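A ResNet's depth is determined by its stage layout, and that layout can be sketched in plain Java. This is a minimal sketch, assuming the ImageNet configurations listed in Table 1 of He et al. The class ResNetStages and its methods are hypothetical and are not part of DJL. ResNetV1 may accept other depths or use a different layout for small images.

```java
import java.util.Arrays;
import java.util.Map;

/**
 * Hypothetical sketch (not DJL code) of the ImageNet ResNet stage layout from He et al., Table 1.
 */
final class ResNetStages {

    private ResNetStages() {}

    /** Residual units in stages conv2_x .. conv5_x, keyed by nominal network depth. */
    static final Map<Integer, int[]> UNITS_PER_STAGE = Map.of(
            18, new int[] {2, 2, 2, 2},
            34, new int[] {3, 4, 6, 3},
            50, new int[] {3, 4, 6, 3},
            101, new int[] {3, 4, 23, 3},
            152, new int[] {3, 8, 36, 3});

    /** In the paper, networks with 50 or more layers use three-convolution bottleneck units. */
    static boolean usesBottleneck(int depth) {
        return depth >= 50;
    }

    /** Counts weighted layers: stem convolution + convolutions in every unit + final fully connected layer. */
    static int countLayers(int depth) {
        int[] units = UNITS_PER_STAGE.get(depth);
        if (units == null) {
            throw new IllegalArgumentException("Unsupported depth: " + depth);
        }
        int convsPerUnit = usesBottleneck(depth) ? 3 : 2;
        return 1 + Arrays.stream(units).sum() * convsPerUnit + 1;
    }

    public static void main(String[] args) {
        for (int depth : new int[] {18, 34, 50, 101, 152}) {
            System.out.printf("ResNet-%d: units %s, bottleneck=%b, layers=%d%n",
                    depth, Arrays.toString(UNITS_PER_STAGE.get(depth)),
                    usesBottleneck(depth), countLayers(depth));
        }
    }
}
```

Running main prints one line per depth. Counting the stem convolution, two or three convolutions per unit, and the final fully connected layer reproduces the nominal depth in every case. For example, 3 + 4 + 6 + 3 = 16 bottleneck units give 16 × 3 + 2 = 50. Projection shortcuts are not counted, following the paper's convention.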

See Also:

  • Method Details

    • residualUnit

      public static ai.djl.nn.Block residualUnit(int numFilters, ai.djl.ndarray.types.Shape stride, boolean dimMatch, boolean bottleneck, float batchNormMomentum)
      Builds a Block that represents a residual unit used in the implementation of the ResNet model.
      Parameters:
      numFilters - the number of output channels
      stride - the stride of the convolution in each dimension
      dimMatch - whether the input and output have the same number of channels
      bottleneck - whether to use bottleneck architecture
      batchNormMomentum - the momentum to be used for BatchNorm
      Returns:
      a Block that represents a residual unit
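The interplay of numFilters, stride, and dimMatch comes down to shape arithmetic. The residual branch and the shortcut are added element-wise, so both must produce the same shape. The sketch below assumes the layout from the paper: 3x3 convolutions with padding 1, a strided 1x1 projection with padding 0 on the shortcut when dimensions do not match, and numFilters / 4 inner channels in a bottleneck unit. ResidualUnitShapes and its methods are hypothetical helpers, not DJL code.

```java
import java.util.Arrays;

/**
 * Hypothetical shape check for one residual unit (not DJL code). Assumes the paper's layout:
 * 3x3 convolutions with padding 1 in the residual branch, a 1x1 projection with padding 0 on
 * the shortcut when dimMatch is false, and numFilters / 4 inner channels in bottleneck units.
 */
final class ResidualUnitShapes {

    private ResidualUnitShapes() {}

    /** Output length of a convolution along one spatial dimension. */
    static int convOut(int in, int kernel, int stride, int padding) {
        return (in + 2 * padding - kernel) / stride + 1;
    }

    /** Channels of the inner convolutions (a 4x reduction is assumed for bottleneck units). */
    static int innerChannels(int numFilters, boolean bottleneck) {
        if (bottleneck && numFilters % 4 != 0) {
            throw new IllegalArgumentException("bottleneck units need numFilters divisible by 4");
        }
        return bottleneck ? numFilters / 4 : numFilters;
    }

    /** Returns the {channels, height, width} output for a {channels, height, width} input. */
    static int[] outputShape(int[] input, int numFilters, int stride, boolean dimMatch) {
        int channels = input[0];
        int height = input[1];
        int width = input[2];
        // Residual branch: a strided 3x3 (pad 1) and a strided 1x1 (pad 0) give the same
        // spatial size, so the result does not depend on which convolution carries the stride.
        int[] residual = {numFilters, convOut(height, 3, stride, 1), convOut(width, 3, stride, 1)};
        // Shortcut branch: identity when dimMatch is true, otherwise a strided 1x1 projection.
        int[] shortcut = dimMatch
                ? new int[] {channels, height, width}
                : new int[] {numFilters, convOut(height, 1, stride, 0), convOut(width, 1, stride, 0)};
        if (!Arrays.equals(residual, shortcut)) {
            throw new IllegalArgumentException(
                    "branch shapes differ: " + Arrays.toString(residual) + " vs " + Arrays.toString(shortcut));
        }
        return residual; // the branches are summed element-wise, then passed through ReLU
    }

    public static void main(String[] args) {
        // First unit of a downsampling stage: 64 -> 128 channels at stride 2, projection shortcut.
        System.out.println(Arrays.toString(outputShape(new int[] {64, 56, 56}, 128, 2, false)));
        // prints [128, 28, 28]
    }
}
```

In this sketch, asking for an identity shortcut (dimMatch = true) together with a stride of 2 or a change in channel count throws. An identity mapping cannot change shape, which is why the first unit of a downsampling stage uses a projection shortcut.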
    • resnet

      public static ai.djl.nn.SequentialBlock resnet(ResNetV1.Builder builder)
      Creates a new ResNetV1 block using the arguments from the given ResNetV1.Builder.
      Parameters:
      builder - the ResNetV1.Builder with the necessary arguments
      Returns:
      a SequentialBlock that represents the requested ResNet model
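One way to see how a full ResNet is arranged is to trace the feature-map sizes through it. This sketch assumes the ImageNet layout from He et al.: a 7x7 stride-2 stem, a 3x3 stride-2 max pool, four stages whose first units use strides 1, 2, 2, 2, then global average pooling and a fully connected layer. ResNetLayout is hypothetical. The block DJL actually builds may differ in details such as the stem used for small images.

```java
import java.util.Arrays;

/**
 * Hypothetical trace (not DJL code) of the ImageNet ResNet layout from He et al.:
 * stem convolution, max pool, four residual stages, global average pooling, classifier.
 */
final class ResNetLayout {

    private ResNetLayout() {}

    /** Output length of a convolution or pooling window along one spatial dimension. */
    static int convOut(int in, int kernel, int stride, int padding) {
        return (in + 2 * padding - kernel) / stride + 1;
    }

    /** Spatial size after: stem conv, max pool, then stages conv2_x .. conv5_x. */
    static int[] spatialSizes(int imageSize) {
        int[] sizes = new int[6];
        int size = convOut(imageSize, 7, 2, 3); // 7x7 stem convolution, stride 2, padding 3
        sizes[0] = size;
        size = convOut(size, 3, 2, 1);          // 3x3 max pool, stride 2, padding 1
        sizes[1] = size;
        int[] stageStrides = {1, 2, 2, 2};      // stride of the first unit in each stage
        for (int i = 0; i < stageStrides.length; i++) {
            // Later units in a stage use stride 1 and keep the size unchanged.
            size = convOut(size, 3, stageStrides[i], 1);
            sizes[i + 2] = size;
        }
        return sizes;
    }

    /** Output channels of each stage: basic units vs. bottleneck units (4x expansion). */
    static int[] stageChannels(boolean bottleneck) {
        return bottleneck ? new int[] {256, 512, 1024, 2048} : new int[] {64, 128, 256, 512};
    }

    /** Channels entering global average pooling, i.e. the classifier's input width. */
    static int finalChannels(boolean bottleneck) {
        int[] channels = stageChannels(bottleneck);
        return channels[channels.length - 1];
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(spatialSizes(224))); // prints [112, 56, 56, 28, 14, 7]
        System.out.println(finalChannels(true));                // prints 2048
    }
}
```

For a 224 x 224 input, the four stages end at 56, 28, 14, and 7 pixels per side, matching the output sizes listed in the paper.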
    • builder

      public static ResNetV1.Builder builder()
      Creates a builder to build a ResNetV1.
      Returns:
      a new builder
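ResNetV1 splits construction between the builder() factory, which collects arguments, and the static resnet(ResNetV1.Builder) method, which assembles the model. The hypothetical MiniResNet below mirrors that split in plain Java so it runs without DJL. It returns a text summary instead of a Block, and its setter names and default momentum are illustrative. They do not describe ResNetV1.Builder.

```java
import java.util.Set;

/**
 * Hypothetical mirror (not DJL code) of the documented pattern: a static builder() factory,
 * fluent setters, and a static resnet(Builder) method that performs the construction.
 */
final class MiniResNet {

    private static final Set<Integer> SUPPORTED_DEPTHS = Set.of(18, 34, 50, 101, 152);

    private MiniResNet() {}

    /** Counterpart of builder(): returns a fresh builder. */
    static Builder builder() {
        return new Builder();
    }

    /** Counterpart of resnet(Builder): reads the builder's arguments and "assembles" the model. */
    static String resnet(Builder builder) {
        boolean bottleneck = builder.numLayers >= 50;
        return "ResNet-" + builder.numLayers
                + " (" + (bottleneck ? "bottleneck" : "basic") + " units, "
                + builder.outSize + " outputs, BN momentum " + builder.batchNormMomentum + ")";
    }

    static final class Builder {
        private int numLayers;
        private long outSize;
        private float batchNormMomentum = 0.9f; // illustrative default, not taken from DJL

        Builder setNumLayers(int numLayers) {
            this.numLayers = numLayers;
            return this;
        }

        Builder setOutSize(long outSize) {
            this.outSize = outSize;
            return this;
        }

        Builder optBatchNormMomentum(float momentum) {
            this.batchNormMomentum = momentum;
            return this;
        }

        /** Validates the arguments, then delegates to the static factory. */
        String build() {
            if (!SUPPORTED_DEPTHS.contains(numLayers)) {
                throw new IllegalArgumentException("unsupported depth " + numLayers);
            }
            if (outSize <= 0) {
                throw new IllegalArgumentException("outSize must be positive");
            }
            return resnet(this);
        }
    }

    public static void main(String[] args) {
        System.out.println(MiniResNet.builder().setNumLayers(50).setOutSize(10).build());
        // prints ResNet-50 (bottleneck units, 10 outputs, BN momentum 0.9)
    }
}
```

In this sketch, build() rejects invalid arguments before anything is constructed. resnet(...) stays a plain static method that any configured builder can be passed to.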