比较与torch.Tensor.scatter_的差异
torch.Tensor.scatter_
torch.Tensor.scatter_(dim, index, src, reduce) -> Tensor
更多内容详见torch.Tensor.scatter_。
mindspore.ops.tensor_scatter_elements
mindspore.ops.tensor_scatter_elements(
input_x,
indices,
updates,
axis=0,
reduction='none'
) -> Tensor
差异对比
PyTorch:用给定的值替换Tensor中指定索引位置的元素。
MindSpore:MindSpore此算子实现功能与PyTorch一致,PyTorch中该接口为Tensor接口,调用方式略有不同。
分类 |
子类 |
PyTorch |
MindSpore |
差异 |
---|---|---|---|---|
参数 |
参数1 |
dim |
axis |
功能一致,参数名不同 |
参数2 |
index |
indices |
功能一致,参数名不同 |
|
参数3 |
src |
updates |
功能一致,参数名不同 |
|
参数4 |
reduce |
reduction |
规约计算方式,目前MindSpore仅支持“none”和“add”模式 |
|
参数5 |
- |
input_x |
PyTorch中该接口为Tensor接口 |
代码示例
两API实现功能一致。
# PyTorch
import torch
t = torch.zeros((3, 4), dtype=torch.float32)
indices = torch.tensor([[1, 2], [0, 1]])
values = torch.tensor([[3, 4], [5, 6]], dtype=torch.float32)
t.scatter_(0, indices, values)
print(t)
# tensor([[5., 0., 0., 0.],
# [3., 6., 0., 0.],
# [0., 4., 0., 0.]])
# MindSpore
import numpy as np
import mindspore
from mindspore import Tensor, Parameter
from mindspore import ops
input_x = Parameter(Tensor(np.zeros((3, 4)), mindspore.float32), name="x")
indices = Tensor(np.array([[1, 2], [0, 1]]), mindspore.int32)
updates = Tensor(np.array([[3, 4], [5, 6]]), mindspore.float32)
axis = 0
reduction = "none"
output = ops.tensor_scatter_elements(input_x, indices, updates, axis, reduction)
print(output)
# [[5. 0. 0. 0.]
# [3. 6. 0. 0.]
# [0. 4. 0. 0.]]