Class Loss
- java.lang.Object
  - ai.djl.training.evaluator.Evaluator
    - ai.djl.training.loss.Loss
- Direct Known Subclasses:
  AbstractCompositeLoss, BertMaskedLanguageModelLoss, BertNextSentenceLoss, ElasticNetWeightDecay, HingeLoss, IndexLoss, L1Loss, L1WeightDecay, L2Loss, L2WeightDecay, MaskedSoftmaxCrossEntropyLoss, QuantileL1Loss, SigmoidBinaryCrossEntropyLoss, SoftmaxCrossEntropyLoss, TabNetClassificationLoss, TabNetRegressionLoss, YOLOv3Loss
public abstract class Loss extends Evaluator
Loss functions (or cost functions) are used to evaluate the model predictions against true labels for optimization.
Although all evaluators can be used to measure the performance of a model, not all of them are suited to being used by an optimizer. Loss functions are usually non-negative, with a larger loss representing worse performance. They are also real-valued so that models can be compared accurately.
When creating a loss function, you should avoid having the loss depend on the batch size. For example, if you have a loss per item in a batch and sum those losses, your loss would be numItemsInBatch*avgLoss. Instead, you should take the mean of those losses to divide out the batchSize factor. Otherwise, the learning rate becomes difficult to tune, since any change in the batch size would throw it off. With a variable batch size, tuning would be even more difficult.
For more details about the class internals, see Evaluator.
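The batch-size advice above can be illustrated with a minimal plain-Java sketch. It does not use DJL: the class BatchLossReduction and its methods are hypothetical names introduced here, and the per-item losses are made-up values. It only shows that a summed loss scales with the batch size while a mean loss does not.

```java
/** Illustrative only: reductions over hypothetical per-item losses. */
final class BatchLossReduction {

    /** Sum reduction: grows linearly with the number of items in the batch. */
    static float sumLoss(float[] perItemLoss) {
        float total = 0f;
        for (float value : perItemLoss) {
            total += value;
        }
        return total;
    }

    /** Mean reduction: divides out the batch size. */
    static float meanLoss(float[] perItemLoss) {
        return sumLoss(perItemLoss) / perItemLoss.length;
    }

    public static void main(String[] args) {
        float[] smallBatch = {0.5f, 0.5f};
        float[] largeBatch = {0.5f, 0.5f, 0.5f, 0.5f};
        // Same per-item quality, different batch sizes.
        System.out.println(sumLoss(smallBatch) + " vs " + sumLoss(largeBatch));   // prints "1.0 vs 2.0"
        System.out.println(meanLoss(smallBatch) + " vs " + meanLoss(largeBatch)); // prints "0.5 vs 0.5"
    }
}
```

With the sum, doubling the batch size doubles the loss (and hence the gradient), which acts like a silent change to the learning rate; with the mean, the value stays comparable across batch sizes.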
-
-
Field Summary
-
Fields inherited from class ai.djl.training.evaluator.Evaluator
totalInstances
-
-
Constructor Summary
Constructor - Description
- Loss(java.lang.String name) - Base class for metric with abstract update methods.
-
Method Summary
Modifier and Type, Method - Description
- void addAccumulator(java.lang.String key) - Adds an accumulator for the results of the evaluation with the given key.
- static ElasticNetWeightDecay elasticNetWeightedDecay(NDList parameters) - Returns a new instance of ElasticNetWeightDecay with default weight and name.
- static ElasticNetWeightDecay elasticNetWeightedDecay(java.lang.String name, float weight1, float weight2, NDList parameters) - Returns a new instance of ElasticNetWeightDecay.
- static ElasticNetWeightDecay elasticNetWeightedDecay(java.lang.String name, float weight, NDList parameters) - Returns a new instance of ElasticNetWeightDecay.
- static ElasticNetWeightDecay elasticNetWeightedDecay(java.lang.String name, NDList parameters) - Returns a new instance of ElasticNetWeightDecay with default weight.
- float getAccumulator(java.lang.String key) - Returns the accumulated evaluator value.
- static HingeLoss hingeLoss() - Returns a new instance of HingeLoss with default arguments.
- static HingeLoss hingeLoss(java.lang.String name) - Returns a new instance of HingeLoss with default arguments.
- static HingeLoss hingeLoss(java.lang.String name, int margin, float weight) - Returns a new instance of HingeLoss with the given arguments.
- static L1Loss l1Loss() - Returns a new instance of L1Loss with default weight and batch axis.
- static L1Loss l1Loss(java.lang.String name) - Returns a new instance of L1Loss with default weight and batch axis.
- static L1Loss l1Loss(java.lang.String name, float weight) - Returns a new instance of L1Loss with given weight.
- static L1WeightDecay l1WeightedDecay(NDList parameters) - Returns a new instance of L1WeightDecay with default weight and name.
- static L1WeightDecay l1WeightedDecay(java.lang.String name, float weight, NDList parameters) - Returns a new instance of L1WeightDecay.
- static L1WeightDecay l1WeightedDecay(java.lang.String name, NDList parameters) - Returns a new instance of L1WeightDecay with default weight.
- static L2Loss l2Loss() - Returns a new instance of L2Loss with default weight and batch axis.
- static L2Loss l2Loss(java.lang.String name) - Returns a new instance of L2Loss with default weight and batch axis.
- static L2Loss l2Loss(java.lang.String name, float weight) - Returns a new instance of L2Loss with given weight and batch axis.
- static L2WeightDecay l2WeightedDecay(NDList parameters) - Returns a new instance of L2WeightDecay with default weight and name.
- static L2WeightDecay l2WeightedDecay(java.lang.String name, float weight, NDList parameters) - Returns a new instance of L2WeightDecay.
- static L2WeightDecay l2WeightedDecay(java.lang.String name, NDList parameters) - Returns a new instance of L2WeightDecay with default weight.
- static MaskedSoftmaxCrossEntropyLoss maskedSoftmaxCrossEntropyLoss() - Returns a new instance of MaskedSoftmaxCrossEntropyLoss with default arguments.
- static MaskedSoftmaxCrossEntropyLoss maskedSoftmaxCrossEntropyLoss(java.lang.String name) - Returns a new instance of MaskedSoftmaxCrossEntropyLoss with default arguments.
- static MaskedSoftmaxCrossEntropyLoss maskedSoftmaxCrossEntropyLoss(java.lang.String name, float weight, int classAxis, boolean sparseLabel, boolean fromLogit) - Returns a new instance of MaskedSoftmaxCrossEntropyLoss with the given arguments.
- static QuantileL1Loss quantileL1Loss(float quantile) - Returns a new instance of QuantileL1Loss with given quantile.
- static QuantileL1Loss quantileL1Loss(java.lang.String name, float quantile) - Returns a new instance of QuantileL1Loss with given quantile.
- void resetAccumulator(java.lang.String key) - Resets the evaluator value with the given key.
- static SigmoidBinaryCrossEntropyLoss sigmoidBinaryCrossEntropyLoss() - Returns a new instance of SigmoidBinaryCrossEntropyLoss with default arguments.
- static SigmoidBinaryCrossEntropyLoss sigmoidBinaryCrossEntropyLoss(java.lang.String name) - Returns a new instance of SigmoidBinaryCrossEntropyLoss with default arguments.
- static SigmoidBinaryCrossEntropyLoss sigmoidBinaryCrossEntropyLoss(java.lang.String name, float weight, boolean fromSigmoid) - Returns a new instance of SigmoidBinaryCrossEntropyLoss with the given arguments.
- static SoftmaxCrossEntropyLoss softmaxCrossEntropyLoss() - Returns a new instance of SoftmaxCrossEntropyLoss with default arguments.
- static SoftmaxCrossEntropyLoss softmaxCrossEntropyLoss(java.lang.String name) - Returns a new instance of SoftmaxCrossEntropyLoss with default arguments.
- static SoftmaxCrossEntropyLoss softmaxCrossEntropyLoss(java.lang.String name, float weight, int classAxis, boolean sparseLabel, boolean fromLogit) - Returns a new instance of SoftmaxCrossEntropyLoss with the given arguments.
- void updateAccumulator(java.lang.String key, NDList labels, NDList predictions) - Updates the evaluator with the given key based on a NDList of labels and predictions.
Methods inherited from class ai.djl.training.evaluator.Evaluator
checkLabelShapes, checkLabelShapes, evaluate, getName
-
Method Detail
-
l1Loss
public static L1Loss l1Loss()
Returns a new instance of L1Loss with default weight and batch axis.
- Returns:
  a new instance of L1Loss
-
l1Loss
public static L1Loss l1Loss(java.lang.String name)
Returns a new instance of L1Loss with default weight and batch axis.
- Parameters:
  name - the name of the loss
- Returns:
  a new instance of L1Loss
-
l1Loss
public static L1Loss l1Loss(java.lang.String name, float weight)
Returns a new instance of L1Loss with given weight.
- Parameters:
  name - the name of the loss
  weight - the weight to apply on loss value, default 1
- Returns:
  a new instance of L1Loss
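As a rough illustration of what an L1 loss with a weight measures, here is a plain-Java sketch of the usual definition, weight * mean(|label - prediction|). It is not DJL's implementation: L1LossSketch and its method are hypothetical names, and plain float arrays stand in for NDList inputs.

```java
/** Illustrative only: the textbook weighted L1 (mean absolute error) loss. */
final class L1LossSketch {

    /** Returns weight * mean(|labels[i] - predictions[i]|). */
    static float l1(float[] labels, float[] predictions, float weight) {
        if (labels.length != predictions.length) {
            throw new IllegalArgumentException("labels and predictions must have the same length");
        }
        float total = 0f;
        for (int i = 0; i < labels.length; i++) {
            total += Math.abs(labels[i] - predictions[i]);
        }
        // Taking the mean keeps the value independent of the batch size.
        return weight * total / labels.length;
    }

    public static void main(String[] args) {
        float[] labels = {1f, 2f, 3f};
        float[] predictions = {1.5f, 2f, 2f};
        System.out.println(l1(labels, predictions, 1f)); // prints 0.5
    }
}
```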
-
quantileL1Loss
public static QuantileL1Loss quantileL1Loss(float quantile)
Returns a new instance of QuantileL1Loss with given quantile.
- Parameters:
  quantile - the quantile position of the data to focus on
- Returns:
  a new instance of QuantileL1Loss
-
quantileL1Loss
public static QuantileL1Loss quantileL1Loss(java.lang.String name, float quantile)
Returns a new instance of QuantileL1Loss with given quantile.
- Parameters:
  name - the name of the loss
  quantile - the quantile position of the data to focus on
- Returns:
  a new instance of QuantileL1Loss
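To give a feel for how a quantile parameter shifts the focus of an L1-style loss, the sketch below implements the standard quantile (pinball) loss in plain Java. This is an assumption about the general technique, not a transcription of QuantileL1Loss; the class and method names are hypothetical, and DJL's scaling may differ.

```java
/** Illustrative only: the standard quantile ("pinball") loss. */
final class QuantileLossSketch {

    /** Returns mean(max(q * e, (q - 1) * e)) with e = label - prediction. */
    static float quantileL1(float[] labels, float[] predictions, float quantile) {
        if (labels.length != predictions.length) {
            throw new IllegalArgumentException("labels and predictions must have the same length");
        }
        float total = 0f;
        for (int i = 0; i < labels.length; i++) {
            float error = labels[i] - predictions[i];
            // Under-prediction (error > 0) costs q per unit, over-prediction costs (1 - q).
            total += Math.max(quantile * error, (quantile - 1f) * error);
        }
        return total / labels.length;
    }

    public static void main(String[] args) {
        float[] labels = {2f, 2f};
        float[] predictions = {1f, 3f}; // one under-prediction, one over-prediction
        System.out.println(quantileL1(labels, predictions, 0.75f)); // prints 0.5
    }
}
```

With quantile = 0.5 both directions are penalized equally, so the value is half the plain mean absolute error; a quantile above 0.5 penalizes under-prediction more heavily.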
-
l2Loss
public static L2Loss l2Loss()
Returns a new instance of L2Loss with default weight and batch axis.
- Returns:
  a new instance of L2Loss
-
l2Loss
public static L2Loss l2Loss(java.lang.String name)
Returns a new instance of L2Loss with default weight and batch axis.
- Parameters:
  name - the name of the loss
- Returns:
  a new instance of L2Loss
-
l2Loss
public static L2Loss l2Loss(java.lang.String name, float weight)
Returns a new instance of L2Loss with given weight and batch axis.
- Parameters:
  name - the name of the loss
  weight - the weight to apply on loss value, default 1
- Returns:
  a new instance of L2Loss
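For comparison with the L1 family, a weighted squared-error loss can be sketched in plain Java as weight * mean((label - prediction)^2). This is the textbook form and only an approximation of what L2Loss does: any constant factor such as 1/2 is assumed to be folded into weight, and L2LossSketch is a hypothetical name.

```java
/** Illustrative only: the textbook weighted L2 (mean squared error) loss. */
final class L2LossSketch {

    /** Returns weight * mean((labels[i] - predictions[i])^2). */
    static float l2(float[] labels, float[] predictions, float weight) {
        if (labels.length != predictions.length) {
            throw new IllegalArgumentException("labels and predictions must have the same length");
        }
        float total = 0f;
        for (int i = 0; i < labels.length; i++) {
            float error = labels[i] - predictions[i];
            total += error * error; // large errors dominate because they are squared
        }
        return weight * total / labels.length;
    }

    public static void main(String[] args) {
        float[] labels = {1f, 3f};
        float[] predictions = {2f, 1f};
        System.out.println(l2(labels, predictions, 0.5f)); // prints 1.25
    }
}
```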
-
sigmoidBinaryCrossEntropyLoss
public static SigmoidBinaryCrossEntropyLoss sigmoidBinaryCrossEntropyLoss()
Returns a new instance of SigmoidBinaryCrossEntropyLoss with default arguments.
- Returns:
  a new instance of SigmoidBinaryCrossEntropyLoss
-
sigmoidBinaryCrossEntropyLoss
public static SigmoidBinaryCrossEntropyLoss sigmoidBinaryCrossEntropyLoss(java.lang.String name)
Returns a new instance of SigmoidBinaryCrossEntropyLoss with default arguments.
- Parameters:
  name - the name of the loss
- Returns:
  a new instance of SigmoidBinaryCrossEntropyLoss
-
sigmoidBinaryCrossEntropyLoss
public static SigmoidBinaryCrossEntropyLoss sigmoidBinaryCrossEntropyLoss(java.lang.String name, float weight, boolean fromSigmoid)
Returns a new instance of SigmoidBinaryCrossEntropyLoss with the given arguments.
- Parameters:
  name - the name of the loss
  weight - the weight to apply on the loss value, default 1
  fromSigmoid - whether the input is from the output of sigmoid, default false
- Returns:
  a new instance of SigmoidBinaryCrossEntropyLoss
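The role of a flag like fromSigmoid can be sketched in plain Java: when the input is a raw score, the sigmoid is applied inside the loss (here in a numerically stable form); when the input already is a sigmoid output, it is treated as a probability. This is a sketch of the general technique under those assumptions, not DJL's code; SigmoidBceSketch, its method, and the epsilon value are hypothetical.

```java
/** Illustrative only: per-item binary cross-entropy for raw scores or probabilities. */
final class SigmoidBceSketch {

    private static final double EPS = 1e-12; // assumed guard against log(0)

    /**
     * @param input a raw score if fromSigmoid is false, otherwise a probability in [0, 1]
     * @param label the target, 0 or 1
     */
    static double bce(double input, double label, boolean fromSigmoid) {
        if (fromSigmoid) {
            // Input is already sigmoid(x): use the definition directly.
            return -(label * Math.log(input + EPS) + (1 - label) * Math.log(1 - input + EPS));
        }
        // Input is a raw score x: stable rewrite of the same expression with p = sigmoid(x).
        return Math.max(input, 0) - input * label + Math.log1p(Math.exp(-Math.abs(input)));
    }

    public static void main(String[] args) {
        double score = 2.0;
        double probability = 1.0 / (1.0 + Math.exp(-score));
        // Both calls describe the same prediction and agree up to rounding.
        System.out.println(bce(score, 0, false));
        System.out.println(bce(probability, 0, true));
    }
}
```

Passing raw scores and letting the loss apply the sigmoid is generally the more numerically robust choice, since the stable form never evaluates log of a value rounded to zero.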
-
softmaxCrossEntropyLoss
public static SoftmaxCrossEntropyLoss softmaxCrossEntropyLoss()
Returns a new instance of SoftmaxCrossEntropyLoss with default arguments.
- Returns:
  a new instance of SoftmaxCrossEntropyLoss
-
softmaxCrossEntropyLoss
public static SoftmaxCrossEntropyLoss softmaxCrossEntropyLoss(java.lang.String name)
Returns a new instance of SoftmaxCrossEntropyLoss with default arguments.
- Parameters:
  name - the name of the loss
- Returns:
  a new instance of SoftmaxCrossEntropyLoss
-
softmaxCrossEntropyLoss
public static SoftmaxCrossEntropyLoss softmaxCrossEntropyLoss(java.lang.String name, float weight, int classAxis, boolean sparseLabel, boolean fromLogit)
Returns a new instance of SoftmaxCrossEntropyLoss with the given arguments.
- Parameters:
  name - the name of the loss
  weight - the weight to apply on the loss value, default 1
  classAxis - the axis that represents the class probabilities, default -1
  sparseLabel - whether labels are integer array or probabilities, default true
  fromLogit - whether labels are log probabilities or un-normalized numbers
- Returns:
  a new instance of SoftmaxCrossEntropyLoss
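The difference between integer (sparse) labels and probability labels can be illustrated with a plain-Java sketch of softmax cross-entropy for a single example. It assumes the inputs are un-normalized scores and applies a log-softmax itself; it is not DJL's implementation, and SoftmaxCrossEntropySketch and its methods are hypothetical names.

```java
/** Illustrative only: softmax cross-entropy for one example with sparse or dense labels. */
final class SoftmaxCrossEntropySketch {

    /** Numerically stable log-softmax over one row of un-normalized scores. */
    static double[] logSoftmax(double[] scores) {
        double max = Double.NEGATIVE_INFINITY;
        for (double s : scores) {
            max = Math.max(max, s);
        }
        double sumExp = 0;
        for (double s : scores) {
            sumExp += Math.exp(s - max); // subtracting the max avoids overflow
        }
        double logSumExp = max + Math.log(sumExp);
        double[] result = new double[scores.length];
        for (int i = 0; i < scores.length; i++) {
            result[i] = scores[i] - logSumExp;
        }
        return result;
    }

    /** Sparse label: the label is the index of the correct class. */
    static double sparseLoss(double[] scores, int labelIndex) {
        return -logSoftmax(scores)[labelIndex];
    }

    /** Dense label: the label is a probability distribution over the classes. */
    static double denseLoss(double[] scores, double[] labelProbabilities) {
        double[] logProbabilities = logSoftmax(scores);
        double loss = 0;
        for (int i = 0; i < scores.length; i++) {
            loss -= labelProbabilities[i] * logProbabilities[i];
        }
        return loss;
    }

    public static void main(String[] args) {
        double[] scores = {1.0, 2.0, 3.0};
        // A one-hot distribution gives the same loss as the matching class index.
        System.out.println(sparseLoss(scores, 2));
        System.out.println(denseLoss(scores, new double[] {0, 0, 1}));
    }
}
```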
-
maskedSoftmaxCrossEntropyLoss
public static MaskedSoftmaxCrossEntropyLoss maskedSoftmaxCrossEntropyLoss()
Returns a new instance of MaskedSoftmaxCrossEntropyLoss with default arguments.
- Returns:
  a new instance of MaskedSoftmaxCrossEntropyLoss
-
maskedSoftmaxCrossEntropyLoss
public static MaskedSoftmaxCrossEntropyLoss maskedSoftmaxCrossEntropyLoss(java.lang.String name)
Returns a new instance of MaskedSoftmaxCrossEntropyLoss with default arguments.
- Parameters:
  name - the name of the loss
- Returns:
  a new instance of MaskedSoftmaxCrossEntropyLoss
-
maskedSoftmaxCrossEntropyLoss
public static MaskedSoftmaxCrossEntropyLoss maskedSoftmaxCrossEntropyLoss(java.lang.String name, float weight, int classAxis, boolean sparseLabel, boolean fromLogit)
Returns a new instance of MaskedSoftmaxCrossEntropyLoss with the given arguments.
- Parameters:
  name - the name of the loss
  weight - the weight to apply on the loss value, default 1
  classAxis - the axis that represents the class probabilities, default -1
  sparseLabel - whether labels are integer array or probabilities, default true
  fromLogit - whether labels are log probabilities or un-normalized numbers
- Returns:
  a new instance of MaskedSoftmaxCrossEntropyLoss
-
hingeLoss
public static HingeLoss hingeLoss()
Returns a new instance of HingeLoss with default arguments.
- Returns:
  a new instance of HingeLoss
-
hingeLoss
public static HingeLoss hingeLoss(java.lang.String name)
Returns a new instance of HingeLoss with default arguments.
- Parameters:
  name - the name of the loss
- Returns:
  a new instance of HingeLoss
-
hingeLoss
public static HingeLoss hingeLoss(java.lang.String name, int margin, float weight)
Returns a new instance of HingeLoss with the given arguments.
- Parameters:
  name - the name of the loss
  margin - the margin in hinge loss. Defaults to 1.0
  weight - the weight to apply on loss value, default 1
- Returns:
  a new instance of HingeLoss
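How margin and weight interact in a hinge loss can be sketched in plain Java using the common definition weight * mean(max(0, margin - label * prediction)). The sketch assumes labels are encoded as -1 or +1 and is not taken from HingeLoss itself; HingeLossSketch is a hypothetical name.

```java
/** Illustrative only: the common hinge loss for labels encoded as -1 or +1. */
final class HingeLossSketch {

    /** Returns weight * mean(max(0, margin - labels[i] * predictions[i])). */
    static float hinge(float[] labels, float[] predictions, int margin, float weight) {
        if (labels.length != predictions.length) {
            throw new IllegalArgumentException("labels and predictions must have the same length");
        }
        float total = 0f;
        for (int i = 0; i < labels.length; i++) {
            // Zero once the prediction is on the correct side by at least the margin.
            total += Math.max(0f, margin - labels[i] * predictions[i]);
        }
        return weight * total / labels.length;
    }

    public static void main(String[] args) {
        float[] labels = {1f, -1f};
        float[] predictions = {0.5f, 0.5f}; // first is inside the margin, second is on the wrong side
        System.out.println(hinge(labels, predictions, 1, 1f)); // prints 1.0
    }
}
```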
-
l1WeightedDecay
public static L1WeightDecay l1WeightedDecay(NDList parameters)
Returns a new instance of L1WeightDecay with default weight and name.
- Parameters:
  parameters - holds the model weights that will be penalized
- Returns:
  a new instance of L1WeightDecay
-
l1WeightedDecay
public static L1WeightDecay l1WeightedDecay(java.lang.String name, NDList parameters)
Returns a new instance of L1WeightDecay with default weight.
- Parameters:
  name - the name of the weight decay
  parameters - holds the model weights that will be penalized
- Returns:
  a new instance of L1WeightDecay
-
l1WeightedDecay
public static L1WeightDecay l1WeightedDecay(java.lang.String name, float weight, NDList parameters)
Returns a new instance of L1WeightDecay.
- Parameters:
  name - the name of the weight decay
  weight - the weight to apply on weight decay value, default 1
  parameters - holds the model weights that will be penalized
- Returns:
  a new instance of L1WeightDecay
-
l2WeightedDecay
public static L2WeightDecay l2WeightedDecay(NDList parameters)
Returns a new instance of L2WeightDecay with default weight and name.
- Parameters:
  parameters - holds the model weights that will be penalized
- Returns:
  a new instance of L2WeightDecay
-
l2WeightedDecay
public static L2WeightDecay l2WeightedDecay(java.lang.String name, NDList parameters)
Returns a new instance of L2WeightDecay with default weight.
- Parameters:
  name - the name of the weight decay
  parameters - holds the model weights that will be penalized
- Returns:
  a new instance of L2WeightDecay
-
l2WeightedDecay
public static L2WeightDecay l2WeightedDecay(java.lang.String name, float weight, NDList parameters)
Returns a new instance of L2WeightDecay.
- Parameters:
  name - the name of the weight decay
  weight - the weight to apply on weight decay value, default 1
  parameters - holds the model weights that will be penalized
- Returns:
  a new instance of L2WeightDecay
-
elasticNetWeightedDecay
public static ElasticNetWeightDecay elasticNetWeightedDecay(NDList parameters)
Returns a new instance of ElasticNetWeightDecay with default weight and name.
- Parameters:
  parameters - holds the model weights that will be penalized
- Returns:
  a new instance of ElasticNetWeightDecay
-
elasticNetWeightedDecay
public static ElasticNetWeightDecay elasticNetWeightedDecay(java.lang.String name, NDList parameters)
Returns a new instance of ElasticNetWeightDecay with default weight.
- Parameters:
  name - the name of the weight decay
  parameters - holds the model weights that will be penalized
- Returns:
  a new instance of ElasticNetWeightDecay
-
elasticNetWeightedDecay
public static ElasticNetWeightDecay elasticNetWeightedDecay(java.lang.String name, float weight, NDList parameters)
Returns a new instance of ElasticNetWeightDecay.
- Parameters:
  name - the name of the weight decay
  weight - the weight to apply on weight decay values, default 1
  parameters - holds the model weights that will be penalized
- Returns:
  a new instance of ElasticNetWeightDecay
-
elasticNetWeightedDecay
public static ElasticNetWeightDecay elasticNetWeightedDecay(java.lang.String name, float weight1, float weight2, NDList parameters)
Returns a new instance of ElasticNetWeightDecay.
- Parameters:
  name - the name of the weight decay
  weight1 - the weight to apply on weight decay L1 value, default 1
  weight2 - the weight to apply on weight decay L2 value, default 1
  parameters - holds the model weights that will be penalized
- Returns:
  a new instance of ElasticNetWeightDecay
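Unlike the losses above, the weight decay terms penalize the model weights rather than the prediction error. The plain-Java sketch below shows the usual elastic net form, weight1 * sum(|w|) + weight2 * sum(w^2), with float[][] standing in for the NDList of parameters. The exact reduction and scaling used by DJL are assumptions here, and WeightDecaySketch is a hypothetical name.

```java
/** Illustrative only: the usual elastic net penalty over a list of parameter arrays. */
final class WeightDecaySketch {

    /** Returns weight1 * sum(|w|) + weight2 * sum(w^2) over all parameters. */
    static float elasticNet(float[][] parameters, float weight1, float weight2) {
        float l1 = 0f;
        float l2 = 0f;
        for (float[] parameter : parameters) {
            for (float w : parameter) {
                l1 += Math.abs(w); // L1 part: pushes weights toward exactly zero
                l2 += w * w;       // L2 part: shrinks large weights most strongly
            }
        }
        return weight1 * l1 + weight2 * l2;
    }

    public static void main(String[] args) {
        float[][] parameters = {{1f, -2f}, {3f}};
        System.out.println(elasticNet(parameters, 0.5f, 0.25f)); // prints 6.5
        System.out.println(elasticNet(parameters, 1f, 0f));      // L1 penalty only, prints 6.0
        System.out.println(elasticNet(parameters, 0f, 1f));      // L2 penalty only, prints 14.0
    }
}
```

Setting one of the two weights to zero reduces the penalty to a pure L1 or pure L2 weight decay, which is how this sketch relates to the l1WeightedDecay and l2WeightedDecay factories.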
-
addAccumulator
public void addAccumulator(java.lang.String key)
Adds an accumulator for the results of the evaluation with the given key.
- Specified by:
  addAccumulator in class Evaluator
- Parameters:
  key - the key for the new accumulator
-
updateAccumulator
public void updateAccumulator(java.lang.String key, NDList labels, NDList predictions)
Updates the evaluator with the given key based on a NDList of labels and predictions.
This is a synchronized operation. You should only call it at the end of a batch or epoch.
- Specified by:
  updateAccumulator in class Evaluator
- Parameters:
  key - the key of the accumulator to update
  labels - a NDList of labels
  predictions - a NDList of predictions
-
resetAccumulator
public void resetAccumulator(java.lang.String key)
Resets the evaluator value with the given key.
- Specified by:
  resetAccumulator in class Evaluator
- Parameters:
  key - the key of the accumulator to reset
-
getAccumulator
public float getAccumulator(java.lang.String key)
Returns the accumulated evaluator value.
- Specified by:
  getAccumulator in class Evaluator
- Parameters:
  key - the key of the accumulator to get
- Returns:
  the accumulated value
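The add / update / reset / get life cycle of keyed accumulators can be pictured with a small plain-Java sketch. It is a guess at the general pattern, not DJL's implementation: KeyedLossAccumulator is a hypothetical class, and its update method takes an already computed batch mean loss and batch size instead of NDList labels and predictions.

```java
import java.util.HashMap;
import java.util.Map;

/** Illustrative only: a keyed running average of a loss, e.g. one key per training phase. */
final class KeyedLossAccumulator {
    private final Map<String, Double> totalLoss = new HashMap<>();
    private final Map<String, Long> totalInstances = new HashMap<>();

    /** Creates (or re-creates) an empty accumulator for the key. */
    synchronized void addAccumulator(String key) {
        totalLoss.put(key, 0.0);
        totalInstances.put(key, 0L);
    }

    /** Folds one batch into the accumulator, weighting the batch mean by the batch size. */
    synchronized void updateAccumulator(String key, double batchMeanLoss, long batchSize) {
        requireKey(key);
        totalLoss.put(key, totalLoss.get(key) + batchMeanLoss * batchSize);
        totalInstances.put(key, totalInstances.get(key) + batchSize);
    }

    /** Clears the accumulated value but keeps the key registered. */
    synchronized void resetAccumulator(String key) {
        requireKey(key);
        addAccumulator(key);
    }

    /** Returns the mean loss per instance seen so far, or NaN if nothing was accumulated. */
    synchronized float getAccumulator(String key) {
        requireKey(key);
        long count = totalInstances.get(key);
        return count == 0 ? Float.NaN : (float) (totalLoss.get(key) / count);
    }

    private void requireKey(String key) {
        if (!totalLoss.containsKey(key)) {
            throw new IllegalArgumentException("No accumulator was added for key: " + key);
        }
    }

    public static void main(String[] args) {
        KeyedLossAccumulator accumulator = new KeyedLossAccumulator();
        accumulator.addAccumulator("train");
        accumulator.updateAccumulator("train", 0.5, 2);
        accumulator.updateAccumulator("train", 2.0, 2);
        System.out.println(accumulator.getAccumulator("train")); // prints 1.25
    }
}
```

Weighting each batch mean by its batch size keeps the accumulated value a per-instance average even when batch sizes vary, in line with the class-level advice about batch-size independence.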