Class OrtTrainingSession

java.lang.Object
ai.onnxruntime.OrtTrainingSession
All Implemented Interfaces:
AutoCloseable

public final class OrtTrainingSession extends Object implements AutoCloseable
Wraps an ONNX training model and allows training and inference calls.

Allows inspection of the model's input and output nodes. Instances are produced by an OrtEnvironment.

Most instance methods throw IllegalStateException if they are called after the session has been closed.
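
The closed-session behaviour described above can be illustrated with a minimal sketch. ClosableSessionSketch is a hypothetical stand-in written for this example, not the ONNX Runtime class; it only shows the general AutoCloseable pattern of checking an "open" flag before doing any work.

```java
// Hypothetical stand-in for a closeable session; not the ONNX Runtime class.
final class ClosableSessionSketch implements AutoCloseable {
    private boolean closed = false;

    // Any instance method first checks that the session is still open.
    int step() {
        checkOpen();
        return 1;
    }

    private void checkOpen() {
        if (closed) {
            throw new IllegalStateException("Session is closed");
        }
    }

    @Override
    public void close() {
        closed = true; // closing twice is harmless in this sketch
    }

    // Returns true if calling step() after close() throws IllegalStateException.
    static boolean throwsAfterClose() {
        ClosableSessionSketch session = new ClosableSessionSketch();
        try (ClosableSessionSketch s = session) {
            s.step(); // fine: the session is open inside the try block
        }
        try {
            session.step(); // the try-with-resources block has already closed it
            return false;
        } catch (IllegalStateException e) {
            return true;
        }
    }

    public static void main(String[] args) {
        System.out.println(throwsAfterClose());
    }
}
```

Keeping every use of a session inside a try-with-resources block, as in the sketch, avoids calling methods on a closed instance.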

  • Method Details

    • getTrainInputNames

      public Set<String> getTrainInputNames()
      Returns an ordered set of the train model input names.
      Returns:
      The training inputs.
    • getTrainOutputNames

      public Set<String> getTrainOutputNames()
      Returns an ordered set of the train model output names.
      Returns:
      The training outputs.
    • getEvalInputNames

      public Set<String> getEvalInputNames()
      Returns an ordered set of the eval model input names.
      Returns:
      The evaluation inputs.
    • getEvalOutputNames

      public Set<String> getEvalOutputNames()
      Returns an ordered set of the eval model output names.
      Returns:
      The evaluation outputs.
    • addProperty

      public void addProperty(String name, float value) throws OrtException
      Adds a float property to this training session checkpoint.
      Parameters:
      name - The property name.
      value - The property value.
      Throws:
      OrtException - If the call failed.
    • addProperty

      public void addProperty(String name, int value) throws OrtException
      Adds an int property to this training session checkpoint.
      Parameters:
      name - The property name.
      value - The property value.
      Throws:
      OrtException - If the call failed.
    • addProperty

      public void addProperty(String name, String value) throws OrtException
      Adds a String property to this training session checkpoint.
      Parameters:
      name - The property name.
      value - The property value.
      Throws:
      OrtException - If the call failed.
    • getFloatProperty

      public float getFloatProperty(String name) throws OrtException
      Gets a float property from this training session checkpoint.
      Parameters:
      name - The property name.
      Returns:
      The property value.
      Throws:
      OrtException - If the property does not exist, or is of the wrong type.
    • getIntProperty

      public int getIntProperty(String name) throws OrtException
      Gets an int property from this training session checkpoint.
      Parameters:
      name - The property name.
      Returns:
      The property value.
      Throws:
      OrtException - If the property does not exist, or is of the wrong type.
    • getStringProperty

      public String getStringProperty(String name) throws OrtException
      Gets a String property from this training session checkpoint.
      Parameters:
      name - The property name.
      Returns:
      The property value.
      Throws:
      OrtException - If the property does not exist, or is of the wrong type.
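
To make the typed getter contract concrete, here is a small self-contained model of a property store with the same method names. TypedPropertiesSketch is hypothetical and is not how ONNX Runtime stores checkpoint properties; it throws IllegalArgumentException where the real methods are documented to throw OrtException.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical model of a typed property store; not ONNX Runtime's implementation.
final class TypedPropertiesSketch {
    private final Map<String, Object> values = new HashMap<>();

    // The overloads box the value, so its runtime type records how it was added.
    void addProperty(String name, float value) { values.put(name, value); }
    void addProperty(String name, int value) { values.put(name, value); }
    void addProperty(String name, String value) { values.put(name, value); }

    float getFloatProperty(String name) { return lookup(name, Float.class); }
    int getIntProperty(String name) { return lookup(name, Integer.class); }
    String getStringProperty(String name) { return lookup(name, String.class); }

    // Fails if the property is missing or was stored with a different type.
    private <T> T lookup(String name, Class<T> type) {
        Object value = values.get(name);
        if (value == null) {
            throw new IllegalArgumentException("No property named '" + name + "'");
        }
        if (!type.isInstance(value)) {
            throw new IllegalArgumentException(
                "Property '" + name + "' is not of type " + type.getSimpleName());
        }
        return type.cast(value);
    }

    public static void main(String[] args) {
        TypedPropertiesSketch props = new TypedPropertiesSketch();
        props.addProperty("epoch", 3);          // resolves to the int overload
        props.addProperty("bestLoss", 0.5f);    // resolves to the float overload
        System.out.println(props.getIntProperty("epoch"));
        System.out.println(props.getFloatProperty("bestLoss"));
    }
}
```

The property names "epoch" and "bestLoss" are made up for the example. Note that an int literal selects the int overload, so reading that property back with getFloatProperty fails in this model.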
    • close

      public void close()
      Specified by:
      close in interface AutoCloseable
    • saveCheckpoint

      public void saveCheckpoint(Path outputPath, boolean saveOptimizer) throws OrtException
      Saves the training session state to the supplied checkpoint directory.
      Parameters:
      outputPath - Path to a checkpoint directory.
      saveOptimizer - Whether the optimizer state should also be saved.
      Throws:
      OrtException - If the native call failed.
    • lazyResetGrad

      public void lazyResetGrad() throws OrtException
      Ensures the gradients are reset to zero before the next call to trainStep(Map).

      Note that this call is lazy: the gradients are cleared as part of running the next trainStep(Map), not before.

      Throws:
      OrtException - If the native call failed.
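
The lazy reset semantics can be modelled with a single accumulated gradient value. LazyResetSketch below is a hypothetical toy, not the native implementation: it assumes that the reset is recorded as a flag and applied at the start of the next training step, which is what the description above implies.

```java
// Hypothetical single-value model of gradient accumulation with a lazy reset.
final class LazyResetSketch {
    private float accumulatedGrad = 0.0f;
    private boolean resetPending = false;

    // Models lazyResetGrad(): only records that a reset is due.
    void lazyResetGrad() {
        resetPending = true;
    }

    // Models trainStep(): applies any pending reset, then accumulates.
    void trainStep(float grad) {
        if (resetPending) {
            accumulatedGrad = 0.0f;
            resetPending = false;
        }
        accumulatedGrad += grad;
    }

    float accumulatedGrad() {
        return accumulatedGrad;
    }

    public static void main(String[] args) {
        LazyResetSketch s = new LazyResetSketch();
        s.trainStep(1.0f);
        s.trainStep(2.0f);   // gradients accumulate across steps
        s.lazyResetGrad();   // nothing is cleared yet
        System.out.println(s.accumulatedGrad());
        s.trainStep(0.5f);   // the reset happens here, before accumulating
        System.out.println(s.accumulatedGrad());
    }
}
```

In this model the accumulated value is still visible between the lazyResetGrad call and the next trainStep call, which is the observable consequence of the reset being lazy.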
    • setSeed

      public static void setSeed(long seed) throws OrtException
      Sets the RNG seed used by ONNX Runtime.

      Note this setting is global across OrtTrainingSession instances.

      Parameters:
      seed - The RNG seed.
      Throws:
      OrtException - If the native call failed.
    • trainStep

      public OrtSession.Result trainStep(Map<String,? extends OnnxTensorLike> inputs) throws OrtException
      Performs a single step of training, accumulating the gradients.
      Parameters:
      inputs - The inputs (must include both the features and the target).
      Returns:
      All outputs produced by the training step.
      Throws:
      OrtException - If the native call failed.
    • trainStep

      public OrtSession.Result trainStep(Map<String,? extends OnnxTensorLike> inputs, OrtSession.RunOptions runOptions) throws OrtException
      Performs a single step of training, accumulating the gradients.
      Parameters:
      inputs - The inputs (must include both the features and the target).
      runOptions - Run options for controlling this specific call.
      Returns:
      All outputs produced by the training step.
      Throws:
      OrtException - If the native call failed.
    • trainStep

      public OrtSession.Result trainStep(Map<String,? extends OnnxTensorLike> inputs, Set<String> requestedOutputs) throws OrtException
      Performs a single step of training, accumulating the gradients.
      Parameters:
      inputs - The inputs (must include both the features and the target).
      requestedOutputs - The requested outputs.
      Returns:
      Requested outputs produced by the training step.
      Throws:
      OrtException - If the native call failed.
    • trainStep

      public OrtSession.Result trainStep(Map<String,? extends OnnxTensorLike> inputs, Map<String,? extends OnnxValue> pinnedOutputs) throws OrtException
      Performs a single step of training, accumulating the gradients.

      The outputs are sorted based on the supplied map traversal order.

      Note: pinned outputs are not owned by the OrtSession.Result object, and are not closed when the result object is closed.

      Parameters:
      inputs - The inputs (must include both the features and the target).
      pinnedOutputs - The requested outputs which the user has allocated.
      Returns:
      Requested outputs produced by the training step.
      Throws:
      OrtException - If the native call failed.
    • trainStep

      public OrtSession.Result trainStep(Map<String,? extends OnnxTensorLike> inputs, Set<String> requestedOutputs, Map<String,? extends OnnxValue> pinnedOutputs, OrtSession.RunOptions runOptions) throws OrtException
      Performs a single step of training, accumulating the gradients.

      The outputs are ordered with the pinned outputs first, then the requested outputs, each group following the traversal order of the collection it was supplied in. An IllegalArgumentException is thrown if the same output name appears in both the requested outputs and the pinned outputs.

      Note: pinned outputs are not owned by the OrtSession.Result object, and are not closed when the result object is closed.

      Parameters:
      inputs - The inputs (must include both the features and the target).
      requestedOutputs - The requested outputs which ORT will allocate.
      pinnedOutputs - The requested outputs which the user has allocated.
      runOptions - Run options for controlling this specific call.
      Returns:
      Requested outputs produced by the training step.
      Throws:
      OrtException - If the native call failed.
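
The ordering and overlap rule for pinned and requested outputs can be sketched in plain Java. OutputOrderSketch and the output names used in its main method are hypothetical; the sketch only mirrors the rule stated above and is not the library's code.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Hypothetical helper mirroring the documented output ordering rule.
final class OutputOrderSketch {
    // Returns output names in result order: pinned first, then requested.
    static List<String> resultOrder(Set<String> requestedOutputs, Map<String, ?> pinnedOutputs) {
        for (String name : requestedOutputs) {
            if (pinnedOutputs.containsKey(name)) {
                throw new IllegalArgumentException(
                    "Output '" + name + "' is both requested and pinned");
            }
        }
        List<String> order = new ArrayList<>(pinnedOutputs.keySet());
        order.addAll(requestedOutputs);
        return order;
    }

    public static void main(String[] args) {
        // Insertion-ordered collections give a predictable traversal order.
        Map<String, Object> pinned = new LinkedHashMap<>();
        pinned.put("loss", new Object()); // placeholder for a user-allocated value
        Set<String> requested = new LinkedHashSet<>(Arrays.asList("logits", "probabilities"));
        System.out.println(resultOrder(requested, pinned));
    }
}
```

Because the order follows collection traversal, passing a LinkedHashMap and LinkedHashSet (or another ordered collection) keeps the result order predictable; a HashSet would give an arbitrary one.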
    • evalStep

      public OrtSession.Result evalStep(Map<String,? extends OnnxTensorLike> inputs) throws OrtException
      Performs a single evaluation step using the supplied inputs.
      Parameters:
      inputs - The model inputs.
      Returns:
      All model outputs.
      Throws:
      OrtException - If the native call failed.
    • evalStep

      public OrtSession.Result evalStep(Map<String,? extends OnnxTensorLike> inputs, OrtSession.RunOptions runOptions) throws OrtException
      Performs a single evaluation step using the supplied inputs.
      Parameters:
      inputs - The model inputs.
      runOptions - Run options for controlling this specific call.
      Returns:
      All model outputs.
      Throws:
      OrtException - If the native call failed.
    • evalStep

      public OrtSession.Result evalStep(Map<String,? extends OnnxTensorLike> inputs, Set<String> requestedOutputs) throws OrtException
      Performs a single evaluation step using the supplied inputs.
      Parameters:
      inputs - The model inputs.
      requestedOutputs - The requested output names.
      Returns:
      The requested outputs.
      Throws:
      OrtException - If the native call failed.
    • evalStep

      public OrtSession.Result evalStep(Map<String,? extends OnnxTensorLike> inputs, Map<String,? extends OnnxValue> pinnedOutputs) throws OrtException
      Performs a single evaluation step using the supplied inputs.

      The outputs are sorted based on the supplied map traversal order.

      Note: pinned outputs are not owned by the OrtSession.Result object, and are not closed when the result object is closed.

      Parameters:
      inputs - The inputs to score.
      pinnedOutputs - The requested outputs which the user has allocated.
      Returns:
      The requested outputs.
      Throws:
      OrtException - If the native call failed.
    • evalStep

      public OrtSession.Result evalStep(Map<String,? extends OnnxTensorLike> inputs, Set<String> requestedOutputs, Map<String,? extends OnnxValue> pinnedOutputs, OrtSession.RunOptions runOptions) throws OrtException
      Performs a single evaluation step using the supplied inputs.

      The outputs are ordered with the pinned outputs first, then the requested outputs, each group following the traversal order of the collection it was supplied in. An IllegalArgumentException is thrown if the same output name appears in both the requested outputs and the pinned outputs.

      Note: pinned outputs are not owned by the OrtSession.Result object, and are not closed when the result object is closed.

      Parameters:
      inputs - The inputs to score.
      requestedOutputs - The requested outputs which ORT will allocate.
      pinnedOutputs - The requested outputs which the user has allocated.
      runOptions - Run options for controlling this specific call.
      Returns:
      The requested outputs.
      Throws:
      OrtException - If the native call failed.
    • setLearningRate

      public void setLearningRate(float learningRate) throws OrtException
      Sets the learning rate for the training session.

      Use this method only when the session has no learning rate scheduler registered. It does not set the initial learning rate for a learning rate scheduler.

      Parameters:
      learningRate - The learning rate.
      Throws:
      OrtException - If the call failed.
    • getLearningRate

      public float getLearningRate() throws OrtException
      Gets the current learning rate for this training session.
      Returns:
      The current learning rate.
      Throws:
      OrtException - If the call failed.
    • optimizerStep

      public void optimizerStep() throws OrtException
      Applies the gradient updates to the trainable parameters using the optimizer model.
      Throws:
      OrtException - If the native call failed.
    • optimizerStep

      public void optimizerStep(OrtSession.RunOptions runOptions) throws OrtException
      Applies the gradient updates to the trainable parameters using the optimizer model.

      The run options can be used to control logging and to terminate the call early.

      Parameters:
      runOptions - Options for controlling the model execution.
      Throws:
      OrtException - If the native call failed.
    • registerLinearLRScheduler

      public void registerLinearLRScheduler(long warmupSteps, long totalSteps, float initialLearningRate) throws OrtException
      Registers a linear learning rate scheduler with linear warmup.
      Parameters:
      warmupSteps - The number of steps to increase the learning rate from zero to initialLearningRate.
      totalSteps - The total number of steps this scheduler operates over.
      initialLearningRate - The maximum learning rate.
      Throws:
      OrtException - If the native call failed.
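
As a rough illustration of what these parameters describe, the sketch below computes a learning rate for a given step. It assumes the common shape of a linear warmup from zero to initialLearningRate followed by a linear decay to zero at totalSteps; the decay half is an assumption, since the description above only specifies the warmup. LinearWarmupSketch is a hypothetical helper, not part of the API.

```java
// Hypothetical schedule calculation; the decay-to-zero shape is an assumption.
final class LinearWarmupSketch {
    static float learningRateAt(long step, long warmupSteps, long totalSteps,
                                float initialLearningRate) {
        if (step < warmupSteps) {
            // Warmup: rise linearly from zero towards initialLearningRate.
            return initialLearningRate * step / Math.max(1L, warmupSteps);
        }
        // Decay: fall linearly from initialLearningRate to zero at totalSteps.
        float remaining = (float) (totalSteps - step) / Math.max(1L, totalSteps - warmupSteps);
        return initialLearningRate * Math.max(0.0f, remaining);
    }

    public static void main(String[] args) {
        long warmupSteps = 10;
        long totalSteps = 100;
        for (long step : new long[] {0, 5, 10, 55, 100}) {
            System.out.println(step + " -> " + learningRateAt(step, warmupSteps, totalSteps, 1.0f));
        }
    }
}
```

Under these assumptions the rate peaks at initialLearningRate exactly when the warmup ends, which matches the parameter being described as the maximum learning rate.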
    • schedulerStep

      public void schedulerStep() throws OrtException
      Updates the learning rate based on the registered learning rate scheduler.
      Throws:
      OrtException - If the native call failed.
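
One plausible way to sequence trainStep, optimizerStep, schedulerStep and lazyResetGrad is sketched below against a hypothetical Steps interface that borrows those method names. The call order and the accumulate parameter are assumptions made for illustration; this page does not prescribe a loop structure, and the real methods take inputs and throw OrtException.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical loop skeleton; Steps is a stand-in, not the real session type.
final class TrainingLoopSketch {
    interface Steps {
        void trainStep();
        void optimizerStep();
        void schedulerStep();
        void lazyResetGrad();
    }

    // Accumulates gradients over `accumulate` batches before each update.
    static void run(Steps session, int batches, int accumulate) {
        for (int batch = 1; batch <= batches; batch++) {
            session.trainStep();             // accumulate gradients
            if (batch % accumulate == 0) {
                session.optimizerStep();     // apply the accumulated gradients
                session.schedulerStep();     // then advance the learning rate
                session.lazyResetGrad();     // clear on the next trainStep
            }
        }
    }

    // Records the call order so that the sequence can be inspected.
    static List<String> trace(int batches, int accumulate) {
        List<String> calls = new ArrayList<>();
        run(new Steps() {
            @Override public void trainStep() { calls.add("trainStep"); }
            @Override public void optimizerStep() { calls.add("optimizerStep"); }
            @Override public void schedulerStep() { calls.add("schedulerStep"); }
            @Override public void lazyResetGrad() { calls.add("lazyResetGrad"); }
        }, batches, accumulate);
        return calls;
    }

    public static void main(String[] args) {
        System.out.println(trace(2, 2));
    }
}
```

Setting accumulate to 1 gives the plain loop with one update per batch; larger values model gradient accumulation across several trainStep calls.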
    • exportModelForInference

      public void exportModelForInference(Path outputPath, String[] outputNames) throws OrtException
      Exports the evaluation model as a model suitable for inference, setting the desired nodes as output nodes.

      Note that this method reloads the evaluation model from the path provided to the training session, and this path must still be valid.

      Parameters:
      outputPath - The path to write out the inference model.
      outputNames - The names of the output nodes.
      Throws:
      OrtException - If the native call failed.