mindspore.ops.Multinomial

class mindspore.ops.Multinomial(seed=0, seed2=0, dtype=mstype.int32)

Returns a tensor of indices sampled from the multinomial probability distribution given by the corresponding row of the input tensor x.

Note

  • The rows of x do not need to sum to one (in that case the values are treated as unnormalized weights), but they must be non-negative, finite, and have a non-zero sum.

  • Random seed: pseudo-random numbers are produced by a deterministic algorithm, and the random seed is the initial value of that algorithm. If two separate calls use the same random seed, they generate the same random numbers.

  • If neither the global random seed nor the operator-level random seed is set, or both are set to 0, the behavior is completely random.

  • If the global random seed is set but the operator-level random seed is not, the global random seed is spliced with 0 to generate random numbers.

  • If the global random seed is not set but the operator-level random seed is, 0 is spliced with the operator-level random seed to generate random numbers.

  • If both the global and the operator-level random seeds are set, the global random seed is spliced with the operator-level random seed to generate random numbers.
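The seed rules above can be summarized in a small pure-Python sketch. The helpers `resolve_seed_pair` and `make_generator` are hypothetical and not part of MindSpore. The sketch assumes that an unset seed (None) behaves like a seed of 0, as the first rule suggests, and it models "splicing" as joining the two seeds into one string seed for Python's `random.Random`; MindSpore's actual combination logic may differ.

```python
import random
from typing import Optional, Tuple


def resolve_seed_pair(global_seed: Optional[int],
                      op_seed: Optional[int]) -> Optional[Tuple[int, int]]:
    """Hypothetical helper: return the (global, operator) seed pair, or None
    when the behavior should be completely random.

    Assumption: None ("not set") and 0 are treated the same way.
    """
    g = global_seed or 0  # an unset global seed behaves like 0
    s = op_seed or 0      # an unset operator-level seed behaves like 0
    if g < 0 or s < 0:
        raise ValueError("seeds must be non-negative")
    if g == 0 and s == 0:
        return None       # rule 1: neither seed set -> completely random
    return (g, s)         # rules 2-4: a missing side is spliced in as 0


def make_generator(global_seed: Optional[int],
                   op_seed: Optional[int]) -> random.Random:
    """Hypothetical helper: build a generator from the resolved seed pair."""
    pair = resolve_seed_pair(global_seed, op_seed)
    if pair is None:
        return random.Random()  # seeded from system entropy
    # "Splice" the two seeds into one deterministic string seed (an assumption).
    return random.Random(f"{pair[0]}:{pair[1]}")


# Same seeds in two separate calls -> same random numbers.
a = make_generator(7, 10)
b = make_generator(7, 10)
print([a.randint(0, 9) for _ in range(5)] == [b.randint(0, 9) for _ in range(5)])  # True
```

Treating 0 as "not set" keeps the four rules consistent with each other; if MindSpore distinguishes an explicit 0 from an unset seed, the first two lines of `resolve_seed_pair` would need to change.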

Parameters
  • seed (int, optional) – The operator-level random seed used to generate random numbers. Must be non-negative. Default: 0.

  • seed2 (int, optional) – The global random seed, which is combined with the operator-level random seed to determine the generated random numbers. Must be non-negative. Default: 0.

  • dtype (mindspore.dtype, optional) – The data type of the output. Must be mstype.int32 or mstype.int64. Default: mstype.int32.

Inputs:
  • x (Tensor) - The input tensor containing the (possibly unnormalized) probabilities of each category, as described in the Note. Must be 1-D or 2-D.

  • num_samples (int) - Number of samples to draw. Must be non-negative.

Outputs:

Tensor of sampled indices with data type dtype. It has the same number of rows as x, and each row contains num_samples sampled indices.
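To make the input and output semantics concrete, here is a pure-Python reference sketch of multinomial sampling with replacement (the example output `[[1 1]]` shows that the same index can be drawn more than once). The names `multinomial_reference` and `_sample_row` are hypothetical. The sketch uses inverse-CDF sampling over a running sum of each row's weights and Python's `random` module instead of MindSpore's generator, and it assumes a 1-D input yields a flat list of indices. It will not reproduce MindSpore's exact output for a given seed.

```python
import bisect
import itertools
import math
import random
from typing import List, Optional, Sequence, Union


def _sample_row(weights: Sequence[float], num_samples: int,
                rng: random.Random) -> List[int]:
    """Draw num_samples indices (with replacement) from one row of weights."""
    # The Note requires non-negative, finite weights with a non-zero sum.
    if any(w < 0 or not math.isfinite(w) for w in weights):
        raise ValueError("weights must be non-negative and finite")
    cdf = list(itertools.accumulate(weights))  # running sum of the weights
    total = cdf[-1] if cdf else 0.0
    if total <= 0:
        raise ValueError("weights must have a non-zero sum")
    # Inverse-CDF sampling: u is uniform in [0, total); the chosen index is the
    # first one whose running sum exceeds u, so zero-weight entries are never picked.
    return [bisect.bisect_right(cdf, rng.random() * total)
            for _ in range(num_samples)]


def multinomial_reference(x: Union[Sequence[float], Sequence[Sequence[float]]],
                          num_samples: int,
                          seed: Optional[int] = None):
    """Hypothetical reference: 1-D weights -> list of indices,
    2-D weights -> one list of indices per row."""
    if not isinstance(num_samples, int):
        raise TypeError("num_samples must be an int")
    if num_samples < 0:
        raise ValueError("num_samples must be non-negative")
    rng = random.Random(seed)
    if len(x) > 0 and isinstance(x[0], (list, tuple)):
        return [_sample_row(row, num_samples, rng) for row in x]
    return _sample_row(x, num_samples, rng)


out = multinomial_reference([[0., 9., 4., 0.]], 2, seed=10)
print(out)  # one row of two indices, each either 1 or 2
```

With the weights from the example, index 1 is drawn with probability 9/13 and index 2 with probability 4/13, while the zero-weight indices 0 and 3 are never drawn.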

Raises
  • TypeError – If seed or seed2 is not an int.

  • TypeError – If num_samples is not an int.

  • TypeError – If dtype is neither mstype.int32 nor mstype.int64.

  • ValueError – If seed or seed2 is less than 0.
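The error conditions above can be mirrored by a small argument check. This is a hypothetical sketch, not MindSpore's validator: `check_multinomial_args` is an illustrative name, and the strings "int32" and "int64" stand in for mstype.int32 and mstype.int64 so the code runs without MindSpore installed.

```python
from typing import Any

# Strings standing in for mstype.int32 / mstype.int64 (assumption for this sketch).
ALLOWED_DTYPES = ("int32", "int64")


def check_multinomial_args(seed: Any, seed2: Any, dtype: Any, num_samples: Any) -> None:
    """Hypothetical check that mirrors the Raises list; not MindSpore's validator."""
    # TypeError if either seed is not an int.
    for name, value in (("seed", seed), ("seed2", seed2)):
        if not isinstance(value, int):
            raise TypeError(f"{name} must be an int, got {type(value).__name__}")
    # TypeError if num_samples is not an int.
    if not isinstance(num_samples, int):
        raise TypeError(f"num_samples must be an int, got {type(num_samples).__name__}")
    # TypeError if dtype is neither int32 nor int64.
    if dtype not in ALLOWED_DTYPES:
        raise TypeError(f"dtype must be one of {ALLOWED_DTYPES}, got {dtype!r}")
    # ValueError if either seed is negative.
    for name, value in (("seed", seed), ("seed2", seed2)):
        if value < 0:
            raise ValueError(f"{name} must be non-negative, got {value}")


check_multinomial_args(seed=10, seed2=0, dtype="int32", num_samples=2)  # passes silently
```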

Supported Platforms:

Ascend GPU CPU

Examples

>>> from mindspore import Tensor, ops
>>> from mindspore import dtype as mstype
>>> x = Tensor([[0., 9., 4., 0.]], mstype.float32)
>>> multinomial = ops.Multinomial(seed=10)
>>> output = multinomial(x, 2)
>>> print(output)
[[1 1]]