比较与torch.scatter_add的功能差异
torch.scatter_add
torch.scatter_add(input, dim, index, src)
更多内容详见torch.scatter_add。
mindspore.ops.tensor_scatter_elements
mindspore.ops.tensor_scatter_elements(input_x, indices, updates, axis, reduction)
差异对比
PyTorch:在任意维度 d
上,要求 index.size(d) <= src.size(d)
,即 index
可以选择 src
的部分或全部数据分散到 input
里。
MindSpore: indices
的shape必须和 updates
的shape一致,即 updates
的所有数据都会被 indices
分散到 input_x
里。
功能上无差异。
分类 |
子类 |
PyTorch |
MindSpore |
差异 |
---|---|---|---|---|
参数 |
参数 1 |
input |
input_x |
功能一致,参数名不同 |
参数 2 |
dim |
axis |
功能一致,参数名不同 |
|
参数 3 |
index |
indices |
MindSpore的 |
|
参数 4 |
src |
updates |
功能一致 |
|
参数 5 |
reduction |
MindSpore的 |
代码示例
# PyTorch
import torch
import numpy as np
x = torch.tensor(np.zeros((5, 5)), dtype=torch.float32)
src = torch.tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), dtype=torch.float32)
index = torch.tensor(np.array([[0, 1], [0, 1], [0, 1]]), dtype=torch.int64)
out = torch.scatter_add(x=x, dim=1, index=index, src=src)
print(out)
# tensor([[1., 2., 0., 0., 0.],
# [4., 5., 0., 0., 0.],
# [7., 8., 0., 0., 0.],
# [0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.]])
# MindSpore
import mindspore as ms
import numpy as np
x = ms.Tensor(np.zeros((5, 5)), dtype=ms.float32)
src = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), dtype=ms.float32)
index = ms.Tensor(np.array([[0, 1, 2], [0, 1, 2], [0, 1, 2]]), dtype=ms.int64)
out = ms.ops.tensor_scatter_elements(input_x=x, axis=1, indices=index, updates=src, reduction="add")
print(out)
# [[1. 2. 3. 0. 0.]
# [4. 5. 6. 0. 0.]
# [7. 8. 9. 0. 0.]
# [0. 0. 0. 0. 0.]
# [0. 0. 0. 0. 0.]]