Class EarlyStoppingListener

  • All Implemented Interfaces:
    TrainingListener

    public final class EarlyStoppingListener
    extends java.lang.Object
    implements TrainingListener
    Listener that allows training to be stopped early if the validation loss stops improving or if the allotted time has expired.
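The "validation loss is not improving" condition can be pictured as a patience counter. This is a minimal sketch assuming the rule "stop once `patience` consecutive epochs fail to beat the best loss seen so far". The class `PlateauSketch` and its method `firstStopEpoch` are hypothetical and are not part of DJL, whose exact rule may differ.

```java
/**
 * Hypothetical illustration of patience-based early stopping; not DJL code.
 * Assumed rule: stop once `patience` consecutive epochs fail to improve on
 * the best validation loss observed so far.
 */
final class PlateauSketch {

    private PlateauSketch() {}

    /**
     * Returns the 0-based epoch index at which training would stop,
     * or -1 if training runs through every given epoch.
     */
    static int firstStopEpoch(double[] validationLosses, int patience) {
        double best = Double.POSITIVE_INFINITY;
        int epochsWithoutImprovement = 0;
        for (int epoch = 0; epoch < validationLosses.length; epoch++) {
            double loss = validationLosses[epoch];
            if (loss < best) {
                // New best loss: reset the patience counter.
                best = loss;
                epochsWithoutImprovement = 0;
            } else {
                // No improvement this epoch: use up one unit of patience.
                epochsWithoutImprovement++;
                if (epochsWithoutImprovement >= patience) {
                    return epoch;
                }
            }
        }
        return -1;
    }
}
```

For the losses 0.9, 0.7, 0.72, 0.71, 0.6 and a patience of 2, this sketch stops at epoch index 3 and never sees the later drop to 0.6. A larger patience trades extra training time for robustness to a noisy validation loss.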

    Usage: Add this listener to the training config, as the last listener in the list.

      new DefaultTrainingConfig(...)
            .addTrainingListeners(EarlyStoppingListener.builder()
                    .setEpochPatience(1)
                    .setEarlyStopPctImprovement(1)
                    .setMaxDuration(Duration.ofMinutes(42))
                    .setMinEpochs(1)
                    .build()
            );
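The builder settings above interact with each other. One plausible reading, assumed here rather than taken from the DJL source, is that `earlyStopPctImprovement` is the minimum relative drop in loss, in percent, that counts as an improvement, and that `minEpochs` blocks stopping until that many epochs have completed. The hypothetical `ImprovementRule` class below encodes only that reading; `maxDuration` is left out.

```java
/**
 * Hypothetical reading of two builder settings; not DJL code.
 * An epoch counts as an improvement only if the loss drops by at least
 * minPct percent relative to the best loss so far, and stopping is never
 * allowed before minEpochs epochs have completed.
 */
final class ImprovementRule {

    private ImprovementRule() {}

    /** Relative improvement of current over best, in percent (assumes best > 0). */
    static double pctImprovement(double best, double current) {
        return (best - current) / best * 100.0;
    }

    /** True if current beats best by at least minPct percent. */
    static boolean isImprovement(double best, double current, double minPct) {
        if (Double.isInfinite(best)) {
            return true; // no best loss yet: the first epoch always counts
        }
        return pctImprovement(best, current) >= minPct;
    }

    /** Stopping is considered only after minEpochs epochs have completed. */
    static boolean mayStop(int completedEpochs, int minEpochs,
                           int epochsWithoutImprovement, int epochPatience) {
        return completedEpochs >= minEpochs
                && epochsWithoutImprovement >= epochPatience;
    }
}
```

Under this reading, with `setEarlyStopPctImprovement(1)` a drop from 0.500 to 0.497 (0.6%) would not count as an improvement, while a drop from 0.500 to 0.490 (2%) would.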
     

    Then wrap the call to fit in a try-catch block that catches EarlyStoppingListener.EarlyStoppedException.
    Example:

     try {
       EasyTrain.fit(trainer, 5, trainDataset, testDataset);
     } catch (EarlyStoppingListener.EarlyStoppedException e) {
       // handle early stopping
       log.info("Stopped early at epoch {} because: {}", e.getEpoch(), e.getMessage());
     }
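Throwing an unchecked exception from a listener callback unwinds the whole training loop at once, which is why the call to fit has to be wrapped. The sketch below imitates that control flow in plain Java; `StopSignal`, `MiniTrainer`, and `EarlyStopDemo` are hypothetical stand-ins, not DJL types.

```java
import java.util.function.IntConsumer;

/** Hypothetical unchecked exception carrying the epoch at which training stopped. */
final class StopSignal extends RuntimeException {
    private static final long serialVersionUID = 1L;
    private final int epoch;

    StopSignal(int epoch, String reason) {
        super(reason);
        this.epoch = epoch;
    }

    int epoch() {
        return epoch;
    }
}

/** Hypothetical training loop that notifies a listener at the end of each epoch. */
final class MiniTrainer {

    private MiniTrainer() {}

    static void fit(int numEpochs, IntConsumer onEpoch) {
        for (int epoch = 0; epoch < numEpochs; epoch++) {
            // One epoch of training and validation would run here.
            onEpoch.accept(epoch); // the listener may throw StopSignal
        }
    }
}

/** Hypothetical caller that mirrors the try-catch pattern described above. */
final class EarlyStopDemo {

    private EarlyStopDemo() {}

    /** Returns the epoch at which training stopped early, or -1 if all epochs ran. */
    static int run(int numEpochs, int stopAtEpoch) {
        try {
            MiniTrainer.fit(numEpochs, epoch -> {
                if (epoch == stopAtEpoch) {
                    throw new StopSignal(epoch, "validation loss stopped improving");
                }
            });
            return -1;
        } catch (StopSignal e) {
            // The exception skips all remaining epochs in a single jump.
            return e.epoch();
        }
    }
}
```

One consequence of this design: any code placed after `fit` inside the `try` block (saving the model, for example) is skipped when training stops early, so such steps belong after the try-catch or in a `finally` block.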
     

    Note: Ensure that a Metrics instance is set on the trainer.
    • Method Detail

      • onEpoch

        public void onEpoch(Trainer trainer)
        Listens to the end of an epoch during training.
        Specified by:
        onEpoch in interface TrainingListener
        Parameters:
        trainer - the trainer the listener is attached to
      • onTrainingBatch

        public void onTrainingBatch(Trainer trainer,
                                    TrainingListener.BatchData batchData)
        Listens to the end of training one batch of data during training.
        Specified by:
        onTrainingBatch in interface TrainingListener
        Parameters:
        trainer - the trainer the listener is attached to
        batchData - the data from the batch
      • onValidationBatch

        public void onValidationBatch(Trainer trainer,
                                      TrainingListener.BatchData batchData)
        Listens to the end of validating one batch of data during validation.
        Specified by:
        onValidationBatch in interface TrainingListener
        Parameters:
        trainer - the trainer the listener is attached to
        batchData - the data from the batch
      • onTrainingBegin

        public void onTrainingBegin(Trainer trainer)
        Listens to the beginning of training.
        Specified by:
        onTrainingBegin in interface TrainingListener
        Parameters:
        trainer - the trainer the listener is attached to
      • onTrainingEnd

        public void onTrainingEnd(Trainer trainer)
        Listens to the end of training.
        Specified by:
        onTrainingEnd in interface TrainingListener
        Parameters:
        trainer - the trainer the listener is attached to
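The callbacks above divide the work across the training lifecycle. The following is a minimal sketch of one way to split it, assuming the listener records a start time in onTrainingBegin, accumulates per-batch validation loss in onValidationBatch, and averages that loss in onEpoch. The class `LifecycleSketch` and its methods are hypothetical. They mirror, but do not implement, the TrainingListener methods, and timestamps are passed in explicitly so the logic stays deterministic.

```java
import java.time.Duration;
import java.time.Instant;

/**
 * Hypothetical sketch of how lifecycle callbacks could divide the work;
 * not DJL code. Callers pass timestamps in explicitly.
 */
final class LifecycleSketch {
    private final Duration maxDuration;
    private Instant trainingStart;
    private double lossSum;
    private int batchCount;

    LifecycleSketch(Duration maxDuration) {
        this.maxDuration = maxDuration;
    }

    /** Like onTrainingBegin: remember the start time and clear the accumulator. */
    void trainingBegin(Instant now) {
        trainingStart = now;
        lossSum = 0;
        batchCount = 0;
    }

    /** Like onValidationBatch: add one batch's validation loss. */
    void validationBatch(double batchLoss) {
        lossSum += batchLoss;
        batchCount++;
    }

    /**
     * Like onEpoch: returns the mean validation loss for the epoch
     * (NaN if no batches were seen) and resets the accumulator.
     */
    double epochMeanLoss() {
        double mean = batchCount == 0 ? Double.NaN : lossSum / batchCount;
        lossSum = 0;
        batchCount = 0;
        return mean;
    }

    /** True once more than maxDuration has elapsed since trainingBegin was called. */
    boolean timeExpired(Instant now) {
        return Duration.between(trainingStart, now).compareTo(maxDuration) > 0;
    }
}
```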