Class EvaluationCalibration
- java.lang.Object
-
- org.nd4j.evaluation.BaseEvaluation<EvaluationCalibration>
-
- org.nd4j.evaluation.classification.EvaluationCalibration
-
- All Implemented Interfaces:
Serializable, IEvaluation<EvaluationCalibration>
public class EvaluationCalibration extends BaseEvaluation<EvaluationCalibration>
- See Also:
- Serialized Form
-
-
Field Summary
Fields (modifier and type, field):
protected int axis
static int DEFAULT_HISTOGRAM_NUM_BINS
static int DEFAULT_RELIABILITY_DIAG_NUM_BINS
-
Constructor Summary
Constructors (modifier, constructor, description):
EvaluationCalibration() - Create an EvaluationCalibration instance with the default number of bins
EvaluationCalibration(int reliabilityDiagNumBins, int histogramNumBins) - Create an EvaluationCalibration instance with the specified number of bins
EvaluationCalibration(int reliabilityDiagNumBins, int histogramNumBins, boolean excludeEmptyBins) - Create an EvaluationCalibration instance with the specified number of bins
protected EvaluationCalibration(int axis, int reliabilityDiagNumBins, int histogramNumBins, boolean excludeEmptyBins)
-
Method Summary
All Methods (modifier and type, method, description):
void eval(INDArray labels, INDArray networkPredictions)
void eval(INDArray labels, INDArray predictions, INDArray mask)
void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray, List<? extends Serializable> recordMetaData)
static EvaluationCalibration fromJson(String json)
int getAxis() - Get the axis - see setAxis(int) for details
int[] getLabelCountsEachClass()
int[] getPredictionCountsEachClass()
Histogram getProbabilityHistogram(int labelClassIdx) - Return a probability histogram of the specified label class index.
Histogram getProbabilityHistogramAllClasses() - Return a probability histogram for all predictions/classes.
ReliabilityDiagram getReliabilityDiagram(int classIdx) - Get the reliability diagram for the specified class
Histogram getResidualPlot(int labelClassIdx) - Get the residual plot, only for examples of the specified class.
Histogram getResidualPlotAllClasses() - Get the residual plot for all classes combined.
double getValue(IMetric metric) - Get the value of a given metric for this evaluation.
void merge(EvaluationCalibration other)
EvaluationCalibration newInstance() - Get a new instance of this evaluation, with the same configuration but no data.
int numClasses()
void reset()
void setAxis(int axis) - Set the axis for evaluation - this is the dimension along which the probability (and label classes) are present.
    For DL4J, this can be left as the default setting (axis = 1).
    Axis should be set as follows:
    For 2D (OutputLayer), shape [minibatch, numClasses] - axis = 1
    For 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NCW format, shape [minibatch, numClasses, sequenceLength] - axis = 1
    For 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NWC format, shape [minibatch, sequenceLength, numClasses] - axis = 2
    For 4D, CNN2D (DL4J CnnLossLayer), NCHW format, shape [minibatch, channels, height, width] - axis = 1
    For 4D, CNN2D, NHWC format, shape [minibatch, height, width, channels] - axis = 3
String stats()
Methods inherited from class org.nd4j.evaluation.BaseEvaluation
attempFromLegacyFromJson, eval, evalTimeSeries, evalTimeSeries, fromJson, fromYaml, reshapeAndExtractNotMasked, toJson, toString, toYaml
-
-
-
-
Field Detail
-
DEFAULT_RELIABILITY_DIAG_NUM_BINS
public static final int DEFAULT_RELIABILITY_DIAG_NUM_BINS
- See Also:
- Constant Field Values
-
DEFAULT_HISTOGRAM_NUM_BINS
public static final int DEFAULT_HISTOGRAM_NUM_BINS
- See Also:
- Constant Field Values
-
axis
protected int axis
-
-
Constructor Detail
-
EvaluationCalibration
protected EvaluationCalibration(int axis, int reliabilityDiagNumBins, int histogramNumBins, boolean excludeEmptyBins)
-
EvaluationCalibration
public EvaluationCalibration()
Create an EvaluationCalibration instance with the default number of bins
-
EvaluationCalibration
public EvaluationCalibration(int reliabilityDiagNumBins, int histogramNumBins)
Create an EvaluationCalibration instance with the specified number of bins
- Parameters:
reliabilityDiagNumBins - Number of bins for the reliability diagram (usually 10)
histogramNumBins - Number of bins for the histograms
-
EvaluationCalibration
public EvaluationCalibration(int reliabilityDiagNumBins, int histogramNumBins, boolean excludeEmptyBins)
Create an EvaluationCalibration instance with the specified number of bins
- Parameters:
reliabilityDiagNumBins - Number of bins for the reliability diagram (usually 10)
histogramNumBins - Number of bins for the histograms
excludeEmptyBins - For the reliability diagram, whether empty bins should be excluded
-
-
Method Detail
-
setAxis
public void setAxis(int axis)
Set the axis for evaluation - this is the dimension along which the probability (and label classes) are present.
For DL4J, this can be left as the default setting (axis = 1).
Axis should be set as follows:
For 2D (OutputLayer), shape [minibatch, numClasses] - axis = 1
For 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NCW format, shape [minibatch, numClasses, sequenceLength] - axis = 1
For 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NWC format, shape [minibatch, sequenceLength, numClasses] - axis = 2
For 4D, CNN2D (DL4J CnnLossLayer), NCHW format, shape [minibatch, channels, height, width] - axis = 1
For 4D, CNN2D, NHWC format, shape [minibatch, height, width, channels] - axis = 3
- Parameters:
axis - Axis to use for evaluation
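The axis values listed above all coincide with the position of the class/channel dimension in the layout name. The following is a minimal sketch of that observation, not part of ND4J: the class AxisSketch, its methods, and the layout label "NC" for the 2D case are hypothetical names introduced only for illustration.

```java
// Hypothetical helper (not an ND4J class): derives the evaluation axis from a layout name.
final class AxisSketch {

    // The axis is the index of 'C' (classes/channels) in the layout string:
    // "NC" -> 1, "NCW" -> 1, "NWC" -> 2, "NCHW" -> 1, "NHWC" -> 3.
    static int classAxis(String layout) {
        int axis = layout.indexOf('C');
        if (axis < 1) {
            throw new IllegalArgumentException("Layout must contain C after the minibatch dimension: " + layout);
        }
        return axis;
    }

    // Reads the number of classes from an array shape, given its layout.
    static long numClasses(long[] shape, String layout) {
        if (shape.length != layout.length()) {
            throw new IllegalArgumentException("Shape rank does not match layout " + layout);
        }
        return shape[classAxis(layout)];
    }

    public static void main(String[] args) {
        long[] nwcShape = {32, 100, 5}; // [minibatch, sequenceLength, numClasses]
        System.out.println(classAxis("NWC"));            // axis for NWC data
        System.out.println(numClasses(nwcShape, "NWC")); // size of the class dimension
    }
}
```

With a helper like this, the value returned for the data layout in use is what would be passed to setAxis(int); for DL4J's default layouts that value is 1.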
-
getAxis
public int getAxis()
Get the axis - see setAxis(int) for details
-
eval
public void eval(INDArray labels, INDArray predictions, INDArray mask)
- Specified by:
eval in interface IEvaluation<EvaluationCalibration>
- Overrides:
eval in class BaseEvaluation<EvaluationCalibration>
-
eval
public void eval(INDArray labels, INDArray networkPredictions)
- Specified by:
eval in interface IEvaluation<EvaluationCalibration>
- Overrides:
eval in class BaseEvaluation<EvaluationCalibration>
-
eval
public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray, List<? extends Serializable> recordMetaData)
-
merge
public void merge(EvaluationCalibration other)
-
reset
public void reset()
-
stats
public String stats()
- Returns:
-
numClasses
public int numClasses()
-
getReliabilityDiagram
public ReliabilityDiagram getReliabilityDiagram(int classIdx)
Get the reliability diagram for the specified class
- Parameters:
classIdx - Index of the class to get the reliability diagram for
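This page does not describe how the reliability diagram is computed. As a rough illustration of the usual definition (predictions for one class are grouped into equal-width probability bins, and each bin reports the mean predicted probability against the observed fraction of positive labels), here is a minimal plain-Java sketch. The class and method names are hypothetical, and this is not the ND4J implementation; in particular, empty bins are reported as NaN here rather than being excluded.

```java
import java.util.Arrays;

// Hypothetical sketch of a reliability diagram for a single class (not ND4J code).
final class ReliabilityDiagramSketch {

    // probs[i]  : predicted probability of the class for example i, in [0, 1]
    // labels[i] : 1 if example i belongs to the class, otherwise 0
    // Returns { meanPredictedValue per bin, fractionOfPositives per bin }.
    static double[][] compute(double[] probs, int[] labels, int numBins) {
        double[] probSum = new double[numBins];
        int[] positives = new int[numBins];
        int[] counts = new int[numBins];
        for (int i = 0; i < probs.length; i++) {
            // Equal-width bins over [0, 1]; a probability of exactly 1.0 goes in the last bin.
            int bin = Math.min((int) (probs[i] * numBins), numBins - 1);
            probSum[bin] += probs[i];
            positives[bin] += labels[i];
            counts[bin]++;
        }
        double[] meanPredicted = new double[numBins];
        double[] fractionPositives = new double[numBins];
        for (int b = 0; b < numBins; b++) {
            meanPredicted[b] = counts[b] == 0 ? Double.NaN : probSum[b] / counts[b];
            fractionPositives[b] = counts[b] == 0 ? Double.NaN : (double) positives[b] / counts[b];
        }
        return new double[][] {meanPredicted, fractionPositives};
    }

    public static void main(String[] args) {
        double[] probs = {0.1, 0.3, 0.6, 0.9};
        int[] labels = {0, 1, 1, 1};
        double[][] diagram = compute(probs, labels, 2);
        System.out.println("mean predicted:     " + Arrays.toString(diagram[0]));
        System.out.println("fraction positives: " + Arrays.toString(diagram[1]));
    }
}
```

For a well-calibrated classifier the two arrays are close to each other in every non-empty bin.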
-
getLabelCountsEachClass
public int[] getLabelCountsEachClass()
- Returns:
- The number of observed labels for each class. For N classes, the returned array is of length N, with out[i] being the number of labels of class i
-
getPredictionCountsEachClass
public int[] getPredictionCountsEachClass()
- Returns:
- The number of network predictions for each class. For N classes, the returned array is of length N, with out[i] being the number of predicted values (max probability) for class i
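Both count arrays can be pictured as a tally of the highest-valued class per example: over one-hot labels this gives the label counts, and over predicted probabilities it gives the prediction counts. The sketch below is plain Java with hypothetical names, assuming one-hot labels and 2D [minibatch, numClasses] data; it is not the ND4J implementation.

```java
import java.util.Arrays;

// Hypothetical sketch of per-class label and prediction counts (not ND4J code).
final class ClassCountsSketch {

    // Index of the largest value in a row; ties resolve to the lowest index.
    static int argMax(double[] row) {
        int best = 0;
        for (int i = 1; i < row.length; i++) {
            if (row[i] > row[best]) {
                best = i;
            }
        }
        return best;
    }

    // out[i] = number of rows whose largest value is in column i.
    static int[] countByArgMax(double[][] rows, int numClasses) {
        int[] counts = new int[numClasses];
        for (double[] row : rows) {
            counts[argMax(row)]++;
        }
        return counts;
    }

    public static void main(String[] args) {
        double[][] labels = {{1, 0, 0}, {0, 1, 0}, {0, 1, 0}};
        double[][] predictions = {{0.7, 0.2, 0.1}, {0.4, 0.5, 0.1}, {0.1, 0.3, 0.6}};
        System.out.println(Arrays.toString(countByArgMax(labels, 3)));      // prints [1, 2, 0]
        System.out.println(Arrays.toString(countByArgMax(predictions, 3))); // prints [1, 1, 1]
    }
}
```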
-
getResidualPlotAllClasses
public Histogram getResidualPlotAllClasses()
Get the residual plot for all classes combined. The residual plot is defined as a histogram of
|label_i - prob(class_i | input)| for all classes i and examples.
In general, smaller residuals indicate a better classifier.
- Returns:
- Residual plot (histogram) - all predictions/classes
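The residual definition above can be illustrated with a small plain-Java sketch. It assumes 2D [minibatch, numClasses] arrays and equal-width bins over [0, 1]; the class and method names are hypothetical, and the binning details of the actual ND4J implementation may differ.

```java
import java.util.Arrays;

// Hypothetical sketch of a residual histogram over all classes (not ND4J code).
final class ResidualPlotSketch {

    // Counts |label - probability| over every (example, class) pair into numBins bins on [0, 1].
    static int[] residualHistogram(double[][] labels, double[][] probs, int numBins) {
        int[] binCounts = new int[numBins];
        for (int n = 0; n < labels.length; n++) {
            for (int c = 0; c < labels[n].length; c++) {
                double residual = Math.abs(labels[n][c] - probs[n][c]);
                // A residual of exactly 1.0 is placed in the last bin.
                int bin = Math.min((int) (residual * numBins), numBins - 1);
                binCounts[bin]++;
            }
        }
        return binCounts;
    }

    public static void main(String[] args) {
        double[][] labels = {{1, 0}, {0, 1}, {1, 0}};
        double[][] probs = {{0.9, 0.1}, {0.4, 0.6}, {0.2, 0.8}};
        // Bins: [0, 0.25), [0.25, 0.5), [0.5, 0.75), [0.75, 1.0]
        System.out.println(Arrays.toString(residualHistogram(labels, probs, 4))); // prints [2, 2, 0, 2]
    }
}
```

In this toy data the third example is confidently wrong, which shows up as two counts in the highest bin.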
-
getResidualPlot
public Histogram getResidualPlot(int labelClassIdx)
Get the residual plot, only for examples of the specified class. The residual plot is defined as a histogram of
|label_i - prob(class_i | input)| for all classes i and examples; for this particular method, only predictions where i == labelClassIdx are included.
In general, smaller residuals indicate a better classifier.
- Parameters:
labelClassIdx - Index of the class to get the residual plot for
- Returns:
- Residual plot (histogram) for the specified class
-
getProbabilityHistogramAllClasses
public Histogram getProbabilityHistogramAllClasses()
Return a probability histogram for all predictions/classes.
- Returns:
- Probability histogram
-
getProbabilityHistogram
public Histogram getProbabilityHistogram(int labelClassIdx)
Return a probability histogram of the specified label class index. That is, for label class index i, a histogram of P(class_i | input) is returned, only for those examples that are labelled as class i.
- Parameters:
labelClassIdx - Index of the label class to get the histogram for
- Returns:
- Probability histogram
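The filtering described above (only examples labelled as class i contribute, and only their predicted probability for class i is counted) can be sketched in plain Java as follows. The names are hypothetical, labels are assumed to be given as class indices, and equal-width bins over [0, 1] are an assumption rather than a statement about the ND4J implementation.

```java
import java.util.Arrays;

// Hypothetical sketch of a per-label-class probability histogram (not ND4J code).
final class ProbabilityHistogramSketch {

    // labelClass[n] : true class index of example n
    // probs[n][c]   : predicted probability of class c for example n
    static int[] histogramForLabelClass(int[] labelClass, double[][] probs, int classIdx, int numBins) {
        int[] binCounts = new int[numBins];
        for (int n = 0; n < labelClass.length; n++) {
            if (labelClass[n] != classIdx) {
                continue; // skip examples that are not labelled as classIdx
            }
            double p = probs[n][classIdx];
            // A probability of exactly 1.0 is placed in the last bin.
            int bin = Math.min((int) (p * numBins), numBins - 1);
            binCounts[bin]++;
        }
        return binCounts;
    }

    public static void main(String[] args) {
        int[] labelClass = {0, 1, 0, 1};
        double[][] probs = {{0.9, 0.1}, {0.3, 0.7}, {0.55, 0.45}, {0.8, 0.2}};
        System.out.println(Arrays.toString(histogramForLabelClass(labelClass, probs, 0, 2))); // prints [0, 2]
        System.out.println(Arrays.toString(histogramForLabelClass(labelClass, probs, 1, 2))); // prints [1, 1]
    }
}
```

A histogram concentrated in the upper bins means the examples of that class tend to receive high probability for their true class.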
-
fromJson
public static EvaluationCalibration fromJson(String json)
-
getValue
public double getValue(IMetric metric)
Description copied from interface: IEvaluation
Get the value of a given metric for this evaluation.
-
newInstance
public EvaluationCalibration newInstance()
Description copied from interface: IEvaluation
Get a new instance of this evaluation, with the same configuration but no data.
-
-