public interface Trainer
extends java.lang.AutoCloseable
Trainer provides an easy and manageable interface for training. Trainer is
not thread-safe.
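
Because Trainer extends java.lang.AutoCloseable, it can be managed with a try-with-resources statement. The following is a minimal sketch of that pattern, not the library's code: the nested Trainer interface and the RecordingTrainer class are hypothetical stand-ins that mirror only step() and close(), and the assumption that close() throws no checked exception is read off the documented signature void close().

```java
import java.util.ArrayList;
import java.util.List;

class TrainerLifecycleSketch {

    /** Hypothetical stand-in mirroring only the part of Trainer used here. */
    interface Trainer extends AutoCloseable {
        void step();

        // Assumption: narrowed to throw no checked exception,
        // matching the documented signature "void close()".
        @Override
        void close();
    }

    /** Hypothetical implementation that records calls so their order is visible. */
    static final class RecordingTrainer implements Trainer {
        final List<String> calls = new ArrayList<>();

        @Override
        public void step() {
            calls.add("step");
        }

        @Override
        public void close() {
            calls.add("close");
        }
    }

    /** Uses the trainer inside try-with-resources and returns the recorded calls. */
    static List<String> run() {
        RecordingTrainer trainer = new RecordingTrainer();
        try (Trainer t = trainer) {
            t.step();
        } // close() is invoked here, even if the body had thrown
        return trainer.calls;
    }

    public static void main(String[] args) {
        System.out.println(run()); // prints [step, close]
    }
}
```

Since Trainer is not thread-safe, a trainer used this way is best created, used, and closed on a single thread.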
| Modifier and Type | Method and Description |
|---|---|
| void | close() |
| NDList | forward(NDList input): Applies the forward function of the model once on the given input NDList. |
| Loss | getLoss(): Gets the training Loss function of the trainer. |
| NDManager | getManager(): Gets the NDManager from the model. |
| Model | getModel(): Returns the model used to create this trainer. |
| <T extends TrainingMetric> T | getTrainingMetric(java.lang.Class<T> clazz): Gets the training TrainingMetric that is an instance of the given Class. |
| Loss | getValidationLoss(): Gets the validation Loss function of the trainer. |
| <T extends TrainingMetric> T | getValidationMetric(java.lang.Class<T> clazz): Gets the validation TrainingMetric that is an instance of the given Class. |
| void | initialize(Shape... shapes): Initializes the Model that the Trainer is going to train. |
| default java.lang.Iterable<Batch> | iterateDataset(Dataset dataset): Fetches an iterator that can iterate through the given Dataset. |
| GradientCollector | newGradientCollector(): Returns a new instance of GradientCollector. |
| void | resetTrainingMetrics(): Resets each of the training metrics and loss to its respective initial value. |
| void | setMetrics(Metrics metrics): Attaches a Metrics param to use for benchmark. |
| void | setTrainingListener(TrainingListener listener): Sets a TrainingListener to the Trainer. |
| void | step(): Updates all of the parameters of the model once. |
| void | trainBatch(Batch batch): Trains the model with one iteration of the given Batch of data. |
| void | validateBatch(Batch batch): Validates the given batch of data. |

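void initialize(Shape... shapes)
Initializes the Model that the Trainer is going to train.
Parameters:
    shapes - an array of Shape of the inputs

default java.lang.Iterable<Batch> iterateDataset(Dataset dataset)
Fetches an iterator that can iterate through the given Dataset.
Parameters:
    dataset - the dataset to iterate through
Returns:
    an Iterable of Batch that contains batches of data from the dataset

GradientCollector newGradientCollector()
Returns a new instance of GradientCollector.
Returns:
    a GradientCollector

void trainBatch(Batch batch)
Trains the model with one iteration of the given Batch of data.
Parameters:
    batch - a Batch that contains data and its respective labels

NDList forward(NDList input)
Applies the forward function of the model once on the given input NDList.
Parameters:
    input - the input NDList

void validateBatch(Batch batch)
Validates the given batch of data.
During validation, the loss and training metrics are computed, but gradients are not computed and parameters are not updated.
Parameters:
    batch - a Batch of data

void step()
Updates all of the parameters of the model once.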
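
The contract of trainBatch, step, and validateBatch described above can be illustrated with a small self-contained sketch. It is not DJL's implementation: Batch, Trainer, and ToyTrainer are hypothetical stand-ins, and the model (a single weight fitted to y = w * x with squared loss and a fixed learning rate) is an assumption chosen only to keep the arithmetic visible. The point it shows is that trainBatch prepares an update, step applies it, and validateBatch measures loss without touching the parameters.

```java
import java.util.Arrays;
import java.util.List;

class TrainingLoopSketch {

    /** Hypothetical stand-in for Batch: one input value and its label. */
    static final class Batch {
        final double data;
        final double label;

        Batch(double data, double label) {
            this.data = data;
            this.label = label;
        }
    }

    /** Hypothetical stand-in exposing only the three methods exercised below. */
    interface Trainer {
        void trainBatch(Batch batch);

        void validateBatch(Batch batch);

        void step();
    }

    /** Toy trainer fitting y = w * x with squared loss (an assumption, not DJL code). */
    static final class ToyTrainer implements Trainer {
        double weight = 0.0;
        double pendingGradient = 0.0;
        double validationLoss = 0.0;
        final double learningRate = 0.1;

        @Override
        public void trainBatch(Batch batch) {
            double error = weight * batch.data - batch.label;
            pendingGradient = 2 * error * batch.data; // d/dw of (w*x - y)^2
        }

        @Override
        public void step() {
            weight -= learningRate * pendingGradient; // apply the parameter update once
            pendingGradient = 0.0;
        }

        @Override
        public void validateBatch(Batch batch) {
            double error = weight * batch.data - batch.label;
            validationLoss += error * error; // loss only: no gradient, no update
        }
    }

    /** Runs one epoch: train on every training batch, then validate. */
    static ToyTrainer runEpoch(List<Batch> train, List<Batch> validate) {
        ToyTrainer trainer = new ToyTrainer();
        for (Batch batch : train) {
            trainer.trainBatch(batch);
            trainer.step();
        }
        for (Batch batch : validate) {
            trainer.validateBatch(batch);
        }
        return trainer;
    }

    public static void main(String[] args) {
        List<Batch> data = Arrays.asList(new Batch(1.0, 2.0));
        ToyTrainer trainer = runEpoch(data, data);
        System.out.println("weight=" + trainer.weight + " validationLoss=" + trainer.validationLoss);
    }
}
```

With the real interface, the batches would typically come from iterateDataset(dataset) rather than from a list built by hand.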

void setMetrics(Metrics metrics)
Attaches a Metrics param to use for benchmark.
Parameters:
    metrics - the Metrics class

void setTrainingListener(TrainingListener listener)
Sets a TrainingListener to the Trainer.
Parameters:
    listener - the TrainingListener to be set

void resetTrainingMetrics()
Resets each of the training metrics and loss to its respective initial value.

Loss getValidationLoss()
Gets the validation Loss function of the trainer.
Returns:
    the Loss function

Model getModel()
Returns the model used to create this trainer.

<T extends TrainingMetric> T getTrainingMetric(java.lang.Class<T> clazz)
Gets the training TrainingMetric that is an instance of the given Class.
Type Parameters:
    T - the type of the training metric
Parameters:
    clazz - the Class of the TrainingMetric sought

<T extends TrainingMetric> T getValidationMetric(java.lang.Class<T> clazz)
Gets the validation TrainingMetric that is an instance of the given Class.
Type Parameters:
    T - the type of the validation metric
Parameters:
    clazz - the Class of the TrainingMetric sought

void close()
Specified by:
    close in interface java.lang.AutoCloseable
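
The class-keyed signature shared by getTrainingMetric and getValidationMetric lets the caller receive a correctly typed metric without casting. A minimal sketch of how such a lookup might work follows. It is an illustration under stated assumptions, not the library's implementation: TrainingMetric, Accuracy, and ErrorRate here are hypothetical stand-in classes, findMetric is a hypothetical helper, and returning null when nothing matches is an assumption, since the documentation above does not say what happens in that case.

```java
import java.util.Arrays;
import java.util.List;

class MetricLookupSketch {

    /** Hypothetical stand-in for the TrainingMetric base type. */
    static class TrainingMetric {
        final String name;

        TrainingMetric(String name) {
            this.name = name;
        }
    }

    /** Hypothetical metric subclasses used only for this illustration. */
    static final class Accuracy extends TrainingMetric {
        Accuracy() {
            super("accuracy");
        }
    }

    static final class ErrorRate extends TrainingMetric {
        ErrorRate() {
            super("errorRate");
        }
    }

    /**
     * Returns the first metric that is an instance of the given class.
     * Assumption: null is returned when no metric matches.
     */
    static <T extends TrainingMetric> T findMetric(
            List<? extends TrainingMetric> metrics, Class<T> clazz) {
        for (TrainingMetric metric : metrics) {
            if (clazz.isInstance(metric)) {
                return clazz.cast(metric); // checked cast, no unchecked warning
            }
        }
        return null;
    }

    public static void main(String[] args) {
        List<TrainingMetric> metrics = Arrays.asList(new ErrorRate(), new Accuracy());
        Accuracy accuracy = findMetric(metrics, Accuracy.class); // typed result, no cast at the call site
        System.out.println(accuracy.name); // prints accuracy
    }
}
```

Passing the Class object is what allows the method to both filter by type at run time and return the narrower static type T to the caller.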