Class ROCBinary
java.lang.Object
  org.nd4j.evaluation.BaseEvaluation<ROCBinary>
    org.nd4j.evaluation.classification.ROCBinary
- All Implemented Interfaces:
Serializable, IEvaluation<ROCBinary>
public class ROCBinary extends BaseEvaluation<ROCBinary>
- See Also:
- Serialized Form
Nested Class Summary
- static class ROCBinary.Metric
  AUROC: Area under ROC curve
  AUPRC: Area under Precision-Recall Curve
Field Summary
- protected int axis
- static int DEFAULT_STATS_PRECISION
Method Summary
- double calculateAUC(int outputNum)
  Calculate the AUC (Area Under the ROC Curve). Utilizes trapezoidal integration internally.
- double calculateAUCPR(int outputNum)
  Calculate the AUCPR (Area Under the Precision-Recall Curve). Utilizes trapezoidal integration internally.
- double calculateAverageAuc()
  Macro-average AUC for all outcomes.
- double calculateAverageAUCPR()
  Macro-average AUPRC (area under the precision-recall curve) for all outcomes.
- void eval(INDArray labels, INDArray predictions, INDArray mask, List<? extends Serializable> recordMetaData)
- static ROCBinary fromJson(String json)
- int getAxis()
  Get the axis; see setAxis(int) for details.
- long getCountActualNegative(int outputNum)
  Get the actual negative count (accounting for any masking) for the specified output/column.
- long getCountActualPositive(int outputNum)
  Get the actual positive count (accounting for any masking) for the specified output/column.
- PrecisionRecallCurve getPrecisionRecallCurve(int outputNum)
  Get the Precision-Recall curve for the specified output.
- ROC getROC(int outputNum)
  Get the ROC object for the specified column.
- RocCurve getRocCurve(int outputNum)
  Get the ROC curve for the specified output.
- double getValue(IMetric metric)
  Get the value of a given metric for this evaluation.
- void merge(ROCBinary other)
- ROCBinary newInstance()
  Get a new instance of this evaluation, with the same configuration but no data.
- int numLabels()
  Returns the number of labels (i.e., the size of the prediction/labels arrays), if known.
- void reset()
- double scoreForMetric(ROCBinary.Metric metric, int idx)
- void setAxis(int axis)
  Set the axis for evaluation: the dimension along which the predicted probabilities (and the corresponding labels) for the independent binary classes are laid out. For DL4J, this can be left at the default setting (axis = 1).
- void setLabelNames(List<String> labels)
  Set the label names, for printing via stats().
- String stats()
- String stats(int printPrecision)
Methods inherited from class org.nd4j.evaluation.BaseEvaluation
attempFromLegacyFromJson, eval, eval, eval, evalTimeSeries, evalTimeSeries, fromJson, fromYaml, reshapeAndExtractNotMasked, toJson, toString, toYaml
Field Detail
DEFAULT_STATS_PRECISION
public static final int DEFAULT_STATS_PRECISION
- See Also:
- Constant Field Values
axis
protected int axis
Constructor Detail
ROCBinary
protected ROCBinary(int axis, int thresholdSteps, boolean rocRemoveRedundantPts, List<String> labels)
ROCBinary
public ROCBinary()
ROCBinary
public ROCBinary(int thresholdSteps)
- Parameters:
thresholdSteps - Number of threshold steps to use for the ROC calculation. Set to 0 for exact ROC calculation.
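To make the effect of a non-zero thresholdSteps concrete, here is a minimal plain-Java sketch. It does not use ND4J and is not the library's implementation: it assumes thresholds evenly spaced at k / thresholdSteps for k = 0..thresholdSteps and counts a prediction as positive when p >= threshold. The class and method names are hypothetical.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch (not ND4J): ROC points on an evenly spaced threshold grid.
 * Assumptions: thresholds are k / thresholdSteps for k = 0..thresholdSteps,
 * and a prediction counts as positive when p >= threshold.
 */
final class ThresholdedRocSketch {

    /** Returns {fpr, tpr} per threshold, ordered from threshold 0 up to threshold 1. */
    static double[][] rocAtThresholds(double[] probs, int[] labels, int thresholdSteps) {
        if (thresholdSteps <= 0) {
            throw new IllegalArgumentException("this sketch only covers thresholdSteps > 0");
        }
        long pos = Arrays.stream(labels).filter(l -> l == 1).count();
        long neg = labels.length - pos;
        if (pos == 0 || neg == 0) {
            throw new IllegalArgumentException("need at least one positive and one negative");
        }
        double[][] points = new double[thresholdSteps + 1][];
        for (int k = 0; k <= thresholdSteps; k++) {
            double threshold = (double) k / thresholdSteps;
            long tp = 0;
            long fp = 0;
            for (int i = 0; i < probs.length; i++) {
                if (probs[i] >= threshold) {
                    if (labels[i] == 1) {
                        tp++;
                    } else {
                        fp++;
                    }
                }
            }
            points[k] = new double[] {(double) fp / neg, (double) tp / pos};
        }
        return points;
    }
}
```

For example, rocAtThresholds(new double[]{0.9, 0.8, 0.7, 0.6}, new int[]{1, 0, 1, 0}, 4) evaluates thresholds 0, 0.25, 0.5, 0.75 and 1.0 and ends with {0.5, 0.5} and {0.0, 0.0}. No grid threshold falls between 0.8 and 0.9, so the point (FPR 0, TPR 0.5) is never reached; that loss of resolution is the price of keeping only one count per threshold instead of every prediction.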
ROCBinary
public ROCBinary(int thresholdSteps, boolean rocRemoveRedundantPts)
- Parameters:
thresholdSteps - Number of threshold steps to use for the ROC calculation. If set to 0, the exact calculation is used.
rocRemoveRedundantPts - Usually set to true. If true, remove any redundant points from the ROC and P-R curves.
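The idea behind rocRemoveRedundantPts can be sketched in plain Java: a point that lies on the straight line between its neighbours adds nothing to the shape of the curve, and dropping it leaves the trapezoidal area unchanged. Treating "redundant" as "collinear with its neighbours (including exact duplicates)" is an assumption for illustration, not a description of the ND4J code; the class name is hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical sketch (not ND4J): drops the middle points of collinear runs on a curve.
 * Assumption: "redundant" means collinear with the neighbouring points.
 */
final class RedundantPointsSketch {

    private static final double EPS = 1e-12;

    /** points[i] = {x, y}, already ordered along the curve. */
    static double[][] removeRedundant(double[][] points) {
        if (points.length <= 2) {
            return points.clone();
        }
        List<double[]> kept = new ArrayList<>();
        kept.add(points[0]);
        for (int i = 1; i < points.length - 1; i++) {
            double[] a = kept.get(kept.size() - 1); // last point kept so far
            double[] b = points[i];                 // candidate point
            double[] c = points[i + 1];             // next point on the curve
            // Cross product of (b - a) and (c - a); zero means a, b and c are collinear.
            double cross = (b[0] - a[0]) * (c[1] - a[1]) - (b[1] - a[1]) * (c[0] - a[0]);
            if (Math.abs(cross) > EPS) {
                kept.add(b);
            }
        }
        kept.add(points[points.length - 1]);
        return kept.toArray(new double[0][]);
    }
}
```

For the curve (0,0), (0,0.5), (0,1), (0.5,1), (1,1), only (0,0), (0,1) and (1,1) remain, and both versions enclose the same area.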
Method Detail
setAxis
public void setAxis(int axis)
Set the axis for evaluation: the dimension along which the predicted probabilities (and the corresponding labels) for the independent binary classes are laid out.
For DL4J, this can be left at the default setting (axis = 1).
Axis should be set as follows:
  - 2D (OutputLayer), shape [minibatch, numClasses]: axis = 1
  - 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NCW format, shape [minibatch, numClasses, sequenceLength]: axis = 1
  - 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NWC format, shape [minibatch, sequenceLength, numClasses]: axis = 2
  - 4D, CNN2D (DL4J CnnLossLayer), NCHW format, shape [minibatch, channels, height, width]: axis = 1
  - 4D, CNN2D, NHWC format, shape [minibatch, height, width, channels]: axis = 3
- Parameters:
axis - Axis to use for evaluation
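The axis table above can be encoded as a small lookup. The helper below is hypothetical (it is not part of ND4J); the value it returns is what would be passed to setAxis(int). The pattern behind the table: channels-first layouts keep the classes on axis 1, while channels-last layouts put them on the last axis (rank - 1).

```java
/** Hypothetical helper (not part of ND4J) encoding the setAxis table. */
final class AxisSketch {

    /**
     * @param rank         rank of the labels/predictions array (2, 3 or 4)
     * @param channelsLast true for NWC / NHWC layouts, false for NCW / NCHW (ignored for rank 2)
     * @return the axis along which the independent binary outputs lie
     */
    static int axisFor(int rank, boolean channelsLast) {
        switch (rank) {
            case 2:
                return 1;                    // [minibatch, numClasses]
            case 3:
                return channelsLast ? 2 : 1; // NWC : NCW
            case 4:
                return channelsLast ? 3 : 1; // NHWC : NCHW
            default:
                throw new IllegalArgumentException("Unsupported rank: " + rank);
        }
    }
}
```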
getAxis
public int getAxis()
Get the axis; see setAxis(int) for details.
reset
public void reset()
eval
public void eval(INDArray labels, INDArray predictions, INDArray mask, List<? extends Serializable> recordMetaData)
merge
public void merge(ROCBinary other)
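No description is given for merge(ROCBinary other); the usual reason for such a method is to combine evaluations computed on separate batches or workers. As a hedged illustration of why that can work without revisiting the data: if each output were summarized by per-threshold true/false positive counts (an assumed representation, not ND4J's internal one), two summaries built on the same threshold grid combine by element-wise addition and give the same counts as a single pass over all the data. The class is hypothetical.

```java
/**
 * Hypothetical per-output summary for a thresholded ROC (not the ND4J classes).
 * Assumption: both summaries were built on the same threshold grid.
 */
final class MergeCountsSketch {
    final long[] truePositives;  // index = threshold bin
    final long[] falsePositives; // index = threshold bin

    MergeCountsSketch(long[] tp, long[] fp) {
        if (tp.length != fp.length) {
            throw new IllegalArgumentException("length mismatch");
        }
        this.truePositives = tp.clone();
        this.falsePositives = fp.clone();
    }

    /** Adds the counts bin by bin; neither input is modified. */
    MergeCountsSketch merge(MergeCountsSketch other) {
        if (other.truePositives.length != truePositives.length) {
            throw new IllegalArgumentException("threshold grids differ");
        }
        long[] tp = new long[truePositives.length];
        long[] fp = new long[falsePositives.length];
        for (int k = 0; k < tp.length; k++) {
            tp[k] = truePositives[k] + other.truePositives[k];
            fp[k] = falsePositives[k] + other.falsePositives[k];
        }
        return new MergeCountsSketch(tp, fp);
    }
}
```

An exact calculation could not be merged from counts alone in this way; the individual predictions themselves would have to be combined.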
numLabels
public int numLabels()
Returns the number of labels (i.e., the size of the prediction/labels arrays), if known; returns -1 otherwise.
getCountActualPositive
public long getCountActualPositive(int outputNum)
Get the actual positive count (accounting for any masking) for the specified output/column.
- Parameters:
outputNum - Index of the output (0 to numLabels()-1)
getCountActualNegative
public long getCountActualNegative(int outputNum)
Get the actual negative count (accounting for any masking) for the specified output/column.
- Parameters:
outputNum - Index of the output (0 to numLabels()-1)
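The following plain-Java sketch shows what "accounting for any masking" means for these counts: masked entries are skipped, so they count as neither positive nor negative. It assumes a per-output mask with the same shape as the labels (0 = ignore) and treats a label above 0.5 as positive; the class is hypothetical and independent of ND4J.

```java
/**
 * Hypothetical sketch (not ND4J): actual positive/negative counts for one output column.
 * Assumptions: the mask has the same shape as the labels (0 = ignore this entry),
 * and a label value above 0.5 is treated as positive.
 */
final class MaskedCountsSketch {

    /** Returns {actualPositive, actualNegative}; pass mask = null for no masking. */
    static long[] counts(double[][] labels, double[][] mask, int outputNum) {
        long pos = 0;
        long neg = 0;
        for (int row = 0; row < labels.length; row++) {
            if (mask != null && mask[row][outputNum] == 0.0) {
                continue; // masked out: contributes to neither count
            }
            if (labels[row][outputNum] > 0.5) {
                pos++;
            } else {
                neg++;
            }
        }
        return new long[] {pos, neg};
    }
}
```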
getROC
public ROC getROC(int outputNum)
Get the ROC object for the specified column.
- Parameters:
outputNum - Column (output number)
- Returns:
- The underlying ROC object for this column
getRocCurve
public RocCurve getRocCurve(int outputNum)
Get the ROC curve for the specified output.
- Parameters:
outputNum - Number of the output to get the ROC curve for
- Returns:
- ROC curve
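For the exact case (thresholdSteps = 0), a ROC curve for one output can be built by sorting the predictions from highest to lowest and lowering the threshold one distinct probability at a time. The sketch below is plain Java and hypothetical, not the ND4J RocCurve implementation; it assumes 0/1 labels and at least one example of each class.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Hypothetical sketch (not ND4J): exact ROC curve for one output column.
 * Assumptions: 0/1 labels, at least one positive and one negative.
 */
final class ExactRocSketch {

    /** Returns {fpr, tpr} points from (0,0) to (1,1), one per distinct probability. */
    static double[][] rocCurve(double[] probs, int[] labels) {
        int n = probs.length;
        Integer[] order = new Integer[n];
        for (int i = 0; i < n; i++) {
            order[i] = i;
        }
        // Highest probability first: lowering the threshold admits one score at a time.
        Arrays.sort(order, (a, b) -> Double.compare(probs[b], probs[a]));

        long pos = Arrays.stream(labels).filter(l -> l == 1).count();
        long neg = n - pos;
        if (pos == 0 || neg == 0) {
            throw new IllegalArgumentException("need at least one positive and one negative");
        }

        List<double[]> points = new ArrayList<>();
        points.add(new double[] {0.0, 0.0});
        long tp = 0;
        long fp = 0;
        for (int i = 0; i < n; i++) {
            if (labels[order[i]] == 1) {
                tp++;
            } else {
                fp++;
            }
            // Emit a point only once every tied score has been admitted.
            boolean lastOfTie = i == n - 1 || probs[order[i + 1]] != probs[order[i]];
            if (lastOfTie) {
                points.add(new double[] {(double) fp / neg, (double) tp / pos});
            }
        }
        return points.toArray(new double[0][]);
    }
}
```

For probabilities {0.9, 0.8, 0.7, 0.6} with labels {1, 0, 1, 0}, the points are (0,0), (0,0.5), (0.5,0.5), (0.5,1), (1,1). Tied probabilities yield a single point, because no threshold can separate them.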
getPrecisionRecallCurve
public PrecisionRecallCurve getPrecisionRecallCurve(int outputNum)
Get the Precision-Recall curve for the specified output.
- Parameters:
outputNum - Number of the output to get the P-R curve for
- Returns:
- Precision-recall curve
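A precision-recall curve can be traced with a similar threshold sweep, recording recall = TP / P and precision = TP / (TP + FP) at each distinct probability. This is a hypothetical plain-Java sketch, not ND4J's PrecisionRecallCurve; it assumes 0/1 labels and at least one positive.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Hypothetical sketch (not ND4J's PrecisionRecallCurve): exact P-R points for one output.
 * Assumptions: 0/1 labels and at least one positive example.
 */
final class PrCurveSketch {

    /** Returns {recall, precision} points, from the highest threshold to the lowest. */
    static double[][] prCurve(double[] probs, int[] labels) {
        int n = probs.length;
        Integer[] order = new Integer[n];
        for (int i = 0; i < n; i++) {
            order[i] = i;
        }
        // Highest probability first, so recall never decreases along the sweep.
        Arrays.sort(order, (a, b) -> Double.compare(probs[b], probs[a]));

        long pos = Arrays.stream(labels).filter(l -> l == 1).count();
        if (pos == 0) {
            throw new IllegalArgumentException("recall is undefined without positives");
        }

        List<double[]> points = new ArrayList<>();
        long tp = 0;
        long fp = 0;
        for (int i = 0; i < n; i++) {
            if (labels[order[i]] == 1) {
                tp++;
            } else {
                fp++;
            }
            boolean lastOfTie = i == n - 1 || probs[order[i + 1]] != probs[order[i]];
            if (lastOfTie) {
                double recall = (double) tp / pos;
                double precision = (double) tp / (tp + fp); // tp + fp >= 1 here
                points.add(new double[] {recall, precision});
            }
        }
        return points.toArray(new double[0][]);
    }
}
```

For probabilities {0.9, 0.8, 0.7, 0.6} with labels {1, 0, 1, 0}, the points are (0.5, 1), (0.5, 0.5), (1, 2/3), (1, 0.5). The sketch emits no point at recall 0; how to anchor that end of the curve is an implementation choice.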
calculateAverageAuc
public double calculateAverageAuc()
Macro-average AUC for all outcomes.
- Returns:
- the (macro-)average AUC for all outcomes.
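Macro-averaging gives every output equal weight, regardless of how many examples or positives it has. Assuming the per-output AUCs have already been computed (for example with calculateAUC(i) for i = 0 to numLabels()-1), the average is a plain mean. The helper below is a hypothetical sketch.

```java
import java.util.Arrays;

/** Hypothetical sketch (not ND4J): unweighted (macro) average of per-output scores. */
final class MacroAverageSketch {

    /** Every output counts equally, whatever its number of examples. */
    static double macroAverage(double[] perOutputScores) {
        if (perOutputScores.length == 0) {
            throw new IllegalArgumentException("no outputs");
        }
        return Arrays.stream(perOutputScores).average().getAsDouble();
    }
}
```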
calculateAverageAUCPR
public double calculateAverageAUCPR()
Macro-average AUPRC (area under the precision-recall curve) for all outcomes.
- Returns:
- the (macro-)average AUPRC (area under the precision-recall curve)
calculateAUC
public double calculateAUC(int outputNum)
Calculate the AUC (Area Under the ROC Curve).
Utilizes trapezoidal integration internally.
- Parameters:
outputNum - Output number to calculate AUC for
- Returns:
- AUC
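Trapezoidal integration sums the areas of the trapezoids between consecutive curve points: area = sum over i of (x_i - x_{i-1}) * (y_i + y_{i-1}) / 2, with x the false positive rate and y the true positive rate for a ROC curve. A minimal plain-Java sketch, hypothetical and not the ND4J code, assuming the points are sorted by non-decreasing x:

```java
/** Hypothetical sketch (not ND4J): area under a curve by the trapezoidal rule. */
final class TrapezoidAucSketch {

    /** points[i] = {x, y}, sorted by non-decreasing x. */
    static double trapezoidArea(double[][] points) {
        double area = 0.0;
        for (int i = 1; i < points.length; i++) {
            double dx = points[i][0] - points[i - 1][0];
            // Width times the mean of the two heights.
            area += dx * (points[i][1] + points[i - 1][1]) / 2.0;
        }
        return area;
    }
}
```

The ROC points (0,0), (0,0.5), (0.5,0.5), (0.5,1), (1,1) give 0.75. That matches the fraction of (positive, negative) pairs ranked correctly when the positives score {0.9, 0.7} and the negatives {0.8, 0.6}: 3 of 4.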
calculateAUCPR
public double calculateAUCPR(int outputNum)
Calculate the AUCPR (Area Under the Precision-Recall Curve).
Utilizes trapezoidal integration internally.
- Parameters:
outputNum - Output number to calculate AUCPR for
- Returns:
- AUCPR
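For AUCPR the same trapezoidal rule applies with recall on the x-axis and precision on the y-axis. The end-to-end sketch below sweeps the threshold from high to low, so recall never decreases, and integrates as it goes. It is hypothetical plain Java, not ND4J's implementation, and it assumes the curve starts at (recall 0, precision 1), which is a convention choice.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch (not ND4J): area under the P-R curve by the trapezoidal rule.
 * Assumptions: 0/1 labels, at least one positive, and a starting point at
 * (recall 0, precision 1).
 */
final class AucprSketch {

    static double aucpr(double[] probs, int[] labels) {
        int n = probs.length;
        Integer[] order = new Integer[n];
        for (int i = 0; i < n; i++) {
            order[i] = i;
        }
        // Highest probability first, so recall is non-decreasing along the sweep.
        Arrays.sort(order, (a, b) -> Double.compare(probs[b], probs[a]));

        long pos = Arrays.stream(labels).filter(l -> l == 1).count();
        if (pos == 0) {
            throw new IllegalArgumentException("recall is undefined without positives");
        }

        double prevRecall = 0.0;
        double prevPrecision = 1.0; // assumed anchor point
        double area = 0.0;
        long tp = 0;
        long fp = 0;
        for (int i = 0; i < n; i++) {
            if (labels[order[i]] == 1) {
                tp++;
            } else {
                fp++;
            }
            boolean lastOfTie = i == n - 1 || probs[order[i + 1]] != probs[order[i]];
            if (!lastOfTie) {
                continue;
            }
            double recall = (double) tp / pos;
            double precision = (double) tp / (tp + fp);
            // Trapezoid between the previous P-R point and this one.
            area += (recall - prevRecall) * (precision + prevPrecision) / 2.0;
            prevRecall = recall;
            prevPrecision = precision;
        }
        return area;
    }
}
```

For probabilities {0.9, 0.8, 0.7, 0.6} with labels {1, 0, 1, 0} it returns 19/24 (about 0.792).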
setLabelNames
public void setLabelNames(List<String> labels)
Set the label names, for printing via stats().
stats
public String stats()
stats
public String stats(int printPrecision)
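The layout of stats() and the meaning of printPrecision are not documented here. The sketch below assumes printPrecision is the number of decimal places and shows how names set with setLabelNames could replace column indices in a per-output table. Both the layout and the class are hypothetical; the real stats() output may look different.

```java
import java.util.List;
import java.util.Locale;

/**
 * Hypothetical sketch (not the real stats() layout): a per-output summary table.
 * Assumption: printPrecision is the number of decimal places to print.
 */
final class StatsFormatSketch {

    static String stats(List<String> labelNames, double[] auc, double[] auprc, int printPrecision) {
        String number = "%." + printPrecision + "f";
        String row = "%-10s %10s %10s%n";
        StringBuilder sb = new StringBuilder();
        sb.append(String.format(Locale.ROOT, row, "Label", "AUC", "AUPRC"));
        for (int i = 0; i < auc.length; i++) {
            // Use the label name if names were provided, otherwise the column index.
            String name = labelNames != null ? labelNames.get(i) : String.valueOf(i);
            sb.append(String.format(Locale.ROOT, row, name,
                    String.format(Locale.ROOT, number, auc[i]),
                    String.format(Locale.ROOT, number, auprc[i])));
        }
        return sb.toString();
    }
}
```

Formatting with Locale.ROOT keeps the decimal separator a dot regardless of the JVM's default locale.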
scoreForMetric
public double scoreForMetric(ROCBinary.Metric metric, int idx)
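scoreForMetric(metric, idx) presumably returns the chosen metric for output idx, that is, the value calculateAUC(idx) or calculateAUCPR(idx) would give. The hypothetical stand-in below shows that dispatch over precomputed per-output scores; its own Metric enum mirrors the AUROC and AUPRC constants of ROCBinary.Metric, and everything else is assumed.

```java
/**
 * Hypothetical stand-in (not ND4J) for ROCBinary.Metric and scoreForMetric.
 * Assumption: idx is the output (column) index.
 */
final class MetricScoreSketch {

    enum Metric { AUROC, AUPRC }

    private final double[] aucPerOutput;
    private final double[] aucprPerOutput;

    MetricScoreSketch(double[] aucPerOutput, double[] aucprPerOutput) {
        this.aucPerOutput = aucPerOutput.clone();
        this.aucprPerOutput = aucprPerOutput.clone();
    }

    double scoreForMetric(Metric metric, int idx) {
        switch (metric) {
            case AUROC:
                return aucPerOutput[idx];   // area under the ROC curve for output idx
            case AUPRC:
                return aucprPerOutput[idx]; // area under the P-R curve for output idx
            default:
                throw new IllegalStateException("Unknown metric: " + metric);
        }
    }
}
```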
getValue
public double getValue(IMetric metric)
Description copied from interface: IEvaluation
Get the value of a given metric for this evaluation.
newInstance
public ROCBinary newInstance()
Description copied from interface: IEvaluation
Get a new instance of this evaluation, with the same configuration but no data.