mindspore.ops.ScatterNdMul
- class mindspore.ops.ScatterNdMul(use_locking=False)
Applies sparse multiplication to individual values or slices in a tensor.
Uses the given values to update the parameter through multiplication at the positions specified by the input indices. The operation returns input_x after the update, which makes the updated value convenient to use.
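The update rule can be modeled in NumPy as a reference sketch. This is not MindSpore's implementation: the function name `scatter_nd_mul_reference` is hypothetical, and the sketch assumes that duplicate index rows multiply the same slice once per occurrence (which is what `np.multiply.at` does, since it applies updates unbuffered).

```python
import numpy as np


def scatter_nd_mul_reference(input_x, indices, updates):
    """Hypothetical NumPy model of ScatterNdMul; modifies input_x in place.

    input_x: array of shape (d0, ..., d_{n-1}).
    indices: integer array of shape (..., k) with k <= input_x.ndim.
    updates: array of shape indices.shape[:-1] + input_x.shape[k:].
    """
    k = indices.shape[-1]
    # Flatten the leading index dimensions into N rows of k coordinates each.
    flat_indices = indices.reshape(-1, k)
    # One coordinate array per indexed axis, as NumPy advanced indexing expects.
    index_tuple = tuple(flat_indices.T)
    # Align updates with the N selected slices of input_x.
    flat_updates = updates.reshape((flat_indices.shape[0],) + input_x.shape[k:])
    # Unbuffered in-place multiply; a repeated row is applied once per occurrence.
    np.multiply.at(input_x, index_tuple, flat_updates)
    # Return the updated array, mirroring how the operator returns input_x.
    return input_x


x = np.array([1, 2, 3, 4, 5, 6, 7, 8], dtype=np.float32)
indices = np.array([[2], [4], [1], [7]], dtype=np.int32)
updates = np.array([6, 7, 8, 9], dtype=np.float32)
print(scatter_nd_mul_reference(x, indices, updates))
# [ 1. 16. 18.  4. 35.  6.  7. 72.]
```

Because the function mutates its first argument and also returns it, the result and input_x refer to the same updated data, analogous to the operator's output described below.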
Warning
This is an experimental API that is subject to change or deletion.
Refer to mindspore.ops.scatter_nd_mul() for more details.
- Parameters:
use_locking (bool, optional) – Whether to protect the assignment by a lock. Default: False.
- Inputs:
input_x (Parameter) - The target tensor to update; it must be a Parameter.
indices (Tensor) - The indices at which the multiplication is applied. The data type must be int32 or int64. The rank of indices must be at least 2, and indices.shape[-1] <= len(input_x.shape).
updates (Tensor) - The tensor to multiply with input_x at the positions given by indices. It must have the same data type as input_x, and its shape must be indices.shape[:-1] + input_x.shape[indices.shape[-1]:].
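The shape rules for indices and updates can be checked on plain shape tuples before calling the operator. The helper below, `expected_updates_shape`, is a hypothetical, framework-independent sketch that encodes only the constraints stated above; it is not part of MindSpore.

```python
def expected_updates_shape(x_shape, indices_shape):
    """Hypothetical helper: return the shape updates must have for ScatterNdMul.

    x_shape and indices_shape are tuples (or lists) of ints.
    """
    x_shape = tuple(x_shape)
    indices_shape = tuple(indices_shape)
    # Rank >= 2: leading dims enumerate the updates, the last dim holds coordinates.
    if len(indices_shape) < 2:
        raise ValueError("indices must have rank >= 2")
    k = indices_shape[-1]
    # Each index row can address at most len(x_shape) leading axes of input_x.
    if k > len(x_shape):
        raise ValueError("indices.shape[-1] must be <= len(input_x.shape)")
    # updates.shape == indices.shape[:-1] + input_x.shape[indices.shape[-1]:]
    return indices_shape[:-1] + x_shape[k:]


print(expected_updates_shape((8,), (4, 1)))       # (4,)
print(expected_updates_shape((4, 4, 4), (2, 1)))  # (2, 4, 4)
```

The two calls correspond to the shapes used in the examples: a 1-D input updated element by element, and a 3-D input updated one 4x4 slice at a time.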
- Outputs:
Tensor, the updated input_x, with the same shape and type as input_x.
- Supported Platforms:
GPU
CPU
Examples
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, Parameter, ops
>>> input_x = Parameter(Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mindspore.float32), name="x")
>>> indices = Tensor(np.array([[2], [4], [1], [7]]), mindspore.int32)
>>> updates = Tensor(np.array([6, 7, 8, 9]), mindspore.float32)
>>> scatter_nd_mul = ops.ScatterNdMul()
>>> output = scatter_nd_mul(input_x, indices, updates)
>>> print(output)
[ 1. 16. 18.  4. 35.  6.  7. 72.]
>>> input_x = Parameter(Tensor(np.ones((4, 4, 4)), mindspore.int32))
>>> indices = Tensor(np.array([[0], [2]]), mindspore.int32)
>>> updates = Tensor(np.array([[[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]],
...                            [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]]]), mindspore.int32)
>>> scatter_nd_mul = ops.ScatterNdMul()
>>> output = scatter_nd_mul(input_x, indices, updates)
>>> print(output)
[[[1 1 1 1]
  [2 2 2 2]
  [3 3 3 3]
  [4 4 4 4]]

 [[1 1 1 1]
  [1 1 1 1]
  [1 1 1 1]
  [1 1 1 1]]

 [[5 5 5 5]
  [6 6 6 6]
  [7 7 7 7]
  [8 8 8 8]]

 [[1 1 1 1]
  [1 1 1 1]
  [1 1 1 1]
  [1 1 1 1]]]