Package ai.djl.training.listener
Class SaveModelTrainingListener
java.lang.Object
  ai.djl.training.listener.TrainingListenerAdapter
    ai.djl.training.listener.SaveModelTrainingListener
- All Implemented Interfaces:
TrainingListener
public class SaveModelTrainingListener extends TrainingListenerAdapter
A TrainingListener that saves a model and can save checkpoints.
Nested Class Summary
Nested classes/interfaces inherited from interface ai.djl.training.listener.TrainingListener
TrainingListener.BatchData, TrainingListener.Defaults
Constructor Summary
- SaveModelTrainingListener(java.lang.String outputDir): Constructs a SaveModelTrainingListener using the model's name.
- SaveModelTrainingListener(java.lang.String outputDir, java.lang.String overrideModelName): Constructs a SaveModelTrainingListener.
- SaveModelTrainingListener(java.lang.String outputDir, java.lang.String overrideModelName, int checkpoint): Constructs a SaveModelTrainingListener.
Method Summary
- int getCheckpoint(): Returns the checkpoint frequency (or -1 for no checkpointing) in SaveModelTrainingListener.
- java.lang.String getOverrideModelName(): Returns the override model name to save checkpoints with.
- void onEpoch(Trainer trainer): Listens to the end of an epoch during training.
- void onTrainingEnd(Trainer trainer): Listens to the end of training.
- protected void saveModel(Trainer trainer)
- void setCheckpoint(int checkpoint): Sets the checkpoint frequency in SaveModelTrainingListener.
- void setOverrideModelName(java.lang.String overrideModelName): Sets the override model name to save checkpoints with.
- void setSaveModelCallback(java.util.function.Consumer<Trainer> onSaveModel): Sets the callback function on model saving.
Methods inherited from class ai.djl.training.listener.TrainingListenerAdapter
onTrainingBatch, onTrainingBegin, onValidationBatch
Constructor Detail
SaveModelTrainingListener
public SaveModelTrainingListener(java.lang.String outputDir)
Constructs a SaveModelTrainingListener using the model's name.
Parameters:
outputDir - the directory to output the checkpointed models in
SaveModelTrainingListener
public SaveModelTrainingListener(java.lang.String outputDir, java.lang.String overrideModelName)
Constructs a SaveModelTrainingListener that saves checkpoints under an override model name instead of the model's own name.
Parameters:
outputDir - the directory to output the checkpointed models in
overrideModelName - an override model name to save checkpoints with
SaveModelTrainingListener
public SaveModelTrainingListener(java.lang.String outputDir, java.lang.String overrideModelName, int checkpoint)
Constructs a SaveModelTrainingListener with an override model name and a checkpoint frequency.
Parameters:
outputDir - the directory to output the checkpointed models in
overrideModelName - an override model name to save checkpoints with
checkpoint - the number of epochs between checkpoints (or -1 for no checkpoints)
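As a rough illustration of the naming rule implied by these constructors, the sketch below assumes that a null overrideModelName means "use the model's name". ModelNameSketch and resolveName are hypothetical helpers written for this page, not part of DJL.

```java
import java.util.Objects;

/**
 * Hypothetical sketch (not DJL's code) of how the name used for saved
 * checkpoints could be chosen: the override name when one is given,
 * otherwise the model's own name.
 */
final class ModelNameSketch {

    private ModelNameSketch() {}

    /**
     * Returns the name to save checkpoints with.
     *
     * @param modelName the model's own name; must not be null
     * @param overrideModelName an override name, or null to keep the model's name
     * @return the override name if present, else the model's name
     */
    static String resolveName(String modelName, String overrideModelName) {
        Objects.requireNonNull(modelName, "modelName");
        // Assumption: null signals "no override" (as with the one-argument constructor).
        return overrideModelName != null ? overrideModelName : modelName;
    }
}
```

Under this assumption, resolveName("mlp", null) yields "mlp", while resolveName("mlp", "mnist-mlp") yields "mnist-mlp".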
Method Detail
onEpoch
public void onEpoch(Trainer trainer)
Listens to the end of an epoch during training.
Specified by:
onEpoch in interface TrainingListener
Overrides:
onEpoch in class TrainingListenerAdapter
Parameters:
trainer - the trainer the listener is attached to
onTrainingEnd
public void onTrainingEnd(Trainer trainer)
Listens to the end of training.
Specified by:
onTrainingEnd in interface TrainingListener
Overrides:
onTrainingEnd in class TrainingListenerAdapter
Parameters:
trainer - the trainer the listener is attached to
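The two hooks above fire at different points of a run. The following is a minimal standard-library sketch, assuming a simple loop that calls onEpoch after every finished epoch and onTrainingEnd once at the end; EpochHooks, RecordingHooks and LoopSketch are hypothetical names, and DJL's own training loop may differ in detail.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical stand-in for the two TrainingListener hooks documented above. */
interface EpochHooks {
    void onEpoch(int epoch);

    void onTrainingEnd(int epochsCompleted);
}

/** Records every hook invocation so the call order can be inspected. */
final class RecordingHooks implements EpochHooks {
    final List<String> events = new ArrayList<>();

    @Override
    public void onEpoch(int epoch) {
        events.add("epoch " + epoch);
    }

    @Override
    public void onTrainingEnd(int epochsCompleted) {
        events.add("end after " + epochsCompleted);
    }
}

/** Hypothetical training loop: one onEpoch call per finished epoch, then one onTrainingEnd. */
final class LoopSketch {
    private LoopSketch() {}

    static void fit(int numEpochs, EpochHooks hooks) {
        for (int epoch = 1; epoch <= numEpochs; epoch++) {
            // ... train and validate one epoch here ...
            hooks.onEpoch(epoch);
        }
        hooks.onTrainingEnd(numEpochs);
    }
}
```

Running fit(3, hooks) records three epoch events followed by a single end-of-training event, which is the ordering a listener that saves on epoch boundaries and at the end relies on.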
getOverrideModelName
public java.lang.String getOverrideModelName()
Returns the override model name to save checkpoints with.
Returns:
the override model name to save checkpoints with
setOverrideModelName
public void setOverrideModelName(java.lang.String overrideModelName)
Sets the override model name to save checkpoints with.
Parameters:
overrideModelName - the override model name to save checkpoints with
getCheckpoint
public int getCheckpoint()
Returns the checkpoint frequency (or -1 for no checkpointing) in SaveModelTrainingListener.
Returns:
the checkpoint frequency (or -1 for no checkpointing)
setCheckpoint
public void setCheckpoint(int checkpoint)
Sets the checkpoint frequency in SaveModelTrainingListener.
Parameters:
checkpoint - how many epochs between checkpoints (or -1 for no checkpoints)
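To make the frequency concrete, here is a hypothetical sketch of an "every n epochs" rule. It assumes 1-based epoch counting and that a checkpoint is written whenever the epoch number is divisible by n; CheckpointScheduleSketch is not taken from DJL's source.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical sketch of a "checkpoint every n epochs" rule; not DJL's source.
 * A frequency of -1 (or any value below 1) disables checkpointing.
 */
final class CheckpointScheduleSketch {
    private CheckpointScheduleSketch() {}

    /** Returns true when a checkpoint would be written after the given 1-based epoch. */
    static boolean isCheckpointEpoch(int checkpoint, int epoch) {
        return checkpoint > 0 && epoch % checkpoint == 0;
    }

    /** Lists the 1-based epochs that would produce a checkpoint in a run of numEpochs. */
    static List<Integer> checkpointEpochs(int checkpoint, int numEpochs) {
        List<Integer> epochs = new ArrayList<>();
        for (int epoch = 1; epoch <= numEpochs; epoch++) {
            if (isCheckpointEpoch(checkpoint, epoch)) {
                epochs.add(epoch);
            }
        }
        return epochs;
    }
}
```

With a frequency of 3 over a 10-epoch run, this rule checkpoints after epochs 3, 6 and 9; a frequency of -1 produces no checkpoints at all.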
setSaveModelCallback
public void setSaveModelCallback(java.util.function.Consumer<Trainer> onSaveModel)
Sets the callback function that is invoked when the model is saved. This allows users to add custom properties to the model metadata.
Parameters:
onSaveModel - the callback function on model saving
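The callback receives the Trainer before the model is written, which is what makes it a natural place to record metadata. The sketch below mimics that pattern with the standard java.util.function.Consumer; FakeTrainer, CallbackSaverSketch and the "Accuracy" property are hypothetical stand-ins, not DJL types.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Consumer;

/** Hypothetical stand-in for a trainer whose model carries string metadata. */
final class FakeTrainer {
    final Map<String, String> modelProperties = new LinkedHashMap<>();
    final float validationAccuracy;

    FakeTrainer(float validationAccuracy) {
        this.validationAccuracy = validationAccuracy;
    }
}

/** Hypothetical saver that runs a user callback just before "saving". */
final class CallbackSaverSketch {
    private Consumer<FakeTrainer> onSaveModel;

    void setSaveModelCallback(Consumer<FakeTrainer> onSaveModel) {
        this.onSaveModel = onSaveModel;
    }

    /** Runs the callback (if any) and returns a copy of the metadata that would be saved. */
    Map<String, String> save(FakeTrainer trainer) {
        if (onSaveModel != null) {
            onSaveModel.accept(trainer);
        }
        // A real implementation would write the model files here.
        return new LinkedHashMap<>(trainer.modelProperties);
    }
}
```

Because the callback runs before the save step, any property it adds is part of the metadata snapshot; with no callback registered, the metadata is saved unchanged.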
saveModel
protected void saveModel(Trainer trainer)