mindspore.ops.ScatterAddWithAxis
- class mindspore.ops.ScatterAddWithAxis(axis=0)[源代码]
该操作的输出是通过创建输入 input_x 的副本,然后将 updates 指定的值添加到 indices 指定的位置来更新副本中的值。
说明
三个输入 input_x, updates 和 indices 的秩相同且都大于等于1。
- 参数:
axis (int,可选) - 指定在哪个轴上进行散点加法。默认值:0。
- 输入:
input_x (Parameter) - 相加操作目标Tensor。
indices (Tensor) - 指定相加操作的索引,数据类型为int32或者int64。
updates (Tensor) - 指定与 input_x 相加操作的Tensor,数据类型与 input_x 相同,shape与 indices 相同。
- 输出:
Tensor,更新后的 input_x ,shape和数据类型与 input_x 相同。
- 异常:
TypeError - indices 不是int32或者int64。
ValueError - indices 和 updates 的shape不一致。
- 支持平台:
CPU
样例:
>>> op = ops.ScatterAddWithAxis(0) >>> input_x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32) >>> indices = Tensor(np.array([[1, 0, 2], [0, 2, 1]]), mindspore.int32) >>> updates = Tensor(np.array([[1, 1, 1], [1, 1, 1]]), mindspore.float32) >>> output = op(input_x, indices, updates) >>> print(output) [[ 2. 3. 3.] [ 5. 5. 7.] [ 7. 9. 10.]] >>> op = ops.ScatterAddWithAxis(1) >>> input_x = Tensor(np.array([[1, 2, 3, 4, 5]]), mindspore.int32) >>> indices = Tensor(np.array([[2, 4]]), mindspore.int32) >>> updates = Tensor(np.array([[8, 8]]), mindspore.int32) >>> output = op(input_x, indices, updates) >>> print(output) [[ 1 2 11 4 13]]