Package ai.djl.nn.core
Class ConstantEmbedding
- java.lang.Object
  - ai.djl.nn.AbstractBaseBlock
    - ai.djl.nn.AbstractBlock
      - ai.djl.nn.core.ConstantEmbedding
- All Implemented Interfaces:
Block, AbstractEmbedding, AbstractIndexedEmbedding
public class ConstantEmbedding extends AbstractBlock implements AbstractIndexedEmbedding
An AbstractIndexedEmbedding that always returns a constant value.
-
Field Summary
Fields (Modifier and Type, Field, Description)
- protected NDArray embedding
Fields inherited from class ai.djl.nn.AbstractBlock
children, parameters
-
Fields inherited from class ai.djl.nn.AbstractBaseBlock
inputNames, inputShapes, outputDataTypes, version
-
Constructor Summary
Constructors (Constructor, Description)
- ConstantEmbedding(NDArray embedding): Constructs a constant embedding with the given constant.
-
Method Summary
All Methods (Instance Methods, Concrete Methods)
Modifier and Type, Method: Description
- java.lang.Object decode(byte[] byteArray): Decodes the given byte array into an object of input parameter type.
- NDArray embed(NDManager manager, java.lang.Object[] items): Embeds an array of items.
- long embed(java.lang.Object item): Embeds an item.
- byte[] encode(java.lang.Object input): Encodes an object of input type into a byte array.
- protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params): A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
- Shape[] getOutputShapes(Shape[] inputShapes): Returns the expected output shapes of the block for the specified input shapes.
- boolean hasItem(java.lang.Object item): Returns whether an item is in the embedding.
- void loadParameters(NDManager manager, java.io.DataInputStream is): Loads the parameters from the given input stream.
- void saveParameters(java.io.DataOutputStream os): Writes the parameters of the block to the given outputStream.
- java.util.Optional<?> unembed(long index): Returns the item corresponding to the given index.
Methods inherited from class ai.djl.nn.AbstractBlock
addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters
-
Methods inherited from class ai.djl.nn.AbstractBaseBlock
beforeInitialize, cast, clear, describeInput, forward, forward, forwardInternal, getInputShapes, getOutputDataTypes, getParameters, initialize, initializeChildBlocks, isInitialized, loadMetadata, prepare, readInputShapes, saveInputShapes, saveMetadata, setInitializer, setInitializer, setInitializer, toString
-
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
-
Methods inherited from interface ai.djl.nn.Block
forward, freezeParameters, freezeParameters, getOutputShapes
-
Field Detail
-
embedding
protected NDArray embedding
-
-
Constructor Detail
-
ConstantEmbedding
public ConstantEmbedding(NDArray embedding)
Constructs a constant embedding with the given constant.
The constant is assumed to be a fixed value and starts out frozen. To unfreeze it, use Block.freezeParameters(boolean).
- Parameters:
embedding - the value to return for all embeddings
-
-
Method Detail
-
forwardInternal
protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
- Specified by:
forwardInternal in class AbstractBaseBlock
- Parameters:
parameterStore - the parameter store
inputs - the input NDList
training - true for a training forward pass
params - optional parameters
- Returns:
the output of the forward pass
-
getOutputShapes
public Shape[] getOutputShapes(Shape[] inputShapes)
Returns the expected output shapes of the block for the specified input shapes.
- Specified by:
getOutputShapes in interface Block
- Parameters:
inputShapes - the shapes of the inputs
- Returns:
the expected output shapes of the block
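As a rough illustration of how an output shape could be derived from an input shape, the sketch below assumes that the block appends the dimensions of the constant to the dimensions of the first input, so that every input element receives its own copy of the constant. That rule is an assumption made for this example and is not stated on this page. `OutputShapeSketch` and `outputShape` are hypothetical names, and a plain `long[]` stands in for Shape.

```java
import java.util.Arrays;

// Hypothetical stand-in for illustration only; it does not use DJL classes.
final class OutputShapeSketch {

    private OutputShapeSketch() {}

    // ASSUMPTION: output shape = input dimensions followed by the constant's dimensions.
    static long[] outputShape(long[] inputShape, long[] embeddingShape) {
        long[] out = Arrays.copyOf(inputShape, inputShape.length + embeddingShape.length);
        System.arraycopy(embeddingShape, 0, out, inputShape.length, embeddingShape.length);
        return out;
    }

    public static void main(String[] args) {
        // A batch of 4 sequences with 7 items each, and a constant vector of length 16.
        long[] shape = outputShape(new long[] {4, 7}, new long[] {16});
        System.out.println(Arrays.toString(shape));
    }
}
```

Under this assumption the output depends only on the input's shape, never on its values, which fits a block that returns the same constant for every item.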
-
saveParameters
public void saveParameters(java.io.DataOutputStream os)
Writes the parameters of the block to the given outputStream.
- Specified by:
saveParameters in interface Block
- Overrides:
saveParameters in class AbstractBaseBlock
- Parameters:
os - the output stream to save the parameters to
-
loadParameters
public void loadParameters(NDManager manager, java.io.DataInputStream is)
Loads the parameters from the given input stream.
- Specified by:
loadParameters in interface Block
- Overrides:
loadParameters in class AbstractBaseBlock
- Parameters:
manager - an NDManager to create the parameter arrays
is - the input stream that streams the parameter values
-
unembed
public java.util.Optional<?> unembed(long index)
Returns the item corresponding to the given index.
- Specified by:
unembed in interface AbstractIndexedEmbedding
- Parameters:
index - the index
- Returns:
the item corresponding to the given index
-
encode
public byte[] encode(java.lang.Object input)
Encodes an object of input type into a byte array. This is used in saving and loading the Embedding objects.
- Specified by:
encode in interface AbstractIndexedEmbedding
- Parameters:
input - the input object to be encoded
- Returns:
the encoded byte array.
-
decode
public java.lang.Object decode(byte[] byteArray)
Decodes the given byte array into an object of input parameter type.
- Specified by:
decode in interface AbstractIndexedEmbedding
- Parameters:
byteArray - the byte array to be decoded
- Returns:
the decoded object of input parameter type
-
embed
public long embed(java.lang.Object item)
Embeds an item.
- Specified by:
embed in interface AbstractIndexedEmbedding
- Parameters:
item - the item to embed
- Returns:
the index of the item in the embedding
-
embed
public NDArray embed(NDManager manager, java.lang.Object[] items)
Embeds an array of items.
- Specified by:
embed in interface AbstractEmbedding
- Parameters:
manager - the manager for the new embeddings
items - the items to embed
- Returns:
the embedding NDArray of Shape(items.length, embeddingSize)
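The shape contract of the returned array can be illustrated without DJL. The following is a minimal plain-Java sketch, not the library's implementation: `ConstantEmbedSketch` and its `embed` method are hypothetical names, and a `float[][]` stands in for an NDArray of Shape(items.length, embeddingSize). It assumes the constant is a one-dimensional vector that is copied once per item.

```java
import java.util.Arrays;

// Hypothetical stand-in for illustration only; it does not use DJL classes.
final class ConstantEmbedSketch {

    private ConstantEmbedSketch() {}

    // Returns one copy of the constant per item:
    // rows = items.length, columns = constant.length (the embedding size).
    static float[][] embed(float[] constant, Object[] items) {
        float[][] result = new float[items.length][];
        for (int i = 0; i < items.length; i++) {
            // The item itself is ignored: every item maps to the same constant.
            result[i] = Arrays.copyOf(constant, constant.length);
        }
        return result;
    }

    public static void main(String[] args) {
        float[][] out = embed(new float[] {0.5f, -1.0f}, new Object[] {"a", "b", "c"});
        System.out.println(out.length + " x " + out[0].length);
        System.out.println(Arrays.deepToString(out));
    }
}
```

Each row is a separate copy in this sketch, so changing one row does not affect the others. Whether the real class copies or shares the underlying data is not specified on this page.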
-
hasItem
public boolean hasItem(java.lang.Object item)
Returns whether an item is in the embedding.
- Specified by:
hasItem in interface AbstractEmbedding
- Parameters:
item - the item to test
- Returns:
true if the item is in the embedding
-