Class SeqBatcher


  • public class SeqBatcher
    extends java.lang.Object
    SeqBatcher stores the search state (BatchTensorList), the control variables (e.g. seqLength, offSets), and the batch operations (merge, trim, exitCriteria, etc.) on BatchTensorList.
    • Method Summary

      Modifier and Type Method Description
      void addBatch(SeqBatcher seqBatcherNew)
      Adds new batch.
      java.util.Map<java.lang.Long,NDArray> collectAndTrim()
      Collects the finished sequences and trims the left padding.
      void exitCriteria(NDArray outputIds, long maxLength, long eosTokenId)
      Checks which sequences in the batch need to exit, according to certain criteria such as EOS or maxLength.
      BatchTensorList getData()
      Returns the batch data which is stored as a BatchTensorList.
      boolean sequenceComplete()
      Checks whether all sequences are empty.
      • Methods inherited from class java.lang.Object

        clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
    • Method Detail

      • getData

        public BatchTensorList getData()
        Returns the batch data which is stored as a BatchTensorList.
        Returns:
        the batch data stored as BatchTensorList
      • addBatch

        public void addBatch(SeqBatcher seqBatcherNew)
        Adds new batch.

        Modifies the batch dimension and the left padding.

        Parameters:
        seqBatcherNew - the seqBatcher to add.
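
        The effect on the batch dimension and the left padding can be pictured with plain arrays. The following is a minimal sketch, not the library's implementation: it assumes rectangular batches of token ids, and the class LeftPadMergeSketch, the PAD value, and all method names are hypothetical stand-ins for the NDArray-based state.

```java
import java.util.Arrays;

/** Illustrative only: plain arrays stand in for NDArray and BatchTensorList. */
final class LeftPadMergeSketch {
    static final long PAD = 0L; // hypothetical padding token id

    /** Left-pads every row of a rectangular batch to targetLength. */
    static long[][] leftPad(long[][] batch, int targetLength) {
        long[][] out = new long[batch.length][targetLength];
        for (int i = 0; i < batch.length; i++) {
            int shift = targetLength - batch[i].length;
            Arrays.fill(out[i], 0, shift, PAD);
            System.arraycopy(batch[i], 0, out[i], shift, batch[i].length);
        }
        return out;
    }

    /** Pads both batches to the longer sequence length, then stacks them along the batch dimension. */
    static long[][] merge(long[][] current, long[][] incoming) {
        int seqLength = Math.max(rowLength(current), rowLength(incoming));
        long[][] a = leftPad(current, seqLength);
        long[][] b = leftPad(incoming, seqLength);
        long[][] merged = new long[a.length + b.length][];
        System.arraycopy(a, 0, merged, 0, a.length);
        System.arraycopy(b, 0, merged, a.length, b.length);
        return merged;
    }

    /** Shifts per-row left-padding counts after a batch is padded from oldLength to newLength. */
    static int[] shiftOffsets(int[] offsets, int oldLength, int newLength) {
        int[] out = new int[offsets.length];
        for (int i = 0; i < offsets.length; i++) {
            out[i] = offsets[i] + (newLength - oldLength);
        }
        return out;
    }

    private static int rowLength(long[][] batch) {
        return batch.length == 0 ? 0 : batch[0].length;
    }

    public static void main(String[] args) {
        long[][] current = {{0, 5, 6}, {7, 8, 9}};
        long[][] incoming = {{1, 2}};
        System.out.println(Arrays.deepToString(merge(current, incoming)));
    }
}
```

        In this sketch the incoming batch is shorter, so its rows gain one padding token on the left and its offsets grow by one, while the existing rows are unchanged.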
      • exitCriteria

        public void exitCriteria(NDArray outputIds,
                                 long maxLength,
                                 long eosTokenId)
        Checks which sequences in the batch need to exit, according to certain criteria such as EOS or maxLength.

        This iterates over the batch and is therefore also considered a batch operation.

        Parameters:
        outputIds - output token ids in an incremental forward call
        maxLength - max total sequence length
        eosTokenId - end of sentence token id
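
        The two criteria named above can be sketched as a single pass over the batch. This is an illustration under stated assumptions rather than the actual method body: the class ExitCriteriaSketch, the rowsToExit method, and the explicit lengths array are hypothetical, and plain long arrays stand in for NDArray.

```java
import java.util.ArrayList;
import java.util.List;

/** Illustrative only: decides which rows of a batch should stop generating. */
final class ExitCriteriaSketch {
    /**
     * Returns the batch indices that meet an exit criterion.
     *
     * @param outputIds token id just generated for each row
     * @param lengths total length of each row after appending that token (assumed input)
     * @param maxLength max total sequence length
     * @param eosTokenId end of sentence token id
     */
    static List<Integer> rowsToExit(
            long[] outputIds, long[] lengths, long maxLength, long eosTokenId) {
        List<Integer> exits = new ArrayList<>();
        for (int i = 0; i < outputIds.length; i++) {
            boolean hitEos = outputIds[i] == eosTokenId;
            boolean hitMax = lengths[i] >= maxLength;
            if (hitEos || hitMax) {
                exits.add(i);
            }
        }
        return exits;
    }

    public static void main(String[] args) {
        long[] outputIds = {5, 2, 9};
        long[] lengths = {4, 3, 10};
        System.out.println(rowsToExit(outputIds, lengths, 10, 2));
    }
}
```

        Here row 1 exits because it produced the EOS token and row 2 exits because it reached maxLength; row 0 keeps generating.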
      • collectAndTrim

        public java.util.Map<java.lang.Long,NDArray> collectAndTrim()
        Collects the finished sequences and trims the left padding.
        Returns:
        a map from request id to output token ids
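
        A rough picture of collecting finished rows and dropping their left padding, using plain arrays. This is a sketch only: the class CollectAndTrimSketch, the collect method, and the explicit offsets, requestIds, and finished arrays are hypothetical inputs, and it covers only the per-row trimming of finished sequences, not any update to the rows that remain in the batch.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;

/** Illustrative only: plain arrays stand in for NDArray and the batcher's internal state. */
final class CollectAndTrimSketch {
    /**
     * Maps each finished request id to its token ids with the left padding removed.
     *
     * @param batch left-padded token ids, one row per request
     * @param offsets number of left-padding tokens in each row (assumed input)
     * @param requestIds request id of each row (assumed input)
     * @param finished whether each row has met an exit criterion (assumed input)
     */
    static Map<Long, long[]> collect(
            long[][] batch, int[] offsets, long[] requestIds, boolean[] finished) {
        Map<Long, long[]> out = new LinkedHashMap<>();
        for (int i = 0; i < batch.length; i++) {
            if (finished[i]) {
                // Skip the first offsets[i] padding tokens of this row.
                out.put(requestIds[i], Arrays.copyOfRange(batch[i], offsets[i], batch[i].length));
            }
        }
        return out;
    }

    public static void main(String[] args) {
        long[][] batch = {{0, 0, 7, 8}, {0, 3, 4, 5}};
        Map<Long, long[]> done =
                collect(batch, new int[] {2, 1}, new long[] {101L, 102L}, new boolean[] {true, false});
        done.forEach((id, tokens) -> System.out.println(id + " -> " + Arrays.toString(tokens)));
    }
}
```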
      • sequenceComplete

        public boolean sequenceComplete()
        Checks whether all sequences are empty.
        Returns:
        the boolean indicating whether all sequences are empty