public interface Trainer
extends java.lang.AutoCloseable
The `Trainer` interface provides a session for model training. It offers an easy and manageable interface for training. `Trainer` is not thread-safe, so a single instance should not be shared across threads without external synchronization.
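Since `Trainer` extends `java.lang.AutoCloseable` and is not thread-safe, a natural way to use it is inside a try-with-resources block driven from one thread. The stdlib-only sketch below shows that shape. `MiniTrainer`, `RecordingTrainer`, and `TrainingLoop` are hypothetical stand-ins, not library types. The per-batch `trainBatch` then `step` order, followed by `endEpoch` once per epoch, is an assumption drawn from the method descriptions on this page.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in mirroring a subset of the Trainer lifecycle (not a library type).
interface MiniTrainer extends AutoCloseable {
    void initialize(int... inputShape);
    void trainBatch(List<Double> batch);
    void step();
    void endEpoch();

    @Override
    void close(); // narrows AutoCloseable.close() so callers need not catch Exception
}

// Hypothetical implementation that only records the order of calls.
class RecordingTrainer implements MiniTrainer {
    final List<String> calls = new ArrayList<>();

    @Override public void initialize(int... inputShape) { calls.add("initialize"); }
    @Override public void trainBatch(List<Double> batch) { calls.add("trainBatch"); }
    @Override public void step() { calls.add("step"); }
    @Override public void endEpoch() { calls.add("endEpoch"); }
    @Override public void close() { calls.add("close"); }
}

class TrainingLoop {
    // Drives the trainer from the calling thread only; try-with-resources guarantees close().
    static List<String> run(List<List<Double>> batches, int epochs) {
        RecordingTrainer recorder = new RecordingTrainer();
        try (MiniTrainer trainer = recorder) {
            trainer.initialize(1, 4);
            for (int epoch = 0; epoch < epochs; epoch++) {
                for (List<Double> batch : batches) {
                    trainer.trainBatch(batch); // a real trainer would compute loss and gradients here
                    trainer.step();            // a real trainer would update parameters once here
                }
                trainer.endEpoch();
            }
        }
        return recorder.calls;
    }
}
```

`TrainingLoop.run` returns the recorded call sequence, so the order, including the final `close()`, can be checked directly.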
| Modifier and Type | Method and Description |
|---|---|
| `void` | `close()` |
| `void` | `endEpoch()`: Runs the end-of-epoch actions. |
| `NDList` | `forward(NDList input)`: Applies the forward function of the model once on the given input `NDList`. |
| `java.util.List<Device>` | `getDevices()`: Returns the devices used for training. |
| `<T extends Evaluator> T` | `getEvaluator(java.lang.Class<T> clazz)`: Gets the `Evaluator` that is an instance of the given `Class`. |
| `java.util.List<Evaluator>` | `getEvaluators()`: Gets all `Evaluator`s. |
| `Loss` | `getLoss()`: Gets the training `Loss` function of the trainer. |
| `NDManager` | `getManager()`: Gets the `NDManager` from the model. |
| `Metrics` | `getMetrics()`: Returns the `Metrics` object used for benchmarking. |
| `Model` | `getModel()`: Returns the model used to create this trainer. |
| `void` | `initialize(Shape... shapes)`: Initializes the `Model` that the `Trainer` is going to train. |
| `default java.lang.Iterable<Batch>` | `iterateDataset(Dataset dataset)`: Fetches an iterator that can iterate through the given `Dataset`. |
| `GradientCollector` | `newGradientCollector()`: Returns a new instance of `GradientCollector`. |
| `void` | `setMetrics(Metrics metrics)`: Attaches a `Metrics` object to use for benchmarking. |
| `void` | `step()`: Updates all of the parameters of the model once. |
| `void` | `trainBatch(Batch batch)`: Trains the model with one iteration of the given `Batch` of data. |
| `void` | `validateBatch(Batch batch)`: Validates the given batch of data. |
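`iterateDataset` is a `default` method that returns an `Iterable<Batch>`, which means callers can consume it directly in a for-each loop. How the library's default actually produces batches is not described here, so the sketch below only illustrates the shape of such a method: the hypothetical `BatchSource` interface chops an in-memory list into fixed-size batches and returns a re-iterable view.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;

// Hypothetical interface (not a library type): only batchSize() is abstract.
interface BatchSource {
    int batchSize();

    // Returns an Iterable so callers can use a for-each loop;
    // every call to iterator() starts a fresh pass over the data.
    default Iterable<List<Integer>> iterateDataset(List<Integer> dataset) {
        int size = batchSize();
        if (size <= 0) {
            throw new IllegalArgumentException("batch size must be positive");
        }
        return () -> new Iterator<List<Integer>>() {
            private int position = 0;

            @Override
            public boolean hasNext() {
                return position < dataset.size();
            }

            @Override
            public List<Integer> next() {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                int end = Math.min(position + size, dataset.size());
                List<Integer> batch = new ArrayList<>(dataset.subList(position, end));
                position = end;
                return batch; // the final batch may hold fewer than batchSize() items
            }
        };
    }
}
```

Returning an `Iterable` rather than an `Iterator` keeps the for-each syntax available and lets the same result be traversed more than once, for example once per epoch.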
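
`void initialize(Shape... shapes)`

Initializes the `Model` that the `Trainer` is going to train.

Parameters: `shapes` - an array of `Shape` of the inputs

`default java.lang.Iterable<Batch> iterateDataset(Dataset dataset)`

Fetches an iterator that can iterate through the given `Dataset`.

Parameters: `dataset` - the dataset to iterate through

Returns: an `Iterable` of `Batch` that contains batches of data from the dataset

`GradientCollector newGradientCollector()`

Returns a new instance of `GradientCollector`.

Returns: a `GradientCollector`

`void trainBatch(Batch batch)`

Trains the model with one iteration of the given `Batch` of data.

Parameters: `batch` - a `Batch` that contains data and its respective labels

Throws: `java.lang.IllegalArgumentException` - if the batch engine does not match the trainer engine

`NDList forward(NDList input)`

Applies the forward function of the model once on the given input `NDList`.

Parameters: `input` - the input `NDList`

`void validateBatch(Batch batch)`

Validates the given batch of data. During validation, the evaluators and losses are computed, but gradients are not computed and parameters are not updated.

Parameters: `batch` - a `Batch` of data

Throws: `java.lang.IllegalArgumentException` - if the batch engine does not match the trainer engine

`void step()`

Updates all of the parameters of the model once.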
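The interface splits one training iteration into `trainBatch` and `step`, and describes `validateBatch` as computing losses without gradients or parameter updates. The sketch below assumes, consistent with those descriptions, that `trainBatch` computes the loss and gradient without touching parameters and `step` applies the update once. `ToyLinearTrainer` is hypothetical: it fits `y = w * x` with mean squared error, and plain gradient descent stands in for whatever optimizer a real training configuration uses.

```java
// Hypothetical toy trainer for y = w * x with mean squared error (not a library type).
class ToyLinearTrainer {
    private final double learningRate;
    private double weight;
    private double gradient; // set by trainBatch, consumed by step
    private double lastLoss;

    ToyLinearTrainer(double initialWeight, double learningRate) {
        this.weight = initialWeight;
        this.learningRate = learningRate;
    }

    // Mean squared error of the current weight on one batch.
    private double loss(double[] x, double[] y) {
        double sum = 0;
        for (int i = 0; i < x.length; i++) {
            double err = weight * x[i] - y[i];
            sum += err * err;
        }
        return sum / x.length;
    }

    // Computes the loss and its gradient dL/dw; the weight is left unchanged.
    void trainBatch(double[] x, double[] y) {
        double grad = 0;
        for (int i = 0; i < x.length; i++) {
            grad += 2 * (weight * x[i] - y[i]) * x[i];
        }
        gradient = grad / x.length;
        lastLoss = loss(x, y);
    }

    // Applies the stored gradient once (plain gradient descent), then clears it.
    void step() {
        weight -= learningRate * gradient;
        gradient = 0;
    }

    // Computes the loss only: no gradient is computed and the weight is not updated.
    void validateBatch(double[] x, double[] y) {
        lastLoss = loss(x, y);
    }

    double weight() { return weight; }
    double lastLoss() { return lastLoss; }
}
```

Running validate, then train, then step, then validate again on the same batch shows the distinction: validation only changes the reported loss, while the weight moves only when `step()` is called.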

`Metrics getMetrics()`

Returns the `Metrics` object used for benchmarking.
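
`void setMetrics(Metrics metrics)`

Attaches a `Metrics` object to use for benchmarking.

Parameters: `metrics` - the `Metrics` object to attach

`java.util.List<Device> getDevices()`

Returns the devices used for training.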
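`setMetrics` and `getMetrics` attach and read back an optional metrics object for benchmarking. The sketch below illustrates that optional-attachment pattern, assuming per-batch timings are what gets recorded. `ToyMetrics`, `BenchmarkedTrainer`, and the metric name `train_batch_nanos` are hypothetical; the library's `Metrics` class may expose a different API.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical metrics sink: a list of recorded values per metric name.
class ToyMetrics {
    private final Map<String, List<Long>> values = new HashMap<>();

    void addMetric(String name, long value) {
        values.computeIfAbsent(name, k -> new ArrayList<>()).add(value);
    }

    List<Long> getMetric(String name) {
        return values.getOrDefault(name, List.of());
    }
}

// Hypothetical trainer that records batch timings only when metrics are attached.
class BenchmarkedTrainer {
    private ToyMetrics metrics; // null until setMetrics is called

    void setMetrics(ToyMetrics metrics) { this.metrics = metrics; }

    ToyMetrics getMetrics() { return metrics; }

    void trainBatch(Runnable work) {
        long start = System.nanoTime();
        work.run(); // stands in for the real forward and backward pass
        if (metrics != null) {
            metrics.addMetric("train_batch_nanos", System.nanoTime() - start);
        }
    }
}
```

Keeping the metrics object optional means training pays no bookkeeping cost unless a caller explicitly opts into benchmarking.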

`void endEpoch()`

Runs the end-of-epoch actions.

`Model getModel()`

Returns the model used to create this trainer.

`java.util.List<Evaluator> getEvaluators()`

Gets all `Evaluator`s.

`<T extends Evaluator> T getEvaluator(java.lang.Class<T> clazz)`

Gets the `Evaluator` that is an instance of the given `Class`.

Type Parameters: `T` - the type of the training evaluator

Parameters: `clazz` - the `Class` of the `Evaluator` sought

`void close()`

Specified by: `close` in interface `java.lang.AutoCloseable`
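`getEvaluator(Class<T> clazz)` looks an evaluator up by type and returns it already typed as `T`. A minimal sketch of such a type-safe lookup follows. `ToyEvaluator`, its record implementations, and `EvaluatorLookup` are hypothetical stand-ins, not the library's `Evaluator` types; returning the first match, or `null` when nothing matches, is this sketch's choice, since the page does not state the library's behavior in those cases.

```java
import java.util.List;

// Hypothetical evaluator hierarchy used only for this illustration.
interface ToyEvaluator {
    String name();
}

record Accuracy(String name) implements ToyEvaluator {}

record TopKAccuracy(String name, int k) implements ToyEvaluator {}

record L2Loss(String name) implements ToyEvaluator {}

class EvaluatorLookup {
    // Returns the first evaluator that is an instance of clazz, or null if none matches.
    static <T extends ToyEvaluator> T getEvaluator(List<? extends ToyEvaluator> evaluators, Class<T> clazz) {
        for (ToyEvaluator e : evaluators) {
            if (clazz.isInstance(e)) {
                return clazz.cast(e); // checked cast, so no unchecked warning
            }
        }
        return null;
    }
}
```

Because `Class.isInstance` also matches subtypes, asking for a supertype such as `ToyEvaluator.class` returns the first evaluator of any matching subtype, and `Class.cast` hands it back without an unchecked cast.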