mindspore.nn.ConfusionMatrix
- class mindspore.nn.ConfusionMatrix(num_classes, normalize='no_norm', threshold=0.5)[源代码]
计算混淆矩阵(confusion matrix),通常用于评估分类模型的性能,包括二分类和多分类场景。
如果您只想使用混淆矩阵,请使用该类。如果想计算”PPV”、”TPR”、”TNR”等,请使用’mindspore.nn.ConfusionMatrixMetric’类。
参数:
num_classes (int) - 数据集中的类别数量。
normalize (str) - 计算ConfsMatrix的参数支持四种归一化模式,默认值:None。
“no_norm” (None):不使用标准化。
“target” (str):基于目标值的标准化。
“prediction” (str):基于预测值的标准化。
“all” (str):整个矩阵的标准化。
threshold (float) - 阈值,用于与输入Tensor进行比较。默认值:0.5。
- 支持平台:
Ascend
GPU
CPU
样例:
>>> import numpy as np >>> from mindspore import nn, Tensor >>> >>> x = Tensor(np.array([1, 0, 1, 0])) >>> y = Tensor(np.array([1, 0, 0, 1])) >>> metric = nn.ConfusionMatrix(num_classes=2, normalize='no_norm', threshold=0.5) >>> metric.clear() >>> metric.update(x, y) >>> output = metric.eval() >>> print(output) [[1. 1.] [1. 1.]]