public abstract class EasyHpo
extends java.lang.Object
| Constructor and Description |
|---|
| `EasyHpo()` |
| Modifier and Type | Method and Description |
|---|---|
| `protected abstract Model` | `buildModel(HpSet hpVals)`: Builds the model to train with the given hyperparameter values. |
| `ai.djl.util.Pair<Model,TrainingResult>` | `fit()`: Fits the model using the subclass's implementations of the abstract methods. |
| `protected abstract RandomAccessDataset` | `getDataset(Dataset.Usage usage)`: Returns the dataset to use for the given usage. |
| `protected abstract Shape` | `inputShape(HpSet hpVals)`: Returns the input shape for the model. |
| `protected abstract int` | `numEpochs(HpSet hpVals)`: Returns the number of epochs to train for the current hyperparameter set. |
| `protected abstract int` | `numHyperParameterTests()`: Returns the number of hyperparameter sets to train with. |
| `protected void` | `saveModel(Model model, TrainingResult result)`: Saves the model for the best hyperparameter set. |
| `protected abstract HpSet` | `setupHyperParams()`: Returns the initial hyperparameters. |
| `protected abstract TrainingConfig` | `setupTrainingConfig(HpSet hpVals)`: Returns the `TrainingConfig` to use to train each hyperparameter set. |
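Taken together, the summary describes a template-method design: a subclass implements the hooks and `fit()` drives the search over hyperparameter sets. The following sketch shows that shape using only the JDK, not DJL. Everything in it is hypothetical: `ToyEasyHpo`, `sampleHyperParams`, `trainAndEvaluate`, `saveBest`, and the `Map<String, Double>` standing in for `HpSet` are illustration names, and the loop (try `numHyperParameterTests()` sets, keep the one with the lowest validation loss, save it) is an assumption about how the hooks fit together rather than DJL's actual implementation.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Random;

// Hypothetical, DJL-free sketch of the template-method shape described above.
// A Map<String, Double> stands in for HpSet; the search loop is an assumption.
abstract class ToyEasyHpo {

    // Pairs a hyperparameter set with its validation loss (lower is better).
    static final class Best {
        final Map<String, Double> hpVals;
        final double loss;

        Best(Map<String, Double> hpVals, double loss) {
            this.hpVals = hpVals;
            this.loss = loss;
        }
    }

    // Analogous to numHyperParameterTests(): how many sets to try.
    protected abstract int numHyperParameterTests();

    // Hypothetical hook: draws one candidate set from the subclass's search space.
    protected abstract Map<String, Double> sampleHyperParams(Random rng);

    // Hypothetical hook standing in for buildModel, setupTrainingConfig and training;
    // returns the validation loss for the given set.
    protected abstract double trainAndEvaluate(Map<String, Double> hpVals);

    // Analogous to saveModel(...): does nothing unless a subclass overrides it.
    protected void saveBest(Best best) {}

    // Tries each sampled set, keeps the lowest loss, then hands the winner to saveBest.
    // Returns null if numHyperParameterTests() is zero.
    public final Best fit(long seed) {
        Random rng = new Random(seed);
        Best best = null;
        for (int i = 0; i < numHyperParameterTests(); i++) {
            Map<String, Double> hpVals = new HashMap<>(sampleHyperParams(rng));
            double loss = trainAndEvaluate(hpVals);
            if (best == null || loss < best.loss) {
                best = new Best(hpVals, loss);
            }
        }
        if (best != null) {
            saveBest(best);
        }
        return best;
    }
}
```

A subclass returns its candidate count, samples sets from its own search space, and reports a validation loss per set; `fit(seed)` then returns the lowest-loss set. Seeding the `Random` keeps runs reproducible.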
`public ai.djl.util.Pair<Model,TrainingResult> fit() throws java.io.IOException, TranslateException`

- Throws: `java.io.IOException` - for various exceptions depending on the dataset
- Throws: `TranslateException` - if there is an error while processing input

`protected abstract HpSet setupHyperParams()`
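`setupHyperParams()` returns the initial hyperparameters the search works with. As an illustration of what such a set might describe, the sketch below models a search space as named numeric ranges and draws one concrete configuration from it. `ToySearchSpace`, `addRange`, and `sample` are hypothetical names, uniform sampling is an assumption, and DJL's `HpSet` is not used.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Random;

// Hypothetical stand-in for a hyperparameter search space (not DJL's HpSet).
final class ToySearchSpace {
    // Insertion-ordered so sampled configurations list names in registration order.
    private final Map<String, double[]> ranges = new LinkedHashMap<>();

    // Registers a named range [low, high) and returns this for chaining.
    ToySearchSpace addRange(String name, double low, double high) {
        if (!(low < high)) {
            throw new IllegalArgumentException("expected low < high for " + name);
        }
        ranges.put(name, new double[] {low, high});
        return this;
    }

    // Draws one concrete configuration, uniformly within each range (an assumption).
    Map<String, Double> sample(Random rng) {
        Map<String, Double> config = new LinkedHashMap<>();
        for (Map.Entry<String, double[]> entry : ranges.entrySet()) {
            double low = entry.getValue()[0];
            double high = entry.getValue()[1];
            config.put(entry.getKey(), low + rng.nextDouble() * (high - low));
        }
        return config;
    }
}
```

For example, `new ToySearchSpace().addRange("learningRate", 1e-4, 1e-1).sample(new Random(7))` yields a map with a single `learningRate` entry inside that range.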
`protected abstract RandomAccessDataset getDataset(Dataset.Usage usage) throws java.io.IOException`

- Parameters: `usage` - the usage of the dataset
- Throws: `java.io.IOException` - if the dataset could not be loaded

`protected abstract TrainingConfig setupTrainingConfig(HpSet hpVals)`

Returns the `TrainingConfig` to use to train each hyperparameter set.

- Parameters: `hpVals` - the hyperparameters to train with
- Returns: the `TrainingConfig` to use to train each hyperparameter set

`protected abstract Model buildModel(HpSet hpVals)`

- Parameters: `hpVals` - the hyperparameter values to use for the model

`protected abstract Shape inputShape(HpSet hpVals)`

- Parameters: `hpVals` - the hyperparameter values for the model

`protected abstract int numEpochs(HpSet hpVals)`

- Parameters: `hpVals` - the current hyperparameter set

`protected abstract int numHyperParameterTests()`
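Most of the per-set hooks (`setupTrainingConfig`, `buildModel`, `inputShape`, `numEpochs`) receive the hyperparameter set currently being tried, so a subclass can derive values such as the epoch count from it, while `getDataset` is keyed by usage. The sketch below illustrates both ideas with JDK types only. `ToyUsage` is a hypothetical two-constant stand-in for `Dataset.Usage` (this page does not list the real enum's constants); `ToyHooks` and the `"epochs"` key are likewise hypothetical, and the fixed train/test split and the fallback of one epoch are assumptions.

```java
import java.util.List;
import java.util.Map;

// Hypothetical two-constant stand-in for Dataset.Usage.
enum ToyUsage { TRAIN, TEST }

// Hypothetical per-set hooks built from JDK types only (not DJL code).
final class ToyHooks {
    private final List<Integer> samples;
    private final double trainFraction;

    ToyHooks(List<Integer> samples, double trainFraction) {
        if (trainFraction < 0.0 || trainFraction > 1.0) {
            throw new IllegalArgumentException("trainFraction must be in [0, 1]");
        }
        this.samples = samples;
        this.trainFraction = trainFraction;
    }

    // Analogous to getDataset(usage): the first part of the data for TRAIN, the rest for TEST.
    List<Integer> getDataset(ToyUsage usage) {
        int cut = (int) Math.round(samples.size() * trainFraction);
        switch (usage) {
            case TRAIN:
                return samples.subList(0, cut);
            case TEST:
                return samples.subList(cut, samples.size());
            default:
                throw new IllegalArgumentException("unsupported usage: " + usage);
        }
    }

    // Analogous to numEpochs(hpVals): reads "epochs" from the current set,
    // falling back to 1 when the set does not define it (an assumption).
    int numEpochs(Map<String, Double> hpVals) {
        Double epochs = hpVals.get("epochs");
        return epochs == null ? 1 : (int) Math.round(epochs);
    }
}
```

Keeping the split deterministic means every hyperparameter set is scored against the same test data, so their losses stay comparable.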
`protected void saveModel(Model model, TrainingResult result) throws java.io.IOException`

- Parameters: `model` - the model to save
- Parameters: `result` - the training result for training with this model's hyperparameters
- Throws: `java.io.IOException` - if the model could not be saved
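`saveModel` is the only hook that is not abstract, so overriding it is optional. As a hedged, DJL-free sketch of the kind of work an override might do, the class below writes the winning hyperparameter values and a validation loss to a `.properties` file using `java.util.Properties`, then reads them back. `ToyResultSaver`, its methods, and the `validationLoss` key are hypothetical; a real override would receive the `Model` and `TrainingResult` from the signature above.

```java
import java.io.IOException;
import java.io.Reader;
import java.io.Writer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Map;
import java.util.Properties;

// Hypothetical helper showing the kind of work a saveModel(...) override might do:
// persist the best hyperparameters and their validation loss. Not DJL code.
final class ToyResultSaver {
    private final Path target;

    ToyResultSaver(Path target) {
        this.target = target;
    }

    // Writes each hyperparameter plus the loss as key=value pairs (overwrites the file).
    void save(Map<String, Double> bestHpVals, double validationLoss) throws IOException {
        Properties props = new Properties();
        for (Map.Entry<String, Double> entry : bestHpVals.entrySet()) {
            props.setProperty(entry.getKey(), Double.toString(entry.getValue()));
        }
        props.setProperty("validationLoss", Double.toString(validationLoss));
        try (Writer out = Files.newBufferedWriter(target)) {
            props.store(out, "best hyperparameter set");
        }
    }

    // Reads the file back, for example to reuse the best set in a later run.
    Properties load() throws IOException {
        Properties props = new Properties();
        try (Reader in = Files.newBufferedReader(target)) {
            props.load(in);
        }
        return props;
    }
}
```

A plain-text format keeps the saved set human-readable; values come back as strings and would need `Double.parseDouble` before reuse.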