mindspore.train.ConfusionMatrix
- class mindspore.train.ConfusionMatrix(num_classes, normalize='no_norm', threshold=0.5)[源代码]
计算混淆矩阵(confusion matrix),通常用于评估分类模型的性能,包括二分类和多分类场景。
如果只想使用混淆矩阵,请使用该类。如果想计算”PPV”、”TPR”、”TNR”等,请使用
mindspore.train.ConfusionMatrixMetric
类。- 参数:
num_classes (int) - 数据集中的类别数量。
normalize (str) - 计算ConfusionMatrix的参数支持四种归一化模式,默认值:
"no_norm"
。"no_norm"
:不使用标准化。"target"
:基于目标值的标准化。"prediction"
:基于预测值的标准化。"all"
:整个矩阵的标准化。
threshold (float) - 阈值,用于与输入Tensor进行比较。默认值:
0.5
。
- 支持平台:
Ascend
GPU
CPU
样例:
>>> import numpy as np >>> from mindspore import Tensor >>> from mindspore.train import ConfusionMatrix >>> >>> x = Tensor(np.array([1, 0, 1, 0])) >>> y = Tensor(np.array([1, 0, 0, 1])) >>> metric = ConfusionMatrix(num_classes=2, normalize='no_norm', threshold=0.5) >>> metric.clear() >>> metric.update(x, y) >>> output = metric.eval() >>> print(output) [[1. 1.] [1. 1.]]