Package ai.djl.training.loss
Class TabNetRegressionLoss
- java.lang.Object
  - ai.djl.training.evaluator.Evaluator
    - ai.djl.training.loss.Loss
      - ai.djl.training.loss.TabNetRegressionLoss
public class TabNetRegressionLoss extends Loss
Calculates the loss of TabNet for regression tasks. TabNet is not only used for supervised learning; it is also used for unsupervised learning. For unsupervised learning, the loss should instead come from the decoder (also known as the attentionTransformer of TabNet).
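As a rough illustration of the supervised regression case, the sketch below computes a mean squared error on plain Java arrays. It assumes that the regression term is mean squared error. The names `MseSketch` and `meanSquaredError` are hypothetical and not part of DJL, and the real class works on `NDArray`s rather than `double[]`.

```java
/**
 * Minimal sketch of a mean-squared-error regression loss on plain arrays.
 * Assumption: the supervised regression term is MSE. This is an
 * illustration only, not DJL's implementation.
 */
final class MseSketch {

    private MseSketch() {}

    /** Returns the mean of (label - prediction)^2 over all elements. */
    static double meanSquaredError(double[] labels, double[] predictions) {
        if (labels.length == 0 || labels.length != predictions.length) {
            throw new IllegalArgumentException("arrays must be non-empty and of equal length");
        }
        double sum = 0.0;
        for (int i = 0; i < labels.length; i++) {
            double diff = labels[i] - predictions[i];
            sum += diff * diff;
        }
        return sum / labels.length;
    }

    public static void main(String[] args) {
        double[] labels = {1.0, 2.0, 3.0};
        double[] predictions = {1.5, 2.0, 2.0};
        // (0.25 + 0.0 + 1.0) / 3, roughly 0.4167
        System.out.println(meanSquaredError(labels, predictions));
    }
}
```

Squaring the error penalizes large deviations more heavily than small ones, which is the usual choice for continuous regression targets.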
Field Summary
Fields inherited from class ai.djl.training.evaluator.Evaluator
totalInstances
Constructor Summary
Constructors
- TabNetRegressionLoss(): Calculates the loss of a TabNet instance for regression tasks.
- TabNetRegressionLoss(java.lang.String name): Calculates the loss of a TabNet instance for regression tasks.
Method Summary
Methods
- NDArray evaluate(NDList labels, NDList predictions): Calculates the evaluation between the labels and the predictions.
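The `evaluate` contract can be approximated on plain arrays as follows. This is a hypothetical sketch, not DJL's code: it assumes that `predictions` carries the regression output at index 0 and a model-produced sparsity penalty at index 1 (TabNet regularizes its feature-selection masks for sparsity), and that the loss is the mean squared error plus the mean of that penalty. `TabNetRegressionLossSketch` and its use of `List<double[]>` in place of `NDList` are illustrative stand-ins.

```java
import java.util.List;

/**
 * Hypothetical plain-Java analogue of evaluate(labels, predictions).
 * Assumptions: labels holds one target array; predictions.get(0) is the
 * regression output and predictions.get(1) is a sparsity penalty emitted
 * by the model. The loss is MSE plus the mean of that penalty.
 */
final class TabNetRegressionLossSketch {

    private TabNetRegressionLossSketch() {}

    static double evaluate(List<double[]> labels, List<double[]> predictions) {
        if (labels.size() != 1 || predictions.size() < 2) {
            throw new IllegalArgumentException("expected 1 label array and at least 2 prediction arrays");
        }
        double[] target = labels.get(0);
        double[] output = predictions.get(0);
        double[] sparsity = predictions.get(1);
        if (target.length == 0 || target.length != output.length || sparsity.length == 0) {
            throw new IllegalArgumentException("shape mismatch");
        }
        // Supervised term: mean squared error between targets and outputs.
        double squaredError = 0.0;
        for (int i = 0; i < target.length; i++) {
            double diff = target[i] - output[i];
            squaredError += diff * diff;
        }
        // Regularization term (assumed): average of the model-provided sparsity penalty.
        double sparsitySum = 0.0;
        for (double s : sparsity) {
            sparsitySum += s;
        }
        return squaredError / target.length + sparsitySum / sparsity.length;
    }

    public static void main(String[] args) {
        List<double[]> labels = List.of(new double[] {1.0, 2.0});
        List<double[]> predictions = List.of(new double[] {1.0, 3.0}, new double[] {0.2});
        // MSE = (0 + 1) / 2 = 0.5, plus sparsity mean 0.2, roughly 0.7
        System.out.println(evaluate(labels, predictions));
    }
}
```

Returning a single scalar lets the training loop differentiate one value, so both the fit to the targets and the sparsity pressure flow into the same gradient step.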
Methods inherited from class ai.djl.training.loss.Loss
addAccumulator, elasticNetWeightedDecay, elasticNetWeightedDecay, elasticNetWeightedDecay, elasticNetWeightedDecay, getAccumulator, hingeLoss, hingeLoss, hingeLoss, l1Loss, l1Loss, l1Loss, l1WeightedDecay, l1WeightedDecay, l1WeightedDecay, l2Loss, l2Loss, l2Loss, l2WeightedDecay, l2WeightedDecay, l2WeightedDecay, maskedSoftmaxCrossEntropyLoss, maskedSoftmaxCrossEntropyLoss, maskedSoftmaxCrossEntropyLoss, quantileL1Loss, quantileL1Loss, resetAccumulator, sigmoidBinaryCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, softmaxCrossEntropyLoss, softmaxCrossEntropyLoss, softmaxCrossEntropyLoss, updateAccumulator
Methods inherited from class ai.djl.training.evaluator.Evaluator
checkLabelShapes, checkLabelShapes, getName
Constructor Detail
TabNetRegressionLoss
public TabNetRegressionLoss()
Calculates the loss of a TabNet instance for regression tasks.
TabNetRegressionLoss
public TabNetRegressionLoss(java.lang.String name)
Calculates the loss of a TabNet instance for regression tasks.
Parameters:
- name: the name of the loss function
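The two constructors follow a common Java overloading pattern in which the no-argument form delegates to the named one. The hypothetical `NamedLossSketch` below shows that pattern in isolation. The default name `"TabNetRegressionLoss"` is an assumption, and `getName()` here only mirrors the `getName` method that the class inherits from `Evaluator`.

```java
/**
 * Hypothetical sketch of the two-constructor pattern: the no-argument
 * constructor delegates to the named one with a default name.
 * The default name used here is an assumption for illustration.
 */
final class NamedLossSketch {

    private final String name;

    NamedLossSketch() {
        this("TabNetRegressionLoss"); // assumed default name
    }

    NamedLossSketch(String name) {
        this.name = name;
    }

    String getName() {
        return name;
    }

    public static void main(String[] args) {
        System.out.println(new NamedLossSketch().getName());           // prints TabNetRegressionLoss
        System.out.println(new NamedLossSketch("val_loss").getName()); // prints val_loss
    }
}
```

Passing an explicit name is useful when several loss instances are tracked at once and need distinguishable labels.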