Package ai.djl.nn.transformer
Class BertPretrainingLoss
- java.lang.Object
  - ai.djl.training.evaluator.Evaluator
    - ai.djl.training.loss.Loss
      - ai.djl.training.loss.AbstractCompositeLoss
        - ai.djl.nn.transformer.BertPretrainingLoss
public class BertPretrainingLoss extends AbstractCompositeLoss
Loss that combines the next sentence prediction loss and the masked language model loss used in BERT pretraining.
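Assuming the two components are added with equal weight (the description above says only that they are combined), the pretraining objective can be sketched as follows, where $\mathcal{L}_{\text{NSP}}$ is the next sentence loss and $\mathcal{L}_{\text{MLM}}$ is the masked language model loss:

```latex
\mathcal{L}_{\text{pretrain}} = \mathcal{L}_{\text{NSP}} + \mathcal{L}_{\text{MLM}}
```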
Field Summary
Fields inherited from class ai.djl.training.loss.AbstractCompositeLoss
components
Fields inherited from class ai.djl.training.evaluator.Evaluator
totalInstances

Constructor Summary
BertPretrainingLoss()
  Creates a loss combining the next sentence and masked language model losses for BERT pretraining.
Method Summary
BertMaskedLanguageModelLoss getBertMaskedLanguageModelLoss()
  Gets the BertMaskedLanguageModelLoss.
BertNextSentenceLoss getBertNextSentenceLoss()
  Gets the BertNextSentenceLoss.
protected ai.djl.util.Pair<NDList,NDList> inputForComponent(int componentIndex, NDList labels, NDList predictions)
  Returns the inputs for computing the loss of a component loss.
Methods inherited from class ai.djl.training.loss.AbstractCompositeLoss
addAccumulator, evaluate, getAccumulator, getComponents, resetAccumulator, updateAccumulator
Methods inherited from class ai.djl.training.loss.Loss
elasticNetWeightedDecay, elasticNetWeightedDecay, elasticNetWeightedDecay, elasticNetWeightedDecay, hingeLoss, hingeLoss, hingeLoss, l1Loss, l1Loss, l1Loss, l1WeightedDecay, l1WeightedDecay, l1WeightedDecay, l2Loss, l2Loss, l2Loss, l2WeightedDecay, l2WeightedDecay, l2WeightedDecay, maskedSoftmaxCrossEntropyLoss, maskedSoftmaxCrossEntropyLoss, maskedSoftmaxCrossEntropyLoss, quantileL1Loss, quantileL1Loss, sigmoidBinaryCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, softmaxCrossEntropyLoss, softmaxCrossEntropyLoss, softmaxCrossEntropyLoss
Methods inherited from class ai.djl.training.evaluator.Evaluator
checkLabelShapes, checkLabelShapes, getName
Method Detail
inputForComponent
protected ai.djl.util.Pair<NDList,NDList> inputForComponent(int componentIndex, NDList labels, NDList predictions)
Description copied from class: AbstractCompositeLoss
Returns the inputs for computing the loss of a component loss.
Specified by:
  inputForComponent in class AbstractCompositeLoss
Parameters:
  componentIndex - the index of the component loss
  labels - the label input to the composite loss
  predictions - the predictions input to the composite loss
Returns:
  a pair of the (labels, predictions) inputs to the component loss
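To illustrate how a composite loss might use inputForComponent, the following framework-free Java sketch evaluates each component on the (labels, predictions) pair returned for its index and sums the results. Everything here is hypothetical: NDList is modeled as List<double[]>, Pair as a small record, and the routing rule (component i receives only the i-th label and prediction arrays) is an assumption for illustration, not DJL's actual behavior.

```java
import java.util.List;

// Hypothetical, framework-free sketch of composite-loss routing.
// Assumptions: NDList is modeled as List<double[]>, ai.djl.util.Pair as a record,
// and the index mapping in inputForComponent is illustrative, not DJL's layout.
public class ComponentRoutingSketch {

    /** Stand-in for ai.djl.util.Pair<NDList, NDList>. */
    record Pair(List<double[]> labels, List<double[]> predictions) {}

    /** A component loss reduces its own (labels, predictions) to a scalar. */
    interface ComponentLoss {
        double evaluate(List<double[]> labels, List<double[]> predictions);
    }

    /** Hypothetical routing: component i sees only the i-th label and prediction arrays. */
    static Pair inputForComponent(int componentIndex, List<double[]> labels, List<double[]> predictions) {
        return new Pair(
                List.of(labels.get(componentIndex)),
                List.of(predictions.get(componentIndex)));
    }

    /** Evaluates every component on its routed inputs and sums the results. */
    static double evaluate(List<ComponentLoss> components, List<double[]> labels, List<double[]> predictions) {
        double total = 0.0;
        for (int i = 0; i < components.size(); i++) {
            Pair inputs = inputForComponent(i, labels, predictions);
            total += components.get(i).evaluate(inputs.labels(), inputs.predictions());
        }
        return total;
    }

    /** Toy component: mean absolute error between its single label and prediction arrays. */
    static double meanAbsError(List<double[]> labels, List<double[]> predictions) {
        double[] y = labels.get(0);
        double[] p = predictions.get(0);
        double sum = 0.0;
        for (int j = 0; j < y.length; j++) {
            sum += Math.abs(y[j] - p[j]);
        }
        return sum / y.length;
    }

    public static void main(String[] args) {
        List<double[]> labels = List.of(new double[] {1.0, 0.0}, new double[] {2.0, 4.0});
        List<double[]> predictions = List.of(new double[] {0.5, 0.5}, new double[] {3.0, 4.0});
        List<ComponentLoss> components = List.of(
                ComponentRoutingSketch::meanAbsError,
                ComponentRoutingSketch::meanAbsError);
        System.out.println(evaluate(components, labels, predictions));
    }
}
```

With the toy inputs in main, each component's mean absolute error is 0.5, so the summed loss is 1.0.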
getBertNextSentenceLoss
public BertNextSentenceLoss getBertNextSentenceLoss()
Gets the BertNextSentenceLoss.
Returns:
  the BertNextSentenceLoss
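For intuition about what the returned BertNextSentenceLoss measures, here is a minimal sketch of the standard BERT next sentence prediction objective: two-class softmax cross-entropy over IsNext/NotNext scores, averaged over the batch. The class-index convention (0 = IsNext, 1 = NotNext), the mean reduction, and the plain-array inputs are assumptions; this is not DJL's implementation.

```java
// Hypothetical sketch of a next sentence prediction loss (standard BERT formulation).
// Assumptions: class 0 = IsNext, class 1 = NotNext, and the loss is averaged over the batch.
// This illustrates the concept only and is not taken from DJL's BertNextSentenceLoss.
public class NextSentenceLossSketch {

    /** Mean softmax cross-entropy; logits[i] holds the two class scores of example i. */
    static double loss(double[][] logits, int[] labels) {
        double total = 0.0;
        for (int i = 0; i < logits.length; i++) {
            double a = logits[i][0];
            double b = logits[i][1];
            // log-sum-exp with max subtraction for numerical stability
            double max = Math.max(a, b);
            double logSumExp = max + Math.log(Math.exp(a - max) + Math.exp(b - max));
            // negative log-probability of the true class
            total += logSumExp - logits[i][labels[i]];
        }
        return total / logits.length;
    }

    public static void main(String[] args) {
        double[][] logits = {{2.0, 0.0}, {0.0, 0.0}};
        int[] labels = {0, 1};
        System.out.println(loss(logits, labels));
    }
}
```

An example whose two scores are equal contributes ln 2 to the mean, the loss of a model that cannot tell the classes apart.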
getBertMaskedLanguageModelLoss
public BertMaskedLanguageModelLoss getBertMaskedLanguageModelLoss()
Gets the BertMaskedLanguageModelLoss.
Returns:
  the BertMaskedLanguageModelLoss
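Similarly, the returned BertMaskedLanguageModelLoss corresponds to the masked language model objective. A minimal sketch of the standard BERT formulation computes token-level softmax cross-entropy over the vocabulary, counts only masked positions, and averages over them. The 0/1 mask-weight array, the averaging, and the zero result for a sequence with no masked tokens are assumptions for illustration, not DJL's implementation.

```java
// Hypothetical sketch of a masked language model loss (standard BERT formulation).
// Assumptions: maskWeights[pos] is 1.0 at masked positions and 0.0 elsewhere, and the
// loss is averaged over masked positions. Not taken from DJL's BertMaskedLanguageModelLoss.
public class MaskedLanguageModelLossSketch {

    /** logits[pos] holds vocabulary scores; targetIds[pos] is the original token id. */
    static double loss(double[][] logits, int[] targetIds, double[] maskWeights) {
        double weightedSum = 0.0;
        double weightTotal = 0.0;
        for (int pos = 0; pos < logits.length; pos++) {
            if (maskWeights[pos] == 0.0) {
                continue; // unmasked tokens do not contribute
            }
            // log-sum-exp with max subtraction for numerical stability
            double max = Double.NEGATIVE_INFINITY;
            for (double v : logits[pos]) {
                max = Math.max(max, v);
            }
            double sumExp = 0.0;
            for (double v : logits[pos]) {
                sumExp += Math.exp(v - max);
            }
            double nll = max + Math.log(sumExp) - logits[pos][targetIds[pos]];
            weightedSum += maskWeights[pos] * nll;
            weightTotal += maskWeights[pos];
        }
        // Guard against a sequence with no masked positions.
        return weightTotal == 0.0 ? 0.0 : weightedSum / weightTotal;
    }

    public static void main(String[] args) {
        double[][] logits = {{0, 0, 0, 0}, {5, 0, 0, 0}};
        int[] targetIds = {2, 1};
        double[] maskWeights = {1, 0}; // only position 0 is masked
        System.out.println(loss(logits, targetIds, maskWeights));
    }
}
```

In main, only the first position is masked and its scores are uniform over a four-token vocabulary, so the result equals ln 4; the confidently wrong second position is ignored.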