Package ai.djl.basicmodelzoo.tabular
Class TabNet.DecisionStep
java.lang.Object
  ai.djl.nn.AbstractBaseBlock
    ai.djl.nn.AbstractBlock
      ai.djl.basicmodelzoo.tabular.TabNet.DecisionStep
- All Implemented Interfaces: ai.djl.nn.Block
- Enclosing class: TabNet
public static final class TabNet.DecisionStep
extends ai.djl.nn.AbstractBlock
DecisionStep combines a featureTransformer and an attentionTransformer into a single block.
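The following dependency-free Java sketch shows one decision step for a single sample. It follows the original TabNet paper rather than DJL's source, and rests on three assumptions:

- The attentionTransformer turns prior-scaled logits into a sparse feature mask with sparsemax.
- The mask multiplies the input features.
- The featureTransformer emits numD + numA values. These are split into a ReLU-activated decision part and an attention part for the next step.

The class DecisionStepSketch, its methods, and the lambda standing in for the featureTransformer are hypothetical illustrations, not part of the ai.djl API.

```java
import java.util.Arrays;
import java.util.function.UnaryOperator;

/** Hypothetical sketch of one TabNet decision step for a single sample (not the DJL implementation). */
final class DecisionStepSketch {

    /** Sparsemax: Euclidean projection of z onto the probability simplex, giving a sparse mask. */
    static double[] sparsemax(double[] z) {
        double[] sorted = z.clone();
        Arrays.sort(sorted); // ascending; read from the end to get descending order
        int n = sorted.length;
        double cumSum = 0.0;
        double tau = 0.0;
        for (int k = 1; k <= n; k++) {
            double zk = sorted[n - k]; // k-th largest value
            cumSum += zk;
            // k is in the support while 1 + k * z_(k) exceeds the sum of the k largest values
            if (1.0 + k * zk > cumSum) {
                tau = (cumSum - 1.0) / k;
            }
        }
        double[] p = new double[n];
        for (int i = 0; i < n; i++) {
            p[i] = Math.max(z[i] - tau, 0.0);
        }
        return p;
    }

    /**
     * One decision step. Returns {mask, decision, attention}.
     * attentionLogits stands in for the attentionTransformer's dense + batch-norm output;
     * featureTransform stands in for the featureTransformer and must return numD + numA values.
     */
    static double[][] step(double[] features, double[] priors, double[] attentionLogits,
                           UnaryOperator<double[]> featureTransform, int numD, int numA) {
        int inputDim = features.length;
        double[] scaled = new double[inputDim];
        for (int j = 0; j < inputDim; j++) {
            scaled[j] = priors[j] * attentionLogits[j]; // priors scale down features used earlier
        }
        double[] mask = sparsemax(scaled);
        double[] masked = new double[inputDim];
        for (int j = 0; j < inputDim; j++) {
            masked[j] = mask[j] * features[j]; // soft feature selection
        }
        double[] h = featureTransform.apply(masked);
        if (h.length != numD + numA) {
            throw new IllegalArgumentException("featureTransform must return numD + numA values");
        }
        double[] decision = new double[numD];
        for (int i = 0; i < numD; i++) {
            decision[i] = Math.max(h[i], 0.0); // ReLU on the decision part
        }
        double[] attention = Arrays.copyOfRange(h, numD, numD + numA); // input to the next step
        return new double[][] {mask, decision, attention};
    }

    public static void main(String[] args) {
        // Toy stand-in for the featureTransformer with numD = 2 and numA = 1 (hypothetical).
        UnaryOperator<double[]> toyTransform = v -> new double[] {v[0] - v[1], v[1], v[2] + 1.0};
        double[][] out = step(new double[] {1.0, 2.0, 3.0}, new double[] {1.0, 1.0, 1.0},
                new double[] {1.0, 1.2, 0.0}, toyTransform, 2, 1);
        System.out.println("mask      = " + Arrays.toString(out[0]));
        System.out.println("decision  = " + Arrays.toString(out[1]));
        System.out.println("attention = " + Arrays.toString(out[2]));
    }
}
```

With these toy inputs, sparsemax assigns zero weight to the third feature. That feature therefore contributes nothing to this step's output.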
Field Summary
Fields inherited from class ai.djl.nn.AbstractBlock:
children, parameters
Fields inherited from class ai.djl.nn.AbstractBaseBlock:
inputNames, inputShapes, outputDataTypes, version
Constructor Summary
DecisionStep(int inputDim, int numD, int numA, List<ai.djl.nn.Block> shared, int nInd, int virtualBatchSize, float batchNormMomentum)
Creates a TabNet.DecisionStep with the given parameters.
Method Summary
protected ai.djl.ndarray.NDList forwardInternal(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<String, Object> params)
ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes)
protected void initializeChildBlocks(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.types.DataType dataType, ai.djl.ndarray.types.Shape... inputShapes)
Methods inherited from class ai.djl.nn.AbstractBlock:
addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters
Methods inherited from class ai.djl.nn.AbstractBaseBlock:
beforeInitialize, cast, clear, describeInput, forward, forward, forwardInternal, getInputShapes, getOutputDataTypes, getParameters, initialize, isInitialized, loadMetadata, loadParameters, prepare, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer, setInitializer, toString
Methods inherited from class java.lang.Object:
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface ai.djl.nn.Block:
forward, freezeParameters, freezeParameters, getOutputShapes
Constructor Details
DecisionStep
public DecisionStep(int inputDim, int numD, int numA, List<ai.djl.nn.Block> shared, int nInd, int virtualBatchSize, float batchNormMomentum)
Creates a TabNet.DecisionStep with the given parameters.
Parameters:
inputDim - the number of input dimensions for the attentionTransformer
numD - the number of dimensions other than those for the attentionTransformer
numA - the number of dimensions for the attentionTransformer
shared - the shared fully connected layers
nInd - the number of independent fully connected layers
virtualBatchSize - the virtual batch size
batchNormMomentum - the momentum for the batch normalization layers
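The virtualBatchSize and batchNormMomentum parameters point to ghost batch normalization, which the TabNet paper uses. In that technique, the batch is split into virtual batches of at most virtualBatchSize rows, and each virtual batch is normalized with its own statistics.

The plain-Java sketch below illustrates the technique for a single feature column. It is not DJL's implementation. It assumes a momentum convention in which momentum weights the previous running value. The class GhostBatchNormSketch and its members are hypothetical.

```java
import java.util.Arrays;

/** Hypothetical sketch of ghost batch normalization for one feature column (not DJL's code). */
final class GhostBatchNormSketch {
    private final int virtualBatchSize;
    private final double momentum;
    private final double eps = 1e-5;
    double runningMean = 0.0;
    double runningVar = 1.0;

    GhostBatchNormSketch(int virtualBatchSize, double momentum) {
        if (virtualBatchSize <= 0) {
            throw new IllegalArgumentException("virtualBatchSize must be positive");
        }
        this.virtualBatchSize = virtualBatchSize;
        this.momentum = momentum;
    }

    /** Training-mode pass: each virtual batch is normalized with its own mean and variance. */
    double[] forwardTraining(double[] batch) {
        double[] out = new double[batch.length];
        for (int start = 0; start < batch.length; start += virtualBatchSize) {
            int end = Math.min(start + virtualBatchSize, batch.length);
            int n = end - start;
            double mean = 0.0;
            for (int i = start; i < end; i++) {
                mean += batch[i];
            }
            mean /= n;
            double var = 0.0;
            for (int i = start; i < end; i++) {
                var += (batch[i] - mean) * (batch[i] - mean);
            }
            var /= n; // biased variance, as batch norm uses when normalizing during training
            for (int i = start; i < end; i++) {
                out[i] = (batch[i] - mean) / Math.sqrt(var + eps);
            }
            // Assumed convention: momentum weights the previous running statistic.
            runningMean = momentum * runningMean + (1.0 - momentum) * mean;
            runningVar = momentum * runningVar + (1.0 - momentum) * var;
        }
        return out;
    }

    public static void main(String[] args) {
        GhostBatchNormSketch gbn = new GhostBatchNormSketch(2, 0.9);
        // Two virtual batches, {1, 3} and {10, 20}, each normalized independently.
        double[] y = gbn.forwardTraining(new double[] {1.0, 3.0, 10.0, 20.0});
        System.out.println(Arrays.toString(y) + " runningMean=" + gbn.runningMean);
    }
}
```

Each virtual batch uses its own statistics. As a result, 10 and 20 are normalized to roughly -1 and 1, even though a full-batch mean would place both values above the mean.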
Method Details
forwardInternal
protected ai.djl.ndarray.NDList forwardInternal(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<String, Object> params)
Specified by:
forwardInternal in class ai.djl.nn.AbstractBaseBlock
getOutputShapes
public ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes)
initializeChildBlocks
protected void initializeChildBlocks(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.types.DataType dataType, ai.djl.ndarray.types.Shape... inputShapes)
Overrides:
initializeChildBlocks in class ai.djl.nn.AbstractBaseBlock