mindspore.ops.multinomial

mindspore.ops.multinomial(inputs, num_sample, replacement=True, seed=None)[source]

Returns a tensor of indices sampled from the multinomial probability distribution defined by the corresponding row of the input tensor.

Note

The rows of input do not need to sum to one (in which case we use the values as weights), but must be non-negative, finite and have a non-zero sum.
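
As a rough illustration of the note above, the following pure-Python sketch (standard library only, not MindSpore code) shows how a row of unnormalized weights maps to probabilities. The helper name `normalize_weights` is hypothetical and exists only for this example. Its checks mirror the conditions listed in the note; how MindSpore itself reacts to an invalid row is not stated here.

```python
import math

def normalize_weights(row):
    """Hypothetical helper: turn non-negative weights into probabilities."""
    # Conditions from the note: non-negative, finite, non-zero sum.
    if any(w < 0 or not math.isfinite(w) for w in row):
        raise ValueError("weights must be non-negative and finite")
    total = sum(row)
    if total == 0:
        raise ValueError("weights must have a non-zero sum")
    return [w / total for w in row]

# Scaling a row by a positive constant leaves the distribution unchanged.
probs_a = normalize_weights([0.0, 9.0, 4.0, 0.0])
probs_b = normalize_weights([0.0, 90.0, 40.0, 0.0])
print(probs_a)
print(probs_b)
```

Under this reading, `[0, 9, 4, 0]` and `[0, 90, 40, 0]` describe the same distribution, and the zero-weight entries at indices 0 and 3 carry zero probability.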

Parameters
  • inputs (Tensor) – The input tensor containing probabilities. Must be 1-D or 2-D, with float32 data type.

  • num_sample (int) – Number of samples to draw.

  • replacement (bool, optional) – Whether to draw with replacement or not. Default: True.

  • seed (int, optional) – Seed used as the entropy source for the random number engines that generate pseudo-random numbers. Must be non-negative. Default: None.

Outputs:

Tensor with the same number of rows as inputs. Each row contains num_sample sampled indices. The dtype is int32.
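
The difference between drawing with and without replacement can be sketched for a single row in pure Python. This is a minimal stand-in built on the standard library's `random.choices`, not MindSpore's implementation. The function name `sample_indices` and the error raised when too few non-zero weights are available are assumptions of the sketch.

```python
import random

def sample_indices(weights, num_sample, replacement=True, seed=None):
    """Hypothetical stand-in mimicking the documented semantics for one row."""
    rng = random.Random(seed)
    indices = list(range(len(weights)))
    if replacement:
        # Independent draws: the same index may appear several times.
        return rng.choices(indices, weights=weights, k=num_sample)
    # Assumption of this sketch: refuse to draw more distinct indices
    # than there are entries with non-zero weight.
    if num_sample > sum(1 for w in weights if w > 0):
        raise ValueError("not enough non-zero weights to sample without replacement")
    # Without replacement: draw one index at a time, then zero out its weight.
    remaining = list(weights)
    drawn = []
    for _ in range(num_sample):
        pick = rng.choices(indices, weights=remaining, k=1)[0]
        drawn.append(pick)
        remaining[pick] = 0.0
    return drawn

row = [0.0, 9.0, 4.0, 0.0]
with_repl = sample_indices(row, 4, replacement=True, seed=0)
without_repl = sample_indices(row, 2, replacement=False, seed=0)
print(with_repl, without_repl)
```

In this sketch the result for one row always has length `num_sample`. With replacement, indices may repeat; without replacement, each index appears at most once.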

Raises
  • TypeError – If inputs is not a Tensor whose dtype is float32.

  • TypeError – If num_sample is not an int.

  • TypeError – If seed is neither an int nor None.

Supported Platforms:

GPU

Examples

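>>> import mindspore
>>> from mindspore import Tensor, ops
>>> from mindspore import dtype as mstype
>>> # case 1: The output is random, and its length equals num_sample.
>>> # replacement is False, so each index is drawn at most once.
>>> x = Tensor([0, 9, 4, 0], mindspore.float32)
>>> output = ops.multinomial(x, 2, replacement=False)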
>>> # print(output)
>>> # [1 2] or [2 1]
>>> # Over repeated runs, [1 2] should occur more often than [2 1],
>>> # because the weight at index 1 is larger than the weight at index 2.
>>> print(len(output))
2
>>> # case 2: The output is random, and its length equals num_sample.
>>> # replacement is True (default), so the same index can be drawn more than once.
>>> # Indices whose weight is 0 (here 0 and 3) are not expected to be drawn.
>>> x = Tensor([0, 9, 4, 0], mstype.float32)
>>> output = ops.multinomial(x, 4)
>>> print(output)
[1 1 2 1]
>>> # case 3: num_sample == x_length = 4, and replacement is True, so the same index can be drawn more than once.
>>> x = Tensor([0, 9, 4, 0], mstype.float32)
>>> output = ops.multinomial(x, 4, True)
>>> print(output)
[1 1 2 2]