public abstract class AbstractAccuracy extends Evaluator
Accuracy is an Evaluator that computes the accuracy score.
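The accuracy score is defined as \(accuracy(y, \hat{y}) = \frac{1}{n}\sum_{i=0}^{n-1}1(\hat{y}_i == y_i)\), where \(n\) is the number of instances and \(1(\cdot)\) is the indicator function, equal to 1 when the prediction matches the label and 0 otherwise.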
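
As an illustration of the formula, the following minimal sketch computes the accuracy score over plain Java arrays. It does not use the DJL NDArray API; the class name `AccuracyFormulaSketch` and its `accuracy` method are hypothetical and exist only for this example.

```java
/** Hypothetical helper, not part of DJL: evaluates the accuracy formula on plain int arrays. */
final class AccuracyFormulaSketch {

    /** Returns (1/n) * sum over i of 1(predictions[i] == labels[i]). */
    static double accuracy(int[] labels, int[] predictions) {
        if (labels.length != predictions.length) {
            throw new IllegalArgumentException("labels and predictions must have the same length");
        }
        if (labels.length == 0) {
            throw new IllegalArgumentException("at least one instance is required");
        }
        long correct = 0;
        for (int i = 0; i < labels.length; i++) {
            if (predictions[i] == labels[i]) {
                correct++; // the indicator term 1(y_hat_i == y_i)
            }
        }
        return (double) correct / labels.length;
    }

    public static void main(String[] args) {
        int[] labels = {0, 1, 2, 1};
        int[] predictions = {0, 1, 1, 1};
        // Three of the four predictions match their labels.
        System.out.println(accuracy(labels, predictions)); // prints 0.75
    }
}
```
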
| Modifier and Type | Field and Description |
|---|---|
| protected int | axis |
| protected java.util.Map<java.lang.String,java.lang.Long> | correctInstances |
| protected int | index |
| | totalInstances |

| Constructor and Description |
|---|
| AbstractAccuracy(): Creates an accuracy evaluator that computes accuracy across axis 1 along the 0th index. |
| AbstractAccuracy(java.lang.String name, int index): Creates an accuracy evaluator that computes accuracy across axis 1 along the given index. |
| AbstractAccuracy(java.lang.String name, int index, int axis): Creates an accuracy evaluator. |

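The axis argument names the dimension of the prediction that holds the per-class scores, and the index argument selects which array in the labels list is evaluated. A minimal sketch of the axis idea, assuming a two-dimensional prediction of shape (batch, classes) stored as a plain `float[][]` with non-empty rows, so that axis 1 is the class axis. The class `ArgMaxSketch` is hypothetical and is not how DJL implements this step.

```java
/** Hypothetical helper, not part of DJL: picks the predicted class for each row of scores. */
final class ArgMaxSketch {

    /** For each row (axis 0), returns the position of the largest score along axis 1. */
    static int[] argMaxAlongAxis1(float[][] scores) {
        int[] result = new int[scores.length];
        for (int row = 0; row < scores.length; row++) {
            int best = 0; // assumes every row has at least one score
            for (int cls = 1; cls < scores[row].length; cls++) {
                if (scores[row][cls] > scores[row][best]) {
                    best = cls; // ties keep the earliest class
                }
            }
            result[row] = best;
        }
        return result;
    }
}
```

The resulting class indices can then be compared element by element with the labels to count correct instances.
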
| Modifier and Type | Method and Description |
|---|---|
| protected abstract ai.djl.util.Pair<java.lang.Long,NDArray> | accuracyHelper(NDList labels, NDList predictions): A helper for classes extending AbstractAccuracy. |
| void | addAccumulator(java.lang.String key): Adds an accumulator for the results of the evaluation with the given key. |
| NDArray | evaluate(NDList labels, NDList predictions): Calculates the evaluation between the labels and the predictions. |
| float | getAccumulator(java.lang.String key): Returns the accumulated evaluator value. |
| void | resetAccumulator(java.lang.String key): Resets the evaluator value with the given key. |
| void | updateAccumulator(java.lang.String key, NDList labels, NDList predictions): Updates the evaluator with the given key based on an NDList of labels and predictions. |

Methods inherited from class Evaluator: checkLabelShapes, checkLabelShapes, getName

protected java.util.Map<java.lang.String,java.lang.Long> correctInstances

protected int axis

protected int index

public AbstractAccuracy()

public AbstractAccuracy(java.lang.String name,
int index)
name - the name of the evaluator, default is "Accuracy"
index - the index of the NDArray in labels to compute accuracy for

public AbstractAccuracy(java.lang.String name,
int index,
int axis)
name - the name of the evaluator, default is "Accuracy"
index - the index of the NDArray in labels to compute accuracy for
axis - the axis that represents classes in the prediction, default 1

protected abstract ai.djl.util.Pair<java.lang.Long,NDArray> accuracyHelper(NDList labels, NDList predictions)
A helper for classes extending AbstractAccuracy.
labels - the labels to get accuracy for
predictions - the predictions to get accuracy for

public NDArray evaluate(NDList labels, NDList predictions)
Calculates the evaluation between the labels and the predictions.

public void addAccumulator(java.lang.String key)
Adds an accumulator for the results of the evaluation with the given key.
addAccumulator in class Evaluator
key - the key for the new accumulator

public void updateAccumulator(java.lang.String key,
NDList labels,
NDList predictions)
Updates the evaluator with the given key based on an NDList of labels and predictions.
This is a synchronized operation. You should only call it at the end of a batch or epoch.
updateAccumulator in class Evaluator
key - the key of the accumulator to update
labels - an NDList of labels
predictions - an NDList of predictions

public void resetAccumulator(java.lang.String key)
Resets the evaluator value with the given key.
resetAccumulator in class Evaluator
key - the key of the accumulator to reset

public float getAccumulator(java.lang.String key)
Returns the accumulated evaluator value.
getAccumulator in class Evaluator
key - the key of the accumulator to get
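
The accumulator methods above follow an add, update, get, reset lifecycle per key. The sketch below mimics that lifecycle with plain Java maps and int arrays in place of NDList. It is a hypothetical stand-in, not the DJL implementation: the class `AccumulatorSketch`, the array-based `updateAccumulator` signature, and the choice to return `Float.NaN` for an empty accumulator are all assumptions made for illustration.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical stand-in, not the DJL implementation: keyed running counts of correct and total instances. */
final class AccumulatorSketch {
    private final Map<String, Long> correctInstances = new HashMap<>();
    private final Map<String, Long> totalInstances = new HashMap<>();

    /** Registers a new accumulator, starting both counts at zero. */
    void addAccumulator(String key) {
        correctInstances.put(key, 0L);
        totalInstances.put(key, 0L);
    }

    /** Adds one batch of labels and predictions to the running counts for the key. */
    void updateAccumulator(String key, int[] labels, int[] predictions) {
        if (!totalInstances.containsKey(key)) {
            throw new IllegalArgumentException("unknown accumulator key: " + key);
        }
        if (labels.length != predictions.length) {
            throw new IllegalArgumentException("labels and predictions must have the same length");
        }
        long correct = 0;
        for (int i = 0; i < labels.length; i++) {
            if (labels[i] == predictions[i]) {
                correct++;
            }
        }
        correctInstances.merge(key, correct, Long::sum);
        totalInstances.merge(key, (long) labels.length, Long::sum);
    }

    /** Clears the running counts for the key. */
    void resetAccumulator(String key) {
        correctInstances.put(key, 0L);
        totalInstances.put(key, 0L);
    }

    /** Returns correct / total for the key, or NaN if nothing was accumulated (an assumption of this sketch). */
    float getAccumulator(String key) {
        Long total = totalInstances.get(key);
        if (total == null || total == 0L) {
            return Float.NaN;
        }
        return (float) correctInstances.get(key) / total;
    }
}
```

A typical use under these assumptions would be to call `addAccumulator("train")` once, call `updateAccumulator` after each batch, read the value with `getAccumulator` at the end of an epoch, and then call `resetAccumulator` before the next epoch.
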