Class EasyHpo


  • public abstract class EasyHpo
    extends java.lang.Object
    Helper for easy training with hyperparameters.
    • Constructor Detail

      • EasyHpo

        public EasyHpo()
    • Method Detail

      • fit

        public ai.djl.util.Pair<Model,TrainingResult> fit()
                                                         throws java.io.IOException,
                                                                TranslateException
        Fits the model given the implemented abstract methods.
        Returns:
        the best model and training results
        Throws:
        java.io.IOException - if the dataset could not be loaded or the model could not be saved
        TranslateException - if there is an error while processing input
      • setupHyperParams

        protected abstract HpSet setupHyperParams()
        Returns the initial hyperparameters.
        Returns:
        the initial hyperparameters
      • getDataset

        protected abstract RandomAccessDataset getDataset(Dataset.Usage usage)
                                                   throws java.io.IOException
        Returns the dataset to train with.
        Parameters:
        usage - the usage of the dataset
        Returns:
        the dataset to train with
        Throws:
        java.io.IOException - if the dataset could not be loaded
      • setupTrainingConfig

        protected abstract TrainingConfig setupTrainingConfig(HpSet hpVals)
        Returns the TrainingConfig to use to train each hyperparameter set.
        Parameters:
        hpVals - the hyperparameters to train with
        Returns:
        the TrainingConfig to use to train each hyperparameter set
      • buildModel

        protected abstract Model buildModel(HpSet hpVals)
        Builds the Model and Block to train.
        Parameters:
        hpVals - the hyperparameter values to use for the model
        Returns:
        the model to train
      • inputShape

        protected abstract Shape inputShape(HpSet hpVals)
        Returns the input shape for the model.
        Parameters:
        hpVals - the hyperparameter values for the model
        Returns:
        the model input shape
      • numEpochs

        protected abstract int numEpochs(HpSet hpVals)
        Returns the number of epochs to train for the current hyperparameter set.
        Parameters:
        hpVals - the current hyperparameter set
        Returns:
        the number of epochs
      • numHyperParameterTests

        protected abstract int numHyperParameterTests()
        Returns the number of hyperparameter sets to train with.
        Returns:
        the number of hyperparameter sets to train with
      • saveModel

        protected void saveModel(Model model,
                                 TrainingResult result)
                          throws java.io.IOException
        Saves the model trained with the best hyperparameter set.
        Parameters:
        model - the model to save
        result - the training result for training with this model's hyperparameters
        Throws:
        java.io.IOException - if the model could not be saved
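
The way fit() ties the abstract methods together follows the template-method pattern: the base class owns the search loop and subclasses supply the pieces. The sketch below illustrates that shape only. It uses no DJL classes: EasyHpoSketch, Trained, ToyHpo, and trainAndValidate are hypothetical stand-ins, the hook signatures are simplified to plain maps and ints, and the search strategy (random sampling, keeping the lowest validation loss) is an assumption made for illustration rather than a statement of how EasyHpo is implemented.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Random;

/** Stand-in sketch of the template-method shape; none of these types are DJL classes. */
abstract class EasyHpoSketch {

    /** Hypothetical stand-in for a trained model together with its validation loss. */
    static final class Trained {
        final Map<String, Integer> hpVals;
        final double validateLoss;

        Trained(Map<String, Integer> hpVals, double validateLoss) {
            this.hpVals = hpVals;
            this.validateLoss = validateLoss;
        }
    }

    // Hooks named after the methods documented above, with simplified signatures.

    /** Search space: hyperparameter name -> {lower bound inclusive, upper bound exclusive}. */
    protected abstract Map<String, int[]> setupHyperParams();

    protected abstract int numHyperParameterTests();

    protected abstract int numEpochs(Map<String, Integer> hpVals);

    /** Hypothetical hook standing in for building and training a model; returns its loss. */
    protected abstract double trainAndValidate(Map<String, Integer> hpVals, int epochs);

    /** Overridable hook mirroring saveModel; does nothing by default in this sketch. */
    protected void saveModel(Trained best) {}

    /** Assumed strategy: sample random hyperparameter sets and keep the lowest loss. */
    public Trained fit(long seed) {
        Random random = new Random(seed);
        Map<String, int[]> space = setupHyperParams();
        Trained best = null;
        for (int i = 0; i < numHyperParameterTests(); i++) {
            // Draw one value per hyperparameter from its range.
            Map<String, Integer> hpVals = new LinkedHashMap<>();
            for (Map.Entry<String, int[]> e : space.entrySet()) {
                int lower = e.getValue()[0];
                int upper = e.getValue()[1];
                hpVals.put(e.getKey(), lower + random.nextInt(upper - lower));
            }
            double loss = trainAndValidate(hpVals, numEpochs(hpVals));
            if (best == null || loss < best.validateLoss) {
                best = new Trained(hpVals, loss);
            }
        }
        if (best != null) {
            saveModel(best);
        }
        return best;
    }
}

/** Toy subclass: the "loss" is a made-up function that is smallest at hiddenSize = 64. */
class ToyHpo extends EasyHpoSketch {
    @Override
    protected Map<String, int[]> setupHyperParams() {
        Map<String, int[]> space = new LinkedHashMap<>();
        space.put("hiddenSize", new int[] {16, 129});
        return space;
    }

    @Override
    protected int numHyperParameterTests() {
        return 20;
    }

    @Override
    protected int numEpochs(Map<String, Integer> hpVals) {
        return 2;
    }

    @Override
    protected double trainAndValidate(Map<String, Integer> hpVals, int epochs) {
        int d = hpVals.get("hiddenSize") - 64;
        return (d * d) / (double) epochs;
    }
}
```

Calling new ToyHpo().fit(42L) runs 20 trials and returns the one with the lowest toy loss. In a real EasyHpo subclass the same roles would be filled by the documented methods, with HpSet carrying the hyperparameter values and Model and TrainingResult carrying the outcome.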