mindspore.ops.NoRepeatNGram
- class mindspore.ops.NoRepeatNGram(ngram_size=1)[源代码]
n-grams出现重复,则更新对应n-gram词序列出现的概率。
在beam search过程中,如果连续的 ngram_size 个词存在已生成的词序列中,那么之后预测时,将避免再次出现这连续的 ngram_size 个词。例如:当 ngram_size 为3时,已生成的词序列为[1,2,3,2,3],则下一个预测的词不会为2,并且 log_probs 的值将替换成负FLOAT_MAX。因为连续的3个词2,3,2不会在词序列中出现两次。
- 参数:
ngram_size (int) - 指定n-gram的长度,必须大于0。默认值:
1
。
- 输入:
state_seq (Tensor) - n-gram词序列。是一个三维Tensor,其shape为:
。log_probs (Tensor) - n-gram词序列对应出现的概率,是一个三维Tensor,其shape为:
。当n-gram重复时,log_probs的值将被负FLOAT_MAX替换。
- 输出:
log_probs (Tensor) - 数据类型和shape与输入 log_probs 相同。
- 异常:
TypeError - 如果 ngram_size 不是int。
TypeError - 如果 state_seq 或 log_probs 不是Tensor。
TypeError - 如果 state_seq 的数据类型不是int。
TypeError - 如果 log_probs 的数据类型不是float。
ValueError - 如果 ngram_size 小于0。
ValueError - 如果 ngram_size 大于m。
ValueError - 如果 state_seq 或 log_probs 不是三维的Tensor。
ValueError - 如果 state_seq 和 log_probs 的batch_size不相等。
ValueError - 如果 state_seq 和 log_probs 的beam_width不相等。
- 支持平台:
Ascend
GPU
CPU
样例:
>>> import mindspore >>> from mindspore import Tensor, ops >>> no_repeat_ngram = ops.NoRepeatNGram(ngram_size=3) >>> state_seq = Tensor([[[1, 2, 1, 2, 5, 1, 2], ... [9, 3, 9, 5, 4, 1, 5]], ... [[4, 8, 6, 4, 5, 6, 4], ... [4, 8, 8, 4, 3, 4, 8]]], dtype=mindspore.int32) >>> log_probs = Tensor([[[0.7, 0.8, 0.6, 0.9, 0.2, 0.8, 0.4, 0.6, 0.2, 0.7], ... [0.4, 0.5, 0.6, 0.7, 0.8, 0.1, 0.9, 0.8, 0.7, 0.1]], ... [[0.9, 0.7, 0.6, 0.3, 0.5, 0.3, 0.5, 0.4, 0.8, 0.6], ... [0.5, 0.8, 0.8, 0.7, 0.7, 0.8, 0.2, 0.7, 0.9, 0.7]]], dtype=mindspore.float32) >>> output = no_repeat_ngram(state_seq, log_probs) >>> print(output) [[[ 6.9999999e-01 -3.4028235e+38 6.0000002e-01 8.9999998e-01 2.0000000e-01 -3.4028235e+38 4.0000001e-01 6.0000002e-01 2.0000000e-01 6.9999999e-01] [ 4.0000001e-01 5.0000000e-01 6.0000002e-01 6.9999999e-01 8.0000001e-01 1.0000000e-01 8.9999998e-01 8.0000001e-01 6.9999999e-01 1.0000000e-01]] [[ 8.9999998e-01 6.9999999e-01 6.0000002e-01 3.0000001e-01 5.0000000e-01 -3.4028235e+38 5.0000000e-01 4.0000001e-01 8.0000001e-01 6.0000002e-01] [ 5.0000000e-01 8.0000001e-01 8.0000001e-01 6.9999999e-01 6.9999999e-01 8.0000001e-01 2.0000000e-01 6.9999999e-01 -3.4028235e+38 6.9999999e-01]]]