mindspore.ops.ScatterAddWithAxis
- class mindspore.ops.ScatterAddWithAxis(axis=0)
The output is produced by creating a copy of input_x and then adding the values of updates to it at the positions specified by indices: each index value selects a position along axis, while the remaining coordinates are those of the index element itself.
Note
The three inputs input_x, updates and indices must have the same rank r >= 1.
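As a rough illustration of these semantics, the following NumPy function reproduces both examples on this page. It is a hypothetical reference (`scatter_add_with_axis_ref` is not part of MindSpore), and it assumes that repeated index values accumulate, which the description above does not state explicitly.

```python
import numpy as np


def scatter_add_with_axis_ref(input_x, indices, updates, axis=0):
    """Hypothetical NumPy reference for ScatterAddWithAxis (not MindSpore code)."""
    out = np.array(input_x, copy=True)  # work on a copy; input_x is left untouched
    indices = np.asarray(indices)
    updates = np.asarray(updates, dtype=out.dtype)
    # Same-rank requirement from the Note; ValueError is this sketch's own choice.
    if not (out.ndim == indices.ndim == updates.ndim >= 1):
        raise ValueError("input_x, indices and updates must have the same rank r >= 1")
    # One coordinate array per dimension, each shaped like `indices`.
    coords = list(np.indices(indices.shape))
    # Along `axis`, the target coordinate comes from `indices` instead of the position.
    coords[axis] = indices
    # np.add.at is unbuffered, so repeated target positions accumulate (assumption).
    np.add.at(out, tuple(coords), updates)
    return out
```

`np.add.at` is used rather than `out[tuple(coords)] += updates` because the buffered form applies only one update when the same position appears more than once.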
- Parameters
axis (int, optional) – The axis along which to perform the scatter add. Default: 0.
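For rank-2 inputs, the effect of axis can be written element-wise as below, for every position (i, j) of indices. This is a sketch consistent with both examples on this page; for rank r, only the coordinate at position axis is replaced by the index value, and all other coordinates are taken from the position in indices.

```latex
\begin{aligned}
\text{axis}=0:&\quad \mathrm{output}\big[\mathrm{indices}[i][j]\big][j] \mathrel{+}= \mathrm{updates}[i][j] \\
\text{axis}=1:&\quad \mathrm{output}[i]\big[\mathrm{indices}[i][j]\big] \mathrel{+}= \mathrm{updates}[i][j]
\end{aligned}
```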
- Inputs:
input_x (Tensor) - The target tensor to which updates are added.
indices (Tensor) - The positions along axis at which to add updates. Its data type must be int32 or int64.
updates (Tensor) - The values to add to input_x. Must have the same data type as input_x and the same shape as indices.
- Outputs:
Tensor, the updated copy of input_x, with the same shape and data type as input_x.
- Raises
TypeError – If dtype of indices is neither int32 nor int64.
ValueError – If the shape of indices is not equal to the shape of updates.
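The two documented error conditions can be checked up front with a small NumPy helper before calling the operator. `check_scatter_add_with_axis_inputs` is a hypothetical name, and this is not MindSpore's internal validation, which may check more (for example, the rank requirement in the Note).

```python
import numpy as np


def check_scatter_add_with_axis_inputs(indices, updates):
    """Hypothetical pre-flight check mirroring the documented Raises entries."""
    indices = np.asarray(indices)
    updates = np.asarray(updates)
    # TypeError: indices must be int32 or int64.
    if indices.dtype not in (np.int32, np.int64):
        raise TypeError(f"indices dtype must be int32 or int64, got {indices.dtype}")
    # ValueError: indices and updates must have identical shapes.
    if indices.shape != updates.shape:
        raise ValueError(
            f"indices shape {indices.shape} does not match updates shape {updates.shape}"
        )
```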
- Supported Platforms:
CPU
Examples
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, ops
>>> op = ops.ScatterAddWithAxis(0)
>>> input_x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
>>> indices = Tensor(np.array([[1, 0, 2], [0, 2, 1]]), mindspore.int32)
>>> updates = Tensor(np.array([[1, 1, 1], [1, 1, 1]]), mindspore.float32)
>>> output = op(input_x, indices, updates)
>>> print(output)
[[ 2.  3.  3.]
 [ 5.  5.  7.]
 [ 7.  9. 10.]]
>>> op = ops.ScatterAddWithAxis(1)
>>> input_x = Tensor(np.array([[1, 2, 3, 4, 5]]), mindspore.int32)
>>> indices = Tensor(np.array([[2, 4]]), mindspore.int32)
>>> updates = Tensor(np.array([[8, 8]]), mindspore.int32)
>>> output = op(input_x, indices, updates)
>>> print(output)
[[ 1  2 11  4 13]]