mindspore.mint.scatter_add
- mindspore.mint.scatter_add(input, dim, index, src)
Adds all elements of src to input at the positions specified by index, along the dimension specified by dim. The three inputs input, src and index must have the same rank r >= 1.
For a 3-D tensor, the operation updates input as follows:
input[index[i][j][k]][j][k] += src[i][j][k]  # if dim == 0
input[i][index[i][j][k]][k] += src[i][j][k]  # if dim == 1
input[i][j][index[i][j][k]] += src[i][j][k]  # if dim == 2
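The update rule above generalizes to any rank. As a rough illustration, the following pure-Python sketch applies it to nested lists. The names `scatter_add_reference`, `shape_of` and `get_item` are hypothetical helpers introduced here, not part of MindSpore, and the real operator may behave differently in edge cases such as out-of-range indices.

```python
from copy import deepcopy
from itertools import product


def shape_of(nested):
    """Return the shape of a non-empty rectangular nested list."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0]
    return tuple(shape)


def get_item(nested, pos):
    """Read the element of a nested list at the multi-index pos."""
    for p in pos:
        nested = nested[p]
    return nested


def scatter_add_reference(data, dim, index, src):
    """Illustrative reference for the documented update rule (not MindSpore code)."""
    out = deepcopy(data)          # leave the caller's data untouched
    rank = len(shape_of(data))
    dim = dim % rank              # map a negative dim into [0, rank)
    # Visit every position of index; src is read at the same position.
    for pos in product(*(range(n) for n in shape_of(index))):
        target = list(pos)
        target[dim] = get_item(index, pos)   # redirect only the dim coordinate
        cell = out
        for p in target[:-1]:                # walk down to the innermost list
            cell = cell[p]
        cell[target[-1]] += get_item(src, pos)
    return out
```

With this sketch, repeated indices accumulate: `scatter_add_reference([0, 0, 0], 0, [1, 1, 1], [1, 2, 3])` returns `[0, 6, 0]`.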
- Parameters
input (Tensor) – The target tensor. The rank must be at least 1.
dim (int) – Which dim to scatter. Accepted range is [-r, r) where r = rank(input).
index (Tensor) – The indices of input at which to scatter. The data type must be mindspore.int32 or mindspore.int64. Must have the same rank as input. Except for the dimension specified by dim, the size of each dimension of index must be less than or equal to the size of the corresponding dimension of input.
src (Tensor) – The tensor whose elements are added to input. It must have the same data type as input, and the size of each of its dimensions must be greater than or equal to that of the corresponding dimension of index.
- Returns
Tensor, has the same shape and type as input.
- Raises
TypeError – If index is neither int32 nor int64.
ValueError – If the rank of any of input, index and src is less than 1.
ValueError – If the ranks of input, index and src are not the same.
ValueError – If the size of any dimension of index, except the dimension specified by dim, is greater than the size of the corresponding dimension of input.
ValueError – If the size of any dimension of src is less than that of index.
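The shape conditions listed above can be checked before calling the operator. The sketch below is a hypothetical pre-check written for illustration only: `check_scatter_add_shapes` is not a MindSpore function, and the error for an out-of-range `dim` is an assumption, since the Raises list does not name one.

```python
def check_scatter_add_shapes(input_shape, index_shape, src_shape, dim):
    """Raise ValueError if the shapes break the documented constraints.

    Hypothetical helper; it mirrors the documented conditions, not
    MindSpore's actual validation code.
    """
    shapes = {"input": input_shape, "index": index_shape, "src": src_shape}
    for name, shape in shapes.items():
        if len(shape) < 1:
            raise ValueError(f"rank of {name} must be at least 1")

    rank = len(input_shape)
    if len(index_shape) != rank or len(src_shape) != rank:
        raise ValueError("input, index and src must have the same rank")

    # Assumption: an out-of-range dim is reported as ValueError here.
    if not -rank <= dim < rank:
        raise ValueError(f"dim must be in [{-rank}, {rank})")
    dim %= rank  # normalize a negative dim

    for axis in range(rank):
        # index may exceed input only along the scatter dimension
        if axis != dim and index_shape[axis] > input_shape[axis]:
            raise ValueError(f"index is larger than input along axis {axis}")
        # src must cover every position of index
        if src_shape[axis] < index_shape[axis]:
            raise ValueError(f"src is smaller than index along axis {axis}")
```

For example, `check_scatter_add_shapes((5, 5), (3, 3), (3, 3), 0)` passes silently, while an index of shape `(6, 6)` against the same input raises `ValueError` because axis 1 of index exceeds axis 1 of input.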
- Supported Platforms:
Ascend
Examples
>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import Tensor, mint
>>> input = Tensor(np.array([[1, 2, 3, 4, 5]]), dtype=ms.float32)
>>> src = Tensor(np.array([[8, 8]]), dtype=ms.float32)
>>> index = Tensor(np.array([[2, 4]]), dtype=ms.int64)
>>> out = mint.scatter_add(input=input, dim=1, index=index, src=src)
>>> print(out)
[[1. 2. 11. 4. 13.]]
>>> input = Tensor(np.zeros((5, 5)), dtype=ms.float32)
>>> src = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), dtype=ms.float32)
>>> index = Tensor(np.array([[0, 0, 0], [2, 2, 2], [4, 4, 4]]), dtype=ms.int64)
>>> out = mint.scatter_add(input=input, dim=0, index=index, src=src)
>>> print(out)
[[1. 2. 3. 0. 0.]
 [0. 0. 0. 0. 0.]
 [4. 5. 6. 0. 0.]
 [0. 0. 0. 0. 0.]
 [7. 8. 9. 0. 0.]]
>>> input = Tensor(np.zeros((5, 5)), dtype=ms.float32)
>>> src = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), dtype=ms.float32)
>>> index = Tensor(np.array([[0, 2, 4], [0, 2, 4], [0, 2, 4]]), dtype=ms.int64)
>>> out = mint.scatter_add(input=input, dim=1, index=index, src=src)
>>> print(out)
[[1. 0. 2. 0. 3.]
 [4. 0. 5. 0. 6.]
 [7. 0. 8. 0. 9.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]