Package org.nd4j.autodiff.samediff
Class TrainingConfig
java.lang.Object
    org.nd4j.autodiff.samediff.TrainingConfig

public class TrainingConfig extends Object
Nested Class Summary
Modifier and Type | Class
static class      | TrainingConfig.Builder
Constructor Summary
protected TrainingConfig(IUpdater updater, List<Regularization> regularization, boolean minimize, List<String> dataSetFeatureMapping, List<String> dataSetLabelMapping, List<String> dataSetFeatureMaskMapping, List<String> dataSetLabelMaskMapping, List<String> lossVariables, Map<String,List<IEvaluation>> trainEvaluations, Map<String,Integer> trainEvaluationLabels, Map<String,List<IEvaluation>> validationEvaluations, Map<String,Integer> validationEvaluationLabels, DataType initialLossDataType)

public TrainingConfig(IUpdater updater, List<Regularization> regularization, boolean minimize, List<String> dataSetFeatureMapping, List<String> dataSetLabelMapping, List<String> dataSetFeatureMaskMapping, List<String> dataSetLabelMaskMapping, List<String> lossVariables, DataType initialLossDataType)
    Create a training configuration suitable for training both single input/output and multi input/output networks. See also the TrainingConfig.Builder for creating a TrainingConfig.

public TrainingConfig(IUpdater updater, List<Regularization> regularization, String dataSetFeatureMapping, String dataSetLabelMapping)
    Create a training configuration suitable for training a single input, single output network. See also the TrainingConfig.Builder for creating a TrainingConfig.
Method Summary
static TrainingConfig.Builder builder()
static TrainingConfig fromJson(@NonNull String json)
void incrementEpochCount()
    Increment the epoch count by 1.
void incrementIterationCount()
    Increment the iteration count by 1.
int labelIdx(String s)
    Get the index of the label array that the specified variable is associated with.
static void removeInstances(List<?> list, Class<?> remove)
    Remove any instances of the specified type from the list.
static void removeInstancesWithWarning(List<?> list, Class<?> remove, String warning)
String toJson()
Constructor Detail
TrainingConfig
public TrainingConfig(IUpdater updater, List<Regularization> regularization, String dataSetFeatureMapping, String dataSetLabelMapping)
Create a training configuration suitable for training a single input, single output network.
See also the TrainingConfig.Builder for creating a TrainingConfig.
Parameters:
    updater - The updater configuration to use
    regularization - Regularization for all trainable parameters
    dataSetFeatureMapping - The name of the placeholder/variable that should be set using the feature INDArray from the DataSet (or the first/only feature from a MultiDataSet). For example, if the network input placeholder was called "input", then this should be set to "input".
    dataSetLabelMapping - The name of the placeholder/variable that should be set using the label INDArray from the DataSet (or the first/only label from a MultiDataSet). For example, if the network label placeholder was called "label", then this should be set to "label".
TrainingConfig
public TrainingConfig(IUpdater updater, List<Regularization> regularization, boolean minimize, List<String> dataSetFeatureMapping, List<String> dataSetLabelMapping, List<String> dataSetFeatureMaskMapping, List<String> dataSetLabelMaskMapping, List<String> lossVariables, DataType initialLossDataType)
Create a training configuration suitable for training both single input/output and multi input/output networks.
See also the TrainingConfig.Builder for creating a TrainingConfig.
Parameters:
    updater - The updater configuration to use
    regularization - Regularization for all trainable parameters
    minimize - Set to true if the loss function should be minimized (usually true); false to maximize it
    dataSetFeatureMapping - The names of the placeholders/variables that should be set using the feature INDArray(s) from the DataSet or MultiDataSet. For example, if the network had 2 inputs called "input1" and "input2", and the MultiDataSet features should be mapped with MultiDataSet.getFeatures(0) -> "input1" and MultiDataSet.getFeatures(1) -> "input2", then this should be set to Arrays.asList("input1", "input2").
    dataSetLabelMapping - As per dataSetFeatureMapping, but for the DataSet/MultiDataSet labels
    dataSetFeatureMaskMapping - May be null. If non-null, the variables that the MultiDataSet feature mask arrays should be associated with.
    dataSetLabelMaskMapping - May be null. If non-null, the variables that the MultiDataSet label mask arrays should be associated with.
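The positional convention for dataSetFeatureMapping (array i of the MultiDataSet is assigned to name i of the list) can be sketched without any ND4J dependency. This is a minimal illustration, not the library's implementation: double[] stands in for INDArray, toPlaceholders is a hypothetical helper, and rejecting a count mismatch is an assumption.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class PlaceholderMappingSketch {

    /**
     * Hypothetical helper: pairs the i-th array with the i-th placeholder name,
     * e.g. getFeatures(0) -> names.get(0), getFeatures(1) -> names.get(1).
     * double[] stands in for INDArray so the sketch has no ND4J dependency.
     */
    static Map<String, double[]> toPlaceholders(List<String> names, List<double[]> arrays) {
        // Assumption: a count mismatch is treated as a configuration error
        if (names.size() != arrays.size()) {
            throw new IllegalArgumentException(
                    "Expected " + names.size() + " arrays but got " + arrays.size());
        }
        // LinkedHashMap keeps the placeholders in mapping order
        Map<String, double[]> placeholders = new LinkedHashMap<>();
        for (int i = 0; i < names.size(); i++) {
            placeholders.put(names.get(i), arrays.get(i));
        }
        return placeholders;
    }

    public static void main(String[] args) {
        List<String> featureMapping = Arrays.asList("input1", "input2");
        List<double[]> features = Arrays.asList(new double[] {1, 2}, new double[] {3});
        Map<String, double[]> placeholders = toPlaceholders(featureMapping, features);
        System.out.println(placeholders.keySet()); // prints [input1, input2]
    }
}
```

The single input, single output constructor is the one-element case of the same idea: one name paired with the only feature array.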
TrainingConfig
protected TrainingConfig(IUpdater updater, List<Regularization> regularization, boolean minimize, List<String> dataSetFeatureMapping, List<String> dataSetLabelMapping, List<String> dataSetFeatureMaskMapping, List<String> dataSetLabelMaskMapping, List<String> lossVariables, Map<String,List<IEvaluation>> trainEvaluations, Map<String,Integer> trainEvaluationLabels, Map<String,List<IEvaluation>> validationEvaluations, Map<String,Integer> validationEvaluationLabels, DataType initialLossDataType)
Method Detail
incrementIterationCount
public void incrementIterationCount()
Increment the iteration count by 1
incrementEpochCount
public void incrementEpochCount()
Increment the epoch count by 1
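A minimal sketch of how a fit loop might drive the two counters, assuming the usual convention that one iteration is one parameter update on one minibatch and one epoch is one full pass over the training data. TrainingCountersSketch and simulateFit are hypothetical names, and the getters belong to this sketch rather than to the TrainingConfig API documented here.

```java
public class TrainingCountersSketch {
    // Hypothetical stand-ins for the counters a TrainingConfig tracks
    private int iterationCount = 0;
    private int epochCount = 0;

    public void incrementIterationCount() { iterationCount++; }
    public void incrementEpochCount() { epochCount++; }
    public int getIterationCount() { return iterationCount; }
    public int getEpochCount() { return epochCount; }

    /** Simulates a fit loop: one iteration per minibatch, one epoch per full pass. */
    static TrainingCountersSketch simulateFit(int numEpochs, int batchesPerEpoch) {
        TrainingCountersSketch counters = new TrainingCountersSketch();
        for (int epoch = 0; epoch < numEpochs; epoch++) {
            for (int batch = 0; batch < batchesPerEpoch; batch++) {
                // forward pass, gradient computation and parameter update would go here
                counters.incrementIterationCount();
            }
            counters.incrementEpochCount();
        }
        return counters;
    }

    public static void main(String[] args) {
        TrainingCountersSketch c = simulateFit(3, 4);
        // prints "12 iterations, 3 epochs"
        System.out.println(c.getIterationCount() + " iterations, " + c.getEpochCount() + " epochs");
    }
}
```

Because the iteration count is never reset between epochs in this sketch, it ends up as epochs times batches per epoch.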
builder
public static TrainingConfig.Builder builder()
labelIdx
public int labelIdx(String s)
Get the index of the label array that the specified variable is associated with.
Parameters:
    s - Name of the variable
Returns:
    The index of the label variable, or -1 if not found
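The documented contract (the variable's position in the label mapping, or -1 when it is absent) matches java.util.List.indexOf, so a sketch might look like the following. LabelIdxSketch is a hypothetical stand-in that holds only the label mapping; the real class may implement this differently.

```java
import java.util.Arrays;
import java.util.List;

public class LabelIdxSketch {
    // Hypothetical stand-in: only the label mapping of a TrainingConfig is modelled
    private final List<String> dataSetLabelMapping;

    LabelIdxSketch(List<String> dataSetLabelMapping) {
        this.dataSetLabelMapping = dataSetLabelMapping;
    }

    /** Position of s in the label mapping, or -1 if s is not mapped to a label. */
    int labelIdx(String s) {
        return dataSetLabelMapping.indexOf(s);
    }

    public static void main(String[] args) {
        LabelIdxSketch cfg = new LabelIdxSketch(Arrays.asList("label1", "label2"));
        System.out.println(cfg.labelIdx("label2")); // prints 1
        System.out.println(cfg.labelIdx("input1")); // prints -1
    }
}
```

Under the positional convention of dataSetLabelMapping, index 1 here would correspond to MultiDataSet.getLabels(1).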
removeInstances
public static void removeInstances(List<?> list, Class<?> remove)
Remove any instances of the specified type from the list. This includes any subtypes.
Parameters:
    list - List. May be null
    remove - Type of objects to remove
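Because subtypes are included, the check has to be an instanceof test (Class.isInstance) rather than an exact class comparison. A sketch under that reading follows; Reg, L1Reg, L2Reg and ScaledL2Reg are hypothetical stand-ins, not ND4J regularization classes, and this is not the library's source.

```java
import java.util.ArrayList;
import java.util.List;

public class RemoveInstancesSketch {

    // Hypothetical stand-ins for regularization classes (not the ND4J types)
    interface Reg {}
    static class L1Reg implements Reg {}
    static class L2Reg implements Reg {}
    static class ScaledL2Reg extends L2Reg {} // subtype of L2Reg

    /** Removes every element that is an instance of remove, subtypes included; a null list is a no-op. */
    static void removeInstances(List<?> list, Class<?> remove) {
        if (list == null) {
            return;
        }
        // isInstance matches remove itself and any subclass or implementor of it
        list.removeIf(remove::isInstance);
    }

    public static void main(String[] args) {
        List<Reg> regs = new ArrayList<>();
        regs.add(new L1Reg());
        regs.add(new L2Reg());
        regs.add(new ScaledL2Reg());
        removeInstances(regs, L2Reg.class);
        System.out.println(regs.size());                  // prints 1
        System.out.println(regs.get(0) instanceof L1Reg); // prints true
        removeInstances(null, L2Reg.class);               // no exception
    }
}
```

Comparing getClass() == remove instead would leave the ScaledL2Reg element in the list, which contradicts the documented subtype behavior.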
removeInstancesWithWarning
public static void removeInstancesWithWarning(List<?> list, Class<?> remove, String warning)
toJson
public String toJson()
fromJson
public static TrainingConfig fromJson(@NonNull String json)