Package org.nd4j.evaluation.curves
Class PrecisionRecallCurve
- java.lang.Object
  - org.nd4j.evaluation.curves.BaseCurve
    - org.nd4j.evaluation.curves.PrecisionRecallCurve
public class PrecisionRecallCurve extends BaseCurve
Nested Class Summary
Nested Classes
- static class PrecisionRecallCurve.Confusion
- static class PrecisionRecallCurve.Point
Field Summary
Fields inherited from class org.nd4j.evaluation.curves.BaseCurve
DEFAULT_FORMAT_PREC
Constructor Summary
Constructors
- PrecisionRecallCurve(double[] threshold, double[] precision, double[] recall, int[] tpCount, int[] fpCount, int[] fnCount, int totalCount)
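The constructor takes parallel per-threshold arrays. The sketch below shows one way such arrays could be derived from raw scores and 0/1 labels, using the standard definitions precision = TP / (TP + FP) and recall = TP / (TP + FN). The class name PrArraysSketch is hypothetical and not part of ND4J. It assumes a sample is predicted positive when its score is greater than or equal to the threshold, and it reports precision as 1.0 when nothing is predicted positive; both are assumptions for illustration, not confirmed ND4J behavior.

```java
/**
 * Hypothetical helper (not part of ND4J): builds arrays shaped like the
 * PrecisionRecallCurve constructor arguments from scores and 0/1 labels.
 * Assumption: a sample is predicted positive when score >= threshold.
 */
class PrArraysSketch {
    final double[] threshold, precision, recall;
    final int[] tpCount, fpCount, fnCount;
    final int totalCount;

    PrArraysSketch(double[] scores, int[] labels, double[] thresholds) {
        if (scores.length != labels.length) {
            throw new IllegalArgumentException("scores and labels must have the same length");
        }
        int n = thresholds.length;
        threshold = thresholds.clone();
        precision = new double[n];
        recall = new double[n];
        tpCount = new int[n];
        fpCount = new int[n];
        fnCount = new int[n];
        totalCount = scores.length;
        for (int t = 0; t < n; t++) {
            int tp = 0, fp = 0, fn = 0;
            for (int i = 0; i < scores.length; i++) {
                boolean predictedPositive = scores[i] >= thresholds[t];
                boolean actualPositive = labels[i] == 1;
                if (predictedPositive && actualPositive) tp++;
                else if (predictedPositive) fp++;
                else if (actualPositive) fn++;
            }
            tpCount[t] = tp;
            fpCount[t] = fp;
            fnCount[t] = fn;
            // Assumed convention: precision is 1.0 when nothing is predicted positive
            precision[t] = (tp + fp) == 0 ? 1.0 : (double) tp / (tp + fp);
            // Recall is 0.0 when there are no actual positives (also an assumption)
            recall[t] = (tp + fn) == 0 ? 0.0 : (double) tp / (tp + fn);
        }
    }
}
```

The true-negative count is not passed per threshold; only totalCount is, so TN can be recovered as totalCount - TP - FP - FN.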
Method Summary
- double calculateAUPRC()
- static PrecisionRecallCurve fromJson(String json)
- static PrecisionRecallCurve fromYaml(String yaml)
- PrecisionRecallCurve.Confusion getConfusionMatrixAtPoint(int point): Get the binary confusion matrix for the given position.
- PrecisionRecallCurve.Confusion getConfusionMatrixAtThreshold(double threshold): Get the binary confusion matrix for the given threshold.
- PrecisionRecallCurve.Point getPointAtPrecision(double precision): Get the point (index, threshold, precision, recall) at the given precision. Specifically, returns the point at the lowest threshold that has precision equal to or greater than the requested precision.
- PrecisionRecallCurve.Point getPointAtRecall(double recall): Get the point (index, threshold, precision, recall) at the given recall. Specifically, returns the point at the highest threshold that has recall equal to or greater than the requested recall.
- PrecisionRecallCurve.Point getPointAtThreshold(double threshold): Get the point (index, threshold, precision, recall) at the given threshold. If the threshold is not found exactly, the point at the next highest threshold exceeding the requested threshold is returned.
- double getPrecision(int i)
- double getRecall(int i)
- double getThreshold(int i)
- String getTitle()
- double[] getX()
- double[] getY()
- int numPoints()
Methods inherited from class org.nd4j.evaluation.curves.BaseCurve
calculateArea, calculateArea, format, fromJson, fromYaml, toJson, toYaml
Method Detail
numPoints
public int numPoints()
getTitle
public String getTitle()
getThreshold
public double getThreshold(int i)
- Parameters:
  i - Point number, 0 to numPoints()-1 inclusive
- Returns:
  Threshold of the given point
getPrecision
public double getPrecision(int i)
- Parameters:
  i - Point number, 0 to numPoints()-1 inclusive
- Returns:
  Precision of the given point
getRecall
public double getRecall(int i)
- Parameters:
  i - Point number, 0 to numPoints()-1 inclusive
- Returns:
  Recall of the given point
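The three indexed accessors share one convention: point i is valid for 0 <= i <= numPoints()-1. The stand-in class below (IndexedCurveSketch, hypothetical and not the ND4J class) mirrors that convention and shows the usual loop over every point; in this sketch an out-of-range index simply surfaces as an ArrayIndexOutOfBoundsException from the backing array.

```java
import java.util.Locale;

/**
 * Hypothetical stand-in (not part of ND4J) mirroring the indexed accessors.
 * Valid point numbers run from 0 to numPoints() - 1 inclusive.
 */
class IndexedCurveSketch {
    private final double[] threshold, precision, recall;

    IndexedCurveSketch(double[] threshold, double[] precision, double[] recall) {
        this.threshold = threshold;
        this.precision = precision;
        this.recall = recall;
    }

    int numPoints() { return threshold.length; }

    double getThreshold(int i) { return threshold[i]; }

    double getPrecision(int i) { return precision[i]; }

    double getRecall(int i) { return recall[i]; }

    /** Walks every point in index order and renders one line per point. */
    String table() {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < numPoints(); i++) {
            sb.append(String.format(Locale.ROOT, "%d: t=%.2f p=%.2f r=%.2f\n",
                    i, getThreshold(i), getPrecision(i), getRecall(i)));
        }
        return sb.toString();
    }
}
```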
calculateAUPRC
public double calculateAUPRC()
- Returns:
  The area under the precision-recall curve
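A minimal sketch of how an area under the precision-recall curve can be computed, assuming trapezoidal integration with recall on the x axis and precision on the y axis over the stored points. This is not confirmed to match calculateAUPRC() exactly; the class AuprcSketch is hypothetical. Using the absolute width of each segment makes the result independent of whether the points run in increasing or decreasing recall order.

```java
/**
 * Hypothetical sketch (not the ND4J implementation): trapezoidal area under a
 * precision-recall curve, with recall as x and precision as y.
 */
class AuprcSketch {
    static double trapezoidArea(double[] recall, double[] precision) {
        if (recall.length != precision.length) {
            throw new IllegalArgumentException("recall and precision must have the same length");
        }
        double area = 0.0;
        for (int i = 0; i < recall.length - 1; i++) {
            // Width of the segment between consecutive points
            double width = Math.abs(recall[i + 1] - recall[i]);
            // Average height of the two endpoints
            double meanHeight = (precision[i] + precision[i + 1]) / 2.0;
            area += width * meanHeight;
        }
        return area;
    }
}
```

For points with recall {1.0, 0.5, 0.0} and precision {0.5, 1.0, 1.0}, the two trapezoids contribute 0.375 and 0.5, for a total of 0.875.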
getPointAtThreshold
public PrecisionRecallCurve.Point getPointAtThreshold(double threshold)
Get the point (index, threshold, precision, recall) at the given threshold. If the threshold is not found exactly, the point at the next highest threshold exceeding the requested threshold is returned.
- Parameters:
  threshold - Threshold to get the point for
- Returns:
  Point (index, threshold, precision, recall) at the given threshold
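The "exact match, otherwise next highest threshold" rule maps naturally onto a binary search, provided the thresholds are sorted in ascending order (an assumption here). The sketch below uses java.util.Arrays.binarySearch, whose negative return value encodes the insertion point, i.e. the index of the first element greater than the key. ThresholdLookupSketch is a hypothetical name, and clamping to the last index when the request exceeds every stored threshold is a choice made only for this sketch.

```java
import java.util.Arrays;

/** Hypothetical sketch (not part of ND4J) of the threshold-to-index rule. */
class ThresholdLookupSketch {
    /**
     * Returns the index of the requested threshold if present, otherwise the
     * index of the next highest threshold. Assumes ascending, sorted thresholds.
     */
    static int indexAtThreshold(double[] sortedThresholds, double threshold) {
        int idx = Arrays.binarySearch(sortedThresholds, threshold);
        if (idx < 0) {
            // Absent key: binarySearch returns -(insertionPoint) - 1
            idx = -idx - 1;
        }
        // Sketch-only choice: clamp when the request exceeds every threshold
        return Math.min(idx, sortedThresholds.length - 1);
    }
}
```

With thresholds {0.0, 0.25, 0.5, 0.75, 1.0}, a request for 0.3 resolves to index 2 (threshold 0.5), while an exact request for 0.5 also resolves to index 2.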
getPointAtPrecision
public PrecisionRecallCurve.Point getPointAtPrecision(double precision)
Get the point (index, threshold, precision, recall) at the given precision. Specifically, returns the point at the lowest threshold that has precision equal to or greater than the requested precision.
- Parameters:
  precision - Precision to get the point for
- Returns:
  Point (index, threshold, precision, recall) at (or closest exceeding) the given precision
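Precision is not, in general, monotonic in the threshold, so "the lowest threshold with precision at or above the target" is most simply found with a linear scan from the low end. The sketch assumes thresholds are stored in ascending order, so index 0 is the lowest threshold. PrecisionLookupSketch is hypothetical, and returning -1 when no point qualifies is a sketch-only choice; the documented method returns a Point.

```java
/** Hypothetical sketch (not part of ND4J) of the getPointAtPrecision rule. */
class PrecisionLookupSketch {
    /**
     * Scans from the lowest threshold upward and returns the first index whose
     * precision is >= target. Assumes ascending thresholds; -1 if none qualify.
     */
    static int indexAtPrecision(double[] precision, double target) {
        for (int i = 0; i < precision.length; i++) {
            if (precision[i] >= target) {
                return i; // lowest threshold meeting the target
            }
        }
        return -1;
    }
}
```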
getPointAtRecall
public PrecisionRecallCurve.Point getPointAtRecall(double recall)
Get the point (index, threshold, precision, recall) at the given recall. Specifically, returns the point at the highest threshold that has recall equal to or greater than the requested recall.
- Parameters:
  recall - Recall to get the point for
- Returns:
  Point (index, threshold, precision, recall) at (or closest exceeding) the given recall
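The recall rule is the mirror image of the precision rule: search from the high-threshold end and stop at the first point whose recall meets the target. As before, the sketch assumes thresholds are stored in ascending order, and RecallLookupSketch and its -1 "not found" result are hypothetical. Because raising the threshold can only shrink the set of predicted positives, recall does not increase with the threshold, so a binary search would also work; the linear scan is kept for clarity.

```java
/** Hypothetical sketch (not part of ND4J) of the getPointAtRecall rule. */
class RecallLookupSketch {
    /**
     * Scans from the highest threshold downward and returns the first index
     * whose recall is >= target. Assumes ascending thresholds; -1 if none.
     */
    static int indexAtRecall(double[] recall, double target) {
        for (int i = recall.length - 1; i >= 0; i--) {
            if (recall[i] >= target) {
                return i; // highest threshold meeting the target
            }
        }
        return -1;
    }
}
```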
getConfusionMatrixAtThreshold
public PrecisionRecallCurve.Confusion getConfusionMatrixAtThreshold(double threshold)
Get the binary confusion matrix for the given threshold. As per getPointAtThreshold(double), if the threshold is not found exactly, the next highest threshold exceeding the requested threshold is used.
- Parameters:
  threshold - Threshold at which to get the confusion matrix
- Returns:
  Binary confusion matrix
getConfusionMatrixAtPoint
public PrecisionRecallCurve.Confusion getConfusionMatrixAtPoint(int point)
Get the binary confusion matrix for the given position. As per getPointAtThreshold(double).
- Parameters:
  point - Position at which to get the binary confusion matrix
- Returns:
  Binary confusion matrix
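The constructor stores per-point TP, FP and FN counts plus a single totalCount, so a full binary confusion matrix at a point needs TN derived from the others. The sketch below (ConfusionSketch, hypothetical and not the ND4J Confusion class) assumes TN = totalCount - TP - FP - FN. A threshold-based lookup would first resolve the threshold to a point index by the getPointAtThreshold(double) rule and then build the matrix the same way.

```java
/**
 * Hypothetical sketch (not part of ND4J) of a binary confusion matrix at one
 * curve point. Assumption: TN = totalCount - TP - FP - FN.
 */
class ConfusionSketch {
    final int tp, fp, fn, tn;

    ConfusionSketch(int tp, int fp, int fn, int tn) {
        this.tp = tp;
        this.fp = fp;
        this.fn = fn;
        this.tn = tn;
    }

    /** Builds the matrix for a point index from per-point count arrays. */
    static ConfusionSketch atPoint(int point, int[] tpCount, int[] fpCount,
                                   int[] fnCount, int totalCount) {
        int tp = tpCount[point];
        int fp = fpCount[point];
        int fn = fnCount[point];
        // TN is not stored per point; derive it from the total
        return new ConfusionSketch(tp, fp, fn, totalCount - tp - fp - fn);
    }
}
```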
fromJson
public static PrecisionRecallCurve fromJson(String json)
fromYaml
public static PrecisionRecallCurve fromYaml(String yaml)