Package ai.djl.training.loss
Class AbstractCompositeLoss
- java.lang.Object
  - ai.djl.training.evaluator.Evaluator
    - ai.djl.training.loss.Loss
      - ai.djl.training.loss.AbstractCompositeLoss
Direct Known Subclasses:
BertPretrainingLoss, SimpleCompositeLoss, SingleShotDetectionLoss
public abstract class AbstractCompositeLoss extends Loss
AbstractCompositeLoss is a Loss class that combines several component Losses into a single, larger loss. It is designed to be extended for more complicated composite losses; for simpler use cases, consider using SimpleCompositeLoss.
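The combining mechanism described above can be modelled without DJL. The sketch below is a plain-Java stand-in, not the DJL API: `float[][]` plays the role of `NDList`, a `BiFunction` plays the role of a component `Loss`, and `ToyCompositeLoss` and `TwoHeadLoss` are hypothetical names. It assumes the composite value is the sum of the component values; check the DJL source for how `evaluate` actually combines them.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiFunction;

// Hypothetical plain-Java model of the composite-loss pattern; this is NOT the DJL API.
// float[][] stands in for NDList, and a BiFunction stands in for a component Loss.
abstract class ToyCompositeLoss {
    // Mirrors the protected `components` field: one loss function per component.
    protected final List<BiFunction<float[], float[], Float>> components = new ArrayList<>();

    // Mirrors inputForComponent: returns {labels, predictions} for component i.
    protected abstract float[][] inputForComponent(
            int componentIndex, float[][] labels, float[][] predictions);

    // Assumption: the composite value is the sum of the component values.
    float evaluate(float[][] labels, float[][] predictions) {
        float total = 0f;
        for (int i = 0; i < components.size(); i++) {
            float[][] inputs = inputForComponent(i, labels, predictions);
            total += components.get(i).apply(inputs[0], inputs[1]);
        }
        return total;
    }
}

// Hypothetical subclass with two components: mean absolute error on array 0,
// mean squared error on array 1.
final class TwoHeadLoss extends ToyCompositeLoss {
    TwoHeadLoss() {
        components.add(TwoHeadLoss::meanAbsoluteError);
        components.add(TwoHeadLoss::meanSquaredError);
    }

    // Component i sees only the i-th label array and the i-th prediction array.
    @Override
    protected float[][] inputForComponent(
            int componentIndex, float[][] labels, float[][] predictions) {
        return new float[][] {labels[componentIndex], predictions[componentIndex]};
    }

    static Float meanAbsoluteError(float[] label, float[] prediction) {
        float sum = 0f;
        for (int j = 0; j < label.length; j++) {
            sum += Math.abs(label[j] - prediction[j]);
        }
        return sum / label.length;
    }

    static Float meanSquaredError(float[] label, float[] prediction) {
        float sum = 0f;
        for (int j = 0; j < label.length; j++) {
            float diff = label[j] - prediction[j];
            sum += diff * diff;
        }
        return sum / label.length;
    }
}
```

For example, `new TwoHeadLoss().evaluate(new float[][] {{1f, 2f}, {3f}}, new float[][] {{0f, 2f}, {5f}})` scores the first pair with mean absolute error (0.5) and the second with mean squared error (4.0), giving 4.5 under the summing assumption. The subclass only decides which components exist and which inputs each one sees; the loop that combines them lives in the base class.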
Field Summary
Fields
- protected java.util.List<Loss> components
Fields inherited from class ai.djl.training.evaluator.Evaluator
totalInstances
Constructor Summary
- AbstractCompositeLoss(java.lang.String name): Constructs a composite loss with the given name.
Method Summary
- void addAccumulator(java.lang.String key): Adds an accumulator for the results of the evaluation with the given key.
- NDArray evaluate(NDList labels, NDList predictions): Calculates the evaluation between the labels and the predictions.
- float getAccumulator(java.lang.String key): Returns the accumulated evaluator value.
- java.util.List<Loss> getComponents(): Returns the component losses that make up the composite loss.
- protected abstract ai.djl.util.Pair<NDList,NDList> inputForComponent(int componentIndex, NDList labels, NDList predictions): Returns the inputs for computing the loss of a single component loss.
- void resetAccumulator(java.lang.String key): Resets the evaluator value with the given key.
- void updateAccumulator(java.lang.String key, NDList labels, NDList predictions): Updates the evaluator with the given key based on an NDList of labels and predictions.
Methods inherited from class ai.djl.training.loss.Loss
elasticNetWeightedDecay, elasticNetWeightedDecay, elasticNetWeightedDecay, elasticNetWeightedDecay, hingeLoss, hingeLoss, hingeLoss, l1Loss, l1Loss, l1Loss, l1WeightedDecay, l1WeightedDecay, l1WeightedDecay, l2Loss, l2Loss, l2Loss, l2WeightedDecay, l2WeightedDecay, l2WeightedDecay, maskedSoftmaxCrossEntropyLoss, maskedSoftmaxCrossEntropyLoss, maskedSoftmaxCrossEntropyLoss, quantileL1Loss, quantileL1Loss, sigmoidBinaryCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, softmaxCrossEntropyLoss, softmaxCrossEntropyLoss, softmaxCrossEntropyLoss
Methods inherited from class ai.djl.training.evaluator.Evaluator
checkLabelShapes, checkLabelShapes, getName
Field Detail
components
protected java.util.List<Loss> components
Method Detail
inputForComponent
protected abstract ai.djl.util.Pair<NDList,NDList> inputForComponent(int componentIndex, NDList labels, NDList predictions)
Returns the inputs for computing the loss of a single component loss.
Parameters:
- componentIndex - the index of the component loss
- labels - the label input to the composite loss
- predictions - the predictions input to the composite loss
Returns:
- a pair of the (labels, predictions) inputs to the component loss
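One possible `inputForComponent` strategy gives component i only the i-th label and the i-th prediction. The standalone sketch below models that in plain Java rather than DJL: `List<float[]>` stands in for `NDList`, `Map.Entry` stands in for `ai.djl.util.Pair`, and `IndexRouting` is a hypothetical name. Other subclasses may route differently, for example by passing the full lists to every component.

```java
import java.util.AbstractMap;
import java.util.List;
import java.util.Map;

// Hypothetical plain-Java model of an index-based inputForComponent; NOT the DJL API.
// List<float[]> stands in for NDList and Map.Entry stands in for ai.djl.util.Pair.
final class IndexRouting {
    private IndexRouting() {}

    // Returns (labels, predictions) for one component: the componentIndex-th array
    // of each input list, each wrapped in a one-element list.
    static Map.Entry<List<float[]>, List<float[]>> inputForComponent(
            int componentIndex, List<float[]> labels, List<float[]> predictions) {
        if (componentIndex < 0
                || componentIndex >= labels.size()
                || componentIndex >= predictions.size()) {
            throw new IndexOutOfBoundsException("no inputs for component " + componentIndex);
        }
        return new AbstractMap.SimpleImmutableEntry<>(
                List.of(labels.get(componentIndex)), List.of(predictions.get(componentIndex)));
    }
}
```

With `labels = [a0, a1]` and `predictions = [p0, p1]`, component 1 receives `([a1], [p1])`, and an index outside either list raises `IndexOutOfBoundsException`. Wrapping each selected array back into a list keeps the component's input type the same as the composite's, which is why the abstract method returns a pair of lists rather than a pair of arrays.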
getComponents
public java.util.List<Loss> getComponents()
Returns the component losses that make up the composite loss.
Returns:
- the component losses that make up the composite loss
evaluate
public NDArray evaluate(NDList labels, NDList predictions)
Calculates the evaluation between the labels and the predictions; for this class, that is the combined value of its component losses.
addAccumulator
public void addAccumulator(java.lang.String key)
Adds an accumulator for the results of the evaluation with the given key.
Overrides:
- addAccumulator in class Loss
Parameters:
- key - the key for the new accumulator
updateAccumulator
public void updateAccumulator(java.lang.String key, NDList labels, NDList predictions)
Updates the evaluator with the given key based on an NDList of labels and predictions.
This is a synchronized operation. You should only call it at the end of a batch or epoch.
Overrides:
- updateAccumulator in class Loss
Parameters:
- key - the key of the accumulator to update
- labels - an NDList of labels
- predictions - an NDList of predictions
resetAccumulator
public void resetAccumulator(java.lang.String key)
Resets the evaluator value with the given key.
Overrides:
- resetAccumulator in class Loss
Parameters:
- key - the key of the accumulator to reset
getAccumulator
public float getAccumulator(java.lang.String key)
Returns the accumulated evaluator value.
Overrides:
- getAccumulator in class Loss
Parameters:
- key - the key of the accumulator to get
Returns:
- the accumulated value
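The accumulator methods above follow a per-key lifecycle: add, update, read, reset (for example, with separate keys for training and validation). The plain-Java sketch below models that lifecycle for a composite that forwards every accumulator call to each of its components. It is not the DJL API: `ToyComponent`, `ToyCompositeAccumulator`, and the running-mean bookkeeping are hypothetical, and it assumes the composite reports the sum of its components' accumulated values.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;

// Hypothetical plain-Java model of keyed loss accumulators; NOT the DJL API.
// One component: a scalar loss function plus a running {sum, count} per key.
final class ToyComponent {
    private final BiFunction<Float, Float, Float> lossFn;
    private final Map<String, float[]> totals = new HashMap<>();

    ToyComponent(BiFunction<Float, Float, Float> lossFn) {
        this.lossFn = lossFn;
    }

    void addAccumulator(String key) {
        totals.put(key, new float[2]);
    }

    void updateAccumulator(String key, float label, float prediction) {
        float[] t = requireKey(key);
        t[0] += lossFn.apply(label, prediction); // running sum of losses
        t[1] += 1f; // number of updates
    }

    void resetAccumulator(String key) {
        requireKey(key);
        totals.put(key, new float[2]);
    }

    // Running mean of the losses seen since the last reset; NaN if there were none.
    float getAccumulator(String key) {
        float[] t = requireKey(key);
        return t[1] == 0f ? Float.NaN : t[0] / t[1];
    }

    private float[] requireKey(String key) {
        float[] t = totals.get(key);
        if (t == null) {
            throw new IllegalArgumentException("no accumulator for key " + key);
        }
        return t;
    }
}

// Composite: every accumulator call is forwarded to each component.
final class ToyCompositeAccumulator {
    private final List<ToyComponent> components = new ArrayList<>();

    ToyCompositeAccumulator add(ToyComponent component) {
        components.add(component);
        return this;
    }

    void addAccumulator(String key) {
        for (ToyComponent c : components) {
            c.addAccumulator(key);
        }
    }

    // Component i is updated with labels[i] and predictions[i].
    void updateAccumulator(String key, float[] labels, float[] predictions) {
        for (int i = 0; i < components.size(); i++) {
            components.get(i).updateAccumulator(key, labels[i], predictions[i]);
        }
    }

    void resetAccumulator(String key) {
        for (ToyComponent c : components) {
            c.resetAccumulator(key);
        }
    }

    // Assumption: the composite value is the sum of the components' accumulated values.
    float getAccumulator(String key) {
        float total = 0f;
        for (ToyComponent c : components) {
            total += c.getAccumulator(key);
        }
        return total;
    }
}
```

In this model, reading a key before any update yields NaN, two updates that produce component losses (1, 4) and (0, 0) yield 0.5 + 2.0 = 2.5, and a reset returns the key to its initial state. Forwarding each call to the components means the composite never stores totals of its own.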