Package ai.djl.training
Class Trainer
- java.lang.Object
-
- ai.djl.training.Trainer
-
- All Implemented Interfaces:
java.lang.AutoCloseable
public class Trainer extends java.lang.Object implements java.lang.AutoCloseable
The Trainer class provides a session for model training.
Trainer provides an easy and manageable interface for training. Trainer is not thread-safe.
See the tutorials on:
- See Also:
- The guide on memory management
-
-
Constructor Summary
Constructors
- Trainer(Model model, TrainingConfig trainingConfig)
-
Method Summary
Modifier and type, method, and description:
- void addMetric(java.lang.String metricName, long begin): Helper to add a metric for a time difference.
- void close()
- NDList evaluate(NDList input): Evaluates the function of the model once on the given input NDList.
- protected void finalize()
- NDList forward(NDList input): Applies the forward function of the model once on the given input NDList.
- NDList forward(NDList data, NDList labels): Applies the forward function of the model once with both data and labels.
- Device[] getDevices(): Returns the devices used for training.
- java.util.List<Evaluator> getEvaluators(): Gets all Evaluators.
- java.util.Optional<java.util.concurrent.ExecutorService> getExecutorService(): Returns the ExecutorService.
- Loss getLoss(): Gets the training Loss function of the trainer.
- NDManager getManager(): Gets the NDManager from the model.
- Metrics getMetrics(): Returns the Metrics param used for benchmarking.
- Model getModel(): Returns the model used to create this trainer.
- TrainingResult getTrainingResult(): Returns the TrainingResult.
- void initialize(Shape... shapes): Initializes the Model that the Trainer is going to train.
- java.lang.Iterable<Batch> iterateDataset(Dataset dataset): Fetches an iterator that can iterate through the given Dataset.
- GradientCollector newGradientCollector(): Returns a new instance of GradientCollector.
- void notifyListeners(java.util.function.Consumer<TrainingListener> listenerConsumer): Executes a method on each of the TrainingListeners.
- void setMetrics(Metrics metrics): Attaches a Metrics param to use for benchmarking.
- void step(): Updates all of the parameters of the model once.
-
-
-
Constructor Detail
-
Trainer
public Trainer(Model model, TrainingConfig trainingConfig)
- Parameters:
model - the model the trainer will train on
trainingConfig - the configuration used by the trainer
-
-
Method Detail
-
initialize
public void initialize(Shape... shapes)
Initializes the Model that the Trainer is going to train.
- Parameters:
shapes - an array of Shape of the inputs
-
iterateDataset
public java.lang.Iterable<Batch> iterateDataset(Dataset dataset) throws java.io.IOException, TranslateException
Fetches an iterator that can iterate through the given Dataset.
- Parameters:
dataset - the dataset to iterate through
- Returns:
an Iterable of Batch that contains batches of data from the dataset
- Throws:
java.io.IOException - for various exceptions depending on the dataset
TranslateException - if there is an error while processing input
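The returned Iterable can be consumed with a for-each loop. The sketch below illustrates that pattern with stand-in types only: `BatchSketch` and `IterateSketch` are hypothetical names, not DJL classes, and the idea that each batch should be released once it has been used is an assumption made for illustration.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical stand-in for a batch that holds resources and must be released.
class BatchSketch implements AutoCloseable {
    final int[] data;
    boolean closed;

    BatchSketch(int... data) {
        this.data = data;
    }

    @Override
    public void close() {
        closed = true;
    }
}

class IterateSketch {
    // Consumes every batch once and releases it before moving on to the next.
    static int sumAll(Iterable<BatchSketch> batches) {
        int total = 0;
        for (BatchSketch batch : batches) {
            try (BatchSketch b = batch) { // closed at the end of each iteration
                for (int v : b.data) {
                    total += v;
                }
            }
        }
        return total;
    }

    public static void main(String[] args) {
        List<BatchSketch> batches =
                Arrays.asList(new BatchSketch(1, 2), new BatchSketch(3, 4));
        System.out.println(sumAll(batches)); // 10
    }
}
```

Closing inside the loop keeps at most one batch alive at a time, which matters when batches hold large buffers.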
-
newGradientCollector
public GradientCollector newGradientCollector()
Returns a new instance of GradientCollector.
- Returns:
a new instance of GradientCollector
-
forward
public NDList forward(NDList input)
Applies the forward function of the model once on the given input NDList.
- Parameters:
input - the input NDList
- Returns:
the output of the forward function
-
forward
public NDList forward(NDList data, NDList labels)
Applies the forward function of the model once with both data and labels.
-
evaluate
public NDList evaluate(NDList input)
Evaluates the function of the model once on the given input NDList.
- Parameters:
input - the input NDList
- Returns:
the output of the predict function
-
step
public void step()
Updates all of the parameters of the model once.
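As a rough illustration of what a single parameter update can look like, the sketch below applies one plain gradient-descent step to a flat array. This is an assumption for illustration only: the documented step() takes no arguments, the actual update rule depends on the training configuration, and `StepSketch` with its explicit gradient array is a hypothetical stand-in.

```java
import java.util.Arrays;

// Hypothetical stand-in: plain gradient descent on a flat parameter array.
class StepSketch {
    final double[] params;
    final double learningRate;

    StepSketch(double[] params, double learningRate) {
        this.params = params;
        this.learningRate = learningRate;
    }

    // Applies one update to every parameter: p = p - learningRate * gradient.
    void step(double[] gradients) {
        if (gradients.length != params.length) {
            throw new IllegalArgumentException("one gradient per parameter is required");
        }
        for (int i = 0; i < params.length; i++) {
            params[i] -= learningRate * gradients[i];
        }
    }

    public static void main(String[] args) {
        StepSketch sketch = new StepSketch(new double[] {1.0, 2.0}, 0.5);
        sketch.step(new double[] {0.5, -1.0});
        System.out.println(Arrays.toString(sketch.params)); // [0.75, 2.5]
    }
}
```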
-
getMetrics
public Metrics getMetrics()
Returns the Metrics param used for benchmarking.
- Returns:
the Metrics param used for benchmarking
-
setMetrics
public void setMetrics(Metrics metrics)
Attaches a Metrics param to use for benchmarking.
- Parameters:
metrics - the Metrics class
-
getDevices
public Device[] getDevices()
Returns the devices used for training.
- Returns:
the devices used for training
-
getLoss
public Loss getLoss()
Gets the training Loss function of the trainer.
- Returns:
the Loss function
-
getModel
public Model getModel()
Returns the model used to create this trainer.
- Returns:
the model associated with this trainer
-
getExecutorService
public java.util.Optional<java.util.concurrent.ExecutorService> getExecutorService()
Returns the ExecutorService.
- Returns:
the ExecutorService
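Because the return type is an Optional, callers need to handle the case where no ExecutorService is present. The sketch below shows one way to do that with the standard library alone; `ExecutorSketch` and `runTask` are hypothetical names, and running the task inline when the Optional is empty is an assumption about what a caller might want, not documented behavior.

```java
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;

class ExecutorSketch {
    // Runs the task on the executor when one is present and waits for it,
    // otherwise runs it inline on the calling thread.
    static void runTask(Optional<ExecutorService> executor, Runnable task) {
        if (executor.isPresent()) {
            CompletableFuture.runAsync(task, executor.get()).join();
        } else {
            task.run();
        }
    }

    public static void main(String[] args) {
        AtomicInteger counter = new AtomicInteger();
        runTask(Optional.empty(), counter::incrementAndGet);
        ExecutorService pool = Executors.newSingleThreadExecutor();
        try {
            runTask(Optional.of(pool), counter::incrementAndGet);
        } finally {
            pool.shutdown();
        }
        System.out.println(counter.get()); // 2
    }
}
```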
-
getEvaluators
public java.util.List<Evaluator> getEvaluators()
Gets all Evaluators.
- Returns:
the evaluators used during training
-
notifyListeners
public final void notifyListeners(java.util.function.Consumer<TrainingListener> listenerConsumer)
Executes a method on each of the TrainingListeners.
- Parameters:
listenerConsumer - a consumer that executes the method
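The consumer-based signature lets the caller choose which listener method to invoke, while the trainer only handles the iteration. A minimal sketch of that pattern follows, using stand-ins: `EpochListener`, `onEpoch`, `ListenerSketch`, and `addListener` are hypothetical and do not reflect the callbacks of the real TrainingListener.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;

// Hypothetical stand-in; the real TrainingListener has its own callbacks.
interface EpochListener {
    void onEpoch(int epoch);
}

class ListenerSketch {
    private final List<EpochListener> listeners = new ArrayList<>();

    void addListener(EpochListener listener) {
        listeners.add(listener);
    }

    // The consumer decides which listener method to call; this method only
    // applies it to every registered listener in registration order.
    final void notifyListeners(Consumer<EpochListener> listenerConsumer) {
        listeners.forEach(listenerConsumer);
    }

    public static void main(String[] args) {
        ListenerSketch trainer = new ListenerSketch();
        List<String> log = new ArrayList<>();
        trainer.addListener(epoch -> log.add("a:" + epoch));
        trainer.addListener(epoch -> log.add("b:" + epoch));
        trainer.notifyListeners(listener -> listener.onEpoch(1));
        System.out.println(log); // [a:1, b:1]
    }
}
```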
-
getTrainingResult
public TrainingResult getTrainingResult()
Returns the TrainingResult.
- Returns:
the TrainingResult
-
finalize
protected void finalize() throws java.lang.Throwable
- Overrides:
finalize in class java.lang.Object
- Throws:
java.lang.Throwable
-
close
public void close()
- Specified by:
close in interface java.lang.AutoCloseable
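Any AutoCloseable can be managed with a try-with-resources statement, which calls close() when the block exits. The sketch below shows the pattern with a stand-in resource; `SessionSketch`, `work`, and `isClosed` are hypothetical names used only for illustration.

```java
// Hypothetical stand-in for an AutoCloseable resource such as a training session.
class SessionSketch implements AutoCloseable {
    private boolean closed;

    void work() {
        if (closed) {
            throw new IllegalStateException("session already closed");
        }
    }

    boolean isClosed() {
        return closed;
    }

    @Override
    public void close() { // declares no checked exception
        closed = true;
    }

    public static void main(String[] args) {
        SessionSketch observed = new SessionSketch();
        try (SessionSketch session = observed) {
            session.work();
        } // close() is called here, even if work() had thrown
        System.out.println(observed.isClosed()); // true
    }
}
```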
-
addMetric
public void addMetric(java.lang.String metricName, long begin)
Helper to add a metric for a time difference.
- Parameters:
metricName - the metric name
begin - the time difference start (this method is called at the time difference end)
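The calling convention described here (capture the start time, do the work, then call the helper at the end) can be sketched as follows. `TimingSketch` is a hypothetical stand-in rather than DJL's Metrics class, and the use of System.nanoTime() as the clock is an assumption.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical stand-in for a metrics sink; this is not DJL's Metrics class.
class TimingSketch {
    private final Map<String, List<Long>> values = new LinkedHashMap<>();

    // The caller passes the start time; the end time is taken when this
    // method is called. Assumption: times are System.nanoTime() readings.
    void addMetric(String metricName, long begin) {
        long elapsed = System.nanoTime() - begin;
        values.computeIfAbsent(metricName, k -> new ArrayList<>()).add(elapsed);
    }

    List<Long> get(String metricName) {
        return values.getOrDefault(metricName, new ArrayList<>());
    }

    public static void main(String[] args) {
        TimingSketch metrics = new TimingSketch();
        long begin = System.nanoTime(); // start of the interval
        long sum = 0;
        for (int i = 0; i < 1000; i++) {
            sum += i; // stand-in for the timed work
        }
        metrics.addMetric("forward", begin); // end of the interval
        System.out.println("forward took " + metrics.get("forward").get(0)
                + " ns (sum=" + sum + ")");
    }
}
```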
-
-