mindspore.ops.tensor_scatter_elements

查看源文件
mindspore.ops.tensor_scatter_elements(input_x, indices, updates, axis=0, reduction='none')[源代码]

updates 中所有的元素按照 reduction 指定的归约操作写入 input_xindices 指定的索引处。 axis 控制scatter操作的方向。 input_xindicesupdates 三者的rank都必须大于或等于1。

下面看一个三维的例子:

output[indices[i][j][k]][j][k] = updates[i][j][k]  # if axis == 0, reduction == "none"

output[i][indices[i][j][k]][k] += updates[i][j][k]  # if axis == 1, reduction == "add"

output[i][j][indices[i][j][k]] = updates[i][j][k]  # if axis == 2, reduction == "none"

警告

  • 如果 indices 中有多个索引向量对应于同一位置,则输出中该位置值是不确定的。

  • 在Ascend平台上,目前仅支持 reduction 设置为 "none" 的实现。

  • 在Ascend平台上,input_x 仅支持float16和float32两种数据类型。

说明

如果 indices 的值超出 input_x 索引上下界,则相应的 updates 不会更新到 input_x ,也不会抛出索引错误。

警告

这是一个实验性API,后续可能修改或删除。

参数:
  • input_x (Tensor) - 输入Tensor。rank必须大于等于1。

  • indices (Tensor) - input_x 要进行scatter操作的目标索引。数据类型为int32或int64,rank必须和 input_x 一致,取值范围是[-s, s),s是 input_xaxis 指定轴的size。

  • updates (Tensor) - 指定与 input_x 进行scatter操作的Tensor,其数据类型与 input_x 类型相同,shape与 indices 的shape相同。

  • axis (int) - input_x 执行scatter操作的轴。取值范围是[-r, r),其中r是 input_x 的rank。默认值: 0

  • reduction (str) - 指定进行的规约操作。支持 "none""add" 。默认值: "none" 。当 reduction 设置为 "none" 时,updates 将根据 indices 赋值给 input_x。当 reduction 设置为 "add" 时,updates 将根据 indices 累加到 input_x

返回:

Tensor,shape和数据类型与输入 input_x 相同。

异常:
  • TypeError - indices 的数据类型不满足int32或int64。

  • ValueError - input_xindicesupdates 中,任意一者的rank小于1。

  • ValueError - updates 的shape和 indices 的shape不一致。

  • ValueError - updates 的rank和 input_x 的rank不一致。

  • RuntimeError - input_x 的数据类型和 updates 的数据类型不能隐式转换。

支持平台:

Ascend GPU CPU

样例:

>>> import mindspore
>>> from mindspore import Tensor, ops
>>> from mindspore import Parameter
>>> import numpy as np
>>> input_x = Parameter(Tensor(np.array([[1, 2, 3, 4, 5]]), mindspore.int32), name="x")
>>> indices = Tensor(np.array([[2, 4]]), mindspore.int32)
>>> updates = Tensor(np.array([[8, 8]]), mindspore.int32)
>>> axis = 1
>>> reduction = "none"
>>> output = ops.tensor_scatter_elements(input_x, indices, updates, axis, reduction)
>>> print(output)
[[1 2 8 4 8]]
>>> input_x = Parameter(Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.int32), name="x")
>>> indices = Tensor(np.array([[1, -1, 2], [0, 2, 1]]), mindspore.int32)
>>> updates = Tensor(np.array([[1, 2, 2], [4, 5, 8]]), mindspore.int32)
>>> axis = 0
>>> reduction = "add"
>>> output = ops.tensor_scatter_elements(input_x, indices, updates, axis, reduction)
>>> print(output)
[[ 5  2  3]
 [ 5  5 14]
 [ 7 15 11]]