public class Accuracy extends TrainingMetric
Accuracy is a TrainingMetric that computes the accuracy score.
The accuracy score is defined as \(accuracy(y, \hat{y}) = \frac{1}{n}\sum_{i=0}^{n-1}\mathbb{1}(\hat{y}_i = y_i)\), where \(\mathbb{1}\) is the indicator function and \(n\) is the number of samples.
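As a plain-Java illustration of the formula above (not the DJL implementation, which operates on NDArrays), the sketch below assumes the labels \(y\) and the predicted classes \(\hat{y}\) are already available as `int` arrays of equal length \(n\). The class and method names are hypothetical.

```java
/** Hypothetical stand-alone illustration of the accuracy formula (not DJL code). */
final class AccuracyFormula {

    /**
     * Returns (1/n) * sum over i = 0..n-1 of 1(yHat_i == y_i).
     *
     * @param labels true class indices y
     * @param predicted predicted class indices yHat
     */
    static float accuracy(int[] labels, int[] predicted) {
        if (labels.length != predicted.length) {
            throw new IllegalArgumentException("labels and predictions differ in length");
        }
        if (labels.length == 0) {
            throw new IllegalArgumentException("accuracy is undefined for n = 0");
        }
        int correct = 0;
        for (int i = 0; i < labels.length; i++) {
            // Indicator function: contributes 1 when the prediction matches the label.
            if (predicted[i] == labels[i]) {
                correct++;
            }
        }
        return (float) correct / labels.length;
    }

    public static void main(String[] args) {
        int[] y = {0, 2, 1, 1};
        int[] yHat = {0, 1, 1, 1};
        System.out.println(accuracy(y, yHat)); // 3 of 4 correct -> 0.75
    }
}
```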
| Constructor and Description |
|---|
| Accuracy(): Creates an accuracy metric that treats axis 1 of the predictions as the class axis and uses the NDArray at index 0 of the labels. |
| Accuracy(java.lang.String name, int index): Creates an accuracy metric that treats axis 1 of the predictions as the class axis and uses the NDArray at the given index of the labels. |
| Accuracy(java.lang.String name, int index, int axis): Creates an accuracy metric. |
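The constructors differ in which label array (index) and which prediction axis (axis) they use. Assuming the predictions are per-class scores laid out as [batch, classes], so that classes lie along axis 1, the predicted class for each sample is the position of the largest score along that axis. The sketch below shows this argmax step in plain Java; the names are hypothetical and this is not DJL's code.

```java
import java.util.Arrays;

/** Hypothetical sketch: turn per-class scores into predicted class indices. */
final class ArgmaxAlongClasses {

    /**
     * For scores laid out as [batch][classes] (classes on axis 1),
     * returns the index of the highest score in each row.
     */
    static int[] argmaxAxis1(float[][] scores) {
        int[] predicted = new int[scores.length];
        for (int row = 0; row < scores.length; row++) {
            int best = 0;
            for (int c = 1; c < scores[row].length; c++) {
                if (scores[row][c] > scores[row][best]) {
                    best = c; // on ties the earlier class is kept
                }
            }
            predicted[row] = best;
        }
        return predicted;
    }

    public static void main(String[] args) {
        float[][] scores = {
            {0.1f, 0.7f, 0.2f},
            {0.8f, 0.1f, 0.1f},
        };
        System.out.println(Arrays.toString(argmaxAxis1(scores))); // [1, 0]
    }
}
```

The resulting class indices can then be compared element by element with the labels, as in the accuracy formula.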
| Modifier and Type | Method and Description |
|---|---|
| void | addCorrectInstances(long numInstances): Adds a number to the count of correct instances. |
| void | addTotalInstances(long totalInstances): Adds a number to the count of total instances. |
| float | getValue(): Calculates the metric value. |
| void | reset(): Resets the metric values. |
| void | update(NDArray labels, NDArray predictions): Computes and updates the accuracy based on the labels and predictions. |
| void | update(NDList labels, NDList predictions): Computes and updates the training metrics based on the labels and predictions. |
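The method summary suggests a running-count design: addCorrectInstances and addTotalInstances accumulate counts, getValue divides them, and reset clears them. The class below is a hypothetical plain-Java sketch of that pattern under those assumptions. It takes precomputed class indices instead of NDArrays, it is not the DJL source, and returning NaN when nothing has been counted is a choice made only for this sketch.

```java
/** Hypothetical running-accuracy accumulator mirroring the summary's method names. */
final class RunningAccuracy {
    private long correctInstances;
    private long totalInstances;

    void reset() {
        correctInstances = 0;
        totalInstances = 0;
    }

    void addCorrectInstances(long numInstances) {
        correctInstances += numInstances;
    }

    void addTotalInstances(long numInstances) {
        totalInstances += numInstances;
    }

    /** Counts matches in one batch and folds them into the running totals. */
    void update(int[] labels, int[] predicted) {
        if (labels.length != predicted.length) {
            throw new IllegalArgumentException("labels and predictions differ in length");
        }
        long correct = 0;
        for (int i = 0; i < labels.length; i++) {
            if (labels[i] == predicted[i]) {
                correct++;
            }
        }
        addCorrectInstances(correct);
        addTotalInstances(labels.length);
    }

    /** Accuracy over everything counted since the last reset (NaN if nothing was counted). */
    float getValue() {
        return totalInstances == 0 ? Float.NaN : (float) correctInstances / totalInstances;
    }

    public static void main(String[] args) {
        RunningAccuracy acc = new RunningAccuracy();
        acc.update(new int[] {0, 1, 2}, new int[] {0, 1, 1}); // 2 of 3 correct
        acc.update(new int[] {3}, new int[] {3});             // 1 of 1 correct
        System.out.println(acc.getValue()); // 3 of 4 -> 0.75
        acc.reset();
        System.out.println(acc.getValue()); // NaN after reset
    }
}
```

In a training loop, reset would typically be called at the start of each epoch and update once per batch.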
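Methods inherited from class TrainingMetric: checkLabelShapes, checkLabelShapes, duplicate, getName

public Accuracy(java.lang.String name, int index, int axis)

Creates an accuracy metric.

Parameters:
- name: the name of the metric; default is "Accuracy"
- index: the index of the NDArray in labels to compute accuracy for
- axis: the axis that represents classes in prediction; default is 1

public Accuracy()

Creates an accuracy metric that treats axis 1 of the predictions as the class axis and uses the NDArray at index 0 of the labels.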
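The index parameter picks one NDArray out of the labels when a model produces several outputs. The plain-Java sketch below models an NDList as a List of already-decoded class-index arrays; the names, and the assumption that the predictions are selected at the same position as the labels, are hypothetical and not taken from DJL.

```java
import java.util.List;

/** Hypothetical sketch: compute accuracy for one entry of a multi-output list. */
final class IndexedAccuracy {

    /** Accuracy of predictions.get(index) against labels.get(index). */
    static float accuracyAt(List<int[]> labels, List<int[]> predictions, int index) {
        int[] y = labels.get(index);
        int[] yHat = predictions.get(index);
        if (y.length != yHat.length || y.length == 0) {
            throw new IllegalArgumentException("need equal, non-empty arrays at index " + index);
        }
        int correct = 0;
        for (int i = 0; i < y.length; i++) {
            if (y[i] == yHat[i]) {
                correct++;
            }
        }
        return (float) correct / y.length;
    }

    public static void main(String[] args) {
        // Two outputs, e.g. a hypothetical model with two classification heads.
        List<int[]> labels = List.of(new int[] {1, 0}, new int[] {2, 2, 0, 1});
        List<int[]> predictions = List.of(new int[] {1, 1}, new int[] {2, 2, 0, 0});
        System.out.println(accuracyAt(labels, predictions, 0)); // 1 of 2 -> 0.5
        System.out.println(accuracyAt(labels, predictions, 1)); // 3 of 4 -> 0.75
    }
}
```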
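public Accuracy(java.lang.String name, int index)

Creates an accuracy metric that treats axis 1 of the predictions as the class axis and uses the NDArray at the given index of the labels.

Parameters:
- name: the name of the metric; default is "Accuracy"
- index: the index of the NDArray in labels to compute accuracy for

public void reset()

Resets the metric values.

Overrides reset in class TrainingMetric.

public void update(NDArray labels, NDArray predictions)

Computes and updates the accuracy based on the labels and predictions.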
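public void update(NDList labels, NDList predictions)

Computes and updates the training metrics based on the labels and predictions.

Overrides update in class TrainingMetric.

Parameters:
- labels: an NDList of labels
- predictions: an NDList of predictions

public float getValue()

Calculates the metric value.

Overrides getValue in class TrainingMetric.

Returns: the accuracy value as a float

public void addCorrectInstances(long numInstances)

Adds a number to the count of correct instances.

Parameters:
- numInstances: the number to increment by

public void addTotalInstances(long totalInstances)

Adds a number to the count of total instances.

Parameters:
- totalInstances: the number to increment by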
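Because addCorrectInstances and addTotalInstances accumulate raw counts, and assuming getValue divides the first count by the second, the result is pooled accuracy over every instance seen since the last reset. The hypothetical sketch below contrasts that with naively averaging per-batch accuracies, which over-weights small batches; the batch counts are made up for illustration.

```java
/** Hypothetical comparison of pooled accuracy and the mean of per-batch accuracies. */
final class PooledVsMeanAccuracy {

    /** Sum of correct counts divided by sum of batch sizes. */
    static float pooled(long[] correct, long[] total) {
        long c = 0;
        long t = 0;
        for (int i = 0; i < correct.length; i++) {
            c += correct[i];
            t += total[i];
        }
        return (float) c / t;
    }

    /** Unweighted mean of each batch's own accuracy. */
    static float meanOfBatches(long[] correct, long[] total) {
        double sum = 0;
        for (int i = 0; i < correct.length; i++) {
            sum += (double) correct[i] / total[i];
        }
        return (float) (sum / correct.length);
    }

    public static void main(String[] args) {
        // Batch 1: 90 of 100 correct; batch 2: 0 of 10 correct (a small final batch).
        long[] correct = {90, 0};
        long[] total = {100, 10};
        System.out.println(pooled(correct, total));        // 90 / 110, about 0.818
        System.out.println(meanOfBatches(correct, total)); // (0.9 + 0.0) / 2 = 0.45
    }
}
```

Keeping counts rather than per-batch ratios means every instance carries equal weight regardless of batch size.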