Package ai.djl.training
Class DefaultTrainingConfig
java.lang.Object
    ai.djl.training.DefaultTrainingConfig
All Implemented Interfaces:
    TrainingConfig
public class DefaultTrainingConfig extends java.lang.Object implements TrainingConfig
DefaultTrainingConfig is an implementation of the TrainingConfig interface.
Constructor Summary
DefaultTrainingConfig(Loss loss)
    Creates an instance of DefaultTrainingConfig with the given Loss.
Method Summary
DefaultTrainingConfig addEvaluator(Evaluator evaluator)
    Adds an Evaluator that needs to be computed during training.
<T extends Evaluator> DefaultTrainingConfig addEvaluators(java.util.Collection<T> evaluators)
    Adds multiple Evaluators that need to be computed during training.
DefaultTrainingConfig addTrainingListeners(TrainingListener... listeners)
    Adds TrainingListeners for training.
Device[] getDevices()
    Gets the Devices that are available for computation.
java.util.List<Evaluator> getEvaluators()
    Returns the list of Evaluators that should be computed during training.
java.util.concurrent.ExecutorService getExecutorService()
    Gets the ExecutorService for parallelization.
ai.djl.util.PairList<Initializer,java.util.function.Predicate<Parameter>> getInitializers()
    Gets a list of Initializer and Predicate pairs used to initialize the parameters of the model.
Loss getLossFunction()
    Gets the Loss function to compute the loss against.
Optimizer getOptimizer()
    Gets the Optimizer to use during training.
java.util.List<TrainingListener> getTrainingListeners()
    Returns the list of TrainingListeners that should be used during training.
DefaultTrainingConfig optDevices(Device[] devices)
    Sets the array of Device available for training.
DefaultTrainingConfig optExecutorService()
    Sets the ExecutorService with the global ForkJoinPool.commonPool().
DefaultTrainingConfig optExecutorService(java.util.concurrent.ExecutorService executorService)
    Sets the ExecutorService to train with multiple threads.
DefaultTrainingConfig optInitializer(Initializer initializer, Parameter.Type type)
    Sets the Initializer to use for the parameters (default from paper).
DefaultTrainingConfig optInitializer(Initializer initializer, java.lang.String name)
    Sets the Initializer to use for the parameters (default from paper).
DefaultTrainingConfig optInitializer(Initializer initializer, java.util.function.Predicate<Parameter> predicate)
    Sets the Initializer to use for the parameters (default from paper).
DefaultTrainingConfig optOptimizer(Optimizer optimizer)
    Sets the Optimizer to use during training.
Constructor Detail
DefaultTrainingConfig
public DefaultTrainingConfig(Loss loss)
Creates an instance of DefaultTrainingConfig with the given Loss. The instance uses default TrainingConfig settings, Adam as the optimizer, and the given Loss. Evaluators and listeners are left to the user's discretion.
Parameters:
    loss - the loss to use for training
Method Detail
optInitializer
public DefaultTrainingConfig optInitializer(Initializer initializer, Parameter.Type type)
Sets the Initializer to use for the parameters (default from paper).
Parameters:
    initializer - the initializer to use for the parameters
    type - the Parameter.Type of the parameters
Returns:
    this DefaultTrainingConfig
optInitializer
public DefaultTrainingConfig optInitializer(Initializer initializer, java.lang.String name)
Sets the Initializer to use for the parameters (default from paper).
Parameters:
    initializer - the initializer to use for the parameters
    name - the name of the parameter
Returns:
    this DefaultTrainingConfig
optInitializer
public DefaultTrainingConfig optInitializer(Initializer initializer, java.util.function.Predicate<Parameter> predicate)
Sets the Initializer to use for the parameters (default from paper).
Parameters:
    initializer - the initializer to use for the parameters
    predicate - the predicate that identifies the parameters
Returns:
    this DefaultTrainingConfig
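The three optInitializer overloads differ only in how they select parameters: by Parameter.Type, by name, or by an arbitrary predicate. As a rough illustration, and assuming (this page does not confirm it) that the type and name overloads behave like shorthand for a predicate, the sketch below uses plain java.util.function.Predicate. Param, byType, byName, select, and the sample names and type strings are all hypothetical stand-ins, not DJL classes or methods.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Predicate;

class InitializerPredicateSketch {

    // Hypothetical stand-in for a model parameter; NOT DJL's Parameter class.
    static final class Param {
        final String name;
        final String type;

        Param(String name, String type) {
            this.name = name;
            this.type = type;
        }
    }

    // Assumed equivalent of selecting by parameter type.
    static Predicate<Param> byType(String type) {
        return p -> p.type.equals(type);
    }

    // Assumed equivalent of selecting by parameter name.
    static Predicate<Param> byName(String name) {
        return p -> p.name.equals(name);
    }

    // Returns the names of all parameters the predicate accepts, in input order.
    static List<String> select(List<Param> params, Predicate<Param> predicate) {
        List<String> names = new ArrayList<>();
        for (Param p : params) {
            if (predicate.test(p)) {
                names.add(p.name);
            }
        }
        return names;
    }

    // Made-up parameters used only for this illustration.
    static List<Param> sampleParams() {
        return Arrays.asList(
                new Param("conv1_weight", "WEIGHT"),
                new Param("conv1_bias", "BIAS"),
                new Param("fc_weight", "WEIGHT"));
    }

    public static void main(String[] args) {
        List<Param> params = sampleParams();
        System.out.println(select(params, byType("WEIGHT")));   // [conv1_weight, fc_weight]
        System.out.println(select(params, byName("conv1_bias"))); // [conv1_bias]
        // A free-form predicate can combine both criteria.
        System.out.println(select(params, byType("WEIGHT").and(byName("fc_weight").negate())));
    }
}
```

The predicate form is the most general of the three: conditions can be composed with and, or, and negate, which the type and name forms cannot express on their own.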
optDevices
public DefaultTrainingConfig optDevices(Device[] devices)
Sets the array of Device available for training.
Parameters:
    devices - an array of devices to be set
Returns:
    this DefaultTrainingConfig
optOptimizer
public DefaultTrainingConfig optOptimizer(Optimizer optimizer)
Sets the Optimizer to use during training.
Parameters:
    optimizer - the optimizer to be set
Returns:
    this DefaultTrainingConfig
optExecutorService
public DefaultTrainingConfig optExecutorService()
Sets the ExecutorService with the global ForkJoinPool.commonPool().
Returns:
    this DefaultTrainingConfig
optExecutorService
public DefaultTrainingConfig optExecutorService(java.util.concurrent.ExecutorService executorService)
Sets the ExecutorService to train with multiple threads.
Parameters:
    executorService - the executor service
Returns:
    this DefaultTrainingConfig
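The two optExecutorService overloads come down to which ExecutorService is used: the shared ForkJoinPool.commonPool() or one supplied by the caller. The sketch below uses only java.util.concurrent to contrast the two choices; it does not call DJL. The class and method names, and the squared-sum workload, are hypothetical placeholders for whatever parallel work a trainer might submit.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;

class ExecutorChoiceSketch {

    // Hypothetical workload: squares 1..n as separate tasks and sums the results.
    static int sumOfSquares(ExecutorService executor, int n) {
        List<Future<Integer>> futures = new ArrayList<>();
        for (int i = 1; i <= n; i++) {
            final int value = i;
            futures.add(executor.submit(() -> value * value));
        }
        int total = 0;
        try {
            for (Future<Integer> future : futures) {
                total += future.get(); // blocks until that task has finished
            }
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException("task failed", e);
        }
        return total;
    }

    // Uses the JVM-wide common pool; the caller does not own it and does not shut it down.
    static int withCommonPool(int n) {
        return sumOfSquares(ForkJoinPool.commonPool(), n);
    }

    // Uses a dedicated pool; the caller owns it and must release it.
    static int withDedicatedPool(int n, int threads) {
        ExecutorService pool = Executors.newFixedThreadPool(threads);
        try {
            return sumOfSquares(pool, n);
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println(withCommonPool(4));       // 30
        System.out.println(withDedicatedPool(4, 2)); // 30
    }
}
```

A dedicated pool gives control over the thread count and isolates the work from other users of the common pool, at the cost of having to shut the pool down explicitly.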
addEvaluators
public <T extends Evaluator> DefaultTrainingConfig addEvaluators(java.util.Collection<T> evaluators)
Adds multiple Evaluators that need to be computed during training.
Type Parameters:
    T - the type of evaluator to be added
Parameters:
    evaluators - the evaluators to be added
Returns:
    this DefaultTrainingConfig
addEvaluator
public DefaultTrainingConfig addEvaluator(Evaluator evaluator)
Adds an Evaluator that needs to be computed during training.
Parameters:
    evaluator - the evaluator to be added
Returns:
    this DefaultTrainingConfig
addTrainingListeners
public DefaultTrainingConfig addTrainingListeners(TrainingListener... listeners)
Adds TrainingListeners for training.
Parameters:
    listeners - the TrainingListeners to add
Returns:
    this DefaultTrainingConfig
getDevices
public Device[] getDevices()
Gets the Devices that are available for computation. A Trainer needs this to know what kind of device it is running on and how many devices it is running on.
Specified by:
    getDevices in interface TrainingConfig
Returns:
    an array of Device
getInitializers
public ai.djl.util.PairList<Initializer,java.util.function.Predicate<Parameter>> getInitializers()
Gets a list of Initializer and Predicate pairs used to initialize the parameters of the model.
Specified by:
    getInitializers in interface TrainingConfig
Returns:
    a list of Initializer and Predicate pairs
getOptimizer
public Optimizer getOptimizer()
Gets the Optimizer to use during training.
Specified by:
    getOptimizer in interface TrainingConfig
Returns:
    an Optimizer
getLossFunction
public Loss getLossFunction()
Gets the Loss function to compute the loss against.
Specified by:
    getLossFunction in interface TrainingConfig
Returns:
    a Loss function
getExecutorService
public java.util.concurrent.ExecutorService getExecutorService()
Gets the ExecutorService for parallelization.
Specified by:
    getExecutorService in interface TrainingConfig
Returns:
    an ExecutorService
getEvaluators
public java.util.List<Evaluator> getEvaluators()
Returns the list of Evaluators that should be computed during training.
Specified by:
    getEvaluators in interface TrainingConfig
Returns:
    a list of Evaluators
getTrainingListeners
public java.util.List<TrainingListener> getTrainingListeners()
Returns the list of TrainingListeners that should be used during training.
Specified by:
    getTrainingListeners in interface TrainingConfig
Returns:
    a list of TrainingListeners