Package ai.djl.basicmodelzoo.tabular
Class TabNet
- java.lang.Object
  - ai.djl.nn.AbstractBaseBlock
    - ai.djl.nn.AbstractBlock
      - ai.djl.basicmodelzoo.tabular.TabNet
- All Implemented Interfaces:
ai.djl.nn.Block
public final class TabNet extends ai.djl.nn.AbstractBlock

This class contains a generic implementation of TabNet adapted from https://towardsdatascience.com/implementing-tabnet-in-pytorch-fc977c383279 (original author: Samrat Thapa).

TabNet is a neural network architecture for tabular data developed by the research team at Google Cloud AI. It achieved state-of-the-art results on several datasets in both regression and classification problems. Another desirable feature of TabNet is interpretability: unlike most deep learning models, where the network acts like a black box, TabNet lets us see which features the model selects.

See https://arxiv.org/pdf/1908.07442.pdf for more information about TabNet.
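The interpretability mentioned above comes from the feature-selection masks. The TabNet paper defines an aggregate importance for feature j of sample b as the sum over decision steps i of eta_b[i] * M_b,j[i], normalized so the importances of all features sum to 1, where eta_b[i] = sum over c of ReLU(d_b,c[i]) measures how much step i contributes to the decision. The plain-Java sketch below computes that formula for one sample. It assumes the per-step masks and decision outputs are already available as arrays (this page does not say whether TabNet exposes them), and FeatureImportanceSketch is a hypothetical name.

```java
/**
 * Plain-Java sketch of TabNet's aggregate feature importance for one sample, following
 * the formula in the TabNet paper. Hypothetical helper: it assumes the masks and decision
 * outputs are given as arrays and does not show how DJL's TabNet exposes them, if at all.
 */
final class FeatureImportanceSketch {

    private FeatureImportanceSketch() {}

    /**
     * masks[i][j] is the mask value of feature j at decision step i; decisions[i][c] is
     * component c of step i's decision output. Returns importance[j], normalized to sum to 1.
     */
    static double[] aggregate(double[][] masks, double[][] decisions) {
        int numFeatures = masks[0].length;
        double[] importance = new double[numFeatures];
        for (int i = 0; i < masks.length; i++) {
            // eta[i] = sum over c of ReLU(d[i][c]): how much step i contributes to the decision
            double eta = 0;
            for (double d : decisions[i]) {
                eta += Math.max(d, 0.0);
            }
            // Weight this step's mask by its contribution and accumulate per feature
            for (int j = 0; j < numFeatures; j++) {
                importance[j] += eta * masks[i][j];
            }
        }
        // Normalize across features so the importances sum to 1
        double total = 0;
        for (double v : importance) {
            total += v;
        }
        if (total > 0) {
            for (int j = 0; j < numFeatures; j++) {
                importance[j] /= total;
            }
        }
        return importance;
    }
}
```

For masks {{1, 0}, {0.5, 0.5}} and decision outputs {{1, -1}, {2}}, eta is 1 for the first step and 2 for the second, so the two features get importances of 2/3 and 1/3.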
Nested Class Summary
- static class TabNet.AttentionTransformer: AttentionTransformer is where the TabNet model learns the relationship between relevant features and decides which features to pass on to the feature transformer of the current decision step.
- static class TabNet.Builder: The Builder to construct a TabNet object.
- static class TabNet.DecisionStep: DecisionStep combines a featureTransformer and an attentionTransformer.
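The AttentionTransformer description does not say how features are selected. In the TabNet paper, the attentive transformer computes a mask M[i] = sparsemax(P[i-1] * h_i(a[i-1])), where h_i is an FC + batch-norm layer and P is a prior scale updated as P[i] = P[i-1] * (gamma - M[i]), so features already used are down-weighted in later steps. The plain-Java sketch below shows only the sparsemax and prior update for a single sample. The class AttentionMaskSketch and its gamma parameter are illustrative assumptions, the FC + batch-norm step is omitted, and this is not DJL's code.

```java
import java.util.Arrays;

/**
 * Plain-Java sketch of how a TabNet-style attentive transformer turns logits into a
 * sparse feature mask for one sample. Hypothetical helper: the FC + batch-norm layer
 * that produces the logits is omitted.
 */
final class AttentionMaskSketch {

    private AttentionMaskSketch() {}

    /** Sparsemax (Martins and Astudillo, 2016): Euclidean projection of z onto the probability simplex. */
    static float[] sparsemax(float[] z) {
        float[] sorted = z.clone();
        Arrays.sort(sorted); // ascending; walk it from the end to visit values largest first
        int n = sorted.length;
        double cumSum = 0;
        double supportSum = 0;
        int k = 0;
        for (int i = 0; i < n; i++) {
            float zi = sorted[n - 1 - i]; // (i + 1)-th largest value
            cumSum += zi;
            // Support condition: 1 + k * z_(k) > sum of the k largest values
            if (1 + (i + 1) * zi > cumSum) {
                k = i + 1;
                supportSum = cumSum;
            }
        }
        double tau = (supportSum - 1) / k; // threshold; k >= 1 for non-empty input
        float[] p = new float[n];
        for (int i = 0; i < n; i++) {
            p[i] = (float) Math.max(z[i] - tau, 0.0); // values below tau become exactly 0
        }
        return p;
    }

    /**
     * One attention step: mask = sparsemax(prior * logits); the prior is then updated
     * in place to prior * (gamma - mask).
     */
    static float[] step(float[] logits, float[] prior, float gamma) {
        float[] scaled = new float[logits.length];
        for (int j = 0; j < logits.length; j++) {
            scaled[j] = prior[j] * logits[j];
        }
        float[] mask = sparsemax(scaled);
        for (int j = 0; j < prior.length; j++) {
            prior[j] *= gamma - mask[j];
        }
        return mask;
    }
}
```

Unlike softmax, sparsemax can return exact zeros, which is what makes the masks sparse and readable. For example, with logits {1, 0, 0} and a prior of all ones, the mask is {1, 0, 0}, and with gamma = 1.5 the prior becomes {0.5, 1.5, 1.5}.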
Method Summary
- static TabNet.Builder builder(): Creates a builder to build a TabNet.
- static ai.djl.nn.Block featureTransformer(java.util.List<ai.djl.nn.Block> sharedBlocks, int outDim, int numIndependent, int virtualBatchSize, float batchNormMomentum): Creates a featureTransformer Block.
- protected ai.djl.ndarray.NDList forwardInternal(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
- ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes)
- static ai.djl.nn.Block gluBlock(ai.djl.nn.Block sharedBlock, int outDim, int virtualBatchSize, float batchNormMomentum): Creates an FC-BN-GLU block used in TabNet.
- protected void initializeChildBlocks(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.types.DataType dataType, ai.djl.ndarray.types.Shape... inputShapes)
- static ai.djl.ndarray.NDArray tabNetGLU(ai.djl.ndarray.NDArray array, int units): Applies the tabNetGLU activation (used mainly in TabNet) to the input NDArray.
- static ai.djl.ndarray.NDList tabNetGLU(ai.djl.ndarray.NDList arrays, int units): Applies the tabNetGLU activation (used mainly in TabNet) to the input singleton NDList.
- static ai.djl.nn.Block tabNetGLUBlock(int units): Creates a LambdaBlock that applies the tabNetGLU(NDArray, int) activation function in its forward function.
Methods inherited from class ai.djl.nn.AbstractBlock
addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters
Methods inherited from class ai.djl.nn.AbstractBaseBlock
beforeInitialize, cast, clear, describeInput, forward, forward, forwardInternal, getInputShapes, getOutputDataTypes, getParameters, initialize, isInitialized, loadMetadata, loadParameters, prepare, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer, setInitializer, toString
Method Detail
tabNetGLU
public static ai.djl.ndarray.NDArray tabNetGLU(ai.djl.ndarray.NDArray array, int units)
Applies the tabNetGLU activation (used mainly in TabNet) to the input NDArray.
- Parameters:
  array - the input NDArray
  units - the number of output features, i.e. half the number of input features
- Returns:
  the NDArray after applying the tabNetGLU function
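As a plain-Java illustration of the gating behind tabNetGLU, the sketch below assumes the usual GLU definition: each input row has 2 * units features, and the first units features are multiplied element-wise by the sigmoid of the last units features. The class GluSketch is hypothetical, and the exact split DJL uses internally is an assumption here.

```java
/**
 * Plain-Java sketch of a GLU activation on a [batch][2 * units] matrix.
 * Assumption: out[i][j] = x[i][j] * sigmoid(x[i][j + units]); not DJL's implementation.
 */
final class GluSketch {

    private GluSketch() {}

    static float sigmoid(float v) {
        return (float) (1.0 / (1.0 + Math.exp(-v)));
    }

    /** Returns a [batch][units] matrix; the second half of each row acts as the gate. */
    static float[][] glu(float[][] x, int units) {
        float[][] out = new float[x.length][units];
        for (int i = 0; i < x.length; i++) {
            if (x[i].length != 2 * units) {
                throw new IllegalArgumentException(
                        "row " + i + " has " + x[i].length + " features, expected " + 2 * units);
            }
            for (int j = 0; j < units; j++) {
                out[i][j] = x[i][j] * sigmoid(x[i][j + units]);
            }
        }
        return out;
    }
}
```

For instance, GluSketch.glu(new float[][] {{1f, 2f, 0f, 0f}}, 2) returns {{0.5f, 1.0f}}, because sigmoid(0) = 0.5 halves both values.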
tabNetGLU
public static ai.djl.ndarray.NDList tabNetGLU(ai.djl.ndarray.NDList arrays, int units)
Applies the tabNetGLU activation (used mainly in TabNet) to the input singleton NDList.
- Parameters:
  arrays - the input singleton NDList
  units - the number of output features, i.e. half the number of input features
- Returns:
  the singleton NDList after applying the tabNetGLU function
tabNetGLUBlock
public static ai.djl.nn.Block tabNetGLUBlock(int units)
Creates a LambdaBlock that applies the tabNetGLU(NDArray, int) activation function in its forward function.
- Parameters:
  units - the number of output features, i.e. half the number of input features
- Returns:
  a LambdaBlock that applies the tabNetGLU(NDArray, int) activation function
gluBlock
public static ai.djl.nn.Block gluBlock(ai.djl.nn.Block sharedBlock, int outDim, int virtualBatchSize, float batchNormMomentum)
Creates an FC-BN-GLU block used in TabNet. Because GLU halves the feature dimension, a fully connected layer first doubles it, so the GLU receives 2 * outDim features and returns outDim.
- Parameters:
  sharedBlock - the shared fully connected layer
  outDim - the output feature dimension
  virtualBatchSize - the virtual batch size for ghost batch norm
  batchNormMomentum - the momentum used for the ghost batch norm layer
- Returns:
  an FC-BN-GLU block
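The FC-BN-GLU pipeline described above can be sketched in plain Java for the training-mode forward pass. Assumptions: the FC layer has no bias, ghost batch norm normalizes each chunk of virtualBatchSize rows with that chunk's own mean and variance, the learnable scale and shift are fixed to 1 and 0, and running statistics (where batchNormMomentum would apply) are left out. The class FcBnGluSketch and its methods are hypothetical, not DJL's gluBlock.

```java
/**
 * Plain-Java sketch of an FC -> ghost batch norm -> GLU block (training-mode forward only).
 * Assumptions: no FC bias, BN scale = 1 and shift = 0, no running statistics.
 */
final class FcBnGluSketch {

    private FcBnGluSketch() {}

    /** Fully connected layer: x is [batch][inDim], w is [inDim][outDim]. */
    static float[][] linear(float[][] x, float[][] w) {
        int outDim = w[0].length;
        float[][] y = new float[x.length][outDim];
        for (int i = 0; i < x.length; i++) {
            for (int k = 0; k < w.length; k++) {
                for (int j = 0; j < outDim; j++) {
                    y[i][j] += x[i][k] * w[k][j];
                }
            }
        }
        return y;
    }

    /** Ghost batch norm: each chunk of virtualBatchSize rows is normalized with its own statistics. */
    static float[][] ghostBatchNorm(float[][] x, int virtualBatchSize, float eps) {
        int dim = x[0].length;
        float[][] y = new float[x.length][dim];
        for (int start = 0; start < x.length; start += virtualBatchSize) {
            int end = Math.min(start + virtualBatchSize, x.length);
            int n = end - start;
            for (int j = 0; j < dim; j++) {
                double mean = 0;
                for (int i = start; i < end; i++) {
                    mean += x[i][j];
                }
                mean /= n;
                double variance = 0;
                for (int i = start; i < end; i++) {
                    variance += (x[i][j] - mean) * (x[i][j] - mean);
                }
                variance /= n;
                for (int i = start; i < end; i++) {
                    y[i][j] = (float) ((x[i][j] - mean) / Math.sqrt(variance + eps));
                }
            }
        }
        return y;
    }

    /** GLU: first `units` columns gated by the sigmoid of the last `units` columns. */
    static float[][] glu(float[][] x, int units) {
        float[][] out = new float[x.length][units];
        for (int i = 0; i < x.length; i++) {
            for (int j = 0; j < units; j++) {
                double gate = 1.0 / (1.0 + Math.exp(-x[i][j + units]));
                out[i][j] = (float) (x[i][j] * gate);
            }
        }
        return out;
    }

    /** FC up to 2 * outDim features, ghost batch norm, then GLU back down to outDim. */
    static float[][] forward(float[][] x, float[][] w, int outDim, int virtualBatchSize) {
        if (w[0].length != 2 * outDim) {
            throw new IllegalArgumentException("FC layer must produce 2 * outDim features");
        }
        float[][] h = linear(x, w);
        h = ghostBatchNorm(h, virtualBatchSize, 1e-5f);
        return glu(h, outDim);
    }
}
```

Splitting the batch into virtual batches is what distinguishes ghost batch norm from ordinary batch norm: with virtualBatchSize = 2, the column {1, 3, 10, 20} normalizes to roughly {-1, 1, -1, 1}, because each pair is standardized separately.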
featureTransformer
public static ai.djl.nn.Block featureTransformer(java.util.List<ai.djl.nn.Block> sharedBlocks, int outDim, int numIndependent, int virtualBatchSize, float batchNormMomentum)
Creates a featureTransformer Block. The feature transformer is where all the selected features are processed to generate the final output.
- Parameters:
  sharedBlocks - the shared blocks of the feature transformer
  outDim - the output dimension of the feature transformer
  numIndependent - the number of independent blocks of the feature transformer
  virtualBatchSize - the virtual batch size for ghost batch norm
  batchNormMomentum - the momentum for the batch norm layer
- Returns:
  a feature transformer
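The TabNet paper stacks FC-BN-GLU blocks with residual connections and scales each sum by sqrt(0.5) so the variance does not change dramatically through the network. The plain-Java sketch below shows that wiring, with each block modeled as a UnaryOperator<float[][]>. Assumptions: shared blocks run before the independent ones, and the first block has no residual connection because it may change the feature dimension. Whether DJL's featureTransformer follows this exactly is not stated on this page, and FeatureTransformerSketch is a hypothetical name.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.UnaryOperator;

/**
 * Plain-Java sketch of the residual wiring in a TabNet-style feature transformer.
 * Each block stands in for an FC-BN-GLU block. Not DJL's implementation.
 */
final class FeatureTransformerSketch {

    private static final float SCALE = (float) Math.sqrt(0.5);

    private FeatureTransformerSketch() {}

    /** Applies the first block directly, then (h + block(h)) * sqrt(0.5) for every later block. */
    static float[][] forward(float[][] x, List<UnaryOperator<float[][]>> blocks) {
        if (blocks.isEmpty()) {
            throw new IllegalArgumentException("need at least one block");
        }
        float[][] h = blocks.get(0).apply(x); // may change the feature dimension, so no residual
        for (UnaryOperator<float[][]> block : blocks.subList(1, blocks.size())) {
            float[][] r = block.apply(h);
            float[][] next = new float[h.length][h[0].length];
            for (int i = 0; i < h.length; i++) {
                for (int j = 0; j < h[i].length; j++) {
                    next[i][j] = (h[i][j] + r[i][j]) * SCALE;
                }
            }
            h = next;
        }
        return h;
    }

    /** Orders the blocks as assumed above: shared blocks first, then independent ones. */
    static List<UnaryOperator<float[][]>> chain(
            List<UnaryOperator<float[][]>> shared, List<UnaryOperator<float[][]>> independent) {
        List<UnaryOperator<float[][]>> all = new ArrayList<>(shared);
        all.addAll(independent);
        return all;
    }
}
```

With three identity blocks and the input {{1f}}, the output is approximately {{2f}}: each residual step computes (h + h) * sqrt(0.5) = sqrt(2) * h, and two such steps multiply by 2.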
forwardInternal
protected ai.djl.ndarray.NDList forwardInternal(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
- Specified by:
  forwardInternal in class ai.djl.nn.AbstractBaseBlock
getOutputShapes
public ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes)
initializeChildBlocks
protected void initializeChildBlocks(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.types.DataType dataType, ai.djl.ndarray.types.Shape... inputShapes)
- Overrides:
  initializeChildBlocks in class ai.djl.nn.AbstractBaseBlock
builder
public static TabNet.Builder builder()
Creates a builder to build a TabNet.
- Returns:
  a new builder