mindspore.ops.NoRepeatNGram

class mindspore.ops.NoRepeatNGram(ngram_size=1)

Updates log_probs so that repeated n-grams cannot be generated.

During beam search, if appending a word would reproduce a run of ngram_size consecutive words that already occurs in the generated word sequence, that word is excluded from the next prediction by replacing its value in log_probs with -FLOAT_MAX. For example, when ngram_size is 3 and the generated word sequence is [1, 2, 3, 2, 3], the next predicted word cannot be 2. The sequence ends in [2, 3], and [2, 3, 2] already occurs in it, so predicting 2 would make the 3 consecutive words [2, 3, 2] appear twice.
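The rule above can be sketched in plain Python for a single generated sequence. This is only an illustration of the described behavior under the assumption that the check compares the last ngram_size - 1 words against every earlier n-gram; `banned_next_tokens` is a hypothetical helper, not part of MindSpore, and it does not reflect how the Ascend kernel is implemented.

```python
def banned_next_tokens(seq, ngram_size):
    """Return the words that would complete an already-seen n-gram.

    Hypothetical helper for illustration; not a MindSpore API.
    """
    if ngram_size < 1:
        raise ValueError("ngram_size must be greater than 0")
    if len(seq) < ngram_size:
        return set()  # no complete n-gram has been generated yet

    prefix_len = ngram_size - 1
    # The last (ngram_size - 1) words are the start of the next n-gram.
    prefix = seq[len(seq) - prefix_len:]

    banned = set()
    # Visit every complete n-gram already present in the sequence.
    for start in range(len(seq) - ngram_size + 1):
        if seq[start:start + prefix_len] == prefix:
            # The word that followed this prefix before must not follow it again.
            banned.add(seq[start + prefix_len])
    return banned


print(banned_next_tokens([1, 2, 3, 2, 3], 3))  # {2}
```

With ngram_size set to 1 the prefix is empty, so in this sketch every word already in the sequence is banned.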

Parameters

ngram_size (int) – Size of n-grams, must be greater than 0. Default: 1.

Inputs:
  • state_seq (Tensor) - A 3-D tensor with shape: (batch_size, beam_width, m), where m is the length of the word sequence generated so far.

  • log_probs (Tensor) - A 3-D tensor with shape: (batch_size, beam_width, vocab_size). The value of log_probs for a word is replaced with -FLOAT_MAX when predicting that word would repeat an n-gram.

Outputs:
  • log_probs (Tensor) - The output Tensor with same shape and type as original log_probs.

Raises
  • TypeError – If ngram_size is not an int.

  • TypeError – If state_seq or log_probs is not a Tensor.

Supported Platforms:

Ascend

Examples

>>> import mindspore
>>> from mindspore import Tensor, ops
>>> no_repeat_ngram = ops.NoRepeatNGram(ngram_size=3)
>>> state_seq = Tensor([[[1, 2, 1, 2, 5, 1, 2],
...                      [9, 3, 9, 5, 4, 1, 5]],
...                     [[4, 8, 6, 4, 5, 6, 4],
...                      [4, 8, 8, 4, 3, 4, 8]]], dtype=mindspore.int32)
>>> log_probs = Tensor([[[0.7, 0.8, 0.6, 0.9, 0.2, 0.8, 0.4, 0.6, 0.2, 0.7],
...                      [0.4, 0.5, 0.6, 0.7, 0.8, 0.1, 0.9, 0.8, 0.7, 0.1]],
...                     [[0.9, 0.7, 0.6, 0.3, 0.5, 0.3, 0.5, 0.4, 0.8, 0.6],
...                      [0.5, 0.8, 0.8, 0.7, 0.7, 0.8, 0.2, 0.7, 0.9, 0.7]]], dtype=mindspore.float32)
>>> output = no_repeat_ngram(state_seq, log_probs)
>>> print(output)
[[[ 6.9999999e-01 -3.4028235e+38  6.0000002e-01  8.9999998e-01
    2.0000000e-01 -3.4028235e+38  4.0000001e-01  6.0000002e-01
    2.0000000e-01  6.9999999e-01]
  [ 4.0000001e-01  5.0000000e-01  6.0000002e-01  6.9999999e-01
    8.0000001e-01  1.0000000e-01  8.9999998e-01  8.0000001e-01
    6.9999999e-01  1.0000000e-01]]
 [[ 8.9999998e-01  6.9999999e-01  6.0000002e-01  3.0000001e-01
    5.0000000e-01 -3.4028235e+38  5.0000000e-01  4.0000001e-01
    8.0000001e-01  6.0000002e-01]
  [ 5.0000000e-01  8.0000001e-01  8.0000001e-01  6.9999999e-01
    6.9999999e-01  8.0000001e-01  2.0000000e-01  6.9999999e-01
   -3.4028235e+38  6.9999999e-01]]]
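Because the operator is listed for Ascend only, the masked positions in the output above can be cross-checked on any machine with a pure-Python stand-in. The sketch below uses nested lists instead of tensors; `mask_repeats` and `FLOAT_MAX` are hypothetical names introduced here, and the logic assumes the rule described at the top of this page rather than the actual kernel.

```python
FLOAT_MAX = 3.4028235e38  # largest float32 value, as seen in the printed output


def mask_repeats(state_seq, log_probs, ngram_size):
    """Pure-Python stand-in for the operator (hypothetical, not a MindSpore API)."""
    k = ngram_size - 1
    result = []
    for seq_batch, lp_batch in zip(state_seq, log_probs):
        out_batch = []
        for seq, lp in zip(seq_batch, lp_batch):
            row = list(lp)  # copy so the input is left untouched
            prefix = seq[len(seq) - k:]
            for start in range(len(seq) - ngram_size + 1):
                if seq[start:start + k] == prefix:
                    row[seq[start + k]] = -FLOAT_MAX
            out_batch.append(row)
        result.append(out_batch)
    return result


state_seq = [[[1, 2, 1, 2, 5, 1, 2], [9, 3, 9, 5, 4, 1, 5]],
             [[4, 8, 6, 4, 5, 6, 4], [4, 8, 8, 4, 3, 4, 8]]]
log_probs = [[[0.7, 0.8, 0.6, 0.9, 0.2, 0.8, 0.4, 0.6, 0.2, 0.7],
              [0.4, 0.5, 0.6, 0.7, 0.8, 0.1, 0.9, 0.8, 0.7, 0.1]],
             [[0.9, 0.7, 0.6, 0.3, 0.5, 0.3, 0.5, 0.4, 0.8, 0.6],
              [0.5, 0.8, 0.8, 0.7, 0.7, 0.8, 0.2, 0.7, 0.9, 0.7]]]

result = mask_repeats(state_seq, log_probs, ngram_size=3)
# Collect the vocabulary indices that were masked in each beam.
masked = [[[i for i, v in enumerate(row) if v == -FLOAT_MAX] for row in batch]
          for batch in result]
print(masked)  # [[[1, 5], []], [[5], [8]]]
```

The masked indices line up with the -3.4028235e+38 entries in the printed output: words 1 and 5 for the first beam, none for the second, word 5 for the third, and word 8 for the fourth.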