Package ai.djl.training
Class EasyTrain
- java.lang.Object
  - ai.djl.training.EasyTrain
public final class EasyTrain
extends java.lang.Object

Helper for easily training a whole model, a single training batch, or a single validation batch.
Method Summary
- static void evaluateDataset(Trainer trainer, Dataset testDataset)
  Evaluates the test dataset.
- static void fit(Trainer trainer, int numEpoch, Dataset trainingDataset, Dataset validateDataset)
  Runs a basic training loop with the given trainer for numEpoch epochs.
- static void trainBatch(Trainer trainer, Batch batch)
  Trains the model with one iteration of the given Batch of data.
- static void validateBatch(Trainer trainer, Batch batch)
  Validates the given batch of data.
Method Detail
fit
public static void fit(Trainer trainer, int numEpoch, Dataset trainingDataset, Dataset validateDataset) throws java.io.IOException, TranslateException
Runs a basic training loop with the given trainer for numEpoch epochs. After each epoch, the model is validated against validateDataset, unless it is null.

Parameters:
trainer - the trainer to train with
numEpoch - the number of epochs to train
trainingDataset - the dataset to train on
validateDataset - the dataset to validate against; can be null to skip validation
Throws:
java.io.IOException - if reading the dataset fails; the possible causes depend on the Dataset implementation
TranslateException - if there is an error while processing input
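The loop that fit runs can be sketched in plain Java. This is a hypothetical, dependency-free model, not DJL code: FitSketch, RecordingTrainer, and the integer "batches" are illustrative stand-ins. The ordering it records (a parameter update after every training batch, then validation, then an end-of-epoch hook) reflects our reading of DJL's implementation and may differ between versions.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical, dependency-free model of the loop run by EasyTrain.fit (not DJL code). */
final class FitSketch {

    /** Hypothetical trainer stand-in that records each call it receives. */
    static final class RecordingTrainer {
        final List<String> events = new ArrayList<>();

        void trainBatch(int batch) { events.add("train:" + batch); }       // forward, loss, backward
        void step() { events.add("step"); }                                // parameter update
        void validateBatch(int batch) { events.add("validate:" + batch); } // no gradients, no update
        void onEpoch(int epoch) { events.add("epoch:" + epoch); }          // end-of-epoch hook
    }

    /** Trains for numEpoch epochs; a null validateDataset skips validation entirely. */
    static void fit(RecordingTrainer trainer, int numEpoch,
                    List<Integer> trainingDataset, List<Integer> validateDataset) {
        for (int epoch = 0; epoch < numEpoch; epoch++) {
            for (int batch : trainingDataset) {
                trainer.trainBatch(batch);
                trainer.step(); // apply the gradients computed for this batch
            }
            if (validateDataset != null) {
                for (int batch : validateDataset) {
                    trainer.validateBatch(batch);
                }
            }
            trainer.onEpoch(epoch);
        }
    }

    private FitSketch() {}
}
```

With two epochs, training batches [1, 2], and validation batch [3], the recorded order is train:1, step, train:2, step, validate:3, epoch:0, followed by the same sequence ending in epoch:1. Passing null as validateDataset drops the validate entries.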
trainBatch
public static void trainBatch(Trainer trainer, Batch batch)
Trains the model with one iteration of the given Batch of data.

Parameters:
trainer - the trainer to train the batch with
batch - a Batch that contains data and its respective labels
Throws:
java.lang.IllegalArgumentException - if the batch engine does not match the trainer engine
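One iteration here means a forward and backward pass over the batch. In DJL code we have seen, a hand-written loop calls EasyTrain.trainBatch(trainer, batch) and then trainer.step() to apply the resulting gradients, which is also what fit does after each batch. The plain-Java sketch below models that split with a one-weight linear model and squared loss. TrainBatchSketch, its Trainer and Batch classes, the engine-name strings, and step(learningRate) are hypothetical stand-ins, not DJL types.

```java
/** Hypothetical model of EasyTrain.trainBatch: computes gradients but never updates weights. */
final class TrainBatchSketch {

    /** Stand-in trainer for the model pred = weight * x. */
    static final class Trainer {
        final String engine;
        double weight = 1.0;
        double gradient; // filled by trainBatch, consumed by step

        Trainer(String engine) { this.engine = engine; }

        /** Gradient-descent update: the counterpart of calling step() after trainBatch. */
        void step(double learningRate) {
            weight -= learningRate * gradient;
            gradient = 0.0;
        }
    }

    /** Stand-in batch: inputs x, labels y, and the engine that created them. */
    static final class Batch {
        final String engine;
        final double[] x;
        final double[] y;

        Batch(String engine, double[] x, double[] y) {
            this.engine = engine;
            this.x = x;
            this.y = y;
        }
    }

    /** One iteration: forward pass, mean squared loss, and its gradient with respect to weight. */
    static void trainBatch(Trainer trainer, Batch batch) {
        if (!trainer.engine.equals(batch.engine)) {
            throw new IllegalArgumentException("batch engine does not match trainer engine");
        }
        double grad = 0.0;
        for (int i = 0; i < batch.x.length; i++) {
            double error = trainer.weight * batch.x[i] - batch.y[i];
            grad += 2.0 * error * batch.x[i]; // derivative of error^2 with respect to weight
        }
        trainer.gradient = grad / batch.x.length; // weight itself is left unchanged
    }

    private TrainBatchSketch() {}
}
```

For x = {1, 2}, y = {2, 4}, and weight 1.0, trainBatch leaves the weight at 1.0 and sets the gradient to -5.0; a subsequent step(0.1) moves the weight to 1.5. A batch created on a different engine is rejected with IllegalArgumentException, matching the documented behavior.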
validateBatch
public static void validateBatch(Trainer trainer, Batch batch)
Validates the given batch of data.

During validation, the evaluators and losses are computed, but gradients are not computed and parameters are not updated.

Parameters:
trainer - the trainer to validate the batch with
batch - a Batch of data
Throws:
java.lang.IllegalArgumentException - if the batch engine does not match the trainer engine
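The difference between validating and training a batch can be illustrated with a one-weight linear model. In this hypothetical sketch (ValidateBatchSketch and its nested Trainer and Batch are stand-ins, not DJL types), validateBatch runs the forward pass and records the mean squared loss, playing the role of an evaluator, but it neither computes a gradient nor changes the weight.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical model of EasyTrain.validateBatch: loss only, no gradients, no updates. */
final class ValidateBatchSketch {

    /** Stand-in trainer for the model pred = weight * x. */
    static final class Trainer {
        final String engine;
        double weight;
        double gradient; // stays untouched during validation
        final List<Double> validationLosses = new ArrayList<>(); // stand-in for an evaluator

        Trainer(String engine, double weight) {
            this.engine = engine;
            this.weight = weight;
        }
    }

    /** Stand-in batch: inputs x, labels y, and the engine that created them. */
    static final class Batch {
        final String engine;
        final double[] x;
        final double[] y;

        Batch(String engine, double[] x, double[] y) {
            this.engine = engine;
            this.x = x;
            this.y = y;
        }
    }

    /** Forward pass and loss only; records the metric and leaves weight and gradient alone. */
    static void validateBatch(Trainer trainer, Batch batch) {
        if (!trainer.engine.equals(batch.engine)) {
            throw new IllegalArgumentException("batch engine does not match trainer engine");
        }
        double sum = 0.0;
        for (int i = 0; i < batch.x.length; i++) {
            double error = trainer.weight * batch.x[i] - batch.y[i];
            sum += error * error;
        }
        trainer.validationLosses.add(sum / batch.x.length);
    }

    private ValidateBatchSketch() {}
}
```

For weight 1.0, x = {1, 2}, and y = {2, 4}, the recorded loss is 2.5 while the weight stays at 1.0 and the gradient stays at 0.0.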
evaluateDataset
public static void evaluateDataset(Trainer trainer, Dataset testDataset) throws java.io.IOException, TranslateException
Evaluates the model on the given test dataset.

Parameters:
trainer - the trainer to evaluate with
testDataset - the test dataset to evaluate
Throws:
java.io.IOException - if reading the dataset fails; the possible causes depend on the Dataset implementation
TranslateException - if there is an error while processing input
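Evaluating a dataset amounts to running the validation step over each of its batches. In the DJL source we have read, each batch is passed to validateBatch and then closed so its memory is released; the plain-Java sketch below models that loop. EvaluateDatasetSketch, its Trainer and Batch classes, and the mean-squared-error accumulator used as an "evaluator" are hypothetical stand-ins, not DJL types.

```java
import java.util.List;

/** Hypothetical model of EasyTrain.evaluateDataset: validate every batch of a test set. */
final class EvaluateDatasetSketch {

    /** Stand-in trainer holding a fixed weight and the evaluator's running totals. */
    static final class Trainer {
        final double weight;
        double squaredErrorSum;
        int sampleCount;

        Trainer(double weight) { this.weight = weight; }

        double meanSquaredError() { return squaredErrorSum / sampleCount; }
    }

    /** Stand-in batch; close() marks it released once it has been used. */
    static final class Batch implements AutoCloseable {
        final double[] x;
        final double[] y;
        boolean closed;

        Batch(double[] x, double[] y) {
            this.x = x;
            this.y = y;
        }

        @Override
        public void close() { closed = true; }
    }

    /** Validation of one batch: accumulate the metric, no gradient and no update. */
    static void validateBatch(Trainer trainer, Batch batch) {
        for (int i = 0; i < batch.x.length; i++) {
            double error = trainer.weight * batch.x[i] - batch.y[i];
            trainer.squaredErrorSum += error * error;
            trainer.sampleCount++;
        }
    }

    /** Runs validateBatch over every batch, closing each one afterwards. */
    static void evaluateDataset(Trainer trainer, List<Batch> testDataset) {
        for (Batch batch : testDataset) {
            try (Batch b = batch) {
                validateBatch(trainer, b);
            }
        }
    }

    private EvaluateDatasetSketch() {}
}
```

With weight 2.0 and the batches ({1}, {2}) and ({1, 3}, {3, 5}), the squared errors are 0, 1, and 1, so the accumulated mean squared error is 2/3, and both batches end up closed.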