Class OrtTrainingSession
- All Implemented Interfaces:
AutoCloseable
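Wraps an ONNX training model, together with its associated evaluation and optimizer models and a checkpoint holding the training state, and provides training, evaluation and optimizer steps. Allows the inspection of the models' input and output nodes. Produced by an OrtEnvironment.
Most instance methods throw IllegalStateException if they are called after the session has been closed.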
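The methods documented below are normally combined into a training loop. The sketch shows one common call order per mini-batch, assuming a linear learning rate scheduler has been registered (without one, the schedulerStep call is simply dropped). TrainingSteps, TrainingLoopSketch and RecordingSteps are hypothetical names: the interface stands in for OrtTrainingSession with the tensor-map arguments, OrtSession.Result return values and OrtException removed, so that only the ordering remains.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for the OrtTrainingSession methods used in a training loop.
// The real methods take tensor maps, return OrtSession.Result and throw OrtException;
// they are reduced to no-argument calls here so the call order can be shown and tested.
interface TrainingSteps {
    void trainStep();      // forward + backward pass, accumulates gradients
    void optimizerStep();  // applies the accumulated gradients via the optimizer model
    void schedulerStep();  // advances the registered learning rate scheduler
    void lazyResetGrad();  // marks the gradients to be zeroed by the next trainStep
}

final class TrainingLoopSketch {
    private TrainingLoopSketch() {}

    // Runs one pass over numBatches mini-batches in the order:
    // trainStep, optimizerStep, schedulerStep, lazyResetGrad.
    static void runEpoch(TrainingSteps session, int numBatches) {
        for (int batch = 0; batch < numBatches; batch++) {
            session.trainStep();
            session.optimizerStep();
            session.schedulerStep();
            session.lazyResetGrad();
        }
    }
}

// Fake implementation that records each call, used to check the order.
final class RecordingSteps implements TrainingSteps {
    final List<String> calls = new ArrayList<>();

    public void trainStep() { calls.add("trainStep"); }
    public void optimizerStep() { calls.add("optimizerStep"); }
    public void schedulerStep() { calls.add("schedulerStep"); }
    public void lazyResetGrad() { calls.add("lazyResetGrad"); }
}
```

With a real session, trainStep would receive a map of tensors whose keys match getTrainInputNames(), and the returned OrtSession.Result should be closed once the loss has been read.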
-
Method Summary
- void addProperty(String name, float value): Adds a float property to this training session checkpoint.
- void addProperty(String name, int value): Adds an int property to this training session checkpoint.
- void addProperty(String name, String value): Adds a String property to this training session checkpoint.
- void close()
- OrtSession.Result evalStep(Map<String, ? extends OnnxTensorLike> inputs): Performs a single evaluation step using the supplied inputs.
- OrtSession.Result evalStep(Map<String, ? extends OnnxTensorLike> inputs, OrtSession.RunOptions runOptions): Performs a single evaluation step using the supplied inputs.
- OrtSession.Result evalStep(Map<String, ? extends OnnxTensorLike> inputs, Map<String, ? extends OnnxValue> pinnedOutputs): Performs a single evaluation step using the supplied inputs.
- OrtSession.Result evalStep(Map<String, ? extends OnnxTensorLike> inputs, Set<String> requestedOutputs): Performs a single evaluation step using the supplied inputs.
- OrtSession.Result evalStep(Map<String, ? extends OnnxTensorLike> inputs, Set<String> requestedOutputs, Map<String, ? extends OnnxValue> pinnedOutputs, OrtSession.RunOptions runOptions): Performs a single evaluation step using the supplied inputs.
- void exportModelForInference(Path outputPath, String[] outputNames): Exports the evaluation model as a model suitable for inference, setting the desired nodes as output nodes.
- Set<String> getEvalInputNames(): Returns an ordered set of the eval model input names.
- Set<String> getEvalOutputNames(): Returns an ordered set of the eval model output names.
- float getFloatProperty(String name): Gets a float property from this training session checkpoint.
- int getIntProperty(String name): Gets an int property from this training session checkpoint.
- float getLearningRate(): Gets the current learning rate for this training session.
- String getStringProperty(String name): Gets a String property from this training session checkpoint.
- Set<String> getTrainInputNames(): Returns an ordered set of the train model input names.
- Set<String> getTrainOutputNames(): Returns an ordered set of the train model output names.
- void lazyResetGrad(): Ensures the gradients are reset to zero before the next call to trainStep(Map<String, ? extends OnnxTensorLike>).
- void optimizerStep(): Applies the gradient updates to the trainable parameters using the optimizer model.
- void optimizerStep(OrtSession.RunOptions runOptions): Applies the gradient updates to the trainable parameters using the optimizer model.
- void registerLinearLRScheduler(long warmupSteps, long totalSteps, float initialLearningRate): Registers a linear learning rate scheduler with linear warmup.
- void saveCheckpoint(Path outputPath, boolean saveOptimizer): Saves the training session state into the supplied checkpoint directory.
- void schedulerStep(): Updates the learning rate based on the registered learning rate scheduler.
- void setLearningRate(float learningRate): Sets the learning rate for the training session.
- static void setSeed(long seed): Sets the RNG seed used by ONNX Runtime.
- OrtSession.Result trainStep(Map<String, ? extends OnnxTensorLike> inputs): Performs a single step of training, accumulating the gradients.
- OrtSession.Result trainStep(Map<String, ? extends OnnxTensorLike> inputs, OrtSession.RunOptions runOptions): Performs a single step of training, accumulating the gradients.
- OrtSession.Result trainStep(Map<String, ? extends OnnxTensorLike> inputs, Map<String, ? extends OnnxValue> pinnedOutputs): Performs a single step of training, accumulating the gradients.
- OrtSession.Result trainStep(Map<String, ? extends OnnxTensorLike> inputs, Set<String> requestedOutputs): Performs a single step of training, accumulating the gradients.
- OrtSession.Result trainStep(Map<String, ? extends OnnxTensorLike> inputs, Set<String> requestedOutputs, Map<String, ? extends OnnxValue> pinnedOutputs, OrtSession.RunOptions runOptions): Performs a single step of training, accumulating the gradients.
-
Method Details
-
getTrainInputNames
public Set<String> getTrainInputNames()
Returns an ordered set of the train model input names.
- Returns:
- The training inputs.
-
getTrainOutputNames
public Set<String> getTrainOutputNames()
Returns an ordered set of the train model output names.
- Returns:
- The training outputs.
-
getEvalInputNames
public Set<String> getEvalInputNames()
Returns an ordered set of the eval model input names.
- Returns:
- The evaluation inputs.
-
getEvalOutputNames
public Set<String> getEvalOutputNames()
Returns an ordered set of the eval model output names.
- Returns:
- The evaluation outputs.
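Because trainStep and evalStep take a map keyed by input name, a mismatch between that map and the ordered sets above is an easy mistake to make. A minimal sketch of a pre-call check, assuming the expected names are taken from getTrainInputNames() or getEvalInputNames(); InputNameCheck and the input names used in the test ("input", "labels", "weights") are hypothetical.

```java
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Hypothetical helper: compares the names a model expects (for example the ordered set
// returned by getTrainInputNames() or getEvalInputNames()) with the keys of an input map
// before it is passed to trainStep or evalStep.
final class InputNameCheck {
    private InputNameCheck() {}

    // Names the model expects but the map does not supply, in the model's order.
    static List<String> missing(Set<String> expected, Map<String, ?> supplied) {
        List<String> result = new ArrayList<>();
        for (String name : expected) {
            if (!supplied.containsKey(name)) {
                result.add(name);
            }
        }
        return result;
    }

    // Keys in the map that the model does not declare as inputs, in the map's order.
    static Set<String> unexpected(Set<String> expected, Map<String, ?> supplied) {
        Set<String> result = new LinkedHashSet<>(supplied.keySet());
        result.removeAll(expected);
        return result;
    }
}
```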
-
addProperty
public void addProperty(String name, float value) throws OrtException
Adds a float property to this training session checkpoint.
- Parameters:
name - The property name.
value - The property value.
- Throws:
OrtException - If the call failed.
-
addProperty
public void addProperty(String name, int value) throws OrtException
Adds an int property to this training session checkpoint.
- Parameters:
name - The property name.
value - The property value.
- Throws:
OrtException - If the call failed.
-
addProperty
public void addProperty(String name, String value) throws OrtException
Adds a String property to this training session checkpoint.
- Parameters:
name - The property name.
value - The property value.
- Throws:
OrtException - If the call failed.
-
getFloatProperty
public float getFloatProperty(String name) throws OrtException
Gets a float property from this training session checkpoint.
- Parameters:
name - The property name.
- Returns:
- The property value.
- Throws:
OrtException - If the property does not exist or is of the wrong type.
-
getIntProperty
public int getIntProperty(String name) throws OrtException
Gets an int property from this training session checkpoint.
- Parameters:
name - The property name.
- Returns:
- The property value.
- Throws:
OrtException - If the property does not exist or is of the wrong type.
-
getStringProperty
public String getStringProperty(String name) throws OrtException
Gets a String property from this training session checkpoint.
- Parameters:
name - The property name.
- Returns:
- The property value.
- Throws:
OrtException - If the property does not exist or is of the wrong type.
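The property methods above behave like a small typed key-value store attached to the checkpoint. The in-memory model below illustrates the two documented failure cases: a name that was never added, and a name stored with a different type. CheckpointProperties is hypothetical: it throws IllegalArgumentException where the real methods throw OrtException, and it persists nothing.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical in-memory model of the checkpoint property methods: each value is stored
// with its runtime type, and a typed getter fails if the name is absent or the stored
// value has a different type. Not ONNX Runtime code.
final class CheckpointProperties {
    private final Map<String, Object> values = new HashMap<>();

    void addProperty(String name, float value) { values.put(name, value); }
    void addProperty(String name, int value) { values.put(name, value); }
    void addProperty(String name, String value) { values.put(name, value); }

    float getFloatProperty(String name) { return get(name, Float.class); }
    int getIntProperty(String name) { return get(name, Integer.class); }
    String getStringProperty(String name) { return get(name, String.class); }

    // Shared lookup: rejects missing names and values of the wrong type.
    private <T> T get(String name, Class<T> type) {
        Object value = values.get(name);
        if (value == null) {
            throw new IllegalArgumentException("No property named " + name);
        }
        if (!type.isInstance(value)) {
            throw new IllegalArgumentException("Property " + name + " is a "
                + value.getClass().getSimpleName() + ", not a " + type.getSimpleName());
        }
        return type.cast(value);
    }
}
```

For example, an epoch counter could be stored with addProperty before saveCheckpoint and read back with getIntProperty when resuming; the property names in the test are illustrative only.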
-
close
public void close()
- Specified by:
close in interface AutoCloseable
-
saveCheckpoint
public void saveCheckpoint(Path outputPath, boolean saveOptimizer) throws OrtException
Saves the training session state into the supplied checkpoint directory.
- Parameters:
outputPath - Path to a checkpoint directory.
saveOptimizer - Whether the optimizer state should also be saved.
- Throws:
OrtException - If the native call failed.
-
lazyResetGrad
public void lazyResetGrad() throws OrtException
Ensures the gradients are reset to zero before the next call to trainStep(Map<String, ? extends OnnxTensorLike>).
Note that this is a lazy call: the gradients are cleared as part of running the next trainStep(Map<String, ? extends OnnxTensorLike>), not before.
- Throws:
OrtException - If the native call failed.
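The lazy reset is easier to see with a toy model. The class below is not ONNX Runtime code: it tracks a single scalar gradient, trainStep adds to it (matching "accumulating the gradients" in the trainStep documentation), and lazyResetGrad only schedules the reset. LazyGradBuffer is a hypothetical name.

```java
// Toy model (not the ONNX Runtime implementation) of gradient accumulation with a
// lazy reset: trainStep adds this batch's gradient to the running total, and
// lazyResetGrad only sets a flag, so the buffer is zeroed at the start of the
// next trainStep rather than immediately.
final class LazyGradBuffer {
    private double grad = 0.0;
    private boolean resetPending = false;

    void lazyResetGrad() {
        resetPending = true; // nothing is cleared yet
    }

    void trainStep(double batchGradient) {
        if (resetPending) {
            grad = 0.0; // the deferred reset happens here
            resetPending = false;
        }
        grad += batchGradient; // gradients accumulate across calls
    }

    double gradient() {
        return grad;
    }
}
```

Under this reading, calling trainStep on several mini-batches before optimizerStep and lazyResetGrad combines their gradients into a single update, which is one way to emulate a larger batch.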
-
setSeed
public static void setSeed(long seed) throws OrtException
Sets the RNG seed used by ONNX Runtime.
Note that this setting is global across all OrtTrainingSession instances.
- Parameters:
seed - The RNG seed.
- Throws:
OrtException - If the native call failed.
-
trainStep
public OrtSession.Result trainStep(Map<String, ? extends OnnxTensorLike> inputs) throws OrtException
Performs a single step of training, accumulating the gradients.
- Parameters:
inputs - The inputs (must include both the features and the target).
- Returns:
- All outputs produced by the training step.
- Throws:
OrtException - If the native call failed.
-
trainStep
public OrtSession.Result trainStep(Map<String, ? extends OnnxTensorLike> inputs, OrtSession.RunOptions runOptions) throws OrtException
Performs a single step of training, accumulating the gradients.
- Parameters:
inputs - The inputs (must include both the features and the target).
runOptions - Run options for controlling this specific call.
- Returns:
- All outputs produced by the training step.
- Throws:
OrtException - If the native call failed.
-
trainStep
public OrtSession.Result trainStep(Map<String, ? extends OnnxTensorLike> inputs, Set<String> requestedOutputs) throws OrtException
Performs a single step of training, accumulating the gradients.
- Parameters:
inputs - The inputs (must include both the features and the target).
requestedOutputs - The requested outputs.
- Returns:
- Requested outputs produced by the training step.
- Throws:
OrtException - If the native call failed.
-
trainStep
public OrtSession.Result trainStep(Map<String, ? extends OnnxTensorLike> inputs, Map<String, ? extends OnnxValue> pinnedOutputs) throws OrtException
Performs a single step of training, accumulating the gradients.
The outputs are ordered according to the traversal order of the supplied map.
Note: pinned outputs are not owned by the OrtSession.Result object, and are not closed when the result object is closed.
- Parameters:
inputs - The inputs (must include both the features and the target).
pinnedOutputs - The requested outputs which the user has allocated.
- Returns:
- Requested outputs produced by the training step.
- Throws:
OrtException - If the native call failed.
-
trainStep
public OrtSession.Result trainStep(Map<String, ? extends OnnxTensorLike> inputs, Set<String> requestedOutputs, Map<String, ? extends OnnxValue> pinnedOutputs, OrtSession.RunOptions runOptions) throws OrtException
Performs a single step of training, accumulating the gradients.
The outputs are ordered with the pinned outputs first, then the requested outputs, each following the traversal order of the supplied collection. An IllegalArgumentException is thrown if the same output name appears in both the requested outputs and the pinned outputs.
Note: pinned outputs are not owned by the OrtSession.Result object, and are not closed when the result object is closed.
- Parameters:
inputs - The inputs (must include both the features and the target).
requestedOutputs - The requested outputs which ORT will allocate.
pinnedOutputs - The requested outputs which the user has allocated.
runOptions - Run options for controlling this specific call.
- Returns:
- Requested outputs produced by the training step.
- Throws:
OrtException - If the native call failed.
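The ordering rule for the overloads that take both requestedOutputs and pinnedOutputs (this trainStep overload and its evalStep counterpart) can be modelled with plain collections. This is a hypothetical sketch of the documented contract, not the library's implementation, and the output names in the test are made up. If positions in the result matter, ordered collections such as LinkedHashMap and LinkedHashSet are the safer choice, since HashMap and HashSet iteration order is unspecified.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Hypothetical model of the documented result ordering: pinned outputs come first
// (in the map's traversal order), then requested outputs (in the set's traversal
// order), and a name that is both requested and pinned is rejected.
final class OutputOrder {
    private OutputOrder() {}

    static List<String> resultOrder(Set<String> requestedOutputs, Map<String, ?> pinnedOutputs) {
        List<String> order = new ArrayList<>(pinnedOutputs.keySet());
        for (String name : requestedOutputs) {
            if (pinnedOutputs.containsKey(name)) {
                throw new IllegalArgumentException(
                    "Output " + name + " is both requested and pinned");
            }
            order.add(name);
        }
        return order;
    }
}
```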
-
evalStep
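public OrtSession.Result evalStep(Map<String, ? extends OnnxTensorLike> inputs) throws OrtException
Performs a single evaluation step using the supplied inputs.
- Parameters:
inputs - The model inputs.
- Returns:
- All model outputs.
- Throws:
OrtException - If the native call failed.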
-
evalStep
public OrtSession.Result evalStep(Map<String, ? extends OnnxTensorLike> inputs, OrtSession.RunOptions runOptions) throws OrtException
Performs a single evaluation step using the supplied inputs.
- Parameters:
inputs - The model inputs.
runOptions - Run options for controlling this specific call.
- Returns:
- All model outputs.
- Throws:
OrtException - If the native call failed.
-
evalStep
public OrtSession.Result evalStep(Map<String, ? extends OnnxTensorLike> inputs, Set<String> requestedOutputs) throws OrtException
Performs a single evaluation step using the supplied inputs.
- Parameters:
inputs - The model inputs.
requestedOutputs - The requested output names.
- Returns:
- The requested outputs.
- Throws:
OrtException - If the native call failed.
-
evalStep
public OrtSession.Result evalStep(Map<String, ? extends OnnxTensorLike> inputs, Map<String, ? extends OnnxValue> pinnedOutputs) throws OrtException
Performs a single evaluation step using the supplied inputs.
The outputs are ordered according to the traversal order of the supplied map.
Note: pinned outputs are not owned by the OrtSession.Result object, and are not closed when the result object is closed.
- Parameters:
inputs - The inputs to score.
pinnedOutputs - The requested outputs which the user has allocated.
- Returns:
- The requested outputs.
- Throws:
OrtException - If the native call failed.
-
evalStep
public OrtSession.Result evalStep(Map<String, ? extends OnnxTensorLike> inputs, Set<String> requestedOutputs, Map<String, ? extends OnnxValue> pinnedOutputs, OrtSession.RunOptions runOptions) throws OrtException
Performs a single evaluation step using the supplied inputs.
The outputs are ordered with the pinned outputs first, then the requested outputs, each following the traversal order of the supplied collection. An IllegalArgumentException is thrown if the same output name appears in both the requested outputs and the pinned outputs.
Note: pinned outputs are not owned by the OrtSession.Result object, and are not closed when the result object is closed.
- Parameters:
inputs - The inputs to score.
requestedOutputs - The requested outputs which ORT will allocate.
pinnedOutputs - The requested outputs which the user has allocated.
runOptions - Run options for controlling this specific call.
- Returns:
- The requested outputs.
- Throws:
OrtException - If the native call failed.
-
setLearningRate
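public void setLearningRate(float learningRate) throws OrtException
Sets the learning rate for the training session.
This should be used only when no learning rate scheduler is registered with the session. It does not set the initial learning rate of an LR scheduler; that value is passed to registerLinearLRScheduler(long, long, float).
- Parameters:
learningRate - The learning rate.
- Throws:
OrtException - If the call failed.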
-
getLearningRate
public float getLearningRate() throws OrtException
Gets the current learning rate for this training session.
- Returns:
- The current learning rate.
- Throws:
OrtException - If the call failed.
-
optimizerStep
public void optimizerStep() throws OrtException
Applies the gradient updates to the trainable parameters using the optimizer model.
- Throws:
OrtException - If the native call failed.
-
optimizerStep
public void optimizerStep(OrtSession.RunOptions runOptions) throws OrtException
Applies the gradient updates to the trainable parameters using the optimizer model.
The run options can be used to control logging and to terminate the call early.
- Parameters:
runOptions - Options for controlling the model execution.
- Throws:
OrtException - If the native call failed.
-
registerLinearLRScheduler
public void registerLinearLRScheduler(long warmupSteps, long totalSteps, float initialLearningRate) throws OrtException
Registers a linear learning rate scheduler with linear warmup.
- Parameters:
warmupSteps - The number of steps over which the learning rate increases from zero to initialLearningRate.
totalSteps - The total number of steps this scheduler operates over.
initialLearningRate - The maximum learning rate.
- Throws:
OrtException - If the native call failed.
-
schedulerStep
public void schedulerStep() throws OrtException
Updates the learning rate based on the registered learning rate scheduler.
- Throws:
OrtException - If the native call failed.
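To make the linear scheduler concrete, the sketch below computes the learning rate at each step, assuming the common linear-warmup, linear-decay formulation: the rate rises linearly from zero to initialLearningRate over warmupSteps steps, then falls linearly to zero at totalSteps and stays there. The decay phase and the clamp at zero are assumptions about the shape, not statements from this documentation. LinearWarmupSchedule is a hypothetical class, and its schedulerStep method only mirrors the role of the real one.

```java
// Sketch of a linear warmup + linear decay schedule, assuming the common formulation:
//   step < warmupSteps:  lr = initialLearningRate * step / warmupSteps
//   otherwise:           lr = initialLearningRate * max(0, (totalSteps - step) / (totalSteps - warmupSteps))
// Illustrative only; ONNX Runtime's exact boundary handling may differ.
final class LinearWarmupSchedule {
    private final long warmupSteps;
    private final long totalSteps;
    private final double initialLearningRate;
    private long step = 0; // number of schedulerStep() calls so far

    LinearWarmupSchedule(long warmupSteps, long totalSteps, double initialLearningRate) {
        if (warmupSteps < 0 || totalSteps < warmupSteps) {
            throw new IllegalArgumentException("Need 0 <= warmupSteps <= totalSteps");
        }
        this.warmupSteps = warmupSteps;
        this.totalSteps = totalSteps;
        this.initialLearningRate = initialLearningRate;
    }

    // Learning rate at an arbitrary step count.
    double learningRateAt(long s) {
        if (s < warmupSteps) {
            return initialLearningRate * s / warmupSteps; // warmup: 0 up to the peak
        }
        long decaySteps = Math.max(1, totalSteps - warmupSteps); // avoid dividing by zero
        double factor = (double) (totalSteps - s) / decaySteps;
        return initialLearningRate * Math.max(0.0, factor); // decay, clamped at zero
    }

    // Mirrors the role of schedulerStep(): advance one step, return the new rate.
    double schedulerStep() {
        step++;
        return learningRateAt(step);
    }
}
```

With warmupSteps = 10, totalSteps = 100 and initialLearningRate = 1.0, this sketch gives 0.5 at step 5, 1.0 at step 10 and 0.5 at step 55.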
-
exportModelForInference
public void exportModelForInference(Path outputPath, String[] outputNames) throws OrtException
Exports the evaluation model as a model suitable for inference, setting the desired nodes as output nodes.
Note that this method reloads the evaluation model from the path provided when the training session was created, so that path must still be valid.
- Parameters:
outputPath - The path to write out the inference model.
outputNames - The names of the output nodes.
- Throws:
OrtException - If the native call failed.
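Because exportModelForInference reloads the evaluation model from disk, moving or deleting that file after the session is created will make the export fail. A hypothetical pre-flight check, assuming the caller kept the evaluation model Path it used when creating the session. ExportPrecheck is not part of the API, and its duplicate-name check is an extra sanity check of our own rather than a documented requirement; it does not verify that the names exist in the model.

```java
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.HashSet;
import java.util.Set;

// Hypothetical pre-flight check before calling exportModelForInference: rejects
// duplicate output names (an extra sanity check) and fails early if the evaluation
// model file the session will reload is no longer present.
final class ExportPrecheck {
    private ExportPrecheck() {}

    static void check(Path evalModelPath, String[] outputNames) {
        Set<String> seen = new HashSet<>();
        for (String name : outputNames) {
            if (!seen.add(name)) {
                throw new IllegalArgumentException("Duplicate output name: " + name);
            }
        }
        if (!Files.isRegularFile(evalModelPath)) {
            throw new IllegalStateException(
                "Evaluation model no longer found at " + evalModelPath);
        }
    }
}
```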
-