Class EasyTrain


  • public final class EasyTrain
    extends java.lang.Object
    Helper for easily training a whole model, a training batch, or a validation batch.
    • Method Summary

      Modifier and Type Method Description
      static void evaluateDataset(Trainer trainer, Dataset testDataset)
      Evaluates the test dataset.
      static void fit(Trainer trainer, int numEpoch, Dataset trainingDataset, Dataset validateDataset)
      Runs a basic epoch-based training loop with the given trainer.
      static void trainBatch(Trainer trainer, Batch batch)
      Trains the model with one iteration of the given Batch of data.
      static void validateBatch(Trainer trainer, Batch batch)
      Validates the given batch of data.
      • Methods inherited from class java.lang.Object

        clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
    • Method Detail

      • fit

        public static void fit(Trainer trainer,
                               int numEpoch,
                               Dataset trainingDataset,
                               Dataset validateDataset)
                        throws java.io.IOException,
                               TranslateException
        Runs a basic epoch-based training loop with the given trainer.
        Parameters:
        trainer - the trainer to train for
        numEpoch - the number of epochs to train
        trainingDataset - the dataset to train on
        validateDataset - the dataset to validate against, or null to skip validation
        Throws:
        java.io.IOException - for various exceptions depending on the dataset
        TranslateException - if there is an error while processing input
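        A minimal, self-contained sketch of the control flow that fit runs, assuming the loop trains every batch of the training data, applies one parameter update after each batch, and validates only when a validation dataset is supplied. SketchTrainer and FitSketch are hypothetical stand-ins, not DJL classes, and lists of arrays stand in for Dataset.

```java
import java.util.List;

/** Hypothetical sketch of an epoch loop shaped like EasyTrain.fit; not DJL code. */
final class FitSketch {
    private FitSketch() {}

    /** Hypothetical trainer callbacks; NOT the DJL ai.djl.training.Trainer interface. */
    interface SketchTrainer {
        void trainBatch(double[] batch);    // forward pass, loss, and gradients for one batch
        void step();                        // apply the accumulated gradients to the parameters
        void validateBatch(double[] batch); // forward pass and loss only
        void endEpoch(int epoch);           // e.g. log and reset evaluators
    }

    /**
     * Trains for numEpoch epochs. A null validateBatches list skips validation,
     * mirroring the documented "can be null" behavior of validateDataset.
     */
    static void fit(SketchTrainer trainer, int numEpoch,
                    List<double[]> trainingBatches, List<double[]> validateBatches) {
        for (int epoch = 0; epoch < numEpoch; epoch++) {
            // One full pass over the training data per epoch.
            for (double[] batch : trainingBatches) {
                trainer.trainBatch(batch);
                trainer.step(); // assumption: one parameter update per training batch
            }
            // Optional validation pass once the epoch's training batches are done.
            if (validateBatches != null) {
                for (double[] batch : validateBatches) {
                    trainer.validateBatch(batch);
                }
            }
            trainer.endEpoch(epoch);
        }
    }
}
```

        Counting calls with a recording SketchTrainer is an easy way to check the loop shape: two epochs over three training batches and one validation batch give six trainBatch calls, six step calls, and two validateBatch calls.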
      • trainBatch

        public static void trainBatch(Trainer trainer,
                                      Batch batch)
        Trains the model with one iteration of the given Batch of data.
        Parameters:
        trainer - the trainer to train the batch with
        batch - a Batch that contains the data and its respective labels
        Throws:
        java.lang.IllegalArgumentException - if the batch engine does not match the trainer engine
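        To make "one iteration" concrete, the sketch below trains a one-parameter model y = w * x with mean squared error. It computes the loss and gradient for the batch and leaves the weight update to a separate step() call; that split, and the ToyTrainer and ToyBatch types, are assumptions of this sketch rather than DJL's API. The engine check mimics the documented IllegalArgumentException, with engines reduced to plain strings.

```java
/**
 * Toy sketch of one training iteration for the model y = w * x with mean squared error.
 * ToyTrainer and ToyBatch are hypothetical stand-ins, not DJL classes.
 */
final class TrainBatchSketch {
    private TrainBatchSketch() {}

    /** Hypothetical batch: inputs, labels, and the name of the engine that created them. */
    static final class ToyBatch {
        final String engine;
        final double[] data;
        final double[] labels;

        ToyBatch(String engine, double[] data, double[] labels) {
            this.engine = engine;
            this.data = data;
            this.labels = labels;
        }
    }

    /** Hypothetical trainer holding one parameter, its gradient, and the last loss. */
    static final class ToyTrainer {
        final String engine;
        final double learningRate;
        double w;        // the single model parameter
        double grad;     // gradient accumulated by trainBatch
        double lastLoss; // loss of the most recent batch

        ToyTrainer(String engine, double w, double learningRate) {
            this.engine = engine;
            this.w = w;
            this.learningRate = learningRate;
        }

        /** Plain SGD update, then clear the gradient (a hypothetical stand-in for an optimizer step). */
        void step() {
            w -= learningRate * grad;
            grad = 0.0;
        }
    }

    static void trainBatch(ToyTrainer trainer, ToyBatch batch) {
        // Mirrors the documented IllegalArgumentException for mismatched engines.
        if (!trainer.engine.equals(batch.engine)) {
            throw new IllegalArgumentException("batch engine does not match trainer engine");
        }
        int n = batch.data.length;
        double loss = 0.0;
        double grad = 0.0;
        for (int i = 0; i < n; i++) {
            double err = trainer.w * batch.data[i] - batch.labels[i]; // prediction minus label
            loss += err * err / n;                                   // mean squared error
            grad += 2.0 * err * batch.data[i] / n;                   // d(loss)/dw
        }
        trainer.lastLoss = loss;
        trainer.grad += grad; // record the gradient; the update itself happens in step()
    }
}
```

        Starting from w = 0 with learning rate 0.1, one trainBatch on inputs {1, 2} with labels {2, 4} records a loss of 10 and a gradient of -10 without changing w; the following step() moves w to 1.0.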
      • validateBatch

        public static void validateBatch(Trainer trainer,
                                         Batch batch)
        Validates the given batch of data.

        During validation, the losses and evaluators are computed, but no gradients are computed and no parameters are updated.

        Parameters:
        trainer - the trainer to validate the batch with
        batch - a Batch of data
        Throws:
        java.lang.IllegalArgumentException - if the batch engine does not match the trainer engine
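        The difference from trainBatch can be sketched with the same toy model y = w * x: validation runs the forward pass, records the loss and an evaluator (mean absolute error is used here as a hypothetical example), and leaves both the weight and the gradient buffer untouched. ToyTrainer and ToyBatch are hypothetical stand-ins, not DJL classes.

```java
/**
 * Toy sketch of validating one batch for the model y = w * x.
 * ToyTrainer and ToyBatch are hypothetical stand-ins, not DJL classes.
 */
final class ValidateBatchSketch {
    private ValidateBatchSketch() {}

    static final class ToyBatch {
        final String engine;
        final double[] data;
        final double[] labels;

        ToyBatch(String engine, double[] data, double[] labels) {
            this.engine = engine;
            this.data = data;
            this.labels = labels;
        }
    }

    static final class ToyTrainer {
        final String engine;
        double w;        // model parameter; validation must not change it
        double grad;     // gradient buffer; validation must not touch it
        double lastLoss; // mean squared error of the last validated batch
        double lastMae;  // an example "evaluator": mean absolute error

        ToyTrainer(String engine, double w) {
            this.engine = engine;
            this.w = w;
        }
    }

    static void validateBatch(ToyTrainer trainer, ToyBatch batch) {
        // Same documented engine check as trainBatch.
        if (!trainer.engine.equals(batch.engine)) {
            throw new IllegalArgumentException("batch engine does not match trainer engine");
        }
        int n = batch.data.length;
        double loss = 0.0;
        double mae = 0.0;
        for (int i = 0; i < n; i++) {
            double err = trainer.w * batch.data[i] - batch.labels[i]; // forward pass only
            loss += err * err / n;
            mae += Math.abs(err) / n;
        }
        // Loss and evaluator are recorded; grad and w are deliberately left untouched.
        trainer.lastLoss = loss;
        trainer.lastMae = mae;
    }
}
```

        With w = 1.5 and inputs {1, 2} with labels {2, 4}, the recorded loss is 0.625 and the mean absolute error 0.75, while w and the gradient buffer keep their previous values.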
      • evaluateDataset

        public static void evaluateDataset(Trainer trainer,
                                           Dataset testDataset)
                                    throws java.io.IOException,
                                           TranslateException
        Evaluates the test dataset.
        Parameters:
        trainer - the trainer to evaluate on
        testDataset - the test dataset to evaluate
        Throws:
        java.io.IOException - for various exceptions depending on the dataset
        TranslateException - if there is an error while processing input
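        The shape of a dataset-level evaluation can be sketched as: validate each batch in turn and combine the per-batch losses into one number. The ToyBatch type, the reduction of the trainer to a single weight w, and the choice of a sample-weighted mean are assumptions of this sketch; since the real method returns void, it presumably reports through the trainer's losses and evaluators instead, which are not modeled here.

```java
import java.util.List;

/**
 * Toy sketch of evaluating a whole test dataset for the model y = w * x.
 * The trainer is reduced to its single weight w; ToyBatch is a hypothetical stand-in.
 */
final class EvaluateDatasetSketch {
    private EvaluateDatasetSketch() {}

    static final class ToyBatch {
        final double[] data;
        final double[] labels;

        ToyBatch(double[] data, double[] labels) {
            this.data = data;
            this.labels = labels;
        }
    }

    /** Validates one non-empty batch: forward pass and mean squared error, no parameter update. */
    static double validateBatch(double w, ToyBatch batch) {
        double loss = 0.0;
        for (int i = 0; i < batch.data.length; i++) {
            double err = w * batch.data[i] - batch.labels[i];
            loss += err * err;
        }
        return loss / batch.data.length;
    }

    /** Validates every batch and returns the loss averaged over all samples (NaN if there are none). */
    static double evaluateDataset(double w, List<ToyBatch> testDataset) {
        double weightedSum = 0.0;
        int samples = 0;
        for (ToyBatch batch : testDataset) {
            int n = batch.data.length;
            if (n == 0) {
                continue; // skip empty batches instead of dividing by zero
            }
            weightedSum += validateBatch(w, batch) * n; // weight each batch by its size
            samples += n;
        }
        return samples == 0 ? Double.NaN : weightedSum / samples;
    }
}
```

        Weighting by batch size keeps a short final batch from counting as much as a full one.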