Class BaseFunction<T extends Tensor>

java.lang.Object
deepboof.impl.forward.standard.BaseFunction<T>
All Implemented Interfaces:
Function<T>
Direct Known Subclasses:
BaseSpatialWindow, ElementWiseFunction, FunctionBatchNorm_F32, FunctionBatchNorm_F64, FunctionLinear_F32, FunctionLinear_F64

public abstract class BaseFunction<T extends Tensor>
extends java.lang.Object
implements Function<T>
Base class which implements functionality common to all functions.
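BaseFunction's method list pairs each public entry point (initialize, setParameters, forward) with an abstract underscore-prefixed hook (_initialize, _setParameters, _forward). That pairing suggests a template-method design. The following is a minimal sketch of that pattern. It assumes the public methods store the shared protected state and then delegate to the hooks. SimpleTensor, SketchBaseFunction and ScaleFunction are hypothetical stand-ins, not DeepBoof classes, and the real implementation may perform additional shape checks that are omitted here.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for a DeepBoof Tensor: a shape plus a flat, row-major array.
class SimpleTensor {
    final int[] shape;
    final double[] data;

    SimpleTensor(int... shape) {
        this.shape = shape.clone();
        int size = 1;
        for (int s : shape) size *= s;
        this.data = new double[size];
    }
}

// Template-method sketch: public entry points update the shared protected state,
// then delegate to the abstract "_" hooks that each concrete function implements.
abstract class SketchBaseFunction {
    protected int miniBatchSize;
    protected List<SimpleTensor> parameters;            // null until setParameters()
    protected int[] shapeInput;
    protected int[] shapeOutput;
    protected List<int[]> shapeParameters = new ArrayList<>();

    public void initialize(int... shapeInput) {
        this.shapeInput = shapeInput.clone();
        shapeParameters.clear();         // list is reused across initialize() calls
        _initialize();                   // subclass fills shapeOutput and shapeParameters
    }

    public void setParameters(List<SimpleTensor> parameters) {
        this.parameters = parameters;    // reference kept, not copied
        _setParameters(parameters);
    }

    public void forward(SimpleTensor input, SimpleTensor output) {
        miniBatchSize = input.shape[0];  // leading dimension is N
        _forward(input, output);
    }

    public List<SimpleTensor> getParameters() { return parameters; }
    public int[] getOutputShape() { return shapeOutput; }
    public List<int[]> getParameterShapes() { return shapeParameters; }

    public abstract void _initialize();
    public abstract void _setParameters(List<SimpleTensor> parameters);
    public abstract void _forward(SimpleTensor input, SimpleTensor output);
}

// Hypothetical concrete function: scales every element by one learnable scalar.
class ScaleFunction extends SketchBaseFunction {
    private double scale;

    @Override
    public void _initialize() {
        shapeOutput = shapeInput.clone();    // element-wise, so same shape as input
        shapeParameters.add(new int[]{1});   // a single scalar parameter
    }

    @Override
    public void _setParameters(List<SimpleTensor> parameters) {
        scale = parameters.get(0).data[0];
    }

    @Override
    public void _forward(SimpleTensor input, SimpleTensor output) {
        for (int i = 0; i < input.data.length; i++) {
            output.data[i] = scale * input.data[i];
        }
    }
}
```

With this split, a subclass only describes its own shapes and arithmetic, while bookkeeping such as recording the mini-batch size lives in one place. A ScaleFunction initialized with shape (3) can then process inputs of shape (2,3) or (5,3) without being re-initialized.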
  • Field Summary

    Fields 
    Modifier and Type Field Description
    protected int miniBatchSize
    Number of inputs in the mini-batch
    protected java.util.List<T> parameters  
    protected int[] shapeInput  
    protected int[] shapeOutput  
    protected java.util.List<int[]> shapeParameters  
  • Constructor Summary

    Constructors 
    Constructor Description
    BaseFunction()  
  • Method Summary

    Modifier and Type Method Description
    abstract void _forward(T input, T output)
    abstract void _initialize()  
    abstract void _setParameters(java.util.List<T> parameters)
    void forward(T input, T output)
    Performs a forward pass of the function on the provided inputs.
    int[] getOutputShape()
    Returns the output tensor's shape, without the mini-batch dimension.
    java.util.List<T> getParameters()
    If the parameters have been set, then this returns the list of parameters.
    java.util.List<int[]> getParameterShapes()
    Returns the expected shapes of the learnable parameter tensors.
    void initialize(int... shapeInput)
    Initializes internal data structures given the shape of the input tensor, minus the stacked input dimension.
    void setParameters(java.util.List<T> parameters)
    Specifies learnable function parameters, e.g. weights for linear functions.

    Methods inherited from class java.lang.Object

    clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait

    Methods inherited from interface deepboof.Function

    getTensorType
  • Field Details

  • Constructor Details

  • Method Details

    • initialize

      public void initialize(int... shapeInput)
      Description copied from interface: Function
      Initializes internal data structures given the shape of the input tensor, minus the stacked input dimension. For example, the shape (B,C,D) might be passed into initialize, while the actual input tensor has shape (N,B,C,D). N is the number of stacked inputs (the mini-batch size) and may vary after initialization.
      Specified by:
      initialize in interface Function<T extends Tensor>
      Parameters:
      shapeInput - Shape of the input tensor
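The shape convention described for initialize can be made concrete with a small check. This is a sketch only: ShapeConventionDemo and miniBatchSize are hypothetical names, and whether BaseFunction itself validates input shapes this way is not stated in this documentation.

```java
import java.util.Arrays;

// Illustrates the shape convention of initialize(int...): the function is
// initialized with the per-sample shape, e.g. (B,C,D), while forward() later
// receives (N,B,C,D) for any mini-batch size N.
class ShapeConventionDemo {

    // Hypothetical helper: returns N when actualShape is (N, initShape...), else throws.
    static int miniBatchSize(int[] initShape, int[] actualShape) {
        if (actualShape.length != initShape.length + 1) {
            throw new IllegalArgumentException(
                    "Expected " + (initShape.length + 1) + " dimensions, got " + actualShape.length);
        }
        int[] perSample = Arrays.copyOfRange(actualShape, 1, actualShape.length);
        if (!Arrays.equals(perSample, initShape)) {
            throw new IllegalArgumentException("Per-sample shape " + Arrays.toString(perSample)
                    + " does not match " + Arrays.toString(initShape));
        }
        return actualShape[0];
    }

    public static void main(String[] args) {
        int[] init = {3, 32, 32};  // (B,C,D) as passed to initialize
        System.out.println(miniBatchSize(init, new int[]{8, 3, 32, 32}));  // 8
        System.out.println(miniBatchSize(init, new int[]{1, 3, 32, 32}));  // 1, N may vary
    }
}
```

Because only the trailing dimensions are fixed at initialization, the same function instance can serve a training batch of 8 and a single inference sample.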
    • _initialize

      public abstract void _initialize()
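_initialize has no description of its own. Given the protected fields shapeInput, shapeOutput and shapeParameters, one plausible role for the hook is to derive the output and parameter shapes from the stored input shape. The sketch below shows that for a hypothetical fully connected layer with M outputs, assuming the per-sample input is flattened to a vector of length D. LinearShapeSketch is illustrative only and is not FunctionLinear_F32 or FunctionLinear_F64.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch of what an _initialize() override might compute for a
// fully connected layer with M outputs. Field names mirror BaseFunction's
// protected fields; the class itself is a stand-in, not DeepBoof code.
class LinearShapeSketch {
    protected int[] shapeInput;
    protected int[] shapeOutput;
    protected List<int[]> shapeParameters = new ArrayList<>();
    private final int numOutputs;

    LinearShapeSketch(int numOutputs) {
        this.numOutputs = numOutputs;
    }

    // Plays the role of initialize(int...): store the shape, then call the hook.
    void initialize(int... shapeInput) {
        this.shapeInput = shapeInput.clone();
        _initialize();
    }

    void _initialize() {
        // Flatten the per-sample input into a vector of length D.
        int d = 1;
        for (int s : shapeInput) d *= s;

        shapeOutput = new int[]{numOutputs};
        shapeParameters.clear();
        shapeParameters.add(new int[]{numOutputs, d});  // weights W: (M, D)
        shapeParameters.add(new int[]{numOutputs});     // bias b: (M)
    }
}
```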
    • setParameters

      public void setParameters(java.util.List<T> parameters)
      Description copied from interface: Function

      Specifies learnable function parameters, e.g. weights for linear functions. It only needs to be called once after each modification of the parameters, and it must be called before Function.forward(T, T).

      NOTE: References to the parameters may be saved internally, so the tensors should not be modified externally.
      Specified by:
      setParameters in interface Function<T extends Tensor>
      Parameters:
      parameters - Tensors containing the parameters being optimized. Not modified.
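The contract above (parameters set before forward, references possibly kept, getParameters returning null until set) can be sketched as follows. The sketch assumes that setParameters checks each tensor against the shapes reported by getParameterShapes; that check is an assumption, not documented behavior. ShapedTensor and ParameterHolder are hypothetical stand-ins.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical stand-in for a DeepBoof Tensor that only carries a shape.
class ShapedTensor {
    final int[] shape;

    ShapedTensor(int... shape) {
        this.shape = shape.clone();
    }
}

// Sketch, assuming setParameters() validates shapes before keeping a reference.
class ParameterHolder {
    private final List<int[]> expectedShapes;
    private List<ShapedTensor> parameters;  // null until setParameters() is called

    ParameterHolder(List<int[]> expectedShapes) {
        this.expectedShapes = expectedShapes;
    }

    void setParameters(List<ShapedTensor> parameters) {
        if (parameters.size() != expectedShapes.size()) {
            throw new IllegalArgumentException("Expected " + expectedShapes.size() + " parameter tensors");
        }
        for (int i = 0; i < parameters.size(); i++) {
            if (!Arrays.equals(parameters.get(i).shape, expectedShapes.get(i))) {
                throw new IllegalArgumentException("Parameter " + i + " has shape "
                        + Arrays.toString(parameters.get(i).shape));
            }
        }
        // Stored by reference, not copied, which is why callers
        // should not modify the tensors afterwards.
        this.parameters = parameters;
    }

    List<ShapedTensor> getParameters() {
        return parameters;
    }
}
```

Keeping a reference instead of a copy avoids duplicating potentially large weight tensors, at the cost of requiring callers to leave them untouched.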
    • _setParameters

      public abstract void _setParameters(java.util.List<T> parameters)
    • getParameters

      public java.util.List<T> getParameters()
      Description copied from interface: Function
      If the parameters have been set, then this returns the list of parameters. Otherwise null is returned.
      Specified by:
      getParameters in interface Function<T extends Tensor>
      Returns:
      List of parameters or null if they have not been set yet
    • forward

      public void forward(T input, T output)
      Description copied from interface: Function
      Performs a forward pass of the function on the provided inputs.
       Input tensor shape = (N, variable ...)
       - N is the mini-batch size
       - Other dimensions are implementation specific.
      Specified by:
      forward in interface Function<T extends Tensor>
      Parameters:
      input - Input to the function.
      output - Output tensor. Modified.
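A forward pass over an input of shape (N, ...) can be pictured as a loop over N samples. This sketch assumes the tensor data is stored as a flat, row-major array, so that sample n occupies a contiguous block of elements. The layout, the class name and the choice of ReLU as the operation are all illustrative assumptions, not taken from DeepBoof.

```java
// Sketch of how a forward() pass might walk a mini-batch stored as a flat,
// row-major array of shape (N, d1, ..., dk). Layout and names are hypothetical.
class ForwardLayoutDemo {

    // Writes output[i] = max(0, input[i]), processing one sample at a time.
    static void forwardRelu(double[] input, int[] shape, double[] output) {
        int miniBatchSize = shape[0];     // N: leading dimension
        int stride = 1;                   // number of elements per sample
        for (int k = 1; k < shape.length; k++) stride *= shape[k];

        if (input.length != miniBatchSize * stride || output.length != input.length) {
            throw new IllegalArgumentException("Buffer size does not match shape");
        }
        for (int n = 0; n < miniBatchSize; n++) {
            int offset = n * stride;      // first element of sample n
            for (int j = 0; j < stride; j++) {
                output[offset + j] = Math.max(0.0, input[offset + j]);
            }
        }
    }
}
```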
    • _forward

      public abstract void _forward(T input, T output)
    • getParameterShapes

      public java.util.List<int[]> getParameterShapes()
      Description copied from interface: Function
      Returns the expected shapes of the learnable parameter tensors. Only valid after Function.initialize(int...) has been called.
      Specified by:
      getParameterShapes in interface Function<T extends Tensor>
      Returns:
      Expected shapes of the parameter tensors. This data structure may be recycled and is modified on the next call to Function.initialize(int...).
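Since the returned list may be recycled by the next call to initialize, a caller that needs the shapes afterwards should copy them. A shallow copy of the list is not enough, because the int[] elements could also be reused. deepCopy is a hypothetical helper, not part of DeepBoof.

```java
import java.util.ArrayList;
import java.util.List;

// Copies a list of shapes so it survives a later re-initialization.
class ShapeCopy {

    static List<int[]> deepCopy(List<int[]> shapes) {
        List<int[]> copy = new ArrayList<>(shapes.size());
        for (int[] s : shapes) {
            copy.add(s.clone());  // clone each int[] too, not just the list
        }
        return copy;
    }
}
```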
    • getOutputShape

      public int[] getOutputShape()
      Description copied from interface: Function
      Returns the output tensor's shape, without the mini-batch dimension. Only valid after Function.initialize(int...) has been called.
      Specified by:
      getOutputShape in interface Function<T extends Tensor>
      Returns:
      Expected shape of output tensor. This data structure may be recycled and is modified on the next call to Function.initialize(int...).
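getOutputShape omits the mini-batch dimension. So, following the (N, ...) convention described for forward, the output tensor presumably has shape (N, getOutputShape()...). The minimal helper below prepends N. OutputShapes and withMiniBatch are hypothetical names, not DeepBoof API.

```java
import java.util.Arrays;

// Builds the full output shape (N, shapeOutput...) from the per-sample shape.
class OutputShapes {

    static int[] withMiniBatch(int miniBatchSize, int[] shapeOutput) {
        int[] full = new int[shapeOutput.length + 1];
        full[0] = miniBatchSize;
        System.arraycopy(shapeOutput, 0, full, 1, shapeOutput.length);
        return full;
    }

    public static void main(String[] args) {
        int[] perSample = {10};  // e.g. a function with 10 outputs per sample
        System.out.println(Arrays.toString(withMiniBatch(32, perSample)));  // [32, 10]
    }
}
```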