mindspore.ops.ScatterAddWithAxis

class mindspore.ops.ScatterAddWithAxis(axis=0)[源代码]

该操作的输出是通过创建输入 input_x 的副本,然后将 updates 指定的值添加到 indices 指定的位置来更新副本中的值。

说明

三个输入 input_x, updatesindices 的秩相同且都大于等于1。

参数:
  • axis (int,可选) - 指定在哪个轴上进行散点加法。默认值:0。

输入:
  • input_x (Parameter) - 相加操作目标Tensor。

  • indices (Tensor) - 指定相加操作的索引,数据类型为int32或者int64。

  • updates (Tensor) - 指定与 input_x 相加操作的Tensor,数据类型与 input_x 相同,shape与 indices 相同。

输出:

Tensor,更新后的 input_x ,shape和数据类型与 input_x 相同。

异常:
  • TypeError - indices 不是int32或者int64。

  • ValueError - indicesupdates 的shape不一致。

支持平台:

CPU

样例:

>>> op = ops.ScatterAddWithAxis(0)
>>> input_x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
>>> indices = Tensor(np.array([[1, 0, 2], [0, 2, 1]]), mindspore.int32)
>>> updates = Tensor(np.array([[1, 1, 1], [1, 1, 1]]), mindspore.float32)
>>> output = op(input_x, indices, updates)
>>> print(output)
[[ 2.  3.  3.]
 [ 5.  5.  7.]
 [ 7.  9.  10.]]
>>> op = ops.ScatterAddWithAxis(1)
>>> input_x = Tensor(np.array([[1, 2, 3, 4, 5]]), mindspore.int32)
>>> indices = Tensor(np.array([[2, 4]]), mindspore.int32)
>>> updates = Tensor(np.array([[8, 8]]), mindspore.int32)
>>> output = op(input_x, indices, updates)
>>> print(output)
[[ 1  2  11  4  13]]