Differences with torch.multinomial
The following mapping relationships can be found in this file.
| PyTorch APIs | MindSpore APIs |
|---|---|
| torch.multinomial | mindspore.ops.multinomial |
| torch.Tensor.multinomial | mindspore.Tensor.multinomial |
torch.multinomial
torch.multinomial(input, num_samples, replacement=False, *, generator=None, out=None)
For more information, see torch.multinomial.
mindspore.ops.multinomial
mindspore.ops.multinomial(input, num_samples, replacement=True, seed=None)
For more information, see mindspore.ops.multinomial.
Differences
The MindSpore API is functionally consistent with the PyTorch API, except for the default value of the parameter replacement:
MindSpore: replacement defaults to True, so sampling is done with replacement. Each drawn index is put back and can be drawn again.
PyTorch: replacement defaults to False, so sampling is done without replacement. Each index can be drawn at most once, so num_samples should not exceed the number of non-zero elements in input.
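The following minimal pure-Python sketch makes the difference concrete. The function multinomial_sketch is hypothetical and is not part of PyTorch or MindSpore. It uses the standard library's random module, so it will not reproduce either framework's actual random draws. It only shows how replacement changes which results are possible.

```python
import random
from typing import List, Optional, Sequence


def multinomial_sketch(
    weights: Sequence[float],
    num_samples: int,
    replacement: bool,
    seed: Optional[int] = None,
) -> List[int]:
    """Draw indices with probability proportional to `weights`.

    Hypothetical illustration of the semantics only; not the PyTorch or
    MindSpore implementation, and it uses a different random generator.
    """
    rng = random.Random(seed)
    remaining = list(weights)  # copy so the caller's weights are not modified
    if any(w < 0 for w in remaining):
        raise ValueError("weights must be non-negative")
    if not replacement and num_samples > sum(1 for w in remaining if w > 0):
        raise ValueError("not enough non-zero categories to sample without replacement")
    samples = []
    for _ in range(num_samples):
        idx = rng.choices(range(len(remaining)), weights=remaining, k=1)[0]
        samples.append(idx)
        if not replacement:
            remaining[idx] = 0.0  # a drawn index can never be drawn again
    return samples


# Without replacement (PyTorch's default): only indices 1 and 2 have non-zero
# weight, so two draws return both of them in some order.
print(sorted(multinomial_sketch([0, 9, 4, 0], 2, replacement=False, seed=0)))  # [1, 2]

# With replacement (MindSpore's default): indices may repeat, so more samples
# than non-zero categories can be requested.
print(multinomial_sketch([0, 9, 4, 0], 5, replacement=True, seed=0))
```

This sketch raises ValueError when more samples are requested without replacement than there are non-zero weights. The two frameworks may report that case differently.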
| Categories | Subcategories | PyTorch | MindSpore | Differences |
|---|---|---|---|---|
| Parameters | Parameter 1 | input | input | Consistent |
| | Parameter 2 | num_samples | num_samples | Consistent |
| | Parameter 3 | replacement | replacement | The default value is False in PyTorch and True in MindSpore |
| | Parameter 4 | generator | seed | For details, see General Difference Parameter Table |
| | Parameter 5 | out | - | For details, see General Difference Parameter Table |
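When porting calls, it can help to pass replacement explicitly in MindSpore rather than relying on either framework's default. The helper below, torch_args_to_mindspore, is a hypothetical sketch and is not part of either library. It maps torch.multinomial-style arguments onto the keyword names of mindspore.ops.multinomial(input, num_samples, replacement=True, seed=None).

It assumes the PyTorch code is seeded with an integer, for example one passed to torch.Generator().manual_seed. The same integer should not be expected to reproduce PyTorch's exact samples in MindSpore.

```python
from typing import Any, Dict, Optional


def torch_args_to_mindspore(
    num_samples: int,
    replacement: bool = False,  # PyTorch's default, made explicit
    seed: Optional[int] = None,  # assumed integer used to seed torch.Generator
    out: Any = None,
) -> Dict[str, Any]:
    """Hypothetical migration helper: torch-style args -> MindSpore keywords."""
    if out is not None:
        # mindspore.ops.multinomial has no `out` parameter; use the return value.
        raise ValueError("`out` is not supported; assign the returned tensor instead")
    # Always pass `replacement`, because MindSpore's default (True) differs.
    return {"num_samples": num_samples, "replacement": replacement, "seed": seed}


kwargs = torch_args_to_mindspore(2)
print(kwargs)  # {'num_samples': 2, 'replacement': False, 'seed': None}
# The result could then be used as: ms.ops.multinomial(input, **kwargs)
```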
Code Example
# PyTorch
import torch
input = torch.tensor([0, 9, 4, 0], dtype=torch.float32)
output = torch.multinomial(input, 2)  # replacement defaults to False in PyTorch
print(output)
# tensor([1, 2]) or tensor([2, 1])
# MindSpore
import mindspore as ms
input = ms.Tensor([0, 9, 4, 0], dtype=ms.float32)
# Pass replacement=False explicitly: MindSpore defaults to True, PyTorch to False
output = ms.ops.multinomial(input, 2, replacement=False)
print(output)
# [1 2] or [2 1]