mindspore.ops.scatter_update
- mindspore.ops.scatter_update(input_x, indices, updates)
Updates the input tensor values using the given input indices and update values.
Note
Supports implicit type conversion and type promotion.
Because Parameter objects do not support type conversion, an exception is raised when input_x has the lower-precision data type (that is, when promotion would require converting input_x).
updates must have the shape indices.shape + input_x.shape[1:].
For each i, …, j in indices.shape:
input_x[indices[i, …, j], :] = updates[i, …, j, :]
- Parameters
input_x: The tensor whose values are updated. The example below passes a Parameter.
indices: The indices of the entries along the first dimension of input_x to update.
updates: The values written into input_x.
- Returns
Tensor
- Supported Platforms:
Ascend
GPU
CPU
Examples
>>> import mindspore
>>> np_x = mindspore.tensor([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]], mindspore.float32)
>>> input_x = mindspore.Parameter(np_x, name="x")
>>> indices = mindspore.tensor([0, 1], mindspore.int32)
>>> np_updates = mindspore.tensor([[2.0, 1.2, 1.0], [3.0, 1.2, 1.0]])
>>> updates = mindspore.tensor(np_updates, mindspore.float32)
>>> output = mindspore.ops.scatter_update(input_x, indices, updates)
>>> print(output)
[[2. 1.2 1.]
 [3. 1.2 1.]]
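The update rule can also be modeled outside MindSpore, which may help when checking shapes before calling the operator. The following is a minimal NumPy sketch, not MindSpore's implementation: the function name scatter_update_reference is hypothetical, it returns a new array instead of modifying input_x, and it assumes each entry of indices selects one slice along the first dimension of input_x. Where indices contains duplicates, the sketch lets the last write win; that is an assumption of this sketch only, since the text above does not say how the real operator treats duplicates.

```python
import numpy as np


def scatter_update_reference(input_x, indices, updates):
    """Hypothetical NumPy model of the scatter-update rule (illustration only)."""
    input_x = np.asarray(input_x)
    indices = np.asarray(indices)
    # Assumption: updates are cast to the dtype of input_x.
    updates = np.asarray(updates, dtype=input_x.dtype)

    # updates must have the shape indices.shape + input_x.shape[1:].
    expected_shape = indices.shape + input_x.shape[1:]
    if updates.shape != expected_shape:
        raise ValueError(
            f"updates has shape {updates.shape}, expected {expected_shape}"
        )

    # Work on a copy so the caller's array is left unchanged.
    out = input_x.copy()
    # Visit every position (i, ..., j) of indices and overwrite the selected
    # slice of out with the matching slice of updates.
    for pos in np.ndindex(*indices.shape):
        out[indices[pos]] = updates[pos]
    return out


x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]], dtype=np.float32)
result = scatter_update_reference(
    x, [0, 1], [[2.0, 1.2, 1.0], [3.0, 1.2, 1.0]]
)
print(result)
```

Because indices here covers both rows of x, the result consists entirely of the rows of updates, mirroring the MindSpore example above.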