Package ai.djl.modality.nlp.embedding
Class TrainableTextEmbedding
java.lang.Object
  ai.djl.nn.AbstractBaseBlock
    ai.djl.nn.AbstractBlock
      ai.djl.modality.nlp.embedding.TrainableTextEmbedding

All Implemented Interfaces:
TextEmbedding, Block
public class TrainableTextEmbedding extends AbstractBlock implements TextEmbedding
TrainableTextEmbedding is an implementation of TextEmbedding based on the TrainableWordEmbedding block. This TextEmbedding is well suited to cases where no pre-trained embeddings are available, or where a pre-trained embedding needs further training.
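What makes this embedding "trainable" is that its lookup table is an ordinary parameter updated by gradient descent. The following is a minimal sketch under that assumption, using a hypothetical ToyTrainableTable class (not DJL code) and plain SGD: looking up a token returns its row, and an update step changes only that row.

```java
import java.util.Arrays;

/**
 * Hypothetical toy model of a trainable embedding table (not DJL code):
 * one row of weights per vocabulary index, updated by plain SGD.
 */
final class ToyTrainableTable {
    final float[][] weights;

    ToyTrainableTable(int vocabSize, int embeddingSize) {
        weights = new float[vocabSize][embeddingSize];
        // Deterministic starting values stand in for a random initializer.
        for (int i = 0; i < vocabSize; i++) {
            for (int j = 0; j < embeddingSize; j++) {
                weights[i][j] = 0.01f * (i + j);
            }
        }
    }

    /** Forward pass for one token: returns a copy of its row. */
    float[] lookup(int index) {
        return weights[index].clone();
    }

    /** One SGD step: only the row that was looked up receives a gradient. */
    void sgdStep(int index, float[] gradient, float learningRate) {
        for (int j = 0; j < gradient.length; j++) {
            weights[index][j] -= learningRate * gradient[j];
        }
    }

    public static void main(String[] args) {
        ToyTrainableTable table = new ToyTrainableTable(4, 3);
        float[] before = table.lookup(2);
        table.sgdStep(2, new float[] {1f, 1f, 1f}, 0.1f);
        System.out.println(Arrays.toString(before) + " -> " + Arrays.toString(table.lookup(2)));
    }
}
```

Starting from a pre-trained table instead of the deterministic values above corresponds to the "further training" case.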
Field Summary
Fields inherited from class ai.djl.nn.AbstractBlock
children, parameters

Fields inherited from class ai.djl.nn.AbstractBaseBlock
inputNames, inputShapes, outputDataTypes, version

Constructor Summary
TrainableTextEmbedding(TrainableWordEmbedding wordEmbedding)
    Constructs a TrainableTextEmbedding.

Method Summary
NDArray embedText(NDArray textIndices)
    Embeds the text after it has been preprocessed with TextEmbedding.preprocessTextToEmbed(List).
protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
    A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
Shape[] getOutputShapes(Shape[] inputShapes)
    Returns the expected output shapes of the block for the specified input shapes.
void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes)
    Initializes the child blocks of this block.
long[] preprocessTextToEmbed(java.util.List<java.lang.String> text)
    Preprocesses the text to embed into an array to pass into the model.
java.util.List<java.lang.String> unembedText(NDArray textEmbedding)
    Returns the closest matching text for a given embedding.

Methods inherited from class ai.djl.nn.AbstractBlock
addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters

Methods inherited from class ai.djl.nn.AbstractBaseBlock
beforeInitialize, cast, clear, describeInput, forward, forward, forwardInternal, getInputShapes, getOutputDataTypes, getParameters, initialize, isInitialized, loadMetadata, loadParameters, prepare, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer, setInitializer, toString

Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait

Methods inherited from interface ai.djl.nn.Block
forward, freezeParameters, freezeParameters, getOutputShapes

Methods inherited from interface ai.djl.modality.nlp.embedding.TextEmbedding
embedText, embedText

Constructor Detail

TrainableTextEmbedding
public TrainableTextEmbedding(TrainableWordEmbedding wordEmbedding)
Constructs a TrainableTextEmbedding.
Parameters:
wordEmbedding - the word embedding used to embed each word

Method Detail

preprocessTextToEmbed
public long[] preprocessTextToEmbed(java.util.List<java.lang.String> text)
Preprocesses the text to embed into an array to pass into the model. Make sure to call TextEmbedding.embedText(NDManager, long[]) after this.
Specified by:
preprocessTextToEmbed in interface TextEmbedding
Parameters:
text - the text to embed
Returns:
the indices of the text, ready to be embedded

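Preprocessing turns words into vocabulary indices. The following is a minimal sketch of that step, assuming a plain word-to-index map and a reserved index 0 for out-of-vocabulary words; ToyPreprocessor and UNKNOWN_INDEX are hypothetical names, and the real vocabulary handling in TrainableWordEmbedding may differ.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;

/** Hypothetical sketch of word-to-index preprocessing (not DJL code). */
final class ToyPreprocessor {
    // Assumption: index 0 is reserved for words missing from the vocabulary.
    static final long UNKNOWN_INDEX = 0L;

    static long[] preprocessTextToEmbed(List<String> text, Map<String, Long> vocabulary) {
        long[] indices = new long[text.size()];
        for (int i = 0; i < indices.length; i++) {
            // Known words map to their index; unknown words fall back to UNKNOWN_INDEX.
            indices[i] = vocabulary.getOrDefault(text.get(i), UNKNOWN_INDEX);
        }
        return indices;
    }

    public static void main(String[] args) {
        Map<String, Long> vocabulary = Map.of("hello", 1L, "world", 2L);
        long[] indices = preprocessTextToEmbed(List.of("hello", "big", "world"), vocabulary);
        System.out.println(Arrays.toString(indices)); // [1, 0, 2]
    }
}
```

The resulting long[] is the kind of input the embedding step then consumes.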
embedText
public NDArray embedText(NDArray textIndices) throws EmbeddingException
Embeds the text after it has been preprocessed with TextEmbedding.preprocessTextToEmbed(List).
Specified by:
embedText in interface TextEmbedding
Parameters:
textIndices - the indices of the text to embed
Returns:
the embedded text
Throws:
EmbeddingException - if there is an error while trying to embed

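Conceptually, embedding preprocessed indices is a row gather from the embedding table. The following sketch shows that with plain arrays instead of NDArray; ToyEmbedLookup is hypothetical, and the IllegalArgumentException only stands in for the EmbeddingException the real method declares.

```java
import java.util.Arrays;

/** Hypothetical sketch: each index selects one row of an embedding table (not DJL code). */
final class ToyEmbedLookup {
    static float[][] embedText(long[] textIndices, float[][] table) {
        float[][] embedded = new float[textIndices.length][];
        for (int i = 0; i < textIndices.length; i++) {
            long index = textIndices[i];
            if (index < 0 || index >= table.length) {
                // Stand-in for EmbeddingException in this toy version.
                throw new IllegalArgumentException("Index out of vocabulary range: " + index);
            }
            embedded[i] = table[(int) index].clone();
        }
        return embedded;
    }

    public static void main(String[] args) {
        float[][] table = {{0f, 0f}, {1f, 1f}, {2f, 2f}};
        float[][] embedded = embedText(new long[] {2, 0}, table);
        System.out.println(Arrays.deepToString(embedded)); // [[2.0, 2.0], [0.0, 0.0]]
    }
}
```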
unembedText
public java.util.List<java.lang.String> unembedText(NDArray textEmbedding)
Returns the closest matching text for a given embedding.
Specified by:
unembedText in interface TextEmbedding
Parameters:
textEmbedding - the text embedding to find the matching string text for
Returns:
text similar to the passed-in embedding

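One common way to find "the closest matching text" is a nearest-neighbour search over the embedding table. The following sketch uses cosine similarity; this is an illustrative technique chosen here, not necessarily how TrainableTextEmbedding resolves matches, and ToyUnembed is a hypothetical name.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical nearest-neighbour unembedding by cosine similarity (not DJL code). */
final class ToyUnembed {
    static double cosine(float[] a, float[] b) {
        double dot = 0, normA = 0, normB = 0;
        for (int i = 0; i < a.length; i++) {
            dot += (double) a[i] * b[i];
            normA += (double) a[i] * a[i];
            normB += (double) b[i] * b[i];
        }
        // The small epsilon avoids division by zero for all-zero vectors.
        return dot / (Math.sqrt(normA) * Math.sqrt(normB) + 1e-12);
    }

    /** Maps each embedded row to the vocabulary word whose table row is most similar. */
    static List<String> unembedText(float[][] textEmbedding, float[][] table, List<String> vocabulary) {
        List<String> words = new ArrayList<>(textEmbedding.length);
        for (float[] vector : textEmbedding) {
            int best = 0;
            double bestScore = Double.NEGATIVE_INFINITY;
            for (int row = 0; row < table.length; row++) {
                double score = cosine(vector, table[row]);
                if (score > bestScore) {
                    bestScore = score;
                    best = row;
                }
            }
            words.add(vocabulary.get(best));
        }
        return words;
    }

    public static void main(String[] args) {
        float[][] table = {{1f, 0f}, {0f, 1f}, {1f, 1f}};
        List<String> vocabulary = List.of("x", "y", "xy");
        float[][] query = {{0.9f, 0.1f}, {0.2f, 2.0f}};
        System.out.println(unembedText(query, table, vocabulary)); // [x, y]
    }
}
```

The brute-force scan is linear in vocabulary size per word; large vocabularies usually call for an approximate nearest-neighbour index instead.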
forwardInternal
protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
Specified by:
forwardInternal in class AbstractBaseBlock
Parameters:
parameterStore - the parameter store
inputs - the input NDList
training - true for a training forward pass
params - optional parameters
Returns:
the output of the forward pass

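Because the constructor takes a single TrainableWordEmbedding, a natural reading is that the text-level forward pass delegates to that child block. The following sketch shows this composite pattern with a hypothetical ToyBlock interface and two hypothetical implementations; it is an assumption about the structure, not DJL's Block API.

```java
import java.util.Arrays;

/** Hypothetical minimal block abstraction (not DJL's Block interface). */
interface ToyBlock {
    float[][] forward(long[] inputs, boolean training);
}

/** Hypothetical word-level child: looks up one table row per index. */
final class ToyWordEmbeddingBlock implements ToyBlock {
    private final float[][] table;

    ToyWordEmbeddingBlock(float[][] table) {
        this.table = table;
    }

    @Override
    public float[][] forward(long[] inputs, boolean training) {
        float[][] out = new float[inputs.length][];
        for (int i = 0; i < inputs.length; i++) {
            out[i] = table[Math.toIntExact(inputs[i])].clone();
        }
        return out;
    }
}

/** Hypothetical text-level parent: owns one child and forwards to it. */
final class ToyTextEmbeddingBlock implements ToyBlock {
    private final ToyBlock wordEmbedding;

    ToyTextEmbeddingBlock(ToyBlock wordEmbedding) {
        this.wordEmbedding = wordEmbedding;
    }

    @Override
    public float[][] forward(long[] inputs, boolean training) {
        // No computation of its own: inputs and the training flag pass straight to the child.
        return wordEmbedding.forward(inputs, training);
    }

    public static void main(String[] args) {
        ToyBlock child = new ToyWordEmbeddingBlock(new float[][] {{0f, 1f}, {2f, 3f}});
        ToyBlock parent = new ToyTextEmbeddingBlock(child);
        System.out.println(Arrays.deepToString(parent.forward(new long[] {1, 0}, false))); // [[2.0, 3.0], [0.0, 1.0]]
    }
}
```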
initializeChildBlocks
public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes)
Initializes the child blocks of this block. Subclasses that have child blocks must override this method; it determines the correct input shapes for the child blocks from the input shapes requested for this block.
Overrides:
initializeChildBlocks in class AbstractBaseBlock
Parameters:
manager - the manager to use for initialization
dataType - the requested data type
inputShapes - the expected input shapes for this block

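When a block has a single child that consumes the block's own input, the child's input shapes are simply the parent's. The following sketch assumes that is the situation here; ToyInitDemo, ToyChild and ToyParent are hypothetical names, and shapes are modeled as long[] rather than DJL's Shape.

```java
import java.util.Arrays;

/** Hypothetical sketch of child-block initialization (not DJL code). */
final class ToyInitDemo {
    /** Hypothetical child that must be initialized before it can be used. */
    static final class ToyChild {
        private long[][] inputShapes; // null until initialized

        void initialize(long[]... shapes) {
            inputShapes = shapes.clone();
        }

        boolean isInitialized() {
            return inputShapes != null;
        }

        long[][] getInputShapes() {
            return inputShapes;
        }
    }

    /** Hypothetical parent mirroring the role of initializeChildBlocks. */
    static final class ToyParent {
        final ToyChild child = new ToyChild();

        void initializeChildBlocks(long[]... inputShapes) {
            // The only child sees exactly the shapes requested for this block.
            child.initialize(inputShapes);
        }
    }

    public static void main(String[] args) {
        ToyParent parent = new ToyParent();
        parent.initializeChildBlocks(new long[] {2, 5});
        System.out.println(parent.child.isInitialized() + " " + Arrays.deepToString(parent.child.getInputShapes()));
    }
}
```

A block with several chained children would instead pass each child the output shape of the previous one.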
getOutputShapes
public Shape[] getOutputShapes(Shape[] inputShapes)
Returns the expected output shapes of the block for the specified input shapes.
Specified by:
getOutputShapes in interface Block
Parameters:
inputShapes - the shapes of the inputs
Returns:
the expected output shapes of the block
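For an embedding lookup, the usual shape rule is that every input index becomes a vector, so the output shape is the input shape with the embedding size appended. The following sketch assumes that rule for this block; ToyOutputShapes is hypothetical and uses long[] in place of DJL's Shape.

```java
import java.util.Arrays;

/** Hypothetical shape rule for an embedding lookup (an assumption, not DJL source). */
final class ToyOutputShapes {
    /** Returns the input shape with the embedding size appended as a new last axis. */
    static long[] outputShape(long[] inputShape, long embeddingSize) {
        long[] out = Arrays.copyOf(inputShape, inputShape.length + 1);
        out[inputShape.length] = embeddingSize;
        return out;
    }

    public static void main(String[] args) {
        // A batch of 2 sequences of 5 token indices, embedded into 8-dimensional vectors.
        System.out.println(Arrays.toString(outputShape(new long[] {2, 5}, 8))); // [2, 5, 8]
    }
}
```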