mindspore.train.Metric

class mindspore.train.Metric[源代码]

用于计算评估指标的基类。

在计算评估指标时需要调用 clearupdateeval 三个方法,在继承该类自定义评估指标时,也需要实现这三个方法。其中,update 用于计算中间过程的内部结果,eval 用于计算最终评估结果,clear 用于重置中间结果。 请勿直接使用该类,需使用子类如 mindspore.train.MAEmindspore.train.Recall 等。

支持平台:

Ascend GPU CPU

abstract clear()[源代码]

清除内部评估结果。

说明

所有子类都必须重写此接口。

abstract eval()[源代码]

计算最终评估结果。

说明

所有子类都必须重写此接口。

property indexes

获取当前的 indexes 值。默认为None,调用 set_indexes 方法可修改 indexes 值。

set_indexes(indexes)[源代码]

该接口用于重排 update 的输入。

给定(label0, label1, logits)作为 update 的输入,将 indexes 设置为[2, 1],则最终使用(logits, label1)作为 update 的真实输入。

说明

在继承该类自定义评估函数时,需要用装饰器 mindspore.train.rearrange_inputs 修饰 update 方法,否则配置的 indexes 值不生效。

参数:
  • indexes (List(int)) - logits和标签的目标顺序。

输出:

Metric ,类实例本身。

异常:
  • ValueError - 如果输入的index类型不是list或其元素类型不全为int。

样例:

>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore.train import Accuracy
>>>
>>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
>>> y = Tensor(np.array([1, 0, 1]))
>>> y2 = Tensor(np.array([0, 0, 1]))
>>> metric = Accuracy('classification').set_indexes([0, 2])
>>> metric.clear()
>>> # indexes is [0, 2], using x as logits, y2 as label.
>>> metric.update(x, y, y2)
>>> accuracy = metric.eval()
>>> print(accuracy)
0.3333333333333333
abstract update(*inputs)[源代码]

更新内部评估结果。

说明

所有子类都必须重写此接口。

参数:
  • inputs - 可变长度输入参数列表。通常是预测值和对应的真实标签。