mindspore.train.Metric

class mindspore.train.Metric

Base class of all metrics, which are used to evaluate models.

The clear, update, and eval methods are called when evaluating a metric, and subclasses must override all three: update accumulates intermediate results during evaluation, eval computes the final result, and clear reinitializes the intermediate results.

Never use this class directly; instead, instantiate one of its subclasses, for example mindspore.train.MAE or mindspore.train.Recall.
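Why the base class cannot be used directly can be illustrated with Python's standard abc module. The sketch below assumes that clear, update, and eval behave like Python abstract methods (which the "abstract" marker on each of them suggests); MetricBase and IncompleteMetric are hypothetical names, not MindSpore classes, and MindSpore itself is not imported.

```python
from abc import ABC, abstractmethod


class MetricBase(ABC):
    """Hypothetical stand-in for an abstract metric base class."""

    @abstractmethod
    def clear(self):
        """Reset the intermediate results."""

    @abstractmethod
    def update(self, *inputs):
        """Accumulate intermediate results from one batch."""

    @abstractmethod
    def eval(self):
        """Compute the final result from the accumulated state."""


class IncompleteMetric(MetricBase):
    """Hypothetical subclass that overrides clear and update but forgets eval."""

    def clear(self):
        pass

    def update(self, *inputs):
        pass


for cls in (MetricBase, IncompleteMetric):
    try:
        cls()
    except TypeError as err:
        # Python refuses to instantiate a class that still has abstract methods.
        print(cls.__name__, "->", type(err).__name__)
```

Only a subclass that overrides all three methods can be instantiated.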

Supported Platforms:

Ascend GPU CPU

Examples

>>> import numpy as np
>>> import mindspore as ms
>>>
>>> class MyMAE(ms.train.Metric):
...     def __init__(self):
...         super(MyMAE, self).__init__()
...         self.clear()
...
...     def clear(self):
...         self._abs_error_sum = 0
...         self._samples_num = 0
...
...     @ms.train.rearrange_inputs
...     def update(self, *inputs):
...         # After rearrangement, inputs[0] is the prediction and inputs[1] the label.
...         y_pred = inputs[0].asnumpy()
...         y = inputs[1].asnumpy()
...         abs_error = np.abs(y - y_pred)
...         self._abs_error_sum += abs_error.sum()
...         self._samples_num += y.shape[0]
...
...     def eval(self):
...         return self._abs_error_sum / self._samples_num
>>>
>>> x = ms.Tensor(np.array([[0.1, 0.2, 0.6, 0.9], [0.1, 0.2, 0.6, 0.9]]), ms.float32)
>>> y = ms.Tensor(np.array([[0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.1, 0.1]]), ms.float32)
>>> y2 = ms.Tensor(np.array([[0.1, 0.25, 0.7, 0.9], [0.1, 0.25, 0.7, 0.9]]), ms.float32)
>>> metric = MyMAE().set_indexes([0, 2])
>>> metric.clear()
>>> # indexes is [0, 2], so x is used as logits and y2 as the label; y is skipped.
>>> metric.update(x, y, y2)
>>> mae = metric.eval()
>>> print(mae)
0.1499999612569809
>>> print(metric.indexes)
[0, 2]
abstract clear()

An interface that describes the behavior of clearing the internal evaluation result.

Note

All subclasses must override this interface.

abstract eval()

An interface that describes the behavior of computing the evaluation result.

Note

All subclasses must override this interface.

property indexes

Gets the current value of indexes. The default value is None; it can be changed by calling set_indexes.

set_indexes(indexes)

This interface is used to rearrange the inputs of update.

For example, given the inputs (label0, label1, logits), setting indexes to [2, 1] makes (logits, label1) the actual inputs of update.

Note

When customizing a metric, decorate the update function with the decorator mindspore.train.rearrange_inputs() so that the indexes take effect.
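The effect of the decorator can be shown with a plain-Python analogue. This is a sketch under an assumption, not MindSpore's implementation of mindspore.train.rearrange_inputs: rearrange_inputs_sketch, PlainMetric, and DecoratedMetric are hypothetical, and the decorator is assumed to read self.indexes and reorder the positional inputs before calling update, matching the (label0, label1, logits) and [2, 1] example in this section.

```python
import functools


def rearrange_inputs_sketch(func):
    """Hypothetical decorator: reorder positional inputs by self.indexes."""

    @functools.wraps(func)
    def wrapper(self, *inputs):
        indexes = getattr(self, "indexes", None)
        if indexes is not None:
            # Pick inputs by position, in the order given by indexes.
            inputs = tuple(inputs[i] for i in indexes)
        return func(self, *inputs)

    return wrapper


class PlainMetric:
    """Hypothetical metric whose update is NOT decorated: indexes is ignored."""

    def __init__(self):
        self.indexes = None
        self.seen = None

    def set_indexes(self, indexes):
        self.indexes = indexes
        return self

    def update(self, *inputs):
        self.seen = inputs


class DecoratedMetric(PlainMetric):
    """Hypothetical metric whose update IS decorated: indexes reorders inputs."""

    @rearrange_inputs_sketch
    def update(self, *inputs):
        self.seen = inputs


plain = PlainMetric().set_indexes([2, 1])
plain.update("label0", "label1", "logits")
print(plain.seen)  # ('label0', 'label1', 'logits')

decorated = DecoratedMetric().set_indexes([2, 1])
decorated.update("label0", "label1", "logits")
print(decorated.seen)  # ('logits', 'label1')
```

Only the decorated class sees the reordered inputs; the undecorated one silently receives them in their original order.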

Parameters

indexes (List(int)) – The order of logits and labels to be rearranged.

Outputs:

Metric, the instance itself, so that the call can be chained (for example, MyMAE().set_indexes([0, 2])).

Raises

ValueError – If the type of input 'indexes' is not a list or its elements are not all int.
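The documented contract of indexes and set_indexes (default None, a list of int, the instance returned, ValueError otherwise) can be sketched in plain Python. IndexedMetricSketch is a hypothetical class written for illustration; MindSpore's actual validation code may differ in details such as the error message.

```python
class IndexedMetricSketch:
    """Hypothetical class mirroring the documented indexes/set_indexes contract."""

    def __init__(self):
        self._indexes = None  # documented default value

    @property
    def indexes(self):
        return self._indexes

    def set_indexes(self, indexes):
        # Documented rule: indexes must be a list whose elements are all int.
        if not isinstance(indexes, list) or not all(isinstance(i, int) for i in indexes):
            raise ValueError("indexes must be a list of int")
        self._indexes = indexes
        return self  # returning the instance allows chaining


metric = IndexedMetricSketch()
print(metric.indexes)                      # None
print(metric.set_indexes([0, 2]).indexes)  # [0, 2]

try:
    metric.set_indexes((0, 2))  # a tuple, not a list
except ValueError as err:
    print("ValueError:", err)
```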

abstract update(*inputs)

An interface that describes the behavior of updating the internal evaluation result.

Note

All subclasses must override this interface.

Parameters

inputs – A variable-length input argument list, usually the logits and the corresponding labels.
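A plain-Python sketch of an update that accepts a variable-length argument list and accumulates results across several batches before eval is called. BatchMAESketch is hypothetical and does not use MindSpore; it assumes the common convention that inputs[0] holds the predictions and inputs[1] the labels, and, unlike MyMAE in the example above, it averages over individual elements rather than over samples.

```python
class BatchMAESketch:
    """Hypothetical metric: mean absolute error over all elements seen so far."""

    def __init__(self):
        self.clear()

    def clear(self):
        # Reinitialize the intermediate results.
        self._abs_error_sum = 0.0
        self._count = 0

    def update(self, *inputs):
        # Assumed convention: inputs[0] holds predictions, inputs[1] labels.
        if len(inputs) != 2:
            raise ValueError("expected (logits, labels)")
        logits, labels = inputs
        for pred, label in zip(logits, labels):
            self._abs_error_sum += abs(pred - label)
            self._count += 1

    def eval(self):
        if self._count == 0:
            raise RuntimeError("update must be called before eval")
        return self._abs_error_sum / self._count


metric = BatchMAESketch()
metric.update([1.0, 2.0], [1.5, 2.0])  # batch 1: errors 0.5 and 0.0
metric.update([4.0], [3.0])            # batch 2: error 1.0
print(metric.eval())  # 0.5
```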
