Differences with torch.scatter

The following mapping relationships can be found in this file.

| PyTorch APIs         | MindSpore APIs           |
| -------------------- | ------------------------ |
| torch.scatter        | mindspore.ops.scatter    |
| torch.Tensor.scatter | mindspore.Tensor.scatter |

torch.scatter

torch.scatter(input, dim, index, src)

For more information, see torch.scatter.

mindspore.ops.scatter

mindspore.ops.scatter(input, axis, index, src)

For more information, see mindspore.ops.scatter.

Differences

The functionality of the MindSpore API is not fully consistent with that of PyTorch.

PyTorch: index.size(d) <= src.size(d) is required for all dimensions d, so index can select some or all of the data in src to scatter into input.

MindSpore: The shape of index must be the same as the shape of src, so all of the data in src is scattered into input according to index.
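To make the two shape rules concrete, here is a minimal pure-Python sketch of an out-of-place scatter along one dimension of a 2-D nested list. It is an illustration under stated assumptions, not either library's implementation: the names `scatter_2d`, `shape_2d`, and the `rule` parameter are hypothetical, only 2-D inputs are handled, and PyTorch's further requirement that index.size(d) <= input.size(d) for d != dim is not checked explicitly.

```python
def shape_2d(t):
    """Return (rows, cols) of a rectangular 2-D nested list."""
    return (len(t), len(t[0]) if t else 0)


def scatter_2d(inp, dim, index, src, rule="torch"):
    """Hypothetical out-of-place scatter for 2-D nested lists.

    For every position (i, j) of index:
      dim == 0: out[index[i][j]][j] = src[i][j]
      dim == 1: out[i][index[i][j]] = src[i][j]
    """
    if dim not in (0, 1):
        raise ValueError("dim must be 0 or 1 in this 2-D sketch")
    idx_shape, src_shape = shape_2d(index), shape_2d(src)
    if rule == "torch":
        # PyTorch rule: index.size(d) <= src.size(d) for every dimension d.
        ok = all(i <= s for i, s in zip(idx_shape, src_shape))
    elif rule == "mindspore":
        # MindSpore rule as described above: index and src must have the same shape.
        ok = idx_shape == src_shape
    else:
        raise ValueError("rule must be 'torch' or 'mindspore'")
    if not ok:
        raise ValueError(
            f"index shape {idx_shape} not allowed for src shape {src_shape} under {rule!r} rule"
        )
    out = [row[:] for row in inp]  # copy rows so the input is left unchanged
    for i in range(idx_shape[0]):
        for j in range(idx_shape[1]):
            k = index[i][j]
            if dim == 0:
                out[k][j] = src[i][j]
            else:
                out[i][k] = src[i][j]
    return out


zeros = [[0.0] * 5 for _ in range(5)]
src = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
partial_index = [[0, 1], [0, 1], [0, 1]]  # shape (3, 2), smaller than src

print(scatter_2d(zeros, 1, partial_index, src, rule="torch")[0])
# [1, 2, 0.0, 0.0, 0.0]

try:
    scatter_2d(zeros, 1, partial_index, src, rule="mindspore")
except ValueError as err:
    print("rejected:", err)
```

The same partial index is accepted under the PyTorch rule and rejected under the MindSpore rule, which mirrors the difference described above.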

| Categories | Subcategories | PyTorch | MindSpore | Differences |
| ---------- | ------------- | ------- | --------- | ----------- |
| Parameters | Parameter 1   | input   | input     | Consistent |
|            | Parameter 2   | dim     | axis      | Different parameter names |
|            | Parameter 3   | index   | index     | For MindSpore, the shape of index must be the same as the shape of src. For PyTorch, index.size(d) <= src.size(d) is required for all dimensions d. |
|            | Parameter 4   | src     | src       | Consistent |
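Because the only naming difference in the parameter table is dim versus axis, a call written with PyTorch keyword arguments can be translated mechanically. The sketch below is a hypothetical migration helper (`torch_scatter_kwargs_to_ms` is not part of PyTorch or MindSpore); it only renames keywords and does not check the index/src shape rule from Parameter 3.

```python
# Keyword renames taken from the parameter table: only dim -> axis differs.
TORCH_TO_MS_SCATTER = {"input": "input", "dim": "axis", "index": "index", "src": "src"}


def torch_scatter_kwargs_to_ms(kwargs):
    """Hypothetical helper: rename torch.scatter keywords to mindspore.ops.scatter keywords."""
    unknown = set(kwargs) - set(TORCH_TO_MS_SCATTER)
    if unknown:
        # Any keyword that is not listed in the table has no documented mapping here.
        raise KeyError(f"no MindSpore mapping listed for: {sorted(unknown)}")
    return {TORCH_TO_MS_SCATTER[k]: v for k, v in kwargs.items()}


print(torch_scatter_kwargs_to_ms({"dim": 1, "index": "idx", "src": "src"}))
# {'axis': 1, 'index': 'idx', 'src': 'src'}
```

Rejecting unknown keywords, rather than passing them through, keeps the helper from silently forwarding arguments that the table does not cover.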

Code Example 1

Scatter only part of the data in src into input.

# PyTorch
import torch
import numpy as np
input = torch.tensor(np.zeros((5, 5)), dtype=torch.float32)
src = torch.tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), dtype=torch.float32)
index = torch.tensor(np.array([[0, 1], [0, 1], [0, 1]]), dtype=torch.int64)
out = torch.scatter(input=input, dim=1, index=index, src=src)
print(out)
# tensor([[1., 2., 0., 0., 0.],
#         [4., 5., 0., 0., 0.],
#         [7., 8., 0., 0., 0.],
#         [0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0.]])

# MindSpore doesn't support this feature currently.
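One possible workaround, offered as an assumption rather than a documented recipe: PyTorch only reads src at positions that also exist in index, so cropping src to the shape of index (here src[:, :2]) should give an equal-shaped src that produces the same result. The pure-Python sketch below checks that equivalence on the data from Code Example 1; the helpers `scatter_dim1` and `crop_to` are hypothetical, and the corresponding MindSpore call, something like ms.ops.scatter(input, 1, index, src[:, :2]), is not exercised by this sketch.

```python
def scatter_dim1(inp, index, src):
    """PyTorch-style scatter along dim=1 for 2-D lists: out[i][index[i][j]] = src[i][j]."""
    out = [row[:] for row in inp]  # copy rows so the input is left unchanged
    for i, idx_row in enumerate(index):
        for j, k in enumerate(idx_row):
            out[i][k] = src[i][j]
    return out


def crop_to(src, index):
    """Keep only the leading block of src that matches the shape of index."""
    rows, cols = len(index), len(index[0])
    return [row[:cols] for row in src[:rows]]


zeros = [[0.0] * 5 for _ in range(5)]
src = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
index = [[0, 1], [0, 1], [0, 1]]

cropped = crop_to(src, index)  # [[1, 2], [4, 5], [7, 8]], same shape as index
print(scatter_dim1(zeros, index, src) == scatter_dim1(zeros, index, cropped))
# True
```

Cropping works because each output write uses src[i][j] only for positions (i, j) inside index, so the discarded part of src never contributes to the result.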

Code Example 2

Scatter all of the data in src into input.

# PyTorch
import torch
import numpy as np
input = torch.tensor(np.zeros((5, 5)), dtype=torch.float32)
src = torch.tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), dtype=torch.float32)
index = torch.tensor(np.array([[0, 1, 2], [0, 1, 2], [0, 1, 2]]), dtype=torch.int64)
out = torch.scatter(input=input, dim=1, index=index, src=src)
print(out)
# tensor([[1., 2., 3., 0., 0.],
#         [4., 5., 6., 0., 0.],
#         [7., 8., 9., 0., 0.],
#         [0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0.]])

# MindSpore
import mindspore as ms
import numpy as np
input = ms.Tensor(np.zeros((5, 5)), dtype=ms.float32)
src = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), dtype=ms.float32)
index = ms.Tensor(np.array([[0, 1, 2], [0, 1, 2], [0, 1, 2]]), dtype=ms.int64)
out = ms.ops.scatter(input=input, axis=1, index=index, src=src)
print(out)
# [[1. 2. 3. 0. 0.]
#  [4. 5. 6. 0. 0.]
#  [7. 8. 9. 0. 0.]
#  [0. 0. 0. 0. 0.]
#  [0. 0. 0. 0. 0.]]