mindspore.ops.multinomial
- mindspore.ops.multinomial(input, num_samples, replacement=True, seed=None)
Returns a tensor of indices sampled from the multinomial probability distribution defined by the corresponding row of the input tensor.
Note
The rows of input do not need to sum to one (in which case the values are used as weights), but they must be non-negative, finite, and have a non-zero sum.
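The weighting rule in the note can be illustrated with a short pure-Python sketch. `to_probabilities` is a hypothetical helper written for illustration, not part of MindSpore (which performs its own checks). It enforces the three conditions the note lists (non-negative, finite, non-zero sum) and divides each weight by the row sum, which is how unnormalized weights translate into sampling probabilities.

```python
import math
from typing import List, Sequence


def to_probabilities(row: Sequence[float]) -> List[float]:
    """Validate a row of weights and normalize it so it sums to one.

    Hypothetical helper for illustration only; not a MindSpore API.
    """
    for w in row:
        # isfinite rejects both inf and NaN before the sign check.
        if not math.isfinite(w):
            raise ValueError(f"weight {w!r} is not finite")
        if w < 0:
            raise ValueError(f"weight {w!r} is negative")
    total = sum(row)
    if total == 0:
        raise ValueError("weights must have a non-zero sum")
    # Dividing by the sum turns arbitrary non-negative weights into probabilities.
    return [w / total for w in row]


# Weights used in the Examples section: only indices 1 and 2 can ever be drawn.
print(to_probabilities([0, 9, 4, 0]))
```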
- Parameters
input (Tensor) – The input tensor containing probabilities. Must be 1-D or 2-D, with float32 data type.
num_samples (int) – Number of samples to draw.
replacement (bool, optional) – Whether to draw with replacement. Default: True.
seed (int, optional) – Seed used as the entropy source for the random number engines that generate pseudo-random numbers. Must be non-negative. Default: None.
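The effect of `replacement` and `seed` on a single row can be sketched in pure Python. `sample_row` is a hypothetical helper that only mirrors the documented semantics; it is not MindSpore's implementation, and Python's `random.Random` is not expected to reproduce MindSpore's random stream for the same seed. With replacement, every draw uses the original weights, so an index can repeat. Without replacement, this sketch sets a drawn index's weight to zero, so each index appears at most once, and it raises an error when `num_samples` exceeds the number of non-zero weights (an assumption of the sketch, not documented MindSpore behavior).

```python
import random
from typing import List, Optional, Sequence


def sample_row(weights: Sequence[float], num_samples: int,
               replacement: bool = True, seed: Optional[int] = None) -> List[int]:
    """Draw num_samples indices from one row of non-negative weights.

    Hypothetical illustration of the documented semantics only.
    """
    rng = random.Random(seed)  # same seed -> same draws within this sketch
    indices = range(len(weights))
    if replacement:
        # Every draw sees the original weights, so indices can repeat.
        return rng.choices(indices, weights=weights, k=num_samples)
    remaining = list(weights)
    if num_samples > sum(1 for w in remaining if w > 0):
        raise ValueError("not enough non-zero weights to sample without replacement")
    result = []
    for _ in range(num_samples):
        idx = rng.choices(indices, weights=remaining, k=1)[0]
        result.append(idx)
        remaining[idx] = 0  # a drawn index can never be drawn again
    return result


weights = [0, 9, 4, 0]
print(sample_row(weights, 4, replacement=True, seed=1))   # only 1s and 2s, repeats allowed
print(sample_row(weights, 2, replacement=False, seed=1))  # [1, 2] or [2, 1]
```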
- Returns
Tensor with the same number of rows as input, where each row contains num_samples sampled indices. The dtype is float32.
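For a 2-D input, the result keeps one row per input row, each holding num_samples indices. A minimal sketch of that shape, assuming a plain list of lists stands in for the tensor; `multinomial_2d` is hypothetical, samples with replacement only, and is not MindSpore's code.

```python
import random
from typing import List, Optional, Sequence


def multinomial_2d(rows: Sequence[Sequence[float]], num_samples: int,
                   seed: Optional[int] = None) -> List[List[int]]:
    """Sample num_samples indices (with replacement) from each row independently.

    Hypothetical illustration of the output shape only.
    """
    rng = random.Random(seed)
    # One output row per input row, each holding num_samples indices.
    return [rng.choices(range(len(row)), weights=row, k=num_samples) for row in rows]


out = multinomial_2d([[0, 9, 4, 0], [1, 1, 1, 1]], num_samples=5, seed=0)
print(len(out), [len(r) for r in out])  # 2 [5, 5]
```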
- Raises
- Supported Platforms:
Ascend GPU CPU
Examples
>>> import mindspore
>>> from mindspore import Tensor, ops
>>> from mindspore import dtype as mstype
>>> # case 1: The output is random, and its length equals num_samples.
>>> # replacement=False, so each index is drawn at most once.
>>> input = Tensor([0, 9, 4, 0], mindspore.float32)
>>> output = ops.multinomial(input, 2, replacement=False)
>>> # print(output)
>>> # [1 2] or [2 1]
>>> # [1 2] is returned more often than [2 1] because index 1 has a
>>> # larger weight (9) than index 2 (4).
>>> print(len(output))
2
>>> # case 2: The output is random, and its length equals num_samples.
>>> # replacement is True (the default).
>>> # Indices 0 and 3 have zero weight, so only indices 1 and 2 are returned.
>>> input = Tensor([0, 9, 4, 0], mstype.float32)
>>> output = ops.multinomial(input, 4)
>>> print(output)
[1 1 2 1]
>>> # case 3: The output is random, num_samples equals the input length (4),
>>> # and replacement=True is passed explicitly, so the same index can be drawn repeatedly.
>>> input = Tensor([0, 9, 4, 0], mstype.float32)
>>> output = ops.multinomial(input, 4, True)
>>> print(output)
[1 1 2 2]
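The ordering claim in case 1 can be checked exactly. Drawing two indices without replacement from weights [0, 9, 4, 0], the first draw is index 1 with probability 9/13, after which only index 2 remains, so [1 2] occurs with probability 9/13 and [2 1] with probability 4/13. The pure-Python sketch below enumerates every ordered outcome with `fractions.Fraction` to confirm this; the helper `ordered_outcome_probs` is hypothetical and models only the sampling rule, not MindSpore's implementation.

```python
from fractions import Fraction
from itertools import permutations
from typing import Dict, Sequence, Tuple


def ordered_outcome_probs(weights: Sequence[int], k: int) -> Dict[Tuple[int, ...], Fraction]:
    """Exact probability of each ordered sequence of k draws without replacement."""
    probs = {}
    for seq in permutations(range(len(weights)), k):
        p = Fraction(1)
        remaining = sum(weights)
        for idx in seq:
            if remaining == 0:
                p = Fraction(0)
                break
            p *= Fraction(weights[idx], remaining)
            remaining -= weights[idx]  # the drawn weight leaves the pool
        if p:  # keep only outcomes that can actually occur
            probs[seq] = p
    return probs


print(ordered_outcome_probs([0, 9, 4, 0], 2))
# {(1, 2): Fraction(9, 13), (2, 1): Fraction(4, 13)}
```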