mindspore.mint.scatter_add

mindspore.mint.scatter_add(input, dim, index, src)

Adds every element of src to input along the dimension dim, at the positions given by index. The three tensors input, index and src must have the same rank r >= 1.

For a 3-D tensor, the operation updates input as follows:

input[index[i][j][k]][j][k] += src[i][j][k]  # if dim == 0

input[i][index[i][j][k]][k] += src[i][j][k]  # if dim == 1

input[i][j][index[i][j][k]] += src[i][j][k]  # if dim == 2
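
The update rule above generalizes to any rank: each element of src is added to input at the same position, except that the coordinate along dim is replaced by the matching value from index. The following is a minimal pure-Python sketch of that rule on nested lists, not the MindSpore implementation. The names shape_of, get_item and scatter_add_reference are hypothetical helpers introduced here for illustration, and the sketch assumes rectangular nested lists, in-range non-negative index values, and already-valid shapes.

```python
import copy
from itertools import product


def shape_of(nested):
    """Return the shape of a rectangular nested list."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        if not nested:
            break
        nested = nested[0]
    return tuple(shape)


def get_item(nested, position):
    """Read the element of a nested list at a multi-dimensional position."""
    for p in position:
        nested = nested[p]
    return nested


def scatter_add_reference(input, dim, index, src):
    """Illustrative model of the update rule; returns a new nested list."""
    out = copy.deepcopy(input)
    rank = len(shape_of(input))
    dim = dim % rank  # map a negative dim in [-r, 0) onto [0, r); no range check here
    # Visit every position of index; src is read at the same position.
    for position in product(*(range(n) for n in shape_of(index))):
        target = list(position)
        target[dim] = get_item(index, position)  # redirect along dim only
        innermost = get_item(out, target[:-1])
        innermost[target[-1]] += get_item(src, position)
    return out


# First example from this page: scatter two values into a 1 x 5 tensor.
print(scatter_add_reference([[1.0, 2.0, 3.0, 4.0, 5.0]], 1, [[2, 4]], [[8.0, 8.0]]))
# [[1.0, 2.0, 11.0, 4.0, 13.0]]

# Under the += rule, repeated indices accumulate into the same slot.
print(scatter_add_reference([0.0, 0.0, 0.0], 0, [0, 0, 2], [1.0, 2.0, 3.0]))
# [3.0, 0.0, 3.0]
```

Only positions covered by index are visited, which is why src may be larger than index: its extra elements are never read in this model.
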
Parameters
  • input (Tensor) – The target tensor. The rank must be at least 1.

  • dim (int) – The dimension along which to scatter. Accepted range is [-r, r), where r = rank(input).

  • index (Tensor) – The positions in input, along dim, at which elements of src are added. The data type must be mindspore.int32 or mindspore.int64, and the rank must equal that of input. In every dimension other than dim, the size of index must be less than or equal to the size of the corresponding dimension of input.

  • src (Tensor) – The tensor whose elements are added to input. It must have the same data type as input, and the size of each of its dimensions must be greater than or equal to that of the corresponding dimension of index.

Returns

Tensor, has the same shape and type as input.

Raises
  • TypeError – If the data type of index is neither int32 nor int64.

  • ValueError – If the rank of any of input, index and src is less than 1.

  • ValueError – If the ranks of input, index and src are not all the same.

  • ValueError – If, in any dimension other than dim, the size of index is greater than the size of the corresponding dimension of input.

  • ValueError – If the size of any dimension of src is less than that of index.
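
The conditions listed above can be collected into a single pre-check on shapes. The sketch below is a hedged illustration in plain Python, not MindSpore's own validation code: check_scatter_add_args is a hypothetical helper, shapes are passed as tuples, and the index data type is passed as a string. The error raised for an out-of-range dim is an assumption, since this page does not state which exception that case produces.

```python
def check_scatter_add_args(input_shape, dim, index_shape, src_shape, index_dtype="int64"):
    """Mirror the documented Raises conditions; return dim normalized to [0, r)."""
    if index_dtype not in ("int32", "int64"):
        raise TypeError("index must be int32 or int64")

    ranks = (len(input_shape), len(index_shape), len(src_shape))
    if min(ranks) < 1:
        raise ValueError("input, index and src must each have rank >= 1")
    if len(set(ranks)) != 1:
        raise ValueError("input, index and src must have the same rank")

    rank = ranks[0]
    if not -rank <= dim < rank:
        # Assumption: the exception type for this case is not documented on this page.
        raise ValueError("dim must be in [-r, r)")
    dim %= rank

    for axis in range(rank):
        # index may exceed input only along the scatter dimension itself.
        if axis != dim and index_shape[axis] > input_shape[axis]:
            raise ValueError("index is larger than input outside dimension dim")
        # src must cover every position of index.
        if src_shape[axis] < index_shape[axis]:
            raise ValueError("src is smaller than index")
    return dim


# Shapes from the 5 x 5 examples on this page pass the check.
print(check_scatter_add_args((5, 5), -2, (3, 3), (3, 3)))
# 0
```
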

Supported Platforms:

Ascend

Examples

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import Tensor, mint
>>> # Scatter two values into a 1 x 5 tensor along dim 1.
>>> input = Tensor(np.array([[1, 2, 3, 4, 5]]), dtype=ms.float32)
>>> src = Tensor(np.array([[8, 8]]), dtype=ms.float32)
>>> index = Tensor(np.array([[2, 4]]), dtype=ms.int64)
>>> out = mint.scatter_add(input=input, dim=1, index=index, src=src)
>>> print(out)
[[1. 2. 11. 4. 13.]]
>>> # Scatter along dim 0: each row of src is added to the row of input selected by index.
>>> input = Tensor(np.zeros((5, 5)), dtype=ms.float32)
>>> src = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), dtype=ms.float32)
>>> index = Tensor(np.array([[0, 0, 0], [2, 2, 2], [4, 4, 4]]), dtype=ms.int64)
>>> out = mint.scatter_add(input=input, dim=0, index=index, src=src)
>>> print(out)
[[1. 2. 3. 0. 0.]
 [0. 0. 0. 0. 0.]
 [4. 5. 6. 0. 0.]
 [0. 0. 0. 0. 0.]
 [7. 8. 9. 0. 0.]]
>>> # Scatter along dim 1: each column of src is added to the column of input selected by index.
>>> input = Tensor(np.zeros((5, 5)), dtype=ms.float32)
>>> src = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), dtype=ms.float32)
>>> index = Tensor(np.array([[0, 2, 4], [0, 2, 4], [0, 2, 4]]), dtype=ms.int64)
>>> out = mint.scatter_add(input=input, dim=1, index=index, src=src)
>>> print(out)
[[1. 0. 2. 0. 3.]
 [4. 0. 5. 0. 6.]
 [7. 0. 8. 0. 9.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]