Package deepboof.impl.forward.standard
Class BaseFunction<T extends Tensor>
java.lang.Object
deepboof.impl.forward.standard.BaseFunction<T>
- All Implemented Interfaces:
Function<T>
- Direct Known Subclasses:
BaseSpatialWindow, ElementWiseFunction, FunctionBatchNorm_F32, FunctionBatchNorm_F64, FunctionLinear_F32, FunctionLinear_F64
public abstract class BaseFunction<T extends Tensor> extends java.lang.Object implements Function<T>
Base class which implements functionality common to all functions.
-
Field Summary
Fields
Modifier and Type                Field              Description
protected int                    miniBatchSize      Number of inputs in the mini-batch
protected java.util.List<T>      parameters
protected int[]                  shapeInput
protected int[]                  shapeOutput
protected java.util.List<int[]>  shapeParameters
-
Constructor Summary
Constructors
Constructor     Description
BaseFunction()
-
Method Summary
Modifier and Type      Method                                        Description
abstract void          _forward(T input, T output)
abstract void          _initialize()
abstract void          _setParameters(java.util.List<T> parameters)
void                   forward(T input, T output)                    Performs forward pass of the function on the provided inputs.
int[]                  getOutputShape()                              Returns the output tensor's shape, without the mini-batch dimension.
java.util.List<T>      getParameters()                               If the parameters have been set, then this returns the list of parameters.
java.util.List<int[]>  getParameterShapes()                          Returns the shapes of the parameter tensors.
void                   initialize(int... shapeInput)                 Initializes internal data structures given the shape of the input tensor, minus the stacked input dimension.
void                   setParameters(java.util.List<T> parameters)   Specifies learnable function parameters, e.g. weights for linear functions.
-
Field Details
-
shapeInput
protected int[] shapeInput -
shapeParameters
protected java.util.List<int[]> shapeParameters -
shapeOutput
protected int[] shapeOutput -
parameters
protected java.util.List<T> parameters
-
miniBatchSize
protected int miniBatchSize
Number of inputs in the mini-batch
-
-
Constructor Details
-
BaseFunction
public BaseFunction()
-
-
Method Details
-
initialize
public void initialize(int... shapeInput)
Description copied from interface: Function
Initializes internal data structures given the shape of the input tensor, minus the stacked input dimension. For example, an input tensor of shape (B,C,D) might be passed into initialize, while the actual input is (N,B,C,D). N is the number of stacked inputs and is allowed to vary after initialization.
- Specified by:
initialize in interface Function<T extends Tensor>
- Parameters:
shapeInput - Shape of the input tensor
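The split between initialize and _initialize described above reads like a template method: the public call records the per-sample shape and a subclass hook derives everything else. The following is a minimal sketch of that idea, not the DeepBoof source. SketchFunction and SketchLinear are hypothetical stand-ins, and the choice to clone the shape array is an assumption.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for BaseFunction; NOT the DeepBoof implementation.
abstract class SketchFunction {
    protected int[] shapeInput;
    protected int[] shapeOutput;
    protected List<int[]> shapeParameters = new ArrayList<>();

    // Public entry point: stores the shape WITHOUT the mini-batch dimension,
    // then hands control to the subclass hook.
    public void initialize(int... shapeInput) {
        this.shapeInput = shapeInput.clone(); // assumption: defensive copy
        this.shapeParameters.clear();
        _initialize();
    }

    // Subclass hook: expected to fill in shapeOutput and shapeParameters.
    public abstract void _initialize();

    public int[] getOutputShape() { return shapeOutput; }

    public List<int[]> getParameterShapes() { return shapeParameters; }
}

// Hypothetical linear function: flattens each input and maps it to numOutputs values.
class SketchLinear extends SketchFunction {
    private final int numOutputs;

    SketchLinear(int numOutputs) { this.numOutputs = numOutputs; }

    @Override
    public void _initialize() {
        int inputLength = 1;
        for (int d : shapeInput) inputLength *= d;

        shapeOutput = new int[]{numOutputs};
        shapeParameters.add(new int[]{numOutputs, inputLength}); // weights
        shapeParameters.add(new int[]{numOutputs});              // bias
    }
}
```

With this sketch, calling initialize(3, 4, 5) on a SketchLinear with 10 outputs describes inputs of shape (N,3,4,5) for any N, and the first parameter shape comes out as (10,60).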
-
_initialize
public abstract void _initialize() -
setParameters
public void setParameters(java.util.List<T> parameters)
Description copied from interface: Function
Specifies learnable function parameters, e.g. weights for linear functions. This function only needs to be called once each time a parameter has been modified. Must be called before Function.forward(T, T).
NOTE: A reference to the parameters may be saved internally, so the tensors should not be modified externally.
- Specified by:
setParameters in interface Function<T extends Tensor>
- Parameters:
parameters - Tensors containing parameters which are optimized. Not modified.
-
_setParameters
public abstract void _setParameters(java.util.List<T> parameters)
-
getParameters
public java.util.List<T> getParameters()
Description copied from interface: Function
If the parameters have been set, then this returns the list of parameters. Otherwise null is returned.
- Specified by:
getParameters in interface Function<T extends Tensor>
- Returns:
List of parameters or null if they have not been set yet
-
forward
public void forward(T input, T output)
Description copied from interface: Function
Performs forward pass of the function on the provided inputs.
Input tensor shape = (N,variable ... )
- N is the mini-batch size
- Other dimensions are implementation specific.
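To make the (N,variable ... ) convention concrete, here is a hedged sketch of how a public forward could read N from the leading dimension, check the remaining dimensions against the initialized shape, and then delegate to _forward. SketchTensor, SketchForward and SketchDouble are hypothetical stand-ins, not DeepBoof classes, and the exact checks DeepBoof performs may differ.

```java
import java.util.Arrays;

// Hypothetical minimal tensor: a shape plus row-major data. Not DeepBoof's Tensor.
class SketchTensor {
    final int[] shape;
    final float[] data;

    SketchTensor(int... shape) {
        this.shape = shape.clone();
        int length = 1;
        for (int d : shape) length *= d;
        this.data = new float[length];
    }
}

// Hypothetical stand-in showing the forward -> _forward split.
abstract class SketchForward {
    protected int[] shapeInput;
    protected int[] shapeOutput;
    protected int miniBatchSize;

    // Assumption for this sketch: element-wise function, so output shape equals input shape.
    public void initialize(int... shapeInput) {
        this.shapeInput = shapeInput.clone();
        this.shapeOutput = shapeInput.clone();
    }

    public void forward(SketchTensor input, SketchTensor output) {
        if (shapeInput == null)
            throw new IllegalStateException("initialize must be called before forward");
        if (input.shape.length == 0)
            throw new IllegalArgumentException("input needs a leading mini-batch dimension");

        miniBatchSize = input.shape[0]; // N may change from call to call
        checkShape("input", input.shape, shapeInput);
        checkShape("output", output.shape, shapeOutput);
        _forward(input, output);
    }

    // Verifies actual == (miniBatchSize, expected...).
    private void checkShape(String name, int[] actual, int[] expected) {
        boolean ok = actual.length == expected.length + 1 && actual[0] == miniBatchSize;
        for (int i = 0; ok && i < expected.length; i++)
            ok = actual[i + 1] == expected[i];
        if (!ok)
            throw new IllegalArgumentException(name + " has unexpected shape " + Arrays.toString(actual));
    }

    public abstract void _forward(SketchTensor input, SketchTensor output);
}

// Hypothetical element-wise function that doubles every value.
class SketchDouble extends SketchForward {
    @Override
    public void _forward(SketchTensor input, SketchTensor output) {
        for (int i = 0; i < input.data.length; i++)
            output.data[i] = 2 * input.data[i];
    }
}
```

After initialize(2, 3), this sketch accepts tensors of shape (4,2,3) or (7,2,3) alike, but rejects (4,3,2) because the non-batch dimensions no longer match.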
-
_forward
public abstract void _forward(T input, T output)
-
getParameterShapes
public java.util.List<int[]> getParameterShapes()
Description copied from interface: Function
Returns the shapes of the parameter tensors. Only valid after Function.initialize(int...) has been called.
- Specified by:
getParameterShapes in interface Function<T extends Tensor>
- Returns:
Expected shapes of the parameter tensors. This data structure may be recycled and is modified on the next call to Function.initialize(int...).
-
getOutputShape
public int[] getOutputShape()
Description copied from interface: Function
Returns the output tensor's shape, without the mini-batch dimension. Only valid after Function.initialize(int...) has been called.
- Specified by:
getOutputShape in interface Function<T extends Tensor>
- Returns:
Expected shape of output tensor. This data structure may be recycled and is modified on the next call to Function.initialize(int...).
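Because the returned array may be recycled, a caller that needs the shape to survive a later initialize call should probably copy it. The sketch below illustrates the hazard with a hypothetical RecyclingShape class that reuses one array; whether a given DeepBoof function actually reuses its array is not stated here, so treat this as a defensive pattern under that assumption.

```java
// Hypothetical function that reuses a single array for its output shape.
// It expects a two-dimensional input shape and doubles the second dimension.
class RecyclingShape {
    private final int[] shapeOutput = new int[2];

    void initialize(int... shapeInput) {
        shapeOutput[0] = shapeInput[0];
        shapeOutput[1] = shapeInput[1] * 2;
    }

    // Returns the internal array itself, so later initialize calls change it.
    int[] getOutputShape() { return shapeOutput; }
}

class ShapeSnapshot {
    // Defensive copy: unaffected by later initialize calls.
    static int[] of(RecyclingShape f) {
        return f.getOutputShape().clone();
    }
}
```

In this sketch, a reference obtained directly from getOutputShape reflects the most recent initialize call, while the array returned by ShapeSnapshot.of keeps the values from the moment it was taken.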
-