Package ai.djl.training.hyperparameter
Class EasyHpo
- java.lang.Object
-
- ai.djl.training.hyperparameter.EasyHpo
-
public abstract class EasyHpo extends java.lang.Object
Helper for easy training with hyperparameters.
-
-
Constructor Summary
Constructor    Description
EasyHpo()
-
Method Summary
Modifier and Type                       Method                                         Description
protected abstract Model                buildModel(HpSet hpVals)
ai.djl.util.Pair<Model,TrainingResult>  fit()                                          Fits the model given the implemented abstract methods.
protected abstract RandomAccessDataset  getDataset(Dataset.Usage usage)                Returns the dataset to train with.
protected abstract Shape                inputShape(HpSet hpVals)                       Returns the input shape for the model.
protected abstract int                  numEpochs(HpSet hpVals)                        Returns the number of epochs to train for the current hyperparameter set.
protected abstract int                  numHyperParameterTests()                       Returns the number of hyperparameter sets to train with.
protected void                          saveModel(Model model, TrainingResult result)  Saves the best hyperparameter set.
protected abstract HpSet                setupHyperParams()                             Returns the initial hyperparameters.
protected abstract TrainingConfig       setupTrainingConfig(HpSet hpVals)              Returns the TrainingConfig to use to train each hyperparameter set.
-
-
-
Method Detail
-
fit
public ai.djl.util.Pair<Model,TrainingResult> fit() throws java.io.IOException, TranslateException
Fits the model given the implemented abstract methods.
- Returns:
- the best model and training results
- Throws:
java.io.IOException - for various exceptions depending on the dataset
TranslateException - if there is an error while processing input
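The summary above describes a template-method design: subclasses supply the hooks, and fit() drives the search and returns the best result. The sketch below illustrates only that shape. It does not use DJL classes; HpoSketch, Search, Result, ToySearch, sampleLearningRate and trainAndEvaluate are hypothetical stand-ins, and the uniform random sampling is an assumption made for illustration, since this page does not state how EasyHpo picks each hyperparameter set.

```java
import java.util.Random;

/** Stand-in sketch (not DJL code) of the template-method shape described for EasyHpo. */
public class HpoSketch {

    /** Hypothetical result holder: one sampled value and the loss it achieved. */
    public static final class Result {
        public final double learningRate;
        public final double loss;

        Result(double learningRate, double loss) {
            this.learningRate = learningRate;
            this.loss = loss;
        }
    }

    /** Subclasses fill in the hooks; fit() runs the search loop. */
    public abstract static class Search {
        /** Number of hyperparameter sets to try (same role as numHyperParameterTests()). */
        protected abstract int numHyperParameterTests();

        /** Hypothetical hook: draws one candidate value. */
        protected abstract double sampleLearningRate(Random rng);

        /** Hypothetical hook: trains with the candidate and returns its loss. */
        protected abstract double trainAndEvaluate(double learningRate);

        /** Tries each candidate and keeps the one with the lowest loss (null if no tests run). */
        public Result fit(long seed) {
            Random rng = new Random(seed);
            Result best = null;
            for (int i = 0; i < numHyperParameterTests(); i++) {
                double lr = sampleLearningRate(rng);
                double loss = trainAndEvaluate(lr);
                if (best == null || loss < best.loss) {
                    best = new Result(lr, loss);
                }
            }
            return best;
        }
    }

    /** Toy subclass: the "loss" is a quadratic with its minimum at 0.1. */
    public static final class ToySearch extends Search {
        public int evaluations = 0;

        @Override
        protected int numHyperParameterTests() {
            return 20;
        }

        @Override
        protected double sampleLearningRate(Random rng) {
            return rng.nextDouble();
        }

        @Override
        protected double trainAndEvaluate(double learningRate) {
            evaluations++;
            return (learningRate - 0.1) * (learningRate - 0.1);
        }
    }

    public static void main(String[] args) {
        Result best = new ToySearch().fit(42L);
        System.out.println("best learning rate: " + best.learningRate + ", loss: " + best.loss);
    }
}
```

In a real EasyHpo subclass, the corresponding hooks would be the abstract methods listed in the Method Summary (setupHyperParams, getDataset, setupTrainingConfig, buildModel, inputShape, numEpochs and numHyperParameterTests).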
-
setupHyperParams
protected abstract HpSet setupHyperParams()
Returns the initial hyperparameters.
- Returns:
- the initial hyperparameters
-
getDataset
protected abstract RandomAccessDataset getDataset(Dataset.Usage usage) throws java.io.IOException
Returns the dataset to train with.
- Parameters:
usage - the usage of the dataset
- Returns:
- the dataset to train with
- Throws:
java.io.IOException - if the dataset could not be loaded
-
setupTrainingConfig
protected abstract TrainingConfig setupTrainingConfig(HpSet hpVals)
Returns the TrainingConfig to use to train each hyperparameter set.
- Parameters:
hpVals - the hyperparameters to train with
- Returns:
- the TrainingConfig to use to train each hyperparameter set
-
buildModel
protected abstract Model buildModel(HpSet hpVals)
- Parameters:
hpVals - the hyperparameter values to use for the model
- Returns:
- the model to train
-
inputShape
protected abstract Shape inputShape(HpSet hpVals)
Returns the input shape for the model.
- Parameters:
hpVals - the hyperparameter values for the model
- Returns:
- the model input shape
-
numEpochs
protected abstract int numEpochs(HpSet hpVals)
Returns the number of epochs to train for the current hyperparameter set.
- Parameters:
hpVals - the current hyperparameter set
- Returns:
- the number of epochs
-
numHyperParameterTests
protected abstract int numHyperParameterTests()
Returns the number of hyperparameter sets to train with.
- Returns:
- the number of hyperparameter sets to train with
-
saveModel
protected void saveModel(Model model, TrainingResult result) throws java.io.IOException
Saves the best hyperparameter set.
- Parameters:
model - the model to save
result - the training result for training with this model's hyperparameters
- Throws:
java.io.IOException - if the model could not be saved
-
-