mindspore.train.Perplexity

class mindspore.train.Perplexity(ignore_label=None)[源代码]

计算困惑度(perplexity)。困惑度是衡量一个概率分布或语言模型好坏的标准。低困惑度表明语言模型可以很好地预测样本。计算方式如下:

\[PP(W)=P(w_{1}w_{2}...w_{N})^{-\frac{1}{N}}=\sqrt[N]{\frac{1}{P(w_{1}w_{2}...w_{N})}}\]

其中 \(w\) 代表语料库中的单词。根号内是句子概率的倒数,句子越好(概率大),困惑度越小。

参数:
  • ignore_label (Union[int, None]) - 计数时要忽略的无效标签的索引。如果设置为None,它将包括所有条目。默认值: None

支持平台:

Ascend GPU CPU

样例:

>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore.train import Perplexity
>>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
>>> y = Tensor(np.array([1, 0, 1]))
>>> metric = Perplexity(ignore_label=None)
>>> metric.clear()
>>> metric.update(x, y)
>>> perplexity = metric.eval()
>>> print(perplexity)
2.231443166940565
clear()[源代码]

内部评估结果清零。

eval()[源代码]

返回当前评估结果。

返回:

numpy.float64,计算得到的困惑度结果。

异常:
  • RuntimeError - 样本量为0。

update(*inputs)[源代码]

使用 predslabels 更新内部评估结果。

参数:
  • inputs - 输入 predslabelspredslabels 是Tensor、list或numpy.ndarray。 preds 是预测值, labels 是数据的标签。 predslabels 的shape都是 \((N, C)\)

异常:
  • ValueError - 输入数量不是2。

  • RuntimeError - 预测值和标签的长度不同。

  • RuntimeError - 预测值和标签的shape不同。