mindspore.train.Metric
- class mindspore.train.Metric[源代码]
用于计算评估指标的基类。
在计算评估指标时需要调用 clear 、 update 和 eval 三个方法,在继承该类自定义评估指标时,也需要实现这三个方法。其中,update 用于计算中间过程的内部结果,eval 用于计算最终评估结果,clear 用于重置中间结果。 请勿直接使用该类,需使用子类如
mindspore.train.MAE
、mindspore.train.Recall
等。- 支持平台:
Ascend
GPU
CPU
样例:
>>> import numpy as np >>> import mindspore as ms >>> >>> class MyMAE(ms.train.Metric): ... def __init__(self): ... super(MyMAE, self).__init__() ... self.clear() ... ... def clear(self): ... self._abs_error_sum = 0 ... self._samples_num = 0 ... ... def update(self, *inputs): ... y_pred = inputs[0].asnumpy() ... y = inputs[1].asnumpy() ... abs_error_sum = np.abs(y - y_pred) ... self._abs_error_sum += abs_error_sum.sum() ... self._samples_num += y.shape[0] ... ... def eval(self): ... return self._abs_error_sum / self._samples_num >>> >>> x = ms.Tensor(np.array([[0.1, 0.2, 0.6, 0.9], [0.1, 0.2, 0.6, 0.9]]), ms.float32) >>> y = ms.Tensor(np.array([[0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.1, 0.1]]), ms.float32) >>> y2 = ms.Tensor(np.array([[0.1, 0.25, 0.7, 0.9], [0.1, 0.25, 0.7, 0.9]]), ms.float32) >>> metric = MyMAE().set_indexes([0, 2]) >>> metric.clear() >>> # indexes is [0, 2], using x as logits, y2 as label. >>> metric.update(x, y, y2) >>> accuracy = metric.eval() >>> print(accuracy) 1.399999976158142 >>> print(metric.indexes) [0, 2]
- property indexes
获取当前的 indexes 值。默认为None,调用 set_indexes 方法可修改 indexes 值。
- set_indexes(indexes)[源代码]
该接口用于重排 update 的输入。
给定(label0, label1, logits)作为 update 的输入,将 indexes 设置为[2, 1],则最终使用(logits, label1)作为 update 的真实输入。
说明
在继承该类自定义评估函数时,需要用装饰器 mindspore.train.rearrange_inputs 修饰 update 方法,否则配置的 indexes 值不生效。
- 参数:
indexes (List(int)) - logits和标签的目标顺序。
- 输出:
Metric
,类实例本身。- 异常:
ValueError - 如果输入的index类型不是list或其元素类型不全为int。