Differences with torch.scatter
The following mapping relationships can be found in this file.
| PyTorch APIs | MindSpore APIs |
|---|---|
| torch.scatter | mindspore.ops.scatter |
| torch.Tensor.scatter | mindspore.Tensor.scatter |
torch.scatter
torch.scatter(input, dim, index, src)
For more information, see torch.scatter.
mindspore.ops.scatter
mindspore.ops.scatter(input, axis, index, src)
For more information, see mindspore.ops.scatter.
Differences
The MindSpore API does not behave exactly like the PyTorch one. The two differ in which shapes of `index` they accept relative to `src`.
PyTorch: `index.size(d) <= src.size(d)` is required for all dimensions `d`, so `index` can select some or all of the data in `src` to scatter into `input`.

MindSpore: the shape of `index` must be the same as the shape of `src`, so all of the data in `src` is scattered into `input` according to `index`.
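To make the difference concrete, here is a loop-based NumPy sketch of the 2-D scatter rule: for `dim=1`, `out[i][index[i][j]] = src[i][j]`; for `dim=0`, `out[index[i][j]][j] = src[i][j]`. The loops run only over the shape of `index`, which is why PyTorch can accept an `index` smaller than `src`. The helper names `scatter_2d_reference`, `torch_shape_ok`, and `mindspore_shape_ok` are hypothetical and belong to neither library, and the checks cover only the `index`/`src` shape rules described above, not every validation either framework performs.

```python
import numpy as np


def scatter_2d_reference(inp, dim, index, src):
    """Hypothetical loop-based sketch of 2-D scatter along dim 0 or 1."""
    out = inp.copy()
    rows, cols = index.shape
    # Only positions inside index's shape are read from src.
    for i in range(rows):
        for j in range(cols):
            if dim == 0:
                out[index[i, j], j] = src[i, j]
            else:  # dim == 1
                out[i, index[i, j]] = src[i, j]
    return out


def torch_shape_ok(index, src):
    # PyTorch rule: index.size(d) <= src.size(d) for every dimension d.
    return all(i <= s for i, s in zip(index.shape, src.shape))


def mindspore_shape_ok(index, src):
    # MindSpore rule: index and src must have exactly the same shape.
    return tuple(index.shape) == tuple(src.shape)


src = np.arange(1, 10, dtype=np.float32).reshape(3, 3)
partial_index = np.array([[0, 1], [0, 1], [0, 1]])
out = scatter_2d_reference(np.zeros((5, 5), dtype=np.float32), 1, partial_index, src)
print(out)
print(torch_shape_ok(partial_index, src), mindspore_shape_ok(partial_index, src))
```

With the 3x2 `index`, the PyTorch rule is satisfied while the MindSpore rule is not, which is exactly the situation in Code Example 1.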
| Categories | Subcategories | PyTorch | MindSpore | Differences |
|---|---|---|---|---|
| Parameters | Parameter 1 | input | input | Consistent |
|  | Parameter 2 | dim | axis | Different parameter names |
|  | Parameter 3 | index | index | For MindSpore, the shape of `index` must be the same as the shape of `src`; for PyTorch, `index.size(d) <= src.size(d)` is required for all dimensions `d` |
|  | Parameter 4 | src | src | Consistent |
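The parameter table amounts to one rename (`dim` to `axis`) plus the stricter shape rule on `index`. The sketch below is a hypothetical adapter, `to_mindspore_kwargs`, which is not part of MindSpore or PyTorch: it renames PyTorch-style keyword arguments to the names `mindspore.ops.scatter` uses and rejects calls whose `index` and `src` shapes differ. It uses NumPy arrays only so it runs without either framework installed; passing the result as `ms.ops.scatter(**kwargs)` is an untested assumption.

```python
import numpy as np

# PyTorch -> MindSpore parameter names, following the parameter table.
TORCH_TO_MS_PARAMS = {"input": "input", "dim": "axis", "index": "index", "src": "src"}


def to_mindspore_kwargs(**torch_kwargs):
    """Hypothetical helper: rename PyTorch-style scatter kwargs for ms.ops.scatter."""
    ms_kwargs = {TORCH_TO_MS_PARAMS[name]: value for name, value in torch_kwargs.items()}
    index, src = ms_kwargs["index"], ms_kwargs["src"]
    # MindSpore requires index and src to have exactly the same shape.
    if tuple(index.shape) != tuple(src.shape):
        raise ValueError(
            f"index shape {tuple(index.shape)} != src shape {tuple(src.shape)}"
        )
    return ms_kwargs


kwargs = to_mindspore_kwargs(
    input=np.zeros((5, 5), dtype=np.float32),
    dim=1,
    index=np.array([[0, 1, 2]] * 3, dtype=np.int64),
    src=np.ones((3, 3), dtype=np.float32),
)
print(sorted(kwargs))  # names only; values are the arrays passed in
```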
Code Example 1
Perform the scatter operation on part of the `src` data.
```python
# PyTorch
import torch
import numpy as np

input = torch.tensor(np.zeros((5, 5)), dtype=torch.float32)
src = torch.tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), dtype=torch.float32)
# index has 2 columns, fewer than src's 3, so only src[:, :2] is scattered
index = torch.tensor(np.array([[0, 1], [0, 1], [0, 1]]), dtype=torch.int64)
out = torch.scatter(input=input, dim=1, index=index, src=src)
print(out)
# tensor([[1., 2., 0., 0., 0.],
#         [4., 5., 0., 0., 0.],
#         [7., 8., 0., 0., 0.],
#         [0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0.]])

# MindSpore
# Not supported: mindspore.ops.scatter requires index and src to have the same shape.
```
Code Example 2
Perform the scatter operation on all of the `src` data.
```python
# PyTorch
import torch
import numpy as np

input = torch.tensor(np.zeros((5, 5)), dtype=torch.float32)
src = torch.tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), dtype=torch.float32)
# index has the same shape as src, so every element of src is scattered
index = torch.tensor(np.array([[0, 1, 2], [0, 1, 2], [0, 1, 2]]), dtype=torch.int64)
out = torch.scatter(input=input, dim=1, index=index, src=src)
print(out)
# tensor([[1., 2., 3., 0., 0.],
#         [4., 5., 6., 0., 0.],
#         [7., 8., 9., 0., 0.],
#         [0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0.]])

# MindSpore
import mindspore as ms
import numpy as np

input = ms.Tensor(np.zeros((5, 5)), dtype=ms.float32)
src = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), dtype=ms.float32)
index = ms.Tensor(np.array([[0, 1, 2], [0, 1, 2], [0, 1, 2]]), dtype=ms.int64)
out = ms.ops.scatter(input=input, axis=1, index=index, src=src)
print(out)
# [[1. 2. 3. 0. 0.]
#  [4. 5. 6. 0. 0.]
#  [7. 8. 9. 0. 0.]
#  [0. 0. 0. 0. 0.]
#  [0. 0. 0. 0. 0.]]
```