比较与torch.Tensor.scatter_的差异

查看源文件

torch.Tensor.scatter_

torch.Tensor.scatter_(dim, index, src, reduce) -> Tensor

更多内容详见torch.Tensor.scatter_

mindspore.ops.tensor_scatter_elements

mindspore.ops.tensor_scatter_elements(
    input_x,
    indices,
    updates,
    axis=0,
    reduction='none'
) -> Tensor

更多内容详见mindspore.ops.tensor_scatter_elements

差异对比

PyTorch:用给定的值替换Tensor中指定索引位置的元素。

MindSpore:MindSpore此算子实现功能与PyTorch一致,PyTorch中该接口为Tensor接口,调用方式略有不同。

分类

子类

PyTorch

MindSpore

差异

参数

参数1

dim

axis

功能一致,参数名不同

参数2

index

indices

功能一致,参数名不同

参数3

src

updates

功能一致,参数名不同

参数4

reduce

reduction

规约计算方式,目前MindSpore仅支持“none”和“add”模式

参数5

-

input_x

PyTorch中该接口为Tensor接口

代码示例

两API实现功能一致。

# PyTorch
import torch

t = torch.zeros((3, 4), dtype=torch.float32)
indices = torch.tensor([[1, 2], [0, 1]])
values = torch.tensor([[3, 4], [5, 6]], dtype=torch.float32)
t.scatter_(0, indices, values)
print(t)
# tensor([[5., 0., 0., 0.],
#         [3., 6., 0., 0.],
#         [0., 4., 0., 0.]])

# MindSpore
import numpy as np
import mindspore
from mindspore import Tensor, Parameter
from mindspore import ops

input_x = Parameter(Tensor(np.zeros((3, 4)), mindspore.float32), name="x")
indices = Tensor(np.array([[1, 2], [0, 1]]), mindspore.int32)
updates = Tensor(np.array([[3, 4], [5, 6]]), mindspore.float32)
axis = 0
reduction = "none"
output = ops.tensor_scatter_elements(input_x, indices, updates, axis, reduction)
print(output)
# [[5. 0. 0. 0.]
#  [3. 6. 0. 0.]
#  [0. 4. 0. 0.]]