Function Differences with tf.keras.metrics.AUC


tf.keras.metrics.AUC(
    num_thresholds=200, curve='ROC', summation_method='interpolation', name=None,
    dtype=None, thresholds=None
)

For more information, see tf.keras.metrics.AUC.


mindspore.train.auc(x, y, reorder=False)

For more information, see mindspore.train.auc.


TensorFlow: Takes y_pred and y_true as inputs. The curve parameter selects whether the result is based on the ROC curve or the Precision-Recall curve. Users can also set the number of thresholds with num_thresholds, or pass explicit threshold values with thresholds. The interpolate_pr_auc() method is supported as well (MindSpore has no corresponding function). See the API documentation for implementation and usage details. TensorFlow 1.15 supports binary classification only.
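To make the role of num_thresholds concrete, the following is a minimal pure-Python sketch of a ROC AUC computed on a fixed, evenly spaced threshold grid. It is an illustration only, not TensorFlow's implementation: the function name fixed_threshold_auc, the eps padding of the two end thresholds, and the strict "score > threshold" rule are assumptions made for this example.

```python
def fixed_threshold_auc(labels, scores, num_thresholds=3, eps=1e-7):
    """Approximate the ROC AUC on an evenly spaced threshold grid (illustrative)."""
    # Interior thresholds are evenly spaced in (0, 1). The two end points are
    # pushed slightly outside [0, 1] so scores of exactly 0 or 1 are covered.
    inner = [(i + 1) / (num_thresholds - 1) for i in range(num_thresholds - 2)]
    thresholds = [-eps] + inner + [1.0 + eps]

    pos = sum(1 for y in labels if y == 1)
    neg = len(labels) - pos

    # One (fpr, tpr) point per threshold; a sample is predicted positive
    # when its score is strictly greater than the threshold (assumption).
    points = []
    for t in thresholds:
        tp = sum(1 for y, s in zip(labels, scores) if s > t and y == 1)
        fp = sum(1 for y, s in zip(labels, scores) if s > t and y == 0)
        points.append((fp / neg, tp / pos))

    # Thresholds increase, so fpr decreases along the list; integrate the
    # curve with the trapezoidal rule.
    area = 0.0
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        area += (x0 - x1) * (y0 + y1) / 2.0
    return area


result = fixed_threshold_auc([0, 0, 1, 1], [0.0, 0.5, 0.3, 0.9], num_thresholds=3)
print(result)  # 0.75
```

With a grid this coarse the result is only an approximation of the true area in general; a larger num_thresholds gives a finer curve at the cost of more counting.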

MindSpore: Before calling mindspore.train.auc, derive the FPR (false positive rate) and TPR (true positive rate) with mindspore.train.ROC; during that calculation, the thresholds are determined by the element values of y_pred. The computed FPR and TPR are then passed to mindspore.train.auc to calculate the AUC. Both binary and multiclass classification are supported.
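The two-step flow described here (first derive FPR and TPR, then integrate them) can be sketched in plain Python. This is a hedged illustration of the idea, not MindSpore's implementation: the helper names roc_points and trapezoid_auc are hypothetical, and the "score >= threshold" rule is an assumption chosen for the example.

```python
def roc_points(labels, scores):
    """Return (fpr, tpr) lists, using each distinct score as a threshold."""
    pos = sum(1 for y in labels if y == 1)
    neg = len(labels) - pos
    fpr, tpr = [0.0], [0.0]  # start the curve at the origin
    # Sweep thresholds from the highest score down; a sample counts as
    # predicted positive when its score is >= the threshold (assumption).
    for t in sorted(set(scores), reverse=True):
        tp = sum(1 for y, s in zip(labels, scores) if s >= t and y == 1)
        fp = sum(1 for y, s in zip(labels, scores) if s >= t and y == 0)
        fpr.append(fp / neg)
        tpr.append(tp / pos)
    return fpr, tpr


def trapezoid_auc(x, y):
    """Integrate y over x with the trapezoidal rule."""
    return sum((x1 - x0) * (y0 + y1) / 2.0
               for x0, x1, y0, y1 in zip(x, x[1:], y, y[1:]))


fpr, tpr = roc_points([0, 0, 1, 1], [0.0, 0.5, 0.3, 0.9])
area = trapezoid_auc(fpr, tpr)
print(fpr)   # [0.0, 0.0, 0.5, 0.5, 1.0]
print(tpr)   # [0.0, 0.5, 0.5, 1.0, 1.0]
print(area)  # 0.75
```

Because the thresholds come from the data itself, no threshold count needs to be chosen in advance; the curve has one point per distinct score.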

Code Example

import mindspore as ms
from mindspore.train import ROC, auc
import numpy as np

x = ms.Tensor(np.array([[0.28, 0.55, 0.15, 0.05], [0.10, 0.20, 0.05, 0.05], [0.20, 0.05, 0.15, 0.05],
                    [0.05, 0.05, 0.05, 0.75], [0.05, 0.05, 0.05, 0.75]]))
y = ms.Tensor(np.array([0, 1, 2, 3, 2]))
metric = ROC(class_num=4)
metric.update(x, y)
fpr, tpr, thresholds = metric.eval()
print(fpr)
# out: [array([0.        , 0.33333333, 0.33333333, 0.66666667, 1.        ]), array([0.        , 0.33333333, 1.        ]),
# array([0.        , 0.33333333, 1.        ]), array([0.        , 0.33333333, 1.        ])]
print(tpr)
# out: [array([0., 0., 1., 1., 1.]), array([0., 0., 1.]), array([0., 0., 1.]), array([0., 0., 1.])]

# calculate auc for class 0
output = auc(fpr[0], tpr[0])
print(output)
# out: 0.6666666666666667

import tensorflow as tf

m = tf.keras.metrics.AUC(num_thresholds=3)
m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
print(m.result().numpy())
# out: 0.75