Package ai.djl.modality.nlp.embedding
Class TrainableWordEmbedding
- java.lang.Object
  - ai.djl.nn.AbstractBaseBlock
    - ai.djl.nn.AbstractBlock
      - ai.djl.nn.core.Embedding<java.lang.String>
        - ai.djl.modality.nlp.embedding.TrainableWordEmbedding
-
All Implemented Interfaces:
WordEmbedding, Block, AbstractEmbedding<java.lang.String>, AbstractIndexedEmbedding<java.lang.String>
public class TrainableWordEmbedding extends Embedding<java.lang.String> implements WordEmbedding
TrainableWordEmbedding is an implementation of WordEmbedding and Embedding based on a DefaultVocabulary. This WordEmbedding is ideal when there are no pre-trained embeddings available.
-
-
Nested Class Summary
Nested Classes
Modifier and Type, Class
    Description
static class TrainableWordEmbedding.Builder
    A builder for a TrainableWordEmbedding.
-
Nested classes/interfaces inherited from class ai.djl.nn.core.Embedding
Embedding.BaseBuilder<T,B extends Embedding.BaseBuilder<T,B>>, Embedding.DefaultEmbedding, Embedding.DefaultItem
-
-
Field Summary
-
Fields inherited from class ai.djl.nn.core.Embedding
embedding, embeddingSize, fallthroughEmbedding, numEmbeddings, sparseFormat
-
Fields inherited from class ai.djl.nn.AbstractBlock
children, parameters
-
Fields inherited from class ai.djl.nn.AbstractBaseBlock
inputNames, inputShapes, outputDataTypes, version
-
-
Constructor Summary
Constructors
Constructor
    Description
TrainableWordEmbedding(TrainableWordEmbedding.Builder builder)
    Constructs a new instance of TrainableWordEmbedding from the TrainableWordEmbedding.Builder.
TrainableWordEmbedding(Vocabulary vocabulary, int embeddingSize)
    Constructs a new instance of TrainableWordEmbedding from a Vocabulary and a given embedding size.
-
Method Summary
Modifier and Type, Method
    Description
static TrainableWordEmbedding.Builder builder()
    Creates a builder to build an Embedding.
java.lang.String decode(byte[] byteArray)
    Decodes the given byte array into an object of input parameter type.
long embed(java.lang.String item)
    Embeds an item.
NDArray embedWord(NDArray index)
    Embeds the word after it has been preprocessed using WordEmbedding.preprocessWordToEmbed(String).
byte[] encode(java.lang.String input)
    Encodes an object of input type into a byte array.
static TrainableWordEmbedding fromPretrained(NDArray embedding, java.util.List<java.lang.String> items)
    Constructs a pretrained embedding.
static TrainableWordEmbedding fromPretrained(NDArray embedding, java.util.List<java.lang.String> items, SparseFormat sparseFormat)
    Constructs a pretrained embedding.
boolean hasItem(java.lang.String item)
    Returns whether an item is in the embedding.
long preprocessWordToEmbed(java.lang.String word)
    Pre-processes the word to embed into an array to pass into the model.
java.util.Optional<java.lang.String> unembed(long index)
    Returns the item corresponding to the given index.
java.lang.String unembedWord(NDArray word)
    Returns the closest matching word for the given index.
boolean vocabularyContains(java.lang.String word)
    Returns whether an embedding exists for a word.
-
Methods inherited from class ai.djl.nn.core.Embedding
embed, embedding, forwardInternal, getOutputShapes, loadParameters, prepare, saveParameters
-
Methods inherited from class ai.djl.nn.AbstractBlock
addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters
-
Methods inherited from class ai.djl.nn.AbstractBaseBlock
beforeInitialize, cast, clear, describeInput, forward, forward, forwardInternal, getInputShapes, getOutputDataTypes, getParameters, initialize, initializeChildBlocks, isInitialized, loadMetadata, readInputShapes, saveInputShapes, saveMetadata, setInitializer, setInitializer, setInitializer, toString
-
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
-
Methods inherited from interface ai.djl.nn.Block
forward, freezeParameters, freezeParameters, getOutputShapes
-
Methods inherited from interface ai.djl.modality.nlp.embedding.WordEmbedding
embedWord, embedWord
-
-
-
-
Constructor Detail
-
TrainableWordEmbedding
public TrainableWordEmbedding(TrainableWordEmbedding.Builder builder)
Constructs a new instance of TrainableWordEmbedding from the TrainableWordEmbedding.Builder.
Parameters:
builder - the TrainableWordEmbedding.Builder
-
TrainableWordEmbedding
public TrainableWordEmbedding(Vocabulary vocabulary, int embeddingSize)
Constructs a new instance of TrainableWordEmbedding from a Vocabulary and a given embedding size.
Parameters:
vocabulary - a Vocabulary to get tokens from
embeddingSize - the required embedding size
-
-
Method Detail
-
fromPretrained
public static TrainableWordEmbedding fromPretrained(NDArray embedding, java.util.List<java.lang.String> items)
Constructs a pretrained embedding.
Because it is created with pretrained data, it is created as a frozen block. If you wish to update it, call Block.freezeParameters(boolean).
Parameters:
embedding - the embedding array
items - the items in the embedding (in matching order to the embedding array)
Returns:
the created embedding
-
fromPretrained
public static TrainableWordEmbedding fromPretrained(NDArray embedding, java.util.List<java.lang.String> items, SparseFormat sparseFormat)
Constructs a pretrained embedding.
Because it is created with pretrained data, it is created as a frozen block. If you wish to update it, call Block.freezeParameters(boolean).
Parameters:
embedding - the embedding array
items - the items in the embedding (in matching order to the embedding array)
sparseFormat - whether to compute row sparse gradient in the backward calculation
Returns:
the created embedding
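The "frozen block" behavior described for both fromPretrained overloads can be pictured with a toy model. The sketch below is plain Java and does not use DJL. The class FrozenTable and its methods setFrozen, applyGradient, and vectorOf are hypothetical names chosen for illustration. It assumes only what the description states: row i of the embedding array belongs to item i, and a pretrained table ignores updates until it is unfrozen.

```java
import java.util.List;

/** Toy model of a pretrained, initially frozen embedding table (not DJL code). */
public class FrozenTable {
    private final float[][] rows;     // rows[i] is the vector for items.get(i)
    private final List<String> items; // same order as rows
    private boolean frozen = true;    // a pretrained table starts out frozen

    public FrozenTable(float[][] rows, List<String> items) {
        if (rows.length != items.size()) {
            throw new IllegalArgumentException("rows and items must have the same length");
        }
        this.rows = rows;
        this.items = items;
    }

    /** Hypothetical stand-in for toggling the frozen state. */
    public void setFrozen(boolean frozen) {
        this.frozen = frozen;
    }

    /** Applies one gradient step to a row; a frozen table ignores the update. */
    public void applyGradient(int index, float[] grad, float learningRate) {
        if (frozen) {
            return;
        }
        for (int d = 0; d < rows[index].length; d++) {
            rows[index][d] -= learningRate * grad[d];
        }
    }

    /** Returns the stored vector for an item, relying on the matching order. */
    public float[] vectorOf(String item) {
        int index = items.indexOf(item);
        if (index < 0) {
            throw new IllegalArgumentException("Unknown item: " + item);
        }
        return rows[index];
    }
}
```

In this model an update applied before setFrozen(false) leaves the table unchanged. In DJL itself, the description above names Block.freezeParameters(boolean) as the call that changes the frozen state.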
-
vocabularyContains
public boolean vocabularyContains(java.lang.String word)
Returns whether an embedding exists for a word.
Specified by:
vocabularyContains in interface WordEmbedding
Parameters:
word - the word to check
Returns:
true if an embedding exists
-
preprocessWordToEmbed
public long preprocessWordToEmbed(java.lang.String word)
Pre-processes the word to embed into an array to pass into the model.
Make sure to call WordEmbedding.embedWord(NDManager, long) after this.
Specified by:
preprocessWordToEmbed in interface WordEmbedding
Parameters:
word - the word to embed
Returns:
the word that is ready to embed
-
embedWord
public NDArray embedWord(NDArray index) throws EmbeddingException
Embeds the word after it has been preprocessed using WordEmbedding.preprocessWordToEmbed(String).
Specified by:
embedWord in interface WordEmbedding
Parameters:
index - the index of the word to embed
Returns:
the embedded word
Throws:
EmbeddingException - if there is an error while trying to embed
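The two-step flow described for preprocessWordToEmbed and embedWord (first turn the word into an index, then turn the index into a vector) can be sketched without DJL. The class TinyWordEmbedding below is hypothetical and uses plain float arrays instead of NDArray. The random start values stand in for trainable parameters, and throwing on an unknown word is an assumption made for the sketch, not documented behavior of the real class.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

/** Toy two-step word lookup: word -> index -> vector (not DJL code). */
public class TinyWordEmbedding {
    private final Map<String, Long> indexOf = new HashMap<>();
    private final float[][] table; // one trainable row per vocabulary entry

    public TinyWordEmbedding(List<String> vocabulary, int embeddingSize, long seed) {
        Random random = new Random(seed);
        table = new float[vocabulary.size()][embeddingSize];
        for (int i = 0; i < vocabulary.size(); i++) {
            indexOf.put(vocabulary.get(i), (long) i);
            for (int d = 0; d < embeddingSize; d++) {
                // Small random start values; training would adjust them later.
                table[i][d] = (random.nextFloat() - 0.5f) * 0.1f;
            }
        }
    }

    public boolean vocabularyContains(String word) {
        return indexOf.containsKey(word);
    }

    /** Step 1: map the word to its index in the vocabulary. */
    public long preprocessWordToEmbed(String word) {
        Long index = indexOf.get(word);
        if (index == null) {
            // Assumption for this sketch: unknown words are rejected.
            throw new IllegalArgumentException("Unknown word: " + word);
        }
        return index;
    }

    /** Step 2: look up a copy of the vector stored at that index. */
    public float[] embedWord(long index) {
        if (index < 0 || index >= table.length) {
            throw new IndexOutOfBoundsException("No row for index " + index);
        }
        return table[(int) index].clone();
    }
}
```

A caller would typically guard the first step with vocabularyContains(word) and then pass the returned index straight into the second step.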
-
unembedWord
public java.lang.String unembedWord(NDArray word)
Returns the closest matching word for the given index.
Specified by:
unembedWord in interface WordEmbedding
Parameters:
word - the word embedding to find the matching string word for
Returns:
a word similar to the passed-in embedding
-
encode
public byte[] encode(java.lang.String input)
Encodes an object of input type into a byte array. This is used in saving and loading the Embedding objects.
Specified by:
encode in interface AbstractIndexedEmbedding<java.lang.String>
Parameters:
input - the input object to be encoded
Returns:
the encoded byte array
-
decode
public java.lang.String decode(byte[] byteArray)
Decodes the given byte array into an object of input parameter type.
Specified by:
decode in interface AbstractIndexedEmbedding<java.lang.String>
Parameters:
byteArray - the byte array to be decoded
Returns:
the decoded object of input parameter type
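Since encode and decode are used together when saving and loading, they need to be inverses of each other. The following is a minimal sketch of such a round trip for String items, assuming UTF-8 as the byte encoding (this page does not state which charset the class uses). StringCodec is a hypothetical helper, not part of DJL.

```java
import java.nio.charset.StandardCharsets;

/** Toy String <-> byte[] codec; the UTF-8 choice is an assumption. */
public final class StringCodec {
    private StringCodec() {}

    /** Encodes a String item into bytes. */
    public static byte[] encode(String input) {
        return input.getBytes(StandardCharsets.UTF_8);
    }

    /** Decodes bytes produced by encode back into the String item. */
    public static String decode(byte[] byteArray) {
        return new String(byteArray, StandardCharsets.UTF_8);
    }
}
```

Fixing the charset explicitly on both sides keeps the round trip independent of the platform default encoding.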
-
embed
public long embed(java.lang.String item)
Embeds an item.
Specified by:
embed in interface AbstractIndexedEmbedding<java.lang.String>
Parameters:
item - the item to embed
Returns:
the index of the item in the embedding
-
unembed
public java.util.Optional<java.lang.String> unembed(long index)
Returns the item corresponding to the given index.
Specified by:
unembed in interface AbstractIndexedEmbedding<java.lang.String>
Parameters:
index - the index
Returns:
the item corresponding to the given index
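Read together, embed and unembed describe a mapping from item to index and back, with unembed returning an Optional. The sketch below models that contract in plain Java. ItemIndex is a hypothetical class. Using list positions as indices, returning an empty Optional for an out-of-range index, and throwing on an unknown item are all assumptions of the sketch rather than documented behavior.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/** Toy item <-> index mapping (not DJL code). */
public class ItemIndex {
    private final List<String> items;
    private final Map<String, Long> positions = new HashMap<>();

    public ItemIndex(List<String> items) {
        this.items = new ArrayList<>(items);
        for (int i = 0; i < this.items.size(); i++) {
            // Keep the first position if an item appears more than once.
            positions.putIfAbsent(this.items.get(i), (long) i);
        }
    }

    public boolean hasItem(String item) {
        return positions.containsKey(item);
    }

    /** Returns the index of the item; unknown items are rejected in this sketch. */
    public long embed(String item) {
        Long index = positions.get(item);
        if (index == null) {
            throw new IllegalArgumentException("Unknown item: " + item);
        }
        return index;
    }

    /** Returns the item at the index, or an empty Optional if there is none. */
    public Optional<String> unembed(long index) {
        if (index < 0 || index >= items.size()) {
            return Optional.empty();
        }
        return Optional.of(items.get((int) index));
    }
}
```

Callers can then write unembed(i).orElse(fallback) instead of checking for null.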
-
builder
public static TrainableWordEmbedding.Builder builder()
Creates a builder to build an Embedding.
Returns:
a new builder
-
hasItem
public boolean hasItem(java.lang.String item)
Returns whether an item is in the embedding.
Specified by:
hasItem in interface AbstractEmbedding<java.lang.String>
Parameters:
item - the item to test
Returns:
true if the item is in the embedding
-
-