Class ROCMultiClass
- java.lang.Object
  - org.nd4j.evaluation.BaseEvaluation<ROCMultiClass>
    - org.nd4j.evaluation.classification.ROCMultiClass
- All Implemented Interfaces:
Serializable, IEvaluation<ROCMultiClass>
public class ROCMultiClass extends BaseEvaluation<ROCMultiClass>
- See Also:
- Serialized Form
-
Nested Class Summary
- static class ROCMultiClass.Metric
  AUROC: Area under the ROC curve
  AUPRC: Area under the Precision-Recall curve
-
Field Summary
- protected int axis
- static int DEFAULT_STATS_PRECISION
-
Constructor Summary
- ROCMultiClass()
- ROCMultiClass(int thresholdSteps)
- ROCMultiClass(int thresholdSteps, boolean rocRemoveRedundantPts)
- protected ROCMultiClass(int axis, int thresholdSteps, boolean rocRemoveRedundantPts, List<String> labels)
-
Method Summary
- double calculateAUC(int classIdx): Calculate the AUC (Area Under the ROC Curve) for the specified class. Uses trapezoidal integration internally.
- double calculateAUCPR(int classIdx): Calculate the AUPRC (Area Under the Precision-Recall Curve) for the specified class. Uses trapezoidal integration internally.
- double calculateAverageAUC(): Calculate the macro-average (one-vs-all) AUC over all classes.
- double calculateAverageAUCPR(): Calculate the macro-average (one-vs-all) AUPRC (area under the precision-recall curve) over all classes.
- void eval(INDArray labels, INDArray predictions, INDArray mask, List<? extends Serializable> recordMetaData): Evaluate the network, with optional metadata.
- static ROCMultiClass fromJson(String json)
- int getAxis(): Get the axis; see setAxis(int) for details.
- long getCountActualNegative(int outputNum): Get the actual negative count (accounting for any masking) for the specified output/column.
- long getCountActualPositive(int outputNum): Get the actual positive count (accounting for any masking) for the specified class.
- int getNumClasses()
- PrecisionRecallCurve getPrecisionRecallCurve(int classIdx): Get the (one vs. all) Precision-Recall curve for the specified class.
- RocCurve getRocCurve(int classIdx): Get the (one vs. all) ROC curve for the specified class.
- double getValue(IMetric metric): Get the value of a given metric for this evaluation.
- void merge(ROCMultiClass other): Merge this ROCMultiClass instance with another.
- ROCMultiClass newInstance(): Get a new instance of this evaluation, with the same configuration but no data.
- void reset()
- double scoreForMetric(ROCMultiClass.Metric metric, int idx)
- void setAxis(int axis): Set the axis for evaluation: the dimension along which the class probabilities (and label classes) are present.
- String stats()
- String stats(int printPrecision)
-
Methods inherited from class org.nd4j.evaluation.BaseEvaluation
attempFromLegacyFromJson, eval, eval, eval, evalTimeSeries, evalTimeSeries, fromJson, fromYaml, reshapeAndExtractNotMasked, toJson, toString, toYaml
-
Field Detail
-
DEFAULT_STATS_PRECISION
public static final int DEFAULT_STATS_PRECISION
- See Also:
- Constant Field Values
-
axis
protected int axis
-
Constructor Detail
-
ROCMultiClass
protected ROCMultiClass(int axis, int thresholdSteps, boolean rocRemoveRedundantPts, List<String> labels)
-
ROCMultiClass
public ROCMultiClass()
-
ROCMultiClass
public ROCMultiClass(int thresholdSteps)
- Parameters:
thresholdSteps - Number of threshold steps to use for the ROC calculation. Set to 0 for exact ROC calculation.
-
ROCMultiClass
public ROCMultiClass(int thresholdSteps, boolean rocRemoveRedundantPts)
- Parameters:
thresholdSteps - Number of threshold steps to use for the ROC calculation. If set to 0, use the exact calculation.
rocRemoveRedundantPts - Usually set to true. If true, remove any redundant points from the ROC and P-R curves.
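To illustrate what thresholdSteps controls, here is a minimal plain-Java sketch of a thresholded one-vs-all ROC curve for a single class. It assumes (this is not taken from ND4J's source) that thresholdSteps = n means the n + 1 evenly spaced thresholds 0, 1/n, ..., 1, and that a sample counts as predicted positive when its score is at least the threshold. ThresholdedRoc and rocPoints are hypothetical names; the sketch does not call ROCMultiClass.

```java
import java.util.Arrays;

/** Hypothetical sketch: approximate (thresholded) one-vs-all ROC curve for a single class. */
class ThresholdedRoc {

    /**
     * Returns one {fpr, tpr} point per threshold i / steps, for i = 0..steps.
     * A sample is treated as predicted positive when its score is >= the threshold.
     */
    static double[][] rocPoints(double[] scores, boolean[] isPositive, int steps) {
        if (steps <= 0) {
            throw new IllegalArgumentException("steps must be > 0 for the thresholded variant");
        }
        long pos = 0;
        long neg = 0;
        for (boolean p : isPositive) {
            if (p) pos++; else neg++;
        }
        double[][] points = new double[steps + 1][];
        for (int i = 0; i <= steps; i++) {
            double threshold = (double) i / steps; // assumption: evenly spaced thresholds in [0, 1]
            long tp = 0;
            long fp = 0;
            for (int j = 0; j < scores.length; j++) {
                if (scores[j] >= threshold) {
                    if (isPositive[j]) tp++; else fp++;
                }
            }
            double fpr = neg == 0 ? 0.0 : (double) fp / neg;
            double tpr = pos == 0 ? 0.0 : (double) tp / pos;
            points[i] = new double[] {fpr, tpr};
        }
        return points;
    }

    public static void main(String[] args) {
        double[] scores = {0.9, 0.8, 0.3, 0.1};
        boolean[] labels = {true, false, true, false};
        // Prints [[1.0, 1.0], [0.5, 0.5], [0.0, 0.0]]
        System.out.println(Arrays.deepToString(rocPoints(scores, labels, 2)));
    }
}
```

More steps give a finer approximation of the curve at the cost of more counters; per the parameter description above, thresholdSteps = 0 selects the exact calculation instead.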
-
Method Detail
-
setAxis
public void setAxis(int axis)
Set the axis for evaluation: the dimension along which the class probabilities (and label classes) are present.
For DL4J, this can be left as the default setting (axis = 1).
Axis should be set as follows:
  - 2D (OutputLayer), shape [minibatch, numClasses]: axis = 1
  - 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NCW format, shape [minibatch, numClasses, sequenceLength]: axis = 1
  - 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NWC format, shape [minibatch, sequenceLength, numClasses]: axis = 2
  - 4D, CNN2D (DL4J CnnLossLayer), NCHW format, shape [minibatch, channels, height, width]: axis = 1
  - 4D, CNN2D, NHWC format, shape [minibatch, height, width, channels]: axis = 3
- Parameters:
axis - Axis to use for evaluation
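The axis rules above come down to one idea: the axis dimension holds the classes, and every other dimension indexes a separate prediction. The hypothetical helper below (AxisLayout is not an ND4J class) makes that concrete for plain shape arrays. It assumes, without checking ND4J's code, that every non-class dimension is flattened into separate examples before evaluation.

```java
/** Hypothetical helper: how the class axis partitions an output array's shape. */
class AxisLayout {

    /** Number of classes: the size of the dimension chosen as the class axis. */
    static long numClasses(long[] shape, int axis) {
        checkAxis(shape, axis);
        return shape[axis];
    }

    /**
     * Number of individual predictions: the product of every other dimension.
     * E.g. NCW [minibatch, numClasses, sequenceLength] with axis 1 gives minibatch * sequenceLength.
     */
    static long numPredictions(long[] shape, int axis) {
        checkAxis(shape, axis);
        long n = 1;
        for (int d = 0; d < shape.length; d++) {
            if (d != axis) {
                n *= shape[d];
            }
        }
        return n;
    }

    private static void checkAxis(long[] shape, int axis) {
        if (axis < 0 || axis >= shape.length) {
            throw new IllegalArgumentException("axis " + axis + " out of range for rank " + shape.length);
        }
    }

    public static void main(String[] args) {
        long[] ncw = {32, 5, 100}; // [minibatch, numClasses, sequenceLength] -> axis = 1
        long[] nwc = {32, 100, 5}; // [minibatch, sequenceLength, numClasses] -> axis = 2
        System.out.println(numClasses(ncw, 1) + " " + numPredictions(ncw, 1)); // 5 3200
        System.out.println(numClasses(nwc, 2) + " " + numPredictions(nwc, 2)); // 5 3200
    }
}
```

For the NCW and NWC shapes, the same 32 x 100 = 3200 predictions are evaluated; only the position of the class dimension, and therefore the axis value, differs.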
-
getAxis
public int getAxis()
Get the axis; see setAxis(int) for details.
-
reset
public void reset()
-
stats
public String stats()
-
stats
public String stats(int printPrecision)
-
eval
public void eval(INDArray labels, INDArray predictions, INDArray mask, List<? extends Serializable> recordMetaData)
Evaluate the network, with optional metadata.
- Parameters:
labels - Data labels
predictions - Network predictions
recordMetaData - Optional; may be null. If not null, should have size equal to the number of outcomes/guesses
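Because the ROC and P-R curves are computed one vs. all, a 2D eval call can be pictured as producing one binary problem per class: that class's label column becomes the binary label and its prediction column becomes the score. The following is a minimal sketch of that split for plain arrays, assuming one-hot labels and a per-example mask in which 0 means "skip". OneVsAll, BinaryData and forClass are hypothetical; the real method works on INDArray and may support other mask shapes.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch: split 2D multi-class data into one binary (one-vs-all) problem per class. */
class OneVsAll {

    /** Scores and binary labels for a single class, after masking. */
    static final class BinaryData {
        final List<Double> scores = new ArrayList<>();
        final List<Boolean> isPositive = new ArrayList<>();
    }

    /**
     * labels and predictions have shape [numExamples][numClasses], with one-hot labels.
     * mask may be null; otherwise examples with mask[i] == 0 are skipped.
     */
    static BinaryData forClass(double[][] labels, double[][] predictions, double[] mask, int classIdx) {
        BinaryData out = new BinaryData();
        for (int i = 0; i < labels.length; i++) {
            if (mask != null && mask[i] == 0.0) {
                continue; // masked-out example: contributes to no class
            }
            out.scores.add(predictions[i][classIdx]);      // probability assigned to this class
            out.isPositive.add(labels[i][classIdx] > 0.5); // one-hot: 1 for this class, 0 otherwise
        }
        return out;
    }
}
```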
-
getRocCurve
public RocCurve getRocCurve(int classIdx)
Get the (one vs. all) ROC curve for the specified class.
- Parameters:
classIdx - Class index to get the ROC curve for
- Returns:
- ROC curve for the given class
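When the exact calculation is used (thresholdSteps = 0), one natural construction, assumed here rather than taken from ND4J, uses every distinct score as a threshold. The hypothetical ExactRoc sketch below returns the (FPR, TPR) points for one class in order of decreasing threshold, starting from (0, 0). It recounts for every threshold (O(n^2)) for clarity; a production version would sort once and sweep.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical sketch: exact one-vs-all ROC curve, using every distinct score as a threshold. */
class ExactRoc {

    /** Returns {fpr, tpr} points from (0, 0) towards (1, 1), in order of decreasing threshold. */
    static List<double[]> rocPoints(double[] scores, boolean[] isPositive) {
        long pos = 0;
        long neg = 0;
        for (boolean p : isPositive) {
            if (p) pos++; else neg++;
        }
        double[] thresholds = Arrays.stream(scores).distinct().sorted().toArray(); // ascending
        List<double[]> points = new ArrayList<>();
        points.add(new double[] {0.0, 0.0}); // threshold above every score: nothing predicted positive
        for (int t = thresholds.length - 1; t >= 0; t--) {
            long tp = 0;
            long fp = 0;
            for (int j = 0; j < scores.length; j++) {
                if (scores[j] >= thresholds[t]) {
                    if (isPositive[j]) tp++; else fp++;
                }
            }
            double fpr = neg == 0 ? 0.0 : (double) fp / neg;
            double tpr = pos == 0 ? 0.0 : (double) tp / pos;
            points.add(new double[] {fpr, tpr});
        }
        return points;
    }
}
```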
-
getPrecisionRecallCurve
public PrecisionRecallCurve getPrecisionRecallCurve(int classIdx)
Get the (one vs. all) Precision-Recall curve for the specified class.
- Parameters:
classIdx - Class to get the P-R curve for
- Returns:
- Precision-recall curve for the given class
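Each point on a (one vs. all) precision-recall curve comes from a single threshold. The hypothetical helper below computes one such point for plain arrays. Treating precision as 1.0 when nothing is predicted positive is a convention assumed here, not one confirmed for ND4J.

```java
/** Hypothetical sketch: one (precision, recall) point of a one-vs-all P-R curve. */
class PrecisionRecallPoint {

    /** Returns {precision, recall} when samples with score >= threshold are predicted positive. */
    static double[] at(double[] scores, boolean[] isPositive, double threshold) {
        long tp = 0;
        long fp = 0;
        long fn = 0;
        for (int j = 0; j < scores.length; j++) {
            boolean predicted = scores[j] >= threshold;
            if (predicted && isPositive[j]) {
                tp++;
            } else if (predicted) {
                fp++;
            } else if (isPositive[j]) {
                fn++;
            }
        }
        // Assumed convention: precision is 1.0 when nothing is predicted positive.
        double precision = (tp + fp) == 0 ? 1.0 : (double) tp / (tp + fp);
        double recall = (tp + fn) == 0 ? 0.0 : (double) tp / (tp + fn);
        return new double[] {precision, recall};
    }
}
```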
-
calculateAUC
public double calculateAUC(int classIdx)
Calculate the AUC (Area Under the ROC Curve) for the specified class.
Uses trapezoidal integration internally.
- Returns:
- AUC for the given class
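The trapezoidal rule mentioned above sums, over consecutive curve points, the segment width times the average height: AUC ≈ Σ (x_i - x_{i-1}) * (y_i + y_{i-1}) / 2, with x = FPR and y = TPR for an ROC curve. The sketch below (a hypothetical Trapezoid class, not ND4J's implementation) applies the rule to points already sorted by x.

```java
/** Hypothetical sketch of trapezoidal integration over a piecewise-linear curve. */
class Trapezoid {

    /** Area under the curve through (x[i], y[i]); x must be non-decreasing (e.g. ROC points by FPR). */
    static double area(double[] x, double[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException("x and y must have the same length");
        }
        double sum = 0.0;
        for (int i = 1; i < x.length; i++) {
            double width = x[i] - x[i - 1];
            if (width < 0) {
                throw new IllegalArgumentException("x must be non-decreasing");
            }
            sum += width * (y[i] + y[i - 1]) / 2.0; // width times average height of the segment
        }
        return sum;
    }
}
```

For example, the ROC points (0, 0), (0, 0.5), (0.5, 0.5), (0.5, 1), (1, 1) give an area of 0.75. These are the exact points for scores {0.9, 0.8, 0.3, 0.1} with labels {+, -, +, -}, and 0.75 matches the fraction of positive/negative pairs that those scores rank correctly (3 of 4).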
-
calculateAUCPR
public double calculateAUCPR(int classIdx)
Calculate the AUPRC (Area Under the Precision-Recall Curve) for the specified class.
Uses trapezoidal integration internally.
- Returns:
- AUPRC for the given class
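The same trapezoidal rule applies to the precision-recall curve, integrated over recall. The hypothetical sketch below builds P-R points by using each distinct score as a threshold, from highest to lowest, and integrates them as it goes. The starting point (recall 0, precision 1) is an assumption of this sketch; implementations differ in how they anchor the curve, so the result need not match ND4J exactly.

```java
import java.util.Arrays;

/** Hypothetical sketch: trapezoidal area under a one-vs-all precision-recall curve. */
class AucPr {

    static double auprc(double[] scores, boolean[] isPositive) {
        long totalPos = 0;
        for (boolean p : isPositive) {
            if (p) {
                totalPos++;
            }
        }
        if (totalPos == 0) {
            throw new IllegalArgumentException("no positive examples");
        }
        double[] thresholds = Arrays.stream(scores).distinct().sorted().toArray(); // ascending
        double prevRecall = 0.0;
        double prevPrecision = 1.0; // assumed anchor point: recall 0, precision 1
        double area = 0.0;
        // Walk from the highest threshold to the lowest, so recall never decreases.
        for (int t = thresholds.length - 1; t >= 0; t--) {
            long tp = 0;
            long fp = 0;
            for (int j = 0; j < scores.length; j++) {
                if (scores[j] >= thresholds[t]) {
                    if (isPositive[j]) tp++; else fp++;
                }
            }
            double recall = (double) tp / totalPos;
            double precision = (double) tp / (tp + fp); // tp + fp >= 1: the threshold is an actual score
            area += (recall - prevRecall) * (precision + prevPrecision) / 2.0;
            prevRecall = recall;
            prevPrecision = precision;
        }
        return area;
    }
}
```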
-
calculateAverageAUC
public double calculateAverageAUC()
Calculate the macro-average (one-vs-all) AUC for all classes
-
calculateAverageAUCPR
public double calculateAverageAUCPR()
Calculate the macro-average (one-vs-all) AUPRC (area under the precision-recall curve) over all classes
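Macro-averaging means taking the unweighted mean of the per-class (one-vs-all) values, so a rare class counts as much as a common one. A minimal sketch, with MacroAverage as a hypothetical name; it does not reflect how ND4J handles classes that have no positive or no negative examples.

```java
/** Hypothetical sketch: macro-average of per-class metric values such as one-vs-all AUCs. */
class MacroAverage {

    /** Unweighted mean: every class contributes equally, regardless of how many examples it has. */
    static double macroAverage(double[] perClass) {
        if (perClass.length == 0) {
            throw new IllegalArgumentException("no classes");
        }
        double sum = 0.0;
        for (double v : perClass) {
            sum += v;
        }
        return sum / perClass.length;
    }
}
```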
-
getCountActualPositive
public long getCountActualPositive(int outputNum)
Get the actual positive count (accounting for any masking) for the specified class.
- Parameters:
outputNum - Index of the class
-
getCountActualNegative
public long getCountActualNegative(int outputNum)
Get the actual negative count (accounting for any masking) for the specified output/column.
- Parameters:
outputNum - Index of the class
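For a given class, each unmasked example is either an actual positive or an actual negative, so the two counts sum to the number of unmasked examples. The hypothetical sketch below shows this for one-hot labels and a per-example mask (0 means masked out); the real methods report on the data accumulated by eval rather than taking arrays directly.

```java
/** Hypothetical sketch: actual positive/negative counts for one class, honouring a mask. */
class ActualCounts {

    /** Returns {actualPositive, actualNegative} for classIdx; mask may be null (nothing masked). */
    static long[] counts(double[][] labels, double[] mask, int classIdx) {
        long pos = 0;
        long neg = 0;
        for (int i = 0; i < labels.length; i++) {
            if (mask != null && mask[i] == 0.0) {
                continue; // masked examples are neither positive nor negative
            }
            if (labels[i][classIdx] > 0.5) {
                pos++; // one-hot label says this example belongs to classIdx
            } else {
                neg++;
            }
        }
        return new long[] {pos, neg};
    }
}
```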
-
merge
public void merge(ROCMultiClass other)
Merge this ROCMultiClass instance with another. This instance is modified by adding the stats from the other instance.
- Parameters:
other - ROCMultiClass instance to combine with this one
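Merging is cheap when the statistics are kept as counts per threshold: the merged counts are simply element-wise sums. The sketch below assumes such a representation with evenly spaced thresholds (ThresholdCounts is hypothetical, not ND4J's internal class) and rejects merging accumulators with different numbers of steps. It illustrates the idea behind merge, not its actual implementation.

```java
/** Hypothetical sketch: per-threshold counts for one class, which can be merged by addition. */
class ThresholdCounts {
    final long[] truePositives;  // truePositives[i]: positives with score >= i / steps
    final long[] falsePositives; // falsePositives[i]: negatives with score >= i / steps
    long positives;
    long negatives;

    ThresholdCounts(int steps) {
        if (steps <= 0) {
            throw new IllegalArgumentException("steps must be > 0");
        }
        truePositives = new long[steps + 1];
        falsePositives = new long[steps + 1];
    }

    /** Records one example's score for this class and whether it is an actual positive. */
    void add(double score, boolean isPositive) {
        int steps = truePositives.length - 1;
        for (int i = 0; i <= steps; i++) {
            if (score >= (double) i / steps) {
                if (isPositive) truePositives[i]++; else falsePositives[i]++;
            }
        }
        if (isPositive) positives++; else negatives++;
    }

    /** Adds the other accumulator's counts into this one; both must use the same number of steps. */
    void merge(ThresholdCounts other) {
        if (other.truePositives.length != truePositives.length) {
            throw new IllegalArgumentException("threshold steps differ");
        }
        for (int i = 0; i < truePositives.length; i++) {
            truePositives[i] += other.truePositives[i];
            falsePositives[i] += other.falsePositives[i];
        }
        positives += other.positives;
        negatives += other.negatives;
    }
}
```

Because addition is order-independent, separately accumulated instances give the same counts after merging as a single instance that saw all the data.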
-
getNumClasses
public int getNumClasses()
-
fromJson
public static ROCMultiClass fromJson(String json)
-
scoreForMetric
public double scoreForMetric(ROCMultiClass.Metric metric, int idx)
-
getValue
public double getValue(IMetric metric)
Description copied from interface: IEvaluation
Get the value of a given metric for this evaluation.
-
newInstance
public ROCMultiClass newInstance()
Description copied from interface: IEvaluation
Get a new instance of this evaluation, with the same configuration but no data.