mindspore.ops.scatter_update

mindspore.ops.scatter_update(input_x, indices, updates)

Updates the slices of input_x along its first axis that are selected by indices, replacing them with the corresponding values in updates.

Note

  • Supports implicit type conversion and type promotion.

  • Since Parameter objects do not support type conversion, an exception will be thrown when input_x is of a low-precision data type.

  • updates must have the shape indices.shape + input_x.shape[1:].
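The shape rule above is plain tuple concatenation. The sketch below computes the shape that updates must have from the shapes of input_x and indices. expected_updates_shape is a hypothetical helper written for illustration; it is not part of the MindSpore API.

```python
def expected_updates_shape(input_x_shape, indices_shape):
    """Hypothetical helper: the shape `updates` needs for scatter_update.

    Implements updates.shape == indices.shape + input_x.shape[1:],
    where "+" is tuple concatenation.
    """
    input_x_shape = tuple(input_x_shape)
    if len(input_x_shape) == 0:
        raise ValueError("input_x must have at least one dimension")
    # The first axis of input_x is the one addressed by `indices`; the
    # remaining axes are copied whole from each slice of `updates`.
    return tuple(indices_shape) + input_x_shape[1:]


# input_x: (2, 3), indices: (2,) -> updates: (2, 3), as in the Examples section
print(expected_updates_shape((2, 3), (2,)))       # (2, 3)
# input_x: (5, 3, 4), indices: (2, 6) -> updates: (2, 6, 3, 4)
print(expected_updates_shape((5, 3, 4), (2, 6)))  # (2, 6, 3, 4)
```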

For each index tuple (i, …, j) ranging over indices.shape:

$$\text{input\_x}[\text{indices}[i, \ldots, j], :] = \text{updates}[i, \ldots, j, :]$$
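The update rule can be read as a recursive walk over indices and updates in lockstep: every leaf of indices names one slice of input_x, and the matching leaf slice of updates overwrites it. The following pure-Python sketch on nested lists illustrates that reading; scatter_update_reference is a hypothetical name, and the sketch is not MindSpore's implementation.

```python
import copy


def scatter_update_reference(input_x, indices, updates):
    """Hypothetical pure-Python sketch of the rule
    input_x[indices[i, ..., j], :] = updates[i, ..., j, :].

    input_x is a nested list whose first axis is indexed by `indices`;
    indices is an int or a nested list of ints; updates is nested like
    indices, holding one slice of input_x at each leaf position.
    Returns a new list; the argument input_x is left unchanged.
    """
    out = copy.deepcopy(input_x)

    def walk(idx, upd):
        if isinstance(idx, int):
            # Leaf of `indices`: overwrite the whole slice input_x[idx, :].
            out[idx] = copy.deepcopy(upd)
            return
        if len(idx) != len(upd):
            raise ValueError("updates.shape must start with indices.shape")
        # Step through the next axis of `indices` and `updates` together.
        for sub_idx, sub_upd in zip(idx, upd):
            walk(sub_idx, sub_upd)

    walk(indices, updates)
    return out


# 1-D indices, same data as the Examples section: both rows are replaced.
x = [[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]
print(scatter_update_reference(x, [0, 1], [[2.0, 1.2, 1.0], [3.0, 1.2, 1.0]]))
# [[2.0, 1.2, 1.0], [3.0, 1.2, 1.0]]

# 2-D indices of shape (2, 1): updates has shape (2, 1) + (2,) = (2, 1, 2).
rows = [[0, 0], [1, 1], [2, 2], [3, 3]]
print(scatter_update_reference(rows, [[3], [0]], [[[30, 31]], [[5, 5]]]))
# [[5, 5], [1, 1], [2, 2], [30, 31]]
```

This sketch returns a copy rather than modifying its argument, and with duplicate indices it lets the last write in iteration order win; neither choice is meant to describe how MindSpore handles Parameter storage or duplicate indices.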
Parameters
  • input_x (Union[Parameter, Tensor]) – The input parameter or tensor.

  • indices (Tensor) – Indices into the first axis of input_x that select the slices to update. If indices contains duplicates, the order in which the updates are applied is undefined.

  • updates (Tensor) – The new values written into input_x at the positions selected by indices.
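To see why duplicate indices leave the result undefined, the sketch below applies the same two writes to one row in two different sequences. apply_in_order is a hypothetical helper for illustration; it models the writes as sequential, whereas which write wins on a given backend is not specified by this page.

```python
def apply_in_order(input_x, indices, updates, order):
    """Hypothetical helper: apply the row writes of a 1-D scatter update
    one at a time, in the sequence given by `order`."""
    out = [list(row) for row in input_x]
    for k in order:
        # Write updates[k] into the row selected by indices[k].
        out[indices[k]] = list(updates[k])
    return out


x = [[0.0, 0.0], [1.0, 1.0]]
indices = [0, 0]  # row 0 is targeted twice
updates = [[5.0, 5.0], [7.0, 7.0]]

print(apply_in_order(x, indices, updates, order=[0, 1]))  # [[7.0, 7.0], [1.0, 1.0]]
print(apply_in_order(x, indices, updates, order=[1, 0]))  # [[5.0, 5.0], [1.0, 1.0]]
```

Both results are consistent with the documented contract, so callers that need a deterministic result can remove duplicate indices before calling scatter_update.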

Returns

Tensor, with the same shape and data type as input_x.

Supported Platforms:

Ascend GPU CPU

Examples

>>> import mindspore
>>> x = mindspore.tensor([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]], mindspore.float32)
>>> input_x = mindspore.Parameter(x, name="x")
>>> indices = mindspore.tensor([0, 1], mindspore.int32)
>>> # Row 0 of input_x is replaced by updates[0] and row 1 by updates[1].
>>> updates = mindspore.tensor([[2.0, 1.2, 1.0], [3.0, 1.2, 1.0]], mindspore.float32)
>>> output = mindspore.ops.scatter_update(input_x, indices, updates)
>>> print(output)
[[2.  1.2 1. ]
 [3.  1.2 1. ]]