public class DefaultTrainingConfig extends java.lang.Object implements TrainingConfig
DefaultTrainingConfig is an implementation of the TrainingConfig interface.

| Constructor | Description |
|---|---|
| DefaultTrainingConfig(Loss loss) | Creates an instance of DefaultTrainingConfig with the given Loss. |
| Modifier and Type | Method | Description |
|---|---|---|
| DefaultTrainingConfig | addEvaluator(Evaluator evaluator) | Adds an Evaluator that needs to be computed during training. |
| <T extends Evaluator> DefaultTrainingConfig | addEvaluators(java.util.Collection<T> evaluators) | Adds multiple Evaluators that need to be computed during training. |
| DefaultTrainingConfig | addTrainingListeners(TrainingListener... listeners) | Adds TrainingListeners for training. |
| Device[] | getDevices() | Gets the Devices that are available for computation. |
| java.util.List<Evaluator> | getEvaluators() | Returns the list of Evaluators that should be computed during training. |
| java.util.concurrent.ExecutorService | getExecutorService() | Gets the ExecutorService for parallelization. |
| ai.djl.util.PairList<Initializer,java.util.function.Predicate<Parameter>> | getInitializers() | Gets a list of Initializer and Predicate pairs used to initialize the parameters of the model. |
| Loss | getLossFunction() | Gets the Loss function to compute the loss against. |
| Optimizer | getOptimizer() | Gets the Optimizer to use during training. |
| java.util.List<TrainingListener> | getTrainingListeners() | Returns the list of TrainingListeners that should be used during training. |
| DefaultTrainingConfig | optDevices(Device[] devices) | Sets the array of Devices available for training. |
| DefaultTrainingConfig | optExecutorService() | Sets the ExecutorService to the global ForkJoinPool.commonPool(). |
| DefaultTrainingConfig | optExecutorService(java.util.concurrent.ExecutorService executorService) | Sets the ExecutorService used to train with multiple threads. |
| DefaultTrainingConfig | optInitializer(Initializer initializer, Parameter.Type type) | Sets the Initializer to use for the parameters of the given type (default from paper). |
| DefaultTrainingConfig | optInitializer(Initializer initializer, java.util.function.Predicate<Parameter> predicate) | Sets the Initializer to use for the parameters that match the given predicate (default from paper). |
| DefaultTrainingConfig | optInitializer(Initializer initializer, java.lang.String name) | Sets the Initializer to use for the parameter with the given name (default from paper). |
| DefaultTrainingConfig | optOptimizer(Optimizer optimizer) | Sets the Optimizer used during training. |
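Every opt* and add* method above returns the same DefaultTrainingConfig, so a configuration is normally built as one chained expression, and the three optInitializer overloads differ only in how they select parameters. The sketch below models that pattern with hypothetical stand-in types (Param, SketchConfig, and String placeholders for Loss, Optimizer, Initializer, and Evaluator) so that it compiles without DJL; it is not DJL's implementation. Reducing the Parameter.Type and name overloads to the predicate overload is an assumption made for illustration.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;

// Hypothetical stand-in for DJL's Parameter: only a name and a type.
class Param {
    enum Type { WEIGHT, BIAS, GAMMA, BETA }

    final String name;
    final Type type;

    Param(String name, Type type) {
        this.name = name;
        this.type = type;
    }
}

// Hypothetical, stripped-down config illustrating the fluent opt*/add* style.
// Strings stand in for Loss, Optimizer, Initializer and Evaluator objects.
class SketchConfig {
    final String loss;
    String optimizer = "adam"; // mirrors the documented Adam default
    final List<Map.Entry<String, Predicate<Param>>> initializers = new ArrayList<>();
    final List<String> evaluators = new ArrayList<>();

    SketchConfig(String loss) {
        this.loss = loss;
    }

    SketchConfig optOptimizer(String optimizer) {
        this.optimizer = optimizer;
        return this; // returning this is what enables chaining
    }

    // General form: pair an initializer with a predicate over parameters.
    SketchConfig optInitializer(String initializer, Predicate<Param> predicate) {
        initializers.add(Map.entry(initializer, predicate));
        return this;
    }

    // Assumed convenience overload: select parameters by type.
    SketchConfig optInitializer(String initializer, Param.Type type) {
        return optInitializer(initializer, p -> p.type == type);
    }

    // Assumed convenience overload: select a parameter by name.
    SketchConfig optInitializer(String initializer, String name) {
        return optInitializer(initializer, p -> p.name.equals(name));
    }

    SketchConfig addEvaluator(String evaluator) {
        evaluators.add(evaluator);
        return this;
    }
}
```

With the real class, the analogous chain starts from new DefaultTrainingConfig(loss) and calls the same method names with DJL's own Loss, Optimizer, Initializer, and Evaluator objects.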
public DefaultTrainingConfig(Loss loss)
Creates an instance of DefaultTrainingConfig with the given Loss. DefaultTrainingConfig creates a default TrainingConfig with Adam as the optimizer and the given Loss. The evaluators and listeners are left to the user's discretion.
Parameters:
loss - the loss to use for training

public DefaultTrainingConfig optInitializer(Initializer initializer, Parameter.Type type)
Sets the Initializer to use for the parameters of the given type (default from paper).
Parameters:
initializer - the initializer to use for the parameters
type - the Parameter.Type of the parameters
Returns:
this DefaultTrainingConfig

public DefaultTrainingConfig optInitializer(Initializer initializer, java.lang.String name)
Sets the Initializer to use for the parameter with the given name (default from paper).
Parameters:
initializer - the initializer to use for the parameter
name - the name of the parameter
Returns:
this DefaultTrainingConfig

public DefaultTrainingConfig optInitializer(Initializer initializer, java.util.function.Predicate<Parameter> predicate)
Sets the Initializer to use for the parameters that match the given predicate (default from paper).
Parameters:
initializer - the initializer to use for the parameters
predicate - the predicate that identifies the parameters
Returns:
this DefaultTrainingConfig

public DefaultTrainingConfig optDevices(Device[] devices)
Sets the array of Devices available for training.
Parameters:
devices - an array of devices to be set
Returns:
this DefaultTrainingConfig

public DefaultTrainingConfig optOptimizer(Optimizer optimizer)
Sets the Optimizer used during training.
Parameters:
optimizer - the optimizer to be set
Returns:
this DefaultTrainingConfig

public DefaultTrainingConfig optExecutorService()
Sets the ExecutorService to the global ForkJoinPool.commonPool().
Returns:
this DefaultTrainingConfig

public DefaultTrainingConfig optExecutorService(java.util.concurrent.ExecutorService executorService)
Sets the ExecutorService used to train with multiple threads.
Parameters:
executorService - the executor service
Returns:
this DefaultTrainingConfig

public <T extends Evaluator> DefaultTrainingConfig addEvaluators(java.util.Collection<T> evaluators)
Adds multiple Evaluators that need to be computed during training.
Type Parameters:
T - the type of evaluator to be added
Parameters:
evaluators - the evaluators to be added
Returns:
this DefaultTrainingConfig

public DefaultTrainingConfig addEvaluator(Evaluator evaluator)
Adds an Evaluator that needs to be computed during training.
Parameters:
evaluator - the evaluator to be added
Returns:
this DefaultTrainingConfig

public DefaultTrainingConfig addTrainingListeners(TrainingListener... listeners)
Adds TrainingListeners for training.
Parameters:
listeners - the TrainingListeners to add
Returns:
this DefaultTrainingConfig

public Device[] getDevices()
Gets the Devices that are available for computation.

A Trainer needs this because it must know what kind of device it is running on and how many devices it is running on.
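To make the point about device count concrete, the following hypothetical sketch shows one way a trainer could use the number of devices: it splits a batch into near-equal contiguous slices, one per device, and runs each slice as a task on an ExecutorService such as ForkJoinPool.commonPool(), the pool that optExecutorService() selects. DeviceSplitSketch, the String device names, and the slice-size "work" are stand-ins; this is not how DJL's Trainer is implemented.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;

// Hypothetical helper; not part of DJL.
class DeviceSplitSketch {

    // Splits batchSize examples into one contiguous [start, end) range per device.
    // The first (batchSize % deviceCount) devices get one extra example each.
    static List<int[]> split(int batchSize, int deviceCount) {
        if (deviceCount <= 0) {
            throw new IllegalArgumentException("need at least one device");
        }
        List<int[]> ranges = new ArrayList<>();
        int base = batchSize / deviceCount;
        int remainder = batchSize % deviceCount;
        int start = 0;
        for (int i = 0; i < deviceCount; i++) {
            int size = base + (i < remainder ? 1 : 0);
            ranges.add(new int[] {start, start + size});
            start += size;
        }
        return ranges;
    }

    // Runs each device's slice as a separate task on the executor and sums the
    // per-slice results. The "work" is only the slice size, a stand-in for a
    // real forward/backward pass on that device.
    static int processOnDevices(String[] devices, int batchSize, ExecutorService executor) {
        List<Callable<Integer>> tasks = new ArrayList<>();
        for (int[] range : split(batchSize, devices.length)) {
            tasks.add(() -> range[1] - range[0]);
        }
        try {
            int total = 0;
            for (Future<Integer> result : executor.invokeAll(tasks)) {
                total += result.get();
            }
            return total;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } catch (ExecutionException e) {
            throw new IllegalStateException(e);
        }
    }

    // Convenience overload mirroring optExecutorService(): use the global common pool.
    static int processOnDevices(String[] devices, int batchSize) {
        return processOnDevices(devices, batchSize, ForkJoinPool.commonPool());
    }
}
```

Giving the remainder to the first devices keeps slice sizes within one example of each other, so no device processes a much larger share than the others.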
Specified by:
getDevices in interface TrainingConfig
Returns:
an array of Device

public ai.djl.util.PairList<Initializer,java.util.function.Predicate<Parameter>> getInitializers()
Gets a list of Initializer and Predicate pairs used to initialize the parameters of the model.
Specified by:
getInitializers in interface TrainingConfig
Returns:
the Initializer and Predicate pairs

public Optimizer getOptimizer()
Gets the Optimizer to use during training.
Specified by:
getOptimizer in interface TrainingConfig
Returns:
the Optimizer

public Loss getLossFunction()
Gets the Loss function to compute the loss against.
Specified by:
getLossFunction in interface TrainingConfig
Returns:
the Loss function

public java.util.concurrent.ExecutorService getExecutorService()
Gets the ExecutorService for parallelization.
Specified by:
getExecutorService in interface TrainingConfig
Returns:
the ExecutorService

public java.util.List<Evaluator> getEvaluators()
Returns the list of Evaluators that should be computed during training.
Specified by:
getEvaluators in interface TrainingConfig
Returns:
the Evaluators

public java.util.List<TrainingListener> getTrainingListeners()
Returns the list of TrainingListeners that should be used during training.
Specified by:
getTrainingListeners in interface TrainingConfig
Returns:
the TrainingListeners
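getInitializers() returns (Initializer, Predicate) pairs rather than a single initializer, so whatever consumes the configuration has to decide, for every parameter, which pair applies. The sketch below shows one plausible resolution rule under stated assumptions: parameters are represented only by their names, initializers by String labels, pairs are applied in order so that a later matching pair overrides an earlier one, and unmatched parameters keep a fallback standing in for the block's own default. InitializerResolutionSketch and its resolve method are hypothetical, not part of DJL.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;

// Hypothetical helper; not part of DJL. Parameters are represented by name only,
// and initializers by a String label.
class InitializerResolutionSketch {

    // Applies (initializer, predicate) pairs in list order. A later matching pair
    // overrides an earlier one; this precedence rule is an assumption of the sketch.
    // Parameters matched by no pair keep the fallback, a stand-in for the block's default.
    static Map<String, String> resolve(List<String> parameterNames,
                                       List<Map.Entry<String, Predicate<String>>> pairs,
                                       String fallback) {
        Map<String, String> chosen = new LinkedHashMap<>();
        for (String name : parameterNames) {
            chosen.put(name, fallback);
        }
        for (Map.Entry<String, Predicate<String>> pair : pairs) {
            for (String name : parameterNames) {
                if (pair.getValue().test(name)) {
                    chosen.put(name, pair.getKey());
                }
            }
        }
        return chosen;
    }
}
```

Under this rule, registering a broad predicate first and a narrower one afterwards lets the narrower choice win for the parameters it matches.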