public class SaveModelTrainingListener extends TrainingListenerAdapter
TrainingListener that saves a model and can save checkpoints.

Nested classes/interfaces inherited from interface TrainingListener: TrainingListener.BatchData, TrainingListener.Defaults

| Constructor and Description |
|---|
| `SaveModelTrainingListener(java.lang.String outputDir)` Constructs a `SaveModelTrainingListener` using the model's name. |
| `SaveModelTrainingListener(java.lang.String outputDir, java.lang.String overrideModelName)` Constructs a `SaveModelTrainingListener`. |
| `SaveModelTrainingListener(java.lang.String outputDir, java.lang.String overrideModelName, int checkpoint)` Constructs a `SaveModelTrainingListener`. |
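The first constructor saves under the model's own name, while the other two accept an `overrideModelName`. Below is a minimal sketch of that naming choice, assuming a `null` override means "use the model's name". `SaveNameSketch`, `resolveName` and `outputPath` are hypothetical names, not part of DJL, and the sketch does not model how the real listener names individual checkpoint files inside `outputDir`.

```java
import java.nio.file.Path;
import java.nio.file.Paths;

/** Hypothetical stand-in for the naming choice made by the constructors (not DJL code). */
final class SaveNameSketch {
    private final String outputDir;
    private final String overrideModelName; // null means "use the model's own name" (assumption)

    /** Mirrors the one-argument constructor: no override, so the model's name is used. */
    SaveNameSketch(String outputDir) {
        this(outputDir, null);
    }

    /** Mirrors the two-argument constructor: an explicit override name. */
    SaveNameSketch(String outputDir, String overrideModelName) {
        this.outputDir = outputDir;
        this.overrideModelName = overrideModelName;
    }

    /** Returns the name checkpoints are saved with: the override if set, else the model's name. */
    String resolveName(String modelName) {
        return overrideModelName != null ? overrideModelName : modelName;
    }

    /** Returns the directory that checkpoints are written to. */
    Path outputPath() {
        return Paths.get(outputDir);
    }
}
```

For example, `new SaveNameSketch("build/model", "mnist-mlp").resolveName("mlp")` yields `"mnist-mlp"`, while the one-argument form yields `"mlp"`.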
| Modifier and Type | Method and Description |
|---|---|
| `int` | `getCheckpoint()` Returns the checkpoint frequency (or -1 for no checkpointing) in `SaveModelTrainingListener`. |
| `java.lang.String` | `getOverrideModelName()` Returns the override model name to save checkpoints with. |
| `void` | `onEpoch(Trainer trainer)` Listens to the end of an epoch during training. |
| `void` | `onTrainingEnd(Trainer trainer)` Listens to the end of training. |
| `protected void` | `saveModel(Trainer trainer)` |
| `void` | `setCheckpoint(int checkpoint)` Sets the checkpoint frequency in `SaveModelTrainingListener`. |
| `void` | `setOverrideModelName(java.lang.String overrideModelName)` Sets the override model name to save checkpoints with. |
| `void` | `setSaveModelCallback(java.util.function.Consumer<Trainer> onSaveModel)` Sets the callback function on model saving. |
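`getCheckpoint` and `setCheckpoint` express the checkpoint frequency as a number of epochs, with -1 disabling checkpointing. Below is a minimal sketch of that rule, assuming epochs are counted from 1 and a checkpoint is due whenever the epoch count is a multiple of the frequency. `CheckpointSchedule` and its methods are hypothetical, and treating 0 the same as -1 is also an assumption.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch of an "every n epochs, or -1 for never" checkpoint rule (not DJL code). */
final class CheckpointSchedule {
    static final int NO_CHECKPOINT = -1;

    private CheckpointSchedule() {}

    /**
     * Returns true if a checkpoint is due after the given 1-based epoch.
     * Any non-positive frequency (including -1) disables checkpointing here (assumption).
     */
    static boolean isCheckpointEpoch(int epoch, int checkpoint) {
        return checkpoint > 0 && epoch % checkpoint == 0;
    }

    /** Lists the epochs, out of totalEpochs, after which a checkpoint would be written. */
    static List<Integer> checkpointEpochs(int totalEpochs, int checkpoint) {
        List<Integer> epochs = new ArrayList<>();
        for (int epoch = 1; epoch <= totalEpochs; epoch++) {
            if (isCheckpointEpoch(epoch, checkpoint)) {
                epochs.add(epoch);
            }
        }
        return epochs;
    }
}
```

With a frequency of 3 over a 10-epoch run, `checkpointEpochs(10, 3)` returns `[3, 6, 9]`; with `NO_CHECKPOINT` it returns an empty list.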
Methods inherited from class TrainingListenerAdapter: onTrainingBatch, onTrainingBegin, onValidationBatch

public SaveModelTrainingListener(java.lang.String outputDir)
Constructs a SaveModelTrainingListener using the model's name.
Parameters:
outputDir - the directory to output the checkpointed models in

public SaveModelTrainingListener(java.lang.String outputDir, java.lang.String overrideModelName)
Constructs a SaveModelTrainingListener.
Parameters:
outputDir - the directory to output the checkpointed models in
overrideModelName - an override model name to save checkpoints with

public SaveModelTrainingListener(java.lang.String outputDir, java.lang.String overrideModelName, int checkpoint)
Constructs a SaveModelTrainingListener.
Parameters:
outputDir - the directory to output the checkpointed models in
overrideModelName - an override model name to save checkpoints with
checkpoint - the number of epochs between checkpoints, so a value of n adds a checkpoint every n epochs (or -1 for no checkpoints)

public void onEpoch(Trainer trainer)
Listens to the end of an epoch during training.
Specified by: onEpoch in interface TrainingListener
Overrides: onEpoch in class TrainingListenerAdapter
Parameters:
trainer - the trainer the listener is attached to

public void onTrainingEnd(Trainer trainer)
Listens to the end of training.
Specified by: onTrainingEnd in interface TrainingListener
Overrides: onTrainingEnd in class TrainingListenerAdapter
Parameters:
trainer - the trainer the listener is attached to

public java.lang.String getOverrideModelName()
Returns the override model name to save checkpoints with.
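The `onEpoch` and `onTrainingEnd` hooks above split saving between periodic checkpoints and the end of training. Below is a minimal sketch of that event flow, assuming the listener counts epochs itself, saves in `onEpoch` only when a checkpoint is due, and always saves a final model in `onTrainingEnd`; the real class may behave differently (for example, by skipping a final save that duplicates the last checkpoint). `EpochListenerSketch` and its save log are hypothetical, and unlike the real hooks they take no `Trainer`.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical stand-in for the onEpoch/onTrainingEnd flow (not DJL code; no Trainer involved). */
final class EpochListenerSketch {
    private final int checkpoint;                          // epochs between checkpoints, or -1 for none
    private int epoch;                                     // number of completed epochs
    private final List<String> saves = new ArrayList<>();  // records each simulated save

    EpochListenerSketch(int checkpoint) {
        this.checkpoint = checkpoint;
    }

    /** Called at the end of each epoch: counts it and saves a checkpoint if one is due. */
    void onEpoch() {
        epoch++;
        if (checkpoint > 0 && epoch % checkpoint == 0) {
            saveModel("checkpoint@" + epoch);
        }
    }

    /** Called once when training ends: saves the final model (always, in this sketch). */
    void onTrainingEnd() {
        saveModel("final@" + epoch);
    }

    /** Stand-in for writing the model to disk; here it only records the event. */
    private void saveModel(String label) {
        saves.add(label);
    }

    List<String> saves() {
        return saves;
    }
}
```

Running five epochs with a frequency of 2 and then ending training records `checkpoint@2`, `checkpoint@4` and `final@5`.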
public void setOverrideModelName(java.lang.String overrideModelName)
Sets the override model name to save checkpoints with.
Parameters:
overrideModelName - the override model name to save checkpoints with

public int getCheckpoint()
Returns the checkpoint frequency (or -1 for no checkpointing) in SaveModelTrainingListener.

public void setCheckpoint(int checkpoint)
Sets the checkpoint frequency in SaveModelTrainingListener.
Parameters:
checkpoint - how many epochs between checkpoints (or -1 for no checkpoints)

public void setSaveModelCallback(java.util.function.Consumer<Trainer> onSaveModel)
Sets the callback function on model saving. This allows the user to set custom properties in the model metadata.
Parameters:
onSaveModel - the callback function on model saving

protected void saveModel(Trainer trainer)
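`setSaveModelCallback` registers a callback that lets the user add custom properties to the model metadata, and `saveModel` is a protected hook. Below is a minimal sketch of that pattern, assuming the callback runs inside `saveModel` just before the model is written. `FakeModel`, `CallbackSaverSketch` and the `"Accuracy"` property are hypothetical stand-ins; in particular, the real callback is a `Consumer<Trainer>`, whereas this sketch passes the model directly.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Consumer;

/** Hypothetical stand-in for a model whose metadata is a string-to-string map (not DJL's Model). */
final class FakeModel {
    final String name;
    final Map<String, String> properties = new LinkedHashMap<>();
    int saveCount; // how many times the model has been "written"

    FakeModel(String name) {
        this.name = name;
    }
}

/** Hypothetical sketch of a listener that runs a user callback before each save (not DJL code). */
class CallbackSaverSketch {
    private Consumer<FakeModel> onSaveModel; // optional; null means no callback

    /** Registers the callback that may add custom metadata properties before saving. */
    void setSaveModelCallback(Consumer<FakeModel> onSaveModel) {
        this.onSaveModel = onSaveModel;
    }

    /** Protected hook: subclasses may override it to change how the model is saved. */
    protected void saveModel(FakeModel model) {
        if (onSaveModel != null) {
            onSaveModel.accept(model); // let the user attach metadata first
        }
        model.saveCount++; // stand-in for writing the model files
    }
}
```

Running the callback before the write means any properties it sets are part of the saved metadata; a subclass overriding `saveModel` could change where or how the model is written while keeping that ordering.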