mindspore.ops.NoRepeatNGram
- class mindspore.ops.NoRepeatNGram(ngram_size=1)
Updates the log-probabilities of candidate words so that n-grams already present in the generated sequence are not repeated.
During beam search, if appending a word would reproduce a run of ngram_size consecutive words that already exists in the generated word sequence, that word is excluded from the next prediction by replacing its value in log_probs with -FLOAT_MAX. For example, when ngram_size is 3 and the generated word sequence is [1, 2, 3, 2, 3], the next predicted word will not be 2, because choosing 2 would make the 3-gram [2, 3, 2] appear twice in the word sequence.
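The rule described above can be sketched in plain Python for a single sequence. This is only an illustration of the idea, not the operator's actual kernel; the helper name `banned_next_tokens` is hypothetical and is not part of MindSpore.

```python
def banned_next_tokens(seq, ngram_size):
    """Return the tokens that would complete an n-gram already present in seq.

    Hypothetical helper for illustration; not a MindSpore API.
    """
    if ngram_size < 1 or ngram_size > len(seq):
        raise ValueError("ngram_size must satisfy 1 <= ngram_size <= len(seq)")
    k = ngram_size - 1          # number of trailing tokens to match
    tail = seq[len(seq) - k:]   # last k generated tokens (empty when k == 0)
    banned = set()
    # Visit every complete n-gram in seq. If its first k tokens equal the
    # current tail, emitting its last token next would repeat that n-gram.
    for start in range(len(seq) - ngram_size + 1):
        if seq[start:start + k] == tail:
            banned.add(seq[start + k])
    return banned


print(banned_next_tokens([1, 2, 3, 2, 3], 3))  # {2}
```

With ngram_size set to 3, the tail is [2, 3]; the only earlier occurrence of [2, 3] is followed by 2, so 2 is the single banned word.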
- Parameters
ngram_size (int) – Size of n-grams, must be greater than 0. Default: 1.
- Inputs:
state_seq (Tensor) - The generated word sequences, a 3-D tensor with shape: \((batch\_size, beam\_width, m)\), where m is the sequence length.
log_probs (Tensor) - Log-probabilities of the candidate next words, a 3-D tensor with shape: \((batch\_size, beam\_width, vocab\_size)\). A value in log_probs is replaced with -FLOAT_MAX when choosing the corresponding word would repeat an n-gram.
- Outputs:
log_probs (Tensor) - The output Tensor with same shape and type as original log_probs.
- Raises
TypeError – If ngram_size is not an int.
TypeError – If state_seq or log_probs is not a Tensor.
TypeError – If the dtype of state_seq is not int.
TypeError – If the dtype of log_probs is not float.
ValueError – If ngram_size is not greater than 0.
ValueError – If ngram_size is greater than m.
ValueError – If state_seq or log_probs is not a 3-D Tensor.
ValueError – If the batch_size of state_seq and log_probs are not equal.
ValueError – If the beam_width of state_seq and log_probs are not equal.
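The argument constraints listed above can be summarized as a small checker that works on plain shape tuples. This is a sketch under assumptions: the function name `check_no_repeat_ngram_args` is hypothetical, the dtype checks are omitted, and the order in which the real operator performs its checks is not documented here.

```python
def check_no_repeat_ngram_args(ngram_size, state_shape, log_probs_shape):
    """Mirror the documented checks on shape tuples (hypothetical helper)."""
    if not isinstance(ngram_size, int):
        raise TypeError("ngram_size must be an int")
    if ngram_size <= 0:
        raise ValueError("ngram_size must be greater than 0")
    if len(state_shape) != 3 or len(log_probs_shape) != 3:
        raise ValueError("state_seq and log_probs must both be 3-D")
    batch_size, beam_width, m = state_shape
    if ngram_size > m:
        raise ValueError("ngram_size must not be greater than m")
    if log_probs_shape[0] != batch_size:
        raise ValueError("batch_size of state_seq and log_probs must match")
    if log_probs_shape[1] != beam_width:
        raise ValueError("beam_width of state_seq and log_probs must match")


# Valid: state_seq is (2, 2, 7), log_probs is (2, 2, 10), ngram_size is 3.
check_no_repeat_ngram_args(3, (2, 2, 7), (2, 2, 10))

# Invalid: ngram_size exceeds the sequence length m = 7.
try:
    check_no_repeat_ngram_args(8, (2, 2, 7), (2, 2, 10))
except ValueError as err:
    print("rejected:", err)
```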
- Supported Platforms:
Ascend
GPU
CPU
Examples
>>> import mindspore
>>> from mindspore import Tensor, ops
>>> no_repeat_ngram = ops.NoRepeatNGram(ngram_size=3)
>>> state_seq = Tensor([[[1, 2, 1, 2, 5, 1, 2],
...                      [9, 3, 9, 5, 4, 1, 5]],
...                     [[4, 8, 6, 4, 5, 6, 4],
...                      [4, 8, 8, 4, 3, 4, 8]]], dtype=mindspore.int32)
>>> log_probs = Tensor([[[0.7, 0.8, 0.6, 0.9, 0.2, 0.8, 0.4, 0.6, 0.2, 0.7],
...                      [0.4, 0.5, 0.6, 0.7, 0.8, 0.1, 0.9, 0.8, 0.7, 0.1]],
...                     [[0.9, 0.7, 0.6, 0.3, 0.5, 0.3, 0.5, 0.4, 0.8, 0.6],
...                      [0.5, 0.8, 0.8, 0.7, 0.7, 0.8, 0.2, 0.7, 0.9, 0.7]]], dtype=mindspore.float32)
>>> output = no_repeat_ngram(state_seq, log_probs)
>>> print(output)
[[[ 6.9999999e-01 -3.4028235e+38  6.0000002e-01  8.9999998e-01
    2.0000000e-01 -3.4028235e+38  4.0000001e-01  6.0000002e-01
    2.0000000e-01  6.9999999e-01]
  [ 4.0000001e-01  5.0000000e-01  6.0000002e-01  6.9999999e-01
    8.0000001e-01  1.0000000e-01  8.9999998e-01  8.0000001e-01
    6.9999999e-01  1.0000000e-01]]
 [[ 8.9999998e-01  6.9999999e-01  6.0000002e-01  3.0000001e-01
    5.0000000e-01 -3.4028235e+38  5.0000000e-01  4.0000001e-01
    8.0000001e-01  6.0000002e-01]
  [ 5.0000000e-01  8.0000001e-01  8.0000001e-01  6.9999999e-01
    6.9999999e-01  8.0000001e-01  2.0000000e-01  6.9999999e-01
   -3.4028235e+38  6.9999999e-01]]]
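The masked positions in this example can be traced without MindSpore using nested Python lists. The function below is a pure-Python sketch of the behaviour described on this page, not the operator's implementation; it assumes that each word id indexes directly into the last axis of log_probs and that -FLOAT_MAX is the negated largest finite float32 value.

```python
FLOAT32_MAX = (2 - 2 ** -23) * 2.0 ** 127  # largest finite float32, about 3.4028235e+38


def no_repeat_ngram_reference(state_seq, log_probs, ngram_size):
    """Mask, per beam, every word that would repeat an existing n-gram."""
    k = ngram_size - 1
    out = []
    for seq_batch, lp_batch in zip(state_seq, log_probs):
        out_batch = []
        for seq, lp in zip(seq_batch, lp_batch):
            row = list(lp)             # copy so the input stays untouched
            tail = seq[len(seq) - k:]  # last k generated words
            for start in range(len(seq) - ngram_size + 1):
                if seq[start:start + k] == tail:
                    row[seq[start + k]] = -FLOAT32_MAX
            out_batch.append(row)
        out.append(out_batch)
    return out


state_seq = [[[1, 2, 1, 2, 5, 1, 2], [9, 3, 9, 5, 4, 1, 5]],
             [[4, 8, 6, 4, 5, 6, 4], [4, 8, 8, 4, 3, 4, 8]]]
log_probs = [[[0.7, 0.8, 0.6, 0.9, 0.2, 0.8, 0.4, 0.6, 0.2, 0.7],
              [0.4, 0.5, 0.6, 0.7, 0.8, 0.1, 0.9, 0.8, 0.7, 0.1]],
             [[0.9, 0.7, 0.6, 0.3, 0.5, 0.3, 0.5, 0.4, 0.8, 0.6],
              [0.5, 0.8, 0.8, 0.7, 0.7, 0.8, 0.2, 0.7, 0.9, 0.7]]]

masked = no_repeat_ngram_reference(state_seq, log_probs, 3)
for b, batch in enumerate(masked):
    for w, row in enumerate(batch):
        print(b, w, [i for i, v in enumerate(row) if v == -FLOAT32_MAX])
# 0 0 [1, 5]
# 0 1 []
# 1 0 [5]
# 1 1 [8]
```

For the first beam the sequence ends in [1, 2], which earlier was followed by 1 and by 5, so indices 1 and 5 are masked; the second beam ends in [1, 5], which has not occurred before, so nothing is masked.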