Package org.nd4j.autodiff.samediff
Class TrainingConfig.Builder
- java.lang.Object
-
- org.nd4j.autodiff.samediff.TrainingConfig.Builder
-
- Enclosing class:
- TrainingConfig
public static class TrainingConfig.Builder extends Object
-
-
Constructor Summary
Constructor / Description
Builder()
-
Method Summary
All Methods · Instance Methods · Concrete Methods

- TrainingConfig.Builder addEvaluations(boolean validation, @NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations)
  Add requested evaluations for a param/variable, for either training or validation.
- TrainingConfig.Builder addRegularization(Regularization... regularizations)
  Add regularization to all trainable parameters in the network.
- TrainingConfig build()
- TrainingConfig.Builder dataSetFeatureMapping(String... dataSetFeatureMapping)
  Set the names of the placeholders/variables that should be set using the feature INDArray(s) from the DataSet or MultiDataSet.
- TrainingConfig.Builder dataSetFeatureMapping(List<String> dataSetFeatureMapping)
  Set the names of the placeholders/variables that should be set using the feature INDArray(s) from the DataSet or MultiDataSet.
- TrainingConfig.Builder dataSetFeatureMaskMapping(String... dataSetFeatureMaskMapping)
- TrainingConfig.Builder dataSetFeatureMaskMapping(List<String> dataSetFeatureMaskMapping)
  Set the names of the placeholders/variables that should be set using the feature mask INDArray(s) from the DataSet or MultiDataSet.
- TrainingConfig.Builder dataSetLabelMapping(String... dataSetLabelMapping)
  Set the names of the placeholders/variables that should be set using the labels INDArray(s) from the DataSet or MultiDataSet.
- TrainingConfig.Builder dataSetLabelMapping(List<String> dataSetLabelMapping)
  Set the names of the placeholders/variables that should be set using the labels INDArray(s) from the DataSet or MultiDataSet.
- TrainingConfig.Builder dataSetLabelMaskMapping(String... dataSetLabelMaskMapping)
- TrainingConfig.Builder dataSetLabelMaskMapping(List<String> dataSetLabelMaskMapping)
  Set the names of the placeholders/variables that should be set using the label mask INDArray(s) from the DataSet or MultiDataSet.
- TrainingConfig.Builder initialLossDataType(DataType initialLossDataType)
  Set the initial loss data type, defaults to DataType.FLOAT - when setting a data type for a loss function we need a beginning data type to compute the gradients.
- TrainingConfig.Builder l1(double l1)
  Sets the L1 regularization coefficient for all trainable parameters.
- TrainingConfig.Builder l2(double l2)
  Sets the L2 regularization coefficient for all trainable parameters.
- TrainingConfig.Builder markLabelsUnused()
  Calling this method will mark the labels as unused.
- TrainingConfig.Builder minimize(boolean minimize)
  Sets whether the loss function should be minimized (true) or maximized (false). The loss function is usually minimized in SGD. Default: true.
- TrainingConfig.Builder minimize(String... lossVariables)
- TrainingConfig.Builder regularization(List<Regularization> regularization)
  Set the regularization for all trainable parameters in the network.
- TrainingConfig.Builder regularization(Regularization... regularization)
  Set the regularization for all trainable parameters in the network.
- TrainingConfig.Builder skipBuilderValidation(boolean skip)
- TrainingConfig.Builder trainEvaluation(@NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations)
  Add requested training evaluations for a param/variable; reported in the History object returned by fit.
- TrainingConfig.Builder trainEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations)
  Add requested training evaluations for a param/variable; reported in the History object returned by fit.
- TrainingConfig.Builder updater(IUpdater updater)
- TrainingConfig.Builder validationEvaluation(@NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations)
  Add requested validation evaluations for a param/variable; reported in the History object returned by fit.
- TrainingConfig.Builder validationEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations)
  Add requested validation evaluations for a param/variable; reported in the History object returned by fit.
- TrainingConfig.Builder weightDecay(double coefficient, boolean applyLR)
  Add weight decay regularization for all trainable parameters.
-
-
-
Method Detail
-
initialLossDataType
public TrainingConfig.Builder initialLossDataType(DataType initialLossDataType)
Set the initial loss data type; defaults to DataType.FLOAT. When setting a data type for a loss function, we need a starting data type to compute the gradients: an initial value of zero acts as the initial gradient, and this setting controls the data type of that value. This is important when you want fine-grained control over the data types used in the training process.
- Parameters:
  initialLossDataType - the initial loss data type
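For example, to compute loss gradients in double precision (a minimal sketch; the updater choice and the placeholder names "input" and "label" are illustrative, not part of this method):

```java
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.learning.config.Adam;

TrainingConfig config = new TrainingConfig.Builder()
        .updater(new Adam(1e-3))                 // illustrative updater
        .initialLossDataType(DataType.DOUBLE)    // initial zero gradient will be DOUBLE
        .dataSetFeatureMapping("input")          // illustrative placeholder name
        .dataSetLabelMapping("label")            // illustrative placeholder name
        .build();
```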
-
updater
public TrainingConfig.Builder updater(IUpdater updater)
Set the updater (such as Adam, Nesterovs, etc.). This is also how the learning rate (or learning rate schedule) is set.
- Parameters:
  updater - Updater to set
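For example (a sketch; the hyperparameter values are illustrative, and builder is assumed to be an in-scope TrainingConfig.Builder):

```java
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.schedule.ExponentialSchedule;
import org.nd4j.linalg.schedule.ScheduleType;

// Fixed learning rate:
builder.updater(new Adam(1e-3));

// Or a learning rate schedule - here, exponential decay per epoch with momentum 0.9:
builder.updater(new Nesterovs(new ExponentialSchedule(ScheduleType.EPOCH, 0.1, 0.99), 0.9));
```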
-
l1
public TrainingConfig.Builder l1(double l1)
Sets the L1 regularization coefficient for all trainable parameters. Must be >= 0.
See L1Regularization for more details.
- Parameters:
  l1 - L1 regularization coefficient
-
l2
public TrainingConfig.Builder l2(double l2)
Sets the L2 regularization coefficient for all trainable parameters. Must be >= 0.
Note: Generally, WeightDecay (set via weightDecay(double, boolean)) should be preferred to L2 regularization. See the WeightDecay javadoc for further details.
Note: L2 regularization and weight decay usually should not be used together; if any weight decay (or L2) regularization has previously been added, it will be removed first.
- See Also:
  weightDecay(double, boolean)
-
weightDecay
public TrainingConfig.Builder weightDecay(double coefficient, boolean applyLR)
Add weight decay regularization for all trainable parameters. See WeightDecay for more details.
Note: values set by this method will be applied to all applicable layers in the network, unless a different value is explicitly set on a given layer. In other words: values set via this method are used as the default value, and can be overridden on a per-layer basis.
- Parameters:
  coefficient - Weight decay regularization coefficient
  applyLR - Whether the learning rate should be multiplied in when performing weight decay updates. See WeightDecay for more details.
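For instance (a sketch; the coefficient value, updater, and placeholder names are illustrative):

```java
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.linalg.learning.config.Adam;

TrainingConfig config = new TrainingConfig.Builder()
        .updater(new Adam(1e-3))
        .weightDecay(1e-5, true)   // applyLR = true: learning rate is multiplied into the update
        .dataSetFeatureMapping("input")
        .dataSetLabelMapping("label")
        .build();
```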
-
addRegularization
public TrainingConfig.Builder addRegularization(Regularization... regularizations)
Add regularization to all trainable parameters in the network.
- Parameters:
  regularizations - Regularization type(s) to add
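For example, combining L1 regularization with weight decay (a sketch; the coefficients are illustrative, and builder is assumed to be an in-scope TrainingConfig.Builder):

```java
import org.nd4j.linalg.learning.regularization.L1Regularization;
import org.nd4j.linalg.learning.regularization.WeightDecay;

builder.addRegularization(new L1Regularization(1e-4),
                          new WeightDecay(1e-5, true));
```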
-
regularization
public TrainingConfig.Builder regularization(Regularization... regularization)
Set the regularization for all trainable parameters in the network. Note that if any existing regularization types have been added, they will be removed.
- Parameters:
  regularization - Regularization type(s) to add
-
regularization
public TrainingConfig.Builder regularization(List<Regularization> regularization)
Set the regularization for all trainable parameters in the network. Note that if any existing regularization types have been added, they will be removed.
- Parameters:
  regularization - Regularization type(s) to add
-
minimize
public TrainingConfig.Builder minimize(boolean minimize)
Sets whether the loss function should be minimized (true) or maximized (false).
The loss function is usually minimized in SGD.
Default: true.
- Parameters:
  minimize - True to minimize, false to maximize
-
dataSetFeatureMapping
public TrainingConfig.Builder dataSetFeatureMapping(String... dataSetFeatureMapping)
Set the names of the placeholders/variables that should be set using the feature INDArray(s) from the DataSet or MultiDataSet. For example, if the network had 2 inputs called "input1" and "input2", and the MultiDataSet features should be mapped with MultiDataSet.getFeatures(0) -> "input1" and MultiDataSet.getFeatures(1) -> "input2", then this should be set to "input1", "input2".
- Parameters:
  dataSetFeatureMapping - Names of the variables/placeholders that the feature arrays should be mapped to
-
dataSetFeatureMapping
public TrainingConfig.Builder dataSetFeatureMapping(List<String> dataSetFeatureMapping)
Set the names of the placeholders/variables that should be set using the feature INDArray(s) from the DataSet or MultiDataSet. For example, if the network had 2 inputs called "input1" and "input2", and the MultiDataSet features should be mapped with MultiDataSet.getFeatures(0) -> "input1" and MultiDataSet.getFeatures(1) -> "input2", then this should be set to "input1", "input2".
- Parameters:
  dataSetFeatureMapping - Names of the variables/placeholders that the feature arrays should be mapped to
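Continuing the two-input example above (builder is assumed to be an in-scope TrainingConfig.Builder):

```java
import java.util.Arrays;

// MultiDataSet.getFeatures(0) -> "input1", MultiDataSet.getFeatures(1) -> "input2"
builder.dataSetFeatureMapping(Arrays.asList("input1", "input2"));
// Equivalent varargs form:
builder.dataSetFeatureMapping("input1", "input2");
```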
-
dataSetLabelMapping
public TrainingConfig.Builder dataSetLabelMapping(String... dataSetLabelMapping)
Set the names of the placeholders/variables that should be set using the labels INDArray(s) from the DataSet or MultiDataSet. For example, if the network had 2 labels called "label1" and "label2", and the MultiDataSet labels should be mapped with MultiDataSet.getLabels(0) -> "label1" and MultiDataSet.getLabels(1) -> "label2", then this should be set to "label1", "label2".
- Parameters:
  dataSetLabelMapping - Names of the variables/placeholders that the label arrays should be mapped to
-
dataSetLabelMapping
public TrainingConfig.Builder dataSetLabelMapping(List<String> dataSetLabelMapping)
Set the names of the placeholders/variables that should be set using the labels INDArray(s) from the DataSet or MultiDataSet. For example, if the network had 2 labels called "label1" and "label2", and the MultiDataSet labels should be mapped with MultiDataSet.getLabels(0) -> "label1" and MultiDataSet.getLabels(1) -> "label2", then this should be set to "label1", "label2".
- Parameters:
  dataSetLabelMapping - Names of the variables/placeholders that the label arrays should be mapped to
-
markLabelsUnused
public TrainingConfig.Builder markLabelsUnused()
Calling this method will mark the labels as unused. This is essentially a way to turn off label mapping validation in the TrainingConfig builder, for training models without labels.
Put another way: usually you need to call dataSetLabelMapping(String...) to set labels; this method allows you to say that the DataSet/MultiDataSet labels aren't used in training.
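For example, a label-free (e.g., unsupervised) training configuration might look like the following sketch (the updater and the placeholder name "input" are illustrative):

```java
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.linalg.learning.config.Adam;

TrainingConfig config = new TrainingConfig.Builder()
        .updater(new Adam(1e-3))
        .dataSetFeatureMapping("input")
        .markLabelsUnused()       // no label mapping required or validated
        .build();
```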
-
dataSetFeatureMaskMapping
public TrainingConfig.Builder dataSetFeatureMaskMapping(String... dataSetFeatureMaskMapping)
Varargs equivalent of dataSetFeatureMaskMapping(List).
-
dataSetFeatureMaskMapping
public TrainingConfig.Builder dataSetFeatureMaskMapping(List<String> dataSetFeatureMaskMapping)
Set the names of the placeholders/variables that should be set using the feature mask INDArray(s) from the DataSet or MultiDataSet. For example, if the network had 2 mask variables called "mask1" and "mask2", and the MultiDataSet feature masks should be mapped with MultiDataSet.getFeaturesMaskArray(0) -> "mask1" and MultiDataSet.getFeaturesMaskArray(1) -> "mask2", then this should be set to "mask1", "mask2".
- Parameters:
  dataSetFeatureMaskMapping - Names of the variables/placeholders that the feature mask arrays should be mapped to
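For example, pairing the mask mapping with the corresponding feature mapping (a sketch; the names are illustrative, and builder is assumed to be an in-scope TrainingConfig.Builder):

```java
// Features 0 and 1 map to "input1"/"input2"; their masks map to "mask1"/"mask2"
builder.dataSetFeatureMapping("input1", "input2")
       .dataSetFeatureMaskMapping("mask1", "mask2");
```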
-
dataSetLabelMaskMapping
public TrainingConfig.Builder dataSetLabelMaskMapping(String... dataSetLabelMaskMapping)
Varargs equivalent of dataSetLabelMaskMapping(List).
-
dataSetLabelMaskMapping
public TrainingConfig.Builder dataSetLabelMaskMapping(List<String> dataSetLabelMaskMapping)
Set the names of the placeholders/variables that should be set using the label mask INDArray(s) from the DataSet or MultiDataSet. For example, if the network had 2 mask variables called "mask1" and "mask2", and the MultiDataSet label masks should be mapped with MultiDataSet.getLabelsMaskArray(0) -> "mask1" and MultiDataSet.getLabelsMaskArray(1) -> "mask2", then this should be set to "mask1", "mask2".
- Parameters:
  dataSetLabelMaskMapping - Names of the variables/placeholders that the label mask arrays should be mapped to
-
skipBuilderValidation
public TrainingConfig.Builder skipBuilderValidation(boolean skip)
-
minimize
public TrainingConfig.Builder minimize(String... lossVariables)
Set the names of the loss variables to be minimized during training.
-
trainEvaluation
public TrainingConfig.Builder trainEvaluation(@NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations)
Add requested training evaluations for a param/variable. These evaluations will be reported in the History object returned by fit.
- Parameters:
  variableName - The variable to evaluate
  labelIndex - The index of the label to evaluate against
  evaluations - The evaluations to run
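For example, to track classification metrics on a hypothetical "softmax" output variable against label index 0 (a sketch; builder is assumed to be an in-scope TrainingConfig.Builder):

```java
import org.nd4j.evaluation.classification.Evaluation;

builder.trainEvaluation("softmax", 0, new Evaluation());
// The Evaluation results appear in the History object returned by SameDiff.fit(...)
```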
-
trainEvaluation
public TrainingConfig.Builder trainEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations)
Add requested training evaluations for a param/variable. These evaluations will be reported in the History object returned by fit.
- Parameters:
  variable - The variable to evaluate
  labelIndex - The index of the label to evaluate against
  evaluations - The evaluations to run
-
validationEvaluation
public TrainingConfig.Builder validationEvaluation(@NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations)
Add requested validation evaluations for a param/variable. These evaluations will be reported in the History object returned by fit.
- Parameters:
  variableName - The variable to evaluate
  labelIndex - The index of the label to evaluate against
  evaluations - The evaluations to run
-
validationEvaluation
public TrainingConfig.Builder validationEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations)
Add requested validation evaluations for a param/variable. These evaluations will be reported in the History object returned by fit.
- Parameters:
  variable - The variable to evaluate
  labelIndex - The index of the label to evaluate against
  evaluations - The evaluations to run
-
addEvaluations
public TrainingConfig.Builder addEvaluations(boolean validation, @NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations)
Add requested evaluations for a param/variable, for either training or validation. These evaluations will be reported in the History object returned by fit.
- Parameters:
  validation - Whether to add these evaluations as validation or training
  variableName - The variable to evaluate
  labelIndex - The index of the label to evaluate against
  evaluations - The evaluations to run
-
build
public TrainingConfig build()
Create the TrainingConfig from the settings in this builder.
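Putting the pieces together, a typical configuration might look like the following sketch (the placeholder names "input" and "label", the network definition, and the hyperparameters are illustrative):

```java
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.linalg.learning.config.Adam;

SameDiff sd = SameDiff.create();
// ... define the network here, with placeholders "input" and "label" ...

TrainingConfig config = new TrainingConfig.Builder()
        .updater(new Adam(1e-3))
        .l2(1e-4)
        .dataSetFeatureMapping("input")
        .dataSetLabelMapping("label")
        .build();

sd.setTrainingConfig(config);
// History history = sd.fit(trainIterator, numEpochs);  // trainIterator: a DataSetIterator
```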