public class BertMaskedLanguageModelLoss extends Loss
Fields inherited from class Evaluator: totalInstances

| Constructor and Description |
|---|
| BertMaskedLanguageModelLoss(int labelIdx, int maskIdx, int logProbsIdx) Creates a masked language model (MLM) loss. |

| Modifier and Type | Method and Description |
|---|---|
| NDArray | accuracy(NDList labels, NDList predictions) Calculates the percentage of correctly predicted masked tokens. |
| NDArray | evaluate(NDList labels, NDList predictions) Calculates the loss between the labels and the predictions. |

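The two methods above can be illustrated with a plain-Java sketch that uses arrays in place of NDArray and NDList. It assumes the usual masked language modeling formulation: the loss is the negative log-likelihood of the true token averaged over masked positions only, and accuracy compares the highest-scoring token with the label at masked positions. The class MaskedLmSketch, its method names, and the array layout (logProbs[position][vocabId], token ids in labels, 1.0/0.0 weights in mask) are hypothetical; this is not DJL's implementation.

```java
// Plain-Java sketch of a masked language model loss and accuracy; not DJL's implementation.
// Assumed layout: logProbs[i][v] = log-probability of vocabulary id v at position i,
// labels[i] = true token id at position i, mask[i] = 1.0 if position i was masked, else 0.0.
final class MaskedLmSketch {

    private MaskedLmSketch() {}

    // Mean negative log-likelihood of the true tokens, counted only at masked positions.
    static double maskedNll(double[][] logProbs, int[] labels, double[] mask) {
        double weightedLogProb = 0.0;
        double maskTotal = 0.0;
        for (int i = 0; i < labels.length; i++) {
            // Log-probability assigned to the true token; weight 0 drops unmasked positions.
            weightedLogProb += mask[i] * logProbs[i][labels[i]];
            maskTotal += mask[i];
        }
        // Returning 0 when nothing is masked is a choice of this sketch.
        return maskTotal == 0.0 ? 0.0 : -weightedLogProb / maskTotal;
    }

    // Fraction (0 to 1) of masked positions whose highest-scoring token equals the label.
    static double maskedAccuracy(double[][] logProbs, int[] labels, double[] mask) {
        double correct = 0.0;
        double total = 0.0;
        for (int i = 0; i < labels.length; i++) {
            // Argmax over the vocabulary at position i.
            int best = 0;
            for (int v = 1; v < logProbs[i].length; v++) {
                if (logProbs[i][v] > logProbs[i][best]) {
                    best = v;
                }
            }
            if (best == labels[i]) {
                correct += mask[i];
            }
            total += mask[i];
        }
        return total == 0.0 ? 0.0 : correct / total;
    }
}
```

Weighting by the mask means unmasked positions contribute to neither the numerator nor the denominator, so padding and unmasked tokens do not dilute the metric. The sketch returns a fraction; multiply by 100 for a percentage.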
Methods inherited from class Loss: addAccumulator, elasticNetWeightedDecay, elasticNetWeightedDecay, elasticNetWeightedDecay, elasticNetWeightedDecay, getAccumulator, hingeLoss, hingeLoss, hingeLoss, l1Loss, l1Loss, l1Loss, l1WeightedDecay, l1WeightedDecay, l1WeightedDecay, l2Loss, l2Loss, l2Loss, l2WeightedDecay, l2WeightedDecay, l2WeightedDecay, maskedSoftmaxCrossEntropyLoss, maskedSoftmaxCrossEntropyLoss, maskedSoftmaxCrossEntropyLoss, resetAccumulator, sigmoidBinaryCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, softmaxCrossEntropyLoss, softmaxCrossEntropyLoss, softmaxCrossEntropyLoss, updateAccumulator

Methods inherited from class Evaluator: checkLabelShapes, checkLabelShapes, getName

public BertMaskedLanguageModelLoss(int labelIdx,
int maskIdx,
int logProbsIdx)

Creates a masked language model (MLM) loss.

Parameters:

- labelIdx - index of the labels
- maskIdx - index of the mask
- logProbsIdx - index of the log probabilities
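The three constructor arguments appear to be positions within the NDList arguments passed to evaluate and accuracy, selecting which entry holds each input. The plain-Java sketch below assumes that labelIdx and maskIdx select entries from the labels list and logProbsIdx selects an entry from the predictions list; that split is our reading of the parameter names, not something this page states. MlmIndexSketch and select are hypothetical names, and java.util.List stands in for NDList.

```java
import java.util.List;

// Hypothetical illustration of index-based lookup; not DJL code.
// Assumption: labelIdx and maskIdx index into labels, logProbsIdx into predictions.
final class MlmIndexSketch {
    private final int labelIdx;
    private final int maskIdx;
    private final int logProbsIdx;

    MlmIndexSketch(int labelIdx, int maskIdx, int logProbsIdx) {
        this.labelIdx = labelIdx;
        this.maskIdx = maskIdx;
        this.logProbsIdx = logProbsIdx;
    }

    // Returns the entries a loss configured this way would read,
    // in the order: token labels, mask, log probabilities.
    <T> List<T> select(List<T> labels, List<T> predictions) {
        return List.of(labels.get(labelIdx), labels.get(maskIdx), predictions.get(logProbsIdx));
    }
}
```

For example, new MlmIndexSketch(0, 1, 1) would read the token ids from labels[0], the mask from labels[1], and the log probabilities from predictions[1], leaving predictions[0] free for other model outputs.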