public class DefaultTrainingConfig extends java.lang.Object implements TrainingConfig
DefaultTrainingConfig is an implementation of the TrainingConfig interface.

| Constructor | Description |
|---|---|
| `DefaultTrainingConfig(Initializer initializer, Loss loss)` | Creates an instance of DefaultTrainingConfig with the given Initializer and Loss. |
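The constructor fixes two strategies that every training run needs: how the model's parameters start out (the Initializer) and how predictions are scored against labels (the Loss). The sketch below illustrates that shape without depending on DJL. SketchInitializer, SketchLoss, ConstantInitializer, MeanSquaredLoss and SketchTrainingConfig are all hypothetical names, plain double[] arrays stand in for DJL's own types, and the constant initializer and mean squared error are chosen only for illustration.

```java
import java.util.Arrays;

// Hypothetical stand-in for an Initializer: produces starting parameter values.
interface SketchInitializer {
    double[] initialize(int size);
}

// Hypothetical stand-in for a Loss: scores predictions against labels.
interface SketchLoss {
    double evaluate(double[] labels, double[] predictions);
}

// Sets every parameter to the same constant value.
final class ConstantInitializer implements SketchInitializer {
    private final double value;

    ConstantInitializer(double value) {
        this.value = value;
    }

    @Override
    public double[] initialize(int size) {
        double[] params = new double[size];
        Arrays.fill(params, value);
        return params;
    }
}

// Mean squared error: the average of squared label/prediction differences.
final class MeanSquaredLoss implements SketchLoss {
    @Override
    public double evaluate(double[] labels, double[] predictions) {
        double sum = 0.0;
        for (int i = 0; i < labels.length; i++) {
            double diff = labels[i] - predictions[i];
            sum += diff * diff;
        }
        return sum / labels.length;
    }
}

// Mirrors the (Initializer, Loss) constructor shape: both are fixed at creation.
final class SketchTrainingConfig {
    private final SketchInitializer initializer;
    private final SketchLoss loss;

    SketchTrainingConfig(SketchInitializer initializer, SketchLoss loss) {
        this.initializer = initializer;
        this.loss = loss;
    }

    SketchInitializer getInitializer() {
        return initializer;
    }

    SketchLoss getLossFunction() {
        return loss;
    }

    public static void main(String[] args) {
        SketchTrainingConfig config =
                new SketchTrainingConfig(new ConstantInitializer(0.5), new MeanSquaredLoss());
        System.out.println(Arrays.toString(config.getInitializer().initialize(3))); // [0.5, 0.5, 0.5]
        System.out.println(config.getLossFunction().evaluate(
                new double[] {1.0, 2.0}, new double[] {1.0, 4.0})); // 2.0
    }
}
```

Passing both strategies to the constructor, rather than through setters, means a configuration object can never exist without them; the remaining settings are optional and adjusted afterwards.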
| Modifier and Type | Method | Description |
|---|---|---|
| `DefaultTrainingConfig` | `addTrainingMetric(TrainingMetric trainingMetric)` | Adds a TrainingMetric that needs to be computed during training. |
| `int` | `getBatchSize()` | Gets the batch size that must be used during training. |
| `Device[]` | `getDevices()` | Gets the Devices that are available for computation. |
| `Initializer` | `getInitializer()` | Gets the Initializer used to initialize the parameters of the model. |
| `Loss` | `getLossFunction()` | Gets the Loss function used to compute the loss during training. |
| `Optimizer` | `getOptimizer()` | Gets the Optimizer to use during training. |
| `java.util.List<TrainingMetric>` | `getTrainingMetrics()` | Returns the list of TrainingMetrics that should be computed during training. |
| `DefaultTrainingConfig` | `setBatchSize(int batchSize)` | Sets the size of a batch for training. |
| `DefaultTrainingConfig` | `setDevices(Device[] devices)` | Sets the array of Devices available for training. |
| `DefaultTrainingConfig` | `setOptimizer(Optimizer optimizer)` | Sets the Optimizer used during training. |
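Every setter in the method summary returns DefaultTrainingConfig, which suggests the settings are meant to be chained in a single expression. The dependency-free sketch below mimics that fluent style. FluentConfigSketch is a hypothetical class, and plain strings stand in for the real Device, Optimizer and TrainingMetric types.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical fluent configuration mirroring the setter/getter pairs in the
// summary. Strings stand in for Device, Optimizer and TrainingMetric.
final class FluentConfigSketch {
    private int batchSize;
    private String[] devices = new String[0];
    private String optimizer;
    private final List<String> trainingMetrics = new ArrayList<>();

    // Each setter stores its value and returns this, enabling chaining.
    FluentConfigSketch setBatchSize(int batchSize) {
        this.batchSize = batchSize;
        return this;
    }

    FluentConfigSketch setDevices(String[] devices) {
        this.devices = devices.clone(); // keep a private copy
        return this;
    }

    FluentConfigSketch setOptimizer(String optimizer) {
        this.optimizer = optimizer;
        return this;
    }

    // "add", not "set": each call appends one more metric to the list.
    FluentConfigSketch addTrainingMetric(String metric) {
        trainingMetrics.add(metric);
        return this;
    }

    int getBatchSize() {
        return batchSize;
    }

    String[] getDevices() {
        return devices.clone();
    }

    String getOptimizer() {
        return optimizer;
    }

    List<String> getTrainingMetrics() {
        return Collections.unmodifiableList(trainingMetrics);
    }

    public static void main(String[] args) {
        FluentConfigSketch config = new FluentConfigSketch()
                .setBatchSize(32)
                .setDevices(new String[] {"gpu0", "gpu1"})
                .setOptimizer("sgd")
                .addTrainingMetric("accuracy")
                .addTrainingMetric("loss");
        System.out.println(config.getBatchSize() + " " + config.getTrainingMetrics()); // 32 [accuracy, loss]
    }
}
```

Returning this from each setter keeps the whole configuration in one expression, and the add/set naming signals that metrics accumulate while the other settings are replaced. The defensive copies and unmodifiable list returned by the getters are choices made for this sketch only.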
`public DefaultTrainingConfig(Initializer initializer, Loss loss)`

Creates an instance of DefaultTrainingConfig with the given Initializer and Loss.

Parameters:
initializer - the initializer to initialize the parameters with
loss - the loss to use for training

`public DefaultTrainingConfig setDevices(Device[] devices)`

Sets the array of Devices available for training.

Parameters:
devices - an array of devices to be set

Returns:
this DefaultTrainingConfig

`public DefaultTrainingConfig setOptimizer(Optimizer optimizer)`

Sets the Optimizer used during training.

Parameters:
optimizer - the optimizer to be set

Returns:
this DefaultTrainingConfig

`public DefaultTrainingConfig addTrainingMetric(TrainingMetric trainingMetric)`

Adds a TrainingMetric that needs to be computed during training.

Parameters:
trainingMetric - the training metric to be added

Returns:
this DefaultTrainingConfig

`public DefaultTrainingConfig setBatchSize(int batchSize)`

Sets the size of a batch for training.

Parameters:
batchSize - the batch size

Returns:
this DefaultTrainingConfig

`public Device[] getDevices()`

Gets the Devices that are available for computation.

A Trainer needs this information to know what kind of device it is running on and how many devices it is running on.

Specified by:
getDevices in interface TrainingConfig

Returns:
an array of Devices

`public Initializer getInitializer()`

Gets the Initializer used to initialize the parameters of the model.

Specified by:
getInitializer in interface TrainingConfig

Returns:
the Initializer

`public Optimizer getOptimizer()`

Gets the Optimizer to use during training.

Specified by:
getOptimizer in interface TrainingConfig

Returns:
the Optimizer

`public Loss getLossFunction()`

Gets the Loss function used to compute the loss during training.

Specified by:
getLossFunction in interface TrainingConfig

Returns:
the Loss function

`public java.util.List<TrainingMetric> getTrainingMetrics()`

Returns the list of TrainingMetrics that should be computed during training.

Specified by:
getTrainingMetrics in interface TrainingConfig

Returns:
the list of TrainingMetrics

`public int getBatchSize()`

Gets the batch size that must be used during training.

Specified by:
getBatchSize in interface TrainingConfig
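The getDevices description notes that a Trainer needs to know how many devices it is running on. One common reason is data parallelism, where each batch is divided among the available devices. The hypothetical helper splitBatch below (not part of DJL) sketches one way to divide a batch of batchSize samples as evenly as possible; it models only the device count, not the kind of device.

```java
import java.util.Arrays;

// Hypothetical helper, not a DJL API: divides one batch among several devices.
final class BatchSplitSketch {
    static int[] splitBatch(int batchSize, int numDevices) {
        if (numDevices <= 0) {
            throw new IllegalArgumentException("need at least one device");
        }
        int[] sizes = new int[numDevices];
        int base = batchSize / numDevices;      // samples every device receives
        int remainder = batchSize % numDevices; // samples left over after that
        for (int i = 0; i < numDevices; i++) {
            // The first `remainder` devices each take one extra sample.
            sizes[i] = base + (i < remainder ? 1 : 0);
        }
        return sizes;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(splitBatch(32, 4))); // [8, 8, 8, 8]
        System.out.println(Arrays.toString(splitBatch(32, 3))); // [11, 11, 10]
    }
}
```

Because the per-device sizes always sum to batchSize, changing the device list changes how much work each device does per step, but not the number of samples processed per batch.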