mindspore.train.Accuracy

class mindspore.train.Accuracy(eval_type='classification')[source]

Calculates the accuracy for classification and multilabel data.

The Accuracy class maintains two internal counters, the number of correct predictions and the total number of samples, which are used to compute the frequency with which y_pred matches y. This frequency is the accuracy.

\[\text{accuracy} =\frac{\text{true_positive} + \text{true_negative}} {\text{true_positive} + \text{true_negative} + \text{false_positive} + \text{false_negative}}\]
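The relationship between the two counters and the formula can be illustrated with a minimal plain-Python sketch. This is not MindSpore's implementation; the helper name `accuracy_from_counts` and the binary toy data are assumptions made for illustration. It tallies the four confusion-matrix counts for three predictions and checks that the formula gives the same value as the number of correct predictions divided by the total.

```python
def accuracy_from_counts(tp, tn, fp, fn):
    """Accuracy from confusion-matrix counts (hypothetical helper)."""
    total = tp + tn + fp + fn
    if total == 0:
        raise ValueError("no samples")
    return (tp + tn) / total


# Predicted and true class indices for three binary samples.
y_pred = [1, 0, 0]
y_true = [1, 0, 1]
pairs = list(zip(y_pred, y_true))

tp = sum(p == 1 and t == 1 for p, t in pairs)  # predicted 1, label 1
tn = sum(p == 0 and t == 0 for p, t in pairs)  # predicted 0, label 0
fp = sum(p == 1 and t == 0 for p, t in pairs)  # predicted 1, label 0
fn = sum(p == 0 and t == 1 for p, t in pairs)  # predicted 0, label 1

# Equivalent view: correct predictions over total predictions.
correct_num = sum(p == t for p, t in pairs)
total_num = len(pairs)

formula_accuracy = accuracy_from_counts(tp, tn, fp, fn)
counter_accuracy = correct_num / total_num
print(formula_accuracy, counter_accuracy)
```

Both values agree because every correct prediction is either a true positive or a true negative.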
Parameters

eval_type (str) – The type of data over which the accuracy is calculated. Supports ‘classification’ and ‘multilabel’. ‘classification’ means each sample has a single label. ‘multilabel’ means each sample can have multiple labels. Default: 'classification'.

Supported Platforms:

Ascend GPU CPU

Examples

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor
>>> from mindspore.train import Accuracy
>>>
>>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mindspore.float32)
>>> y = Tensor(np.array([1, 0, 1]), mindspore.float32)
>>> metric = Accuracy('classification')
>>> metric.clear()
>>> metric.update(x, y)
>>> accuracy = metric.eval()
>>> print(accuracy)
0.6666666666666666
clear()[source]

Clears the internal evaluation result.

eval()[source]

Computes the accuracy.

Returns

np.float64, the computed result.

Raises

RuntimeError – If the sample size is 0.
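The clear/update/eval lifecycle and the empty-sample error can be sketched with a small stand-in class. `RunningAccuracy` is a hypothetical, simplified illustration in plain Python, not the library's code; in particular, its `update` takes predicted class indices rather than logits.

```python
class RunningAccuracy:
    """Hypothetical stand-in that mimics the clear/update/eval lifecycle."""

    def __init__(self):
        self.clear()

    def clear(self):
        # Reset both counters so a new evaluation starts from scratch.
        self._correct_num = 0
        self._total_num = 0

    def update(self, pred_indices, labels):
        # Simplified: accepts predicted class indices instead of logits.
        self._correct_num += sum(p == t for p, t in zip(pred_indices, labels))
        self._total_num += len(labels)

    def eval(self):
        # Guard against division by zero when no samples were seen.
        if self._total_num == 0:
            raise RuntimeError("Accuracy cannot be computed: sample size is 0.")
        return self._correct_num / self._total_num


metric = RunningAccuracy()
try:
    metric.eval()
    empty_eval_failed = False
except RuntimeError:
    empty_eval_failed = True

metric.update([1, 0, 0], [1, 0, 1])  # 2 of 3 correct
metric.update([1, 1], [1, 1])        # 2 of 2 correct
running = metric.eval()              # accumulated over both batches: 4 / 5
metric.clear()                       # counters are back to zero
```

Accumulating counts rather than per-batch accuracies keeps the result correct when batches differ in size.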

update(*inputs)[source]

Updates the local variables. For ‘classification’, a prediction is correct if the index of its maximum value matches the label. For ‘multilabel’, a prediction is correct if the predicted values match the label.
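The two correctness rules can be sketched in plain Python as follows. The function names are hypothetical, and the multilabel rule here assumes a sample counts as correct only when every entry of its row matches the label row; check the source if you need the exact behavior.

```python
def classification_hits(y_pred, y):
    """1 where the argmax of each prediction row equals the label, else 0."""
    hits = []
    for row, label in zip(y_pred, y):
        # Index of the largest score (first one wins on ties).
        predicted = max(range(len(row)), key=row.__getitem__)
        hits.append(int(predicted == label))
    return hits


def multilabel_hits(y_pred, y):
    """1 where the whole prediction row equals the label row, else 0.

    Assumption: every entry in the row must match for a hit.
    """
    return [int(list(p) == list(t)) for p, t in zip(y_pred, y)]


# Argmax per row is [1, 0, 0]; compared with labels [1, 0, 1].
cls_hits = classification_hits([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]], [1, 0, 1])

# First row matches exactly; second row differs in one position.
ml_hits = multilabel_hits([[0, 1, 1], [1, 0, 0]], [[0, 1, 1], [1, 1, 0]])
```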

Parameters

inputs

Logits and labels. y_pred stands for logits, y stands for labels. y_pred and y must each be a Tensor, a list, or a NumPy array.

  • For the ‘classification’ evaluation type, y_pred is in most cases (not strictly) a list of floating-point numbers in the range \([0, 1]\) with shape \((N, C)\), where \(N\) is the number of cases and \(C\) is the number of categories. y must either be in one-hot format with shape \((N, C)\), or be convertible to one-hot format, in which case its shape is \((N,)\).

  • For the ‘multilabel’ evaluation type, the values of y_pred and y can only be 0 or 1; indices with value 1 indicate the positive categories. The shapes of y_pred and y are both \((N, C)\).
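For the ‘classification’ type, the two accepted label layouts carry the same information. The sketch below uses a hypothetical helper, `to_class_indices`, written in plain Python for illustration only. It reduces one-hot rows of shape \((N, C)\) to class indices of shape \((N,)\) and leaves index labels unchanged.

```python
def to_class_indices(y):
    """Return class indices of shape (N,) from one-hot rows or plain indices."""
    if y and isinstance(y[0], (list, tuple)):
        # One-hot row: the position of the single 1 is the class index.
        return [list(row).index(1) for row in y]
    # Already index labels; accept floats such as 1.0 as well.
    return [int(label) for label in y]


one_hot = [[0, 1], [1, 0], [0, 1]]  # shape (N, C) with N = 3, C = 2
indices = [1, 0, 1]                 # shape (N,)

same = to_class_indices(one_hot) == to_class_indices(indices)
```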

Raises
  • ValueError – If the number of inputs is not 2.

  • ValueError – If the number of classes in the current predicted data does not match that of the previous predicted data.
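The two error conditions can be sketched as a pair of checks that run before any counting. `ClassCountChecker` and `raises_value_error` are hypothetical names used for illustration; the sketch assumes the class count is fixed by the first batch seen and compared on each later call.

```python
class ClassCountChecker:
    """Hypothetical helper mirroring the two documented ValueError cases."""

    def __init__(self):
        self._class_num = 0  # 0 means no batch has been seen yet

    def check(self, *inputs):
        if len(inputs) != 2:
            raise ValueError(f"expected 2 inputs (y_pred, y), got {len(inputs)}")
        y_pred, _ = inputs
        class_num = len(y_pred[0])  # C in the (N, C) prediction shape
        if self._class_num == 0:
            self._class_num = class_num
        elif class_num != self._class_num:
            raise ValueError(
                f"class number mismatch: previous {self._class_num}, "
                f"current {class_num}"
            )


def raises_value_error(fn, *args):
    """Return True if calling fn(*args) raises ValueError."""
    try:
        fn(*args)
    except ValueError:
        return True
    return False


checker = ClassCountChecker()
wrong_arity = raises_value_error(checker.check, [[0.2, 0.8]])  # only one input
checker.check([[0.2, 0.8]], [1])                               # fixes C = 2
mismatch = raises_value_error(checker.check, [[0.1, 0.7, 0.2]], [1])  # C = 3
```

Calling clear() between evaluations on datasets with different class counts would be the natural way to avoid the second error, assuming clear() also resets the stored class count.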