mindspore.ops.NoRepeatNGram
- class mindspore.ops.NoRepeatNGram(ngram_size=1)
Updates log_probs to suppress repeated n-grams.
During beam search, if an n-gram of ngram_size consecutive words already appears in the generated word sequence, subsequent predictions are prevented from producing that n-gram again. For example, when ngram_size is 3 and the generated word sequence is [1, 2, 3, 2, 3], the next predicted word cannot be 2, and the corresponding value of log_probs is replaced with -FLOAT_MAX. The sequence ends with [2, 3], and the 3-gram [2, 3, 2] has already appeared once, so predicting 2 would make it appear twice.
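The blocking rule can be sketched for a single sequence in plain Python. This is a minimal illustration, assuming the behaviour described above; `banned_next_tokens` is a hypothetical helper, not part of MindSpore, and it does not reproduce the Ascend kernel.

```python
def banned_next_tokens(seq, ngram_size):
    """Return the tokens that must not be predicted next for one sequence.

    Hypothetical helper sketching the rule described above; it is not
    MindSpore's kernel. A token is banned if appending it would repeat an
    n-gram of length ngram_size that already occurs in seq.
    """
    if ngram_size < 1:
        raise ValueError("ngram_size must be greater than 0")
    n = ngram_size
    m = len(seq)
    if m < n:
        # Too short for any complete n-gram to exist yet.
        return set()
    # The last n - 1 tokens: the prefix that the next token would extend.
    prefix = tuple(seq[m - n + 1:])
    banned = set()
    # Scan every complete n-gram seq[i:i + n]; if its first n - 1 tokens match
    # the current prefix, its last token would recreate that n-gram.
    for i in range(m - n + 1):
        if tuple(seq[i:i + n - 1]) == prefix:
            banned.add(seq[i + n - 1])
    return banned


# The worked example from the text: [1, 2, 3, 2, 3] with ngram_size=3.
print(banned_next_tokens([1, 2, 3, 2, 3], 3))  # {2}
```

With ngram_size=1 the prefix is empty, so this sketch bans every token already present in seq; whether the operator handles that edge case identically is an assumption.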
- Parameters
ngram_size (int) – Size of n-grams, must be greater than 0. Default: 1.
- Inputs:
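state_seq (Tensor) - A 3-D tensor with shape: (batch_size, beam_width, m), holding the word ids generated so far for each beam, where m is the length of the generated sequence.
log_probs (Tensor) - A 3-D tensor with shape: (batch_size, beam_width, vocab_size), holding the log probabilities of the next word. Entries for words that would complete a repeated n-gram are replaced with -FLOAT_MAX.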
- Outputs:
log_probs (Tensor) - The updated tensor, with the same shape and dtype as the input log_probs.
- Raises
- Supported Platforms:
Ascend
Examples
>>> import mindspore
>>> from mindspore import Tensor, ops
>>> no_repeat_ngram = ops.NoRepeatNGram(ngram_size=3)
>>> state_seq = Tensor([[[1, 2, 1, 2, 5, 1, 2],
...                      [9, 3, 9, 5, 4, 1, 5]],
...                     [[4, 8, 6, 4, 5, 6, 4],
...                      [4, 8, 8, 4, 3, 4, 8]]], dtype=mindspore.int32)
>>> log_probs = Tensor([[[0.7, 0.8, 0.6, 0.9, 0.2, 0.8, 0.4, 0.6, 0.2, 0.7],
...                      [0.4, 0.5, 0.6, 0.7, 0.8, 0.1, 0.9, 0.8, 0.7, 0.1]],
...                     [[0.9, 0.7, 0.6, 0.3, 0.5, 0.3, 0.5, 0.4, 0.8, 0.6],
...                      [0.5, 0.8, 0.8, 0.7, 0.7, 0.8, 0.2, 0.7, 0.9, 0.7]]], dtype=mindspore.float32)
>>> output = no_repeat_ngram(state_seq, log_probs)
>>> print(output)
[[[ 6.9999999e-01 -3.4028235e+38  6.0000002e-01  8.9999998e-01  2.0000000e-01
   -3.4028235e+38  4.0000001e-01  6.0000002e-01  2.0000000e-01  6.9999999e-01]
  [ 4.0000001e-01  5.0000000e-01  6.0000002e-01  6.9999999e-01  8.0000001e-01
    1.0000000e-01  8.9999998e-01  8.0000001e-01  6.9999999e-01  1.0000000e-01]]
 [[ 8.9999998e-01  6.9999999e-01  6.0000002e-01  3.0000001e-01  5.0000000e-01
   -3.4028235e+38  5.0000000e-01  4.0000001e-01  8.0000001e-01  6.0000002e-01]
  [ 5.0000000e-01  8.0000001e-01  8.0000001e-01  6.9999999e-01  6.9999999e-01
    8.0000001e-01  2.0000000e-01  6.9999999e-01 -3.4028235e+38  6.9999999e-01]]]
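Because the operator is listed for Ascend only, a CPU reference can help when checking results like the one printed above. The following is a minimal NumPy sketch, assuming the operator applies the per-beam rule from the description independently to every (batch, beam) pair and that every id in state_seq is a valid index into vocab_size; `no_repeat_ngram_reference` is a hypothetical name, not a MindSpore API. On the example inputs it masks the same four positions that show -3.4028235e+38 in the printed output.

```python
import numpy as np

# -FLOAT_MAX for float32; np.finfo(np.float32).max prints as 3.4028235e+38,
# the magnitude that appears in the documented output.
FLOAT_MAX = np.finfo(np.float32).max


def no_repeat_ngram_reference(state_seq, log_probs, ngram_size):
    """NumPy sketch of the NoRepeatNGram masking (assumed semantics, not the kernel).

    state_seq: int array, shape (batch_size, beam_width, m)
    log_probs: float array, shape (batch_size, beam_width, vocab_size)
    Token ids in state_seq are assumed to be valid indices into vocab_size.
    """
    state_seq = np.asarray(state_seq)
    out = np.array(log_probs, dtype=np.float32)  # copy; inputs are left untouched
    batch_size, beam_width, m = state_seq.shape
    n = ngram_size
    if m < n:
        return out  # no complete n-gram can exist yet
    for b in range(batch_size):
        for k in range(beam_width):
            seq = state_seq[b, k].tolist()
            prefix = seq[m - n + 1:]  # last n - 1 tokens of this beam
            for i in range(m - n + 1):
                # An earlier complete n-gram whose head matches the prefix bans its tail.
                if seq[i:i + n - 1] == prefix:
                    out[b, k, seq[i + n - 1]] = -FLOAT_MAX
    return out


state_seq = [[[1, 2, 1, 2, 5, 1, 2], [9, 3, 9, 5, 4, 1, 5]],
             [[4, 8, 6, 4, 5, 6, 4], [4, 8, 8, 4, 3, 4, 8]]]
log_probs = [[[0.7, 0.8, 0.6, 0.9, 0.2, 0.8, 0.4, 0.6, 0.2, 0.7],
              [0.4, 0.5, 0.6, 0.7, 0.8, 0.1, 0.9, 0.8, 0.7, 0.1]],
             [[0.9, 0.7, 0.6, 0.3, 0.5, 0.3, 0.5, 0.4, 0.8, 0.6],
              [0.5, 0.8, 0.8, 0.7, 0.7, 0.8, 0.2, 0.7, 0.9, 0.7]]]
out = no_repeat_ngram_reference(state_seq, log_probs, ngram_size=3)
# Positions set to -FLOAT_MAX, as (batch, beam, token) triples.
print(np.argwhere(out == -FLOAT_MAX).tolist())
# [[0, 0, 1], [0, 0, 5], [1, 0, 5], [1, 1, 8]]
```

The Python loops keep the sketch easy to follow; a faster version would vectorize the n-gram comparison across beams.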