Class BertPretrainingLoss


  • public class BertPretrainingLoss
    extends AbstractCompositeLoss
    Loss that combines the next-sentence prediction and masked language model losses used in BERT pretraining.
    • Constructor Detail

      • BertPretrainingLoss

        public BertPretrainingLoss()
        Creates a loss that combines the next-sentence prediction and masked language model losses for BERT pretraining.
    • Method Detail

      • inputForComponent

        protected ai.djl.util.Pair<NDList,NDList> inputForComponent(int componentIndex,
                                                                          NDList labels,
                                                                          NDList predictions)
        Description copied from class: AbstractCompositeLoss
        Returns the inputs used to compute the loss for a single component loss.
        Specified by:
        inputForComponent in class AbstractCompositeLoss
        Parameters:
        componentIndex - the index of the component loss
        labels - the label input to the composite loss
        predictions - the predictions input to the composite loss
        Returns:
        a pair of the (labels, predictions) inputs to the component loss
      • getBertNextSentenceLoss

        public BertNextSentenceLoss getBertNextSentenceLoss()
        Gets the BertNextSentenceLoss component of this loss.
        Returns:
        the BertNextSentenceLoss component
      • getBertMaskedLanguageModelLoss

        public BertMaskedLanguageModelLoss getBertMaskedLanguageModelLoss()
        Gets the BertMaskedLanguageModelLoss component of this loss.
        Returns:
        the BertMaskedLanguageModelLoss component
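
The inputForComponent contract described above (given a component index, return the (labels, predictions) pair that component should see) can be illustrated with a small library-free sketch. Everything in it is hypothetical and is not DJL's implementation: the class CompositeLossSketch, the ComponentLoss interface, the Inputs holder, the use of plain double arrays in place of NDList, the rule that component i reads the i-th entry of each input, and the choice to sum the component losses are all assumptions made for illustration.

```java
import java.util.Arrays;
import java.util.List;

/** Hypothetical, library-free sketch of a composite loss; not DJL code. */
public class CompositeLossSketch {

    /** A component loss maps (labels, predictions) to a scalar. */
    interface ComponentLoss {
        double evaluate(double[] labels, double[] predictions);
    }

    /** Stand-in for a (labels, predictions) pair. */
    static final class Inputs {
        final double[] labels;
        final double[] predictions;

        Inputs(double[] labels, double[] predictions) {
            this.labels = labels;
            this.predictions = predictions;
        }
    }

    /** Toy component loss: mean absolute error. */
    static double meanAbsoluteError(double[] labels, double[] predictions) {
        double sum = 0.0;
        for (int i = 0; i < labels.length; i++) {
            sum += Math.abs(labels[i] - predictions[i]);
        }
        return sum / labels.length;
    }

    /**
     * Selects the inputs for one component.
     * Assumption of this sketch: component i reads the i-th labels/predictions entry.
     */
    static Inputs inputForComponent(int componentIndex, double[][] labels, double[][] predictions) {
        return new Inputs(labels[componentIndex], predictions[componentIndex]);
    }

    /** Evaluates every component on its own inputs and sums the results (an assumption). */
    static double evaluate(List<ComponentLoss> components, double[][] labels, double[][] predictions) {
        double total = 0.0;
        for (int i = 0; i < components.size(); i++) {
            Inputs in = inputForComponent(i, labels, predictions);
            total += components.get(i).evaluate(in.labels, in.predictions);
        }
        return total;
    }

    public static void main(String[] args) {
        // Two toy components standing in for the next-sentence and masked-language losses.
        List<ComponentLoss> components = Arrays.asList(
                CompositeLossSketch::meanAbsoluteError,
                CompositeLossSketch::meanAbsoluteError);
        double[][] labels = {{1.0}, {0.0, 1.0}};
        double[][] predictions = {{0.5}, {0.0, 0.5}};
        System.out.println(evaluate(components, labels, predictions)); // prints 0.75
    }
}
```

Here the first component contributes 0.5 and the second 0.25. Keeping the routing step in its own method means a subclass can change which inputs each component sees without touching the evaluation loop.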