mindspore.train.Metric
- class mindspore.train.Metric[source]
Base class for metrics used in model evaluation.
The clear, update, and eval methods should be called when evaluating a metric, and all three should be overridden by subclasses: update accumulates intermediate results during evaluation, eval computes the final result, and clear reinitializes the intermediate results.
Never use this class directly. Instead, instantiate one of its subclasses, for example mindspore.train.MAE or mindspore.train.Recall.
- Supported Platforms:
Ascend
GPU
CPU
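To illustrate how the three interfaces fit together, the following is a minimal sketch of a custom metric, not the library's built-in MAE: the class name MyMAE and its internal statistics are illustrative, and the sketch assumes the base class's _convert_data helper for converting Tensor inputs to NumPy arrays.

import numpy as np
from mindspore.train import Metric

class MyMAE(Metric):
    """Sketch of a mean-absolute-error metric (illustrative only)."""
    def __init__(self):
        super().__init__()
        self.clear()

    def clear(self):
        # Reinitialize the intermediate results.
        self._abs_error_sum = 0.0
        self._samples_num = 0

    def update(self, *inputs):
        # Accumulate intermediate results from one batch; inputs are (y_pred, y).
        y_pred = self._convert_data(inputs[0])
        y = self._convert_data(inputs[1])
        self._abs_error_sum += np.abs(y.reshape(y_pred.shape) - y_pred).sum()
        self._samples_num += y_pred.shape[0]

    def eval(self):
        # Compute the final result from the accumulated statistics.
        if self._samples_num == 0:
            raise RuntimeError("Call update() at least once before eval().")
        return self._abs_error_sum / self._samples_num

Call clear() before an evaluation pass, update() once per batch, and eval() to obtain the final value.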
- abstract clear()[source]
An interface that describes the behavior of clearing the internal evaluation result.
Note
All subclasses must override this interface.
- abstract eval()[source]
An interface that describes the behavior of computing the evaluation result.
Note
All subclasses must override this interface.
- property indexes
Get the current indexes value. The default value is None and can be changed by set_indexes.
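A short illustration of the property, assuming a fresh metric instance (the outputs shown are the expected values, not taken from the example below):

>>> from mindspore.train import Accuracy
>>> metric = Accuracy('classification')
>>> print(metric.indexes)
None
>>> metric = metric.set_indexes([0, 2])
>>> print(metric.indexes)
[0, 2]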
- set_indexes(indexes)[source]
This interface rearranges the inputs of update.
Given inputs (label0, label1, logits), setting the indexes to [2, 1] makes (logits, label1) the actual inputs of update.
Note
When customizing a metric, decorate the update function with the decorator
mindspore.train.rearrange_inputs()
for the indexes to take effect (see the sketch after the example below).
- Parameters
indexes (List(int)) – The order of logits and labels to be rearranged.
- Outputs:
Metric, the original instance of the class.
- Raises
ValueError – If the type of input ‘indexes’ is not a list or its elements are not all int.
Examples
>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore.train import Accuracy
>>>
>>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
>>> y = Tensor(np.array([1, 0, 1]))
>>> y2 = Tensor(np.array([0, 0, 1]))
>>> metric = Accuracy('classification').set_indexes([0, 2])
>>> metric.clear()
>>> # indexes is [0, 2], using x as logits, y2 as label.
>>> metric.update(x, y, y2)
>>> accuracy = metric.eval()
>>> print(accuracy)
0.3333333333333333
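As noted above, for set_indexes to take effect on a custom metric, its update method must be decorated with mindspore.train.rearrange_inputs. The following is a hedged sketch: the class name MyAccuracy and its statistics are illustrative, and the decorator reorders *inputs according to the indexes set on the instance before the method body runs.

from mindspore.train import Metric, rearrange_inputs

class MyAccuracy(Metric):
    """Sketch of a custom accuracy metric that honours set_indexes (illustrative only)."""
    def __init__(self):
        super().__init__()
        self.clear()

    def clear(self):
        self._correct = 0
        self._total = 0

    @rearrange_inputs
    def update(self, *inputs):
        # After metric.set_indexes([...]), the decorator rearranges *inputs
        # before this body runs; without set_indexes, inputs pass through unchanged.
        logits = self._convert_data(inputs[0])
        labels = self._convert_data(inputs[1])
        self._correct += int((logits.argmax(axis=1) == labels).sum())
        self._total += labels.shape[0]

    def eval(self):
        if self._total == 0:
            raise RuntimeError("Call update() at least once before eval().")
        return self._correct / self._total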