Package ai.djl.basicmodelzoo.tabular
Class TabNet.DecisionStep
- java.lang.Object
  - ai.djl.nn.AbstractBaseBlock
    - ai.djl.nn.AbstractBlock
      - ai.djl.basicmodelzoo.tabular.TabNet.DecisionStep
- All Implemented Interfaces:
  ai.djl.nn.Block
- Enclosing class:
  TabNet
public static final class TabNet.DecisionStep
extends ai.djl.nn.AbstractBlock

DecisionStep combines a featureTransformer and an attentionTransformer into a single block.
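The class description does not show how the two transformers interact. The sketch below follows the decision step as described in the TabNet paper (Arik and Pfister), not DJL's source. The attention transformer scales its logits by a prior and applies sparsemax to get a per-feature mask. The input features are then multiplied by that mask before they reach the feature transformer. Everything in the sketch is hypothetical and for illustration only: the class DecisionStepSketch, the methods sparsemax, attentionMask and applyMask, and the use of plain double[] arrays instead of NDArray. In particular, it is an assumption that DJL's attentionTransformer uses sparsemax.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch of the attention half of a TabNet decision step,
 * following the TabNet paper. Not DJL's implementation; names are illustrative.
 */
public class DecisionStepSketch {

    /** Sparsemax: Euclidean projection of z onto the probability simplex (may output exact zeros). */
    static double[] sparsemax(double[] z) {
        double[] sorted = z.clone();
        Arrays.sort(sorted); // ascending; read it back to front for descending order
        int n = sorted.length;
        double cumsum = 0.0;
        double supportSum = 0.0;
        int k = 0;
        for (int i = 1; i <= n; i++) {
            double zi = sorted[n - i];
            cumsum += zi;
            // zi stays in the support while 1 + i * zi > sum of the i largest values
            if (1.0 + i * zi > cumsum) {
                k = i;
                supportSum = cumsum;
            }
        }
        double tau = (supportSum - 1.0) / k; // threshold subtracted from every logit
        double[] p = new double[n];
        for (int i = 0; i < n; i++) {
            p[i] = Math.max(z[i] - tau, 0.0);
        }
        return p;
    }

    /** Attention mask = sparsemax(prior * logits); the prior down-weights features already used. */
    static double[] attentionMask(double[] logits, double[] prior) {
        double[] scaled = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            scaled[i] = prior[i] * logits[i];
        }
        return sparsemax(scaled);
    }

    /** Element-wise product of the input features and the mask, fed to the feature transformer. */
    static double[] applyMask(double[] features, double[] mask) {
        double[] out = new double[features.length];
        for (int i = 0; i < features.length; i++) {
            out[i] = features[i] * mask[i];
        }
        return out;
    }

    public static void main(String[] args) {
        double[] x = {2.0, -1.0, 3.0};       // one row with inputDim = 3
        double[] logits = {1.0, 0.5, -1.0};  // stands in for the attention transformer's FC + BN output
        double[] prior = {1.0, 1.0, 1.0};    // first step: no feature has been used yet
        double[] mask = attentionMask(logits, prior);
        System.out.println(Arrays.toString(mask));                // [0.75, 0.25, 0.0]
        System.out.println(Arrays.toString(applyMask(x, mask)));  // [1.5, -0.25, 0.0]
    }
}
```

Sparsemax can return exact zeros, so the third feature gets no weight at all in this step. That is how TabNet performs sparse, per-instance feature selection. In the paper, the masked features then pass through the feature transformer.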
Constructor Summary
- DecisionStep(int inputDim, int numD, int numA, java.util.List<ai.djl.nn.Block> shared, int nInd, int virtualBatchSize, float batchNormMomentum)
  Creates a TabNet.DecisionStep with the given parameters.
Method Summary
- protected ai.djl.ndarray.NDList forwardInternal(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
- ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes)
- protected void initializeChildBlocks(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.types.DataType dataType, ai.djl.ndarray.types.Shape... inputShapes)
Methods inherited from class ai.djl.nn.AbstractBlock
addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters
Methods inherited from class ai.djl.nn.AbstractBaseBlock
beforeInitialize, cast, clear, describeInput, forward, forward, forwardInternal, getInputShapes, getOutputDataTypes, getParameters, initialize, isInitialized, loadMetadata, loadParameters, prepare, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer, setInitializer, toString
Constructor Detail
DecisionStep
public DecisionStep(int inputDim, int numD, int numA, java.util.List<ai.djl.nn.Block> shared, int nInd, int virtualBatchSize, float batchNormMomentum)
Creates a TabNet.DecisionStep with the given parameters.
Parameters:
- inputDim - the number of input dimensions for the attentionTransformer
- numD - the number of dimensions that are not passed to the attentionTransformer
- numA - the number of dimensions for the attentionTransformer
- shared - the shared fully connected layers
- nInd - the number of independent fully connected layers
- virtualBatchSize - the virtual batch size
- batchNormMomentum - the momentum for the batchNorm layers
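The virtualBatchSize and batchNormMomentum parameters point to ghost (virtual) batch normalization. The TabNet paper uses this technique inside its feature transformer blocks. The sketch below is a hypothetical, single-feature illustration rather than DJL's BatchNorm. Each contiguous chunk of at most virtualBatchSize rows is normalized with its own mean and variance, and a running mean is updated once per chunk. The update rule running = momentum * running + (1 - momentum) * chunkMean is an assumption (the MXNet-style convention); some libraries define momentum the other way round. The names GhostBatchNormSketch, RunningStats, virtualBatches and normalize are illustrative only. Learned scale and shift parameters and the running variance are omitted.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical ghost batch normalization sketch for one feature column (not DJL code). */
public class GhostBatchNormSketch {

    /** Mutable running statistic; only the mean is tracked to keep the sketch short. */
    static final class RunningStats {
        double mean = 0.0;
    }

    /** Contiguous [start, end) ranges of at most virtualBatchSize rows covering n rows. */
    static List<int[]> virtualBatches(int n, int virtualBatchSize) {
        if (virtualBatchSize <= 0) {
            throw new IllegalArgumentException("virtualBatchSize must be positive");
        }
        List<int[]> ranges = new ArrayList<>();
        for (int start = 0; start < n; start += virtualBatchSize) {
            ranges.add(new int[] {start, Math.min(start + virtualBatchSize, n)});
        }
        return ranges;
    }

    /** Normalizes each virtual batch with its own statistics and updates stats.mean per chunk. */
    static double[] normalize(double[] x, int virtualBatchSize, double momentum,
                              RunningStats stats, double eps) {
        double[] out = new double[x.length];
        for (int[] r : virtualBatches(x.length, virtualBatchSize)) {
            int len = r[1] - r[0];
            double mean = 0.0;
            for (int i = r[0]; i < r[1]; i++) {
                mean += x[i];
            }
            mean /= len;
            double var = 0.0;
            for (int i = r[0]; i < r[1]; i++) {
                var += (x[i] - mean) * (x[i] - mean);
            }
            var /= len;
            for (int i = r[0]; i < r[1]; i++) {
                out[i] = (x[i] - mean) / Math.sqrt(var + eps);
            }
            // Assumed MXNet-style convention: a larger momentum keeps more of the old value.
            stats.mean = momentum * stats.mean + (1.0 - momentum) * mean;
        }
        return out;
    }

    public static void main(String[] args) {
        double[] batch = {1, 2, 3, 4, 10, 20};   // 6 rows, one feature
        RunningStats stats = new RunningStats();
        double[] y = normalize(batch, 4, 0.9, stats, 1e-5);  // chunks: rows 0-3 and rows 4-5
        System.out.println(Arrays.toString(y));
        System.out.println("running mean: " + stats.mean);
    }
}
```

With 6 rows and virtualBatchSize = 4, the second chunk is normalized using only its own two rows, independently of the first chunk.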
Method Detail
forwardInternal
protected ai.djl.ndarray.NDList forwardInternal(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
Specified by:
forwardInternal in class ai.djl.nn.AbstractBaseBlock
getOutputShapes
public ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes)
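The page does not say which shapes getOutputShapes reports. The sketch below is hypothetical shape bookkeeping for a TabNet-style decision step. It makes three assumptions. The input has shape [batchSize, inputDim]. The feature transformer output of width numD + numA is split into a decision part (numD) and an attention part (numA), as in the TabNet paper. Each GLU block's fully connected layer has twice the GLU output width, because a GLU halves its input. The class DecisionStepShapes and its methods are illustrative names only. Treat the returned shapes as an assumption, not a description of DJL's return value.

```java
import java.util.Arrays;

/** Hypothetical shape bookkeeping for a TabNet-style decision step (not DJL's getOutputShapes). */
public class DecisionStepShapes {

    /** Width of the fully connected layer inside one GLU block: the GLU halves its input. */
    static long gluFcUnits(int numD, int numA) {
        return 2L * (numD + numA);
    }

    /** Assumed shapes for a [batchSize, inputDim] input. */
    static long[][] shapes(long batchSize, int inputDim, int numD, int numA) {
        long[] mask = {batchSize, inputDim};            // attention mask, one weight per feature
        long[] transformed = {batchSize, numD + numA};  // feature transformer output
        long[] decision = {batchSize, numD};            // part used for the prediction
        long[] attention = {batchSize, numA};           // part fed to the next attention transformer
        return new long[][] {mask, transformed, decision, attention};
    }

    public static void main(String[] args) {
        for (long[] shape : shapes(32, 10, 8, 8)) {
            System.out.println(Arrays.toString(shape));
        }
        System.out.println("GLU FC units: " + gluFcUnits(8, 8));
    }
}
```

Keeping numD and numA as separate widths lets the decision output and the attention input be sized independently, while the feature transformer produces both in one pass.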
initializeChildBlocks
protected void initializeChildBlocks(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.types.DataType dataType, ai.djl.ndarray.types.Shape... inputShapes)
Overrides:
initializeChildBlocks in class ai.djl.nn.AbstractBaseBlock