mindspore.ops.tensor_scatter_elements

mindspore.ops.tensor_scatter_elements(input_x, indices, updates, axis=0, reduction='none')[源代码]

根据索引逐元素更新输入Tensor的值。 接受三个输入, input_xindicesupdates 。三者的秩都大于等于1。input_x 中的数据和 updates 中的数据会按照 indices 提取出来 做 reduction 指定的操作并更新结果到输出。下面看一个三维的例子:

output[indices[i][j][k]][j][k] = updates[i][j][k]  # if axis == 0, reduction == "none"

output[i][indices[i][j][k]][k] += updates[i][j][k]  # if axis == 1, reduction == "add"

output[i][j][indices[i][j][k]] = updates[i][j][k]  # if axis == 2, reduction == "none"

Warning

  • 如果 indices 中有多个索引向量对应于同一位置,则输出中该位置值是不确定的。

  • 在Ascend平台上,目前仅支持 reduction 设置为”none”的实现。

  • 在Ascend平台上,input_x 仅支持float16和float32两种数据类型。

Note

如果 indices 的某些值超出范围,则相应的 updates 不会更新到 input_x ,也不会抛出索引错误。

参数:

  • input_x (Tensor) - 输入Tensor。 input_x 其rank必须至少为1。

  • indices (Tensor) - 输入Tensor的索引,数据类型为int32或int64的。其rank必须和 input_x 一致。取值范围是[-s, s), 这里的s是 input_xaxis 指定轴的size。

  • updates (Tensor) - 指定与 input_x 进行reduction操作的Tensor,其数据类型与输入的数据类型相同。updates的shape必须等于indices的shape。

  • axis (int) - input_x reduction操作的轴,默认值是0。取值范围是[-r, r),其中r是 input_x 的秩。

  • reduction (str) - 指定进行的reduction操作。默认值是”none”,可选”add”。

返回:

Tensor,shape和数据类型与输入 input_x 相同。

异常:

  • TypeError - indices 的数据类型既不是int32,也不是int64。

  • ValueError - input_xindicesupdates 中,任意一者的秩小于1。

  • ValueError - updates 的shape和 indices 的shape不一样。

  • ValueError - updates 的秩和 input_x 的秩不一样。

  • RuntimeError - input_x 的数据类型和 updates 的数据类型不能隐式转换。

支持平台:

Ascend GPU CPU

样例:

>>> from mindspore.ops import functional as F
>>> input_x = Parameter(Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32), name="input_x")
>>> indices = Tensor(np.array([[1, 0, 2], [0, 2, 1]]), mindspore.int32)
>>> updates = Tensor(np.array([[1, 1, 1], [1, 1, 1]]), mindspore.float32)
>>> axis = 0
>>> reduction = "add"
>>> output = F.tensor_scatter_elements(input_x, indices, updates, axis, reduction)
>>> print(output)
[[ 2.0  3.0  3.0]
 [ 5.0  5.0  7.0]
 [ 7.0  9.0  10.0]]
>>> input_x = Parameter(Tensor(np.array([[1, 2, 3, 4, 5]]), mindspore.float32), name="x")
>>> indices = Tensor(np.array([[2, 4]]), mindspore.int32)
>>> updates = Tensor(np.array([[8, 8]]), mindspore.float32)
>>> axis = 1
>>> reduction = "none"
>>> output = F.tensor_scatter_elements(input_x, indices, updates, axis, reduction)
>>> print(output)
[[ 1  2  8  4  8]]