mindspore.ops.tensor_scatter_elements

mindspore.ops.tensor_scatter_elements(input_x, indices, updates, axis=0, reduction='none')

Returns a new tensor obtained by applying updates to input_x at the positions given by indices, using the operation specified by reduction.

Implicit type conversion is not supported.

For example, for a 3-D tensor the output is computed as follows, for each position (i, j, k) of indices:

output[indices[i][j][k]][j][k] = updates[i][j][k]  # if axis == 0, reduction == "none"

output[i][indices[i][j][k]][k] += updates[i][j][k]  # if axis == 1, reduction == "add"

output[i][j][indices[i][j][k]] = updates[i][j][k]  # if axis == 2, reduction == "none"
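To make the indexing rule concrete, here is a minimal pure-Python reference model for the 2-D case. It is a sketch, not MindSpore's implementation: the function name `scatter_elements_2d` is hypothetical, it operates on nested lists rather than tensors, and it assumes every index is in bounds.

```python
from typing import List

Matrix = List[List[int]]


def scatter_elements_2d(input_x: Matrix, indices: Matrix, updates: Matrix,
                        axis: int = 0, reduction: str = "none") -> Matrix:
    """Hypothetical reference model of scatter-elements for 2-D nested lists.

    Assumes all indices are in bounds; out-of-range handling is not modeled.
    """
    if axis not in (0, 1):
        raise ValueError("axis must be 0 or 1 for a 2-D input")
    if reduction not in ("none", "add"):
        raise ValueError('reduction must be "none" or "add"')
    # Work on a copy so input_x is left unchanged ("returns a new tensor").
    output = [row[:] for row in input_x]
    # Visit every position (i, j) of indices and the matching update.
    for i, row in enumerate(indices):
        for j, idx in enumerate(row):
            # Replace the coordinate along `axis` with the index value.
            r, c = (idx, j) if axis == 0 else (i, idx)
            if reduction == "none":
                output[r][c] = updates[i][j]
            else:
                output[r][c] += updates[i][j]
    return output


x = [[1, 2, 3, 4, 5]]
print(scatter_elements_2d(x, [[2, 4]], [[8, 8]], axis=1))                   # [[1, 2, 8, 4, 8]]
print(scatter_elements_2d(x, [[2, 4]], [[8, 8]], axis=1, reduction="add"))  # [[1, 2, 11, 4, 13]]
```

The two calls reproduce the outputs of the MindSpore example at the end of this page, which is a quick way to check the rule output[i][indices[i][j]] for axis == 1.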

Warning

  • If several positions in indices map to the same element of the output, the order in which their updates are applied is nondeterministic, so the value of that element in the output is nondeterministic.

  • On Ascend, reduction currently supports only "none".

  • On Ascend, the data type of input_x must be float16 or float32.

  • This is an experimental API that is subject to change or deletion.
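Why duplicate indices make the result nondeterministic can be seen with a small 1-D sketch. It is illustrative only: the helpers `assign_in_order` and `add_in_order` are hypothetical and simply apply the same (index, value) pairs in two different orders, standing in for two possible execution orders on a parallel device.

```python
from typing import List, Sequence, Tuple

Pairs = Sequence[Tuple[int, int]]


def assign_in_order(input_x: Sequence[int], pairs: Pairs) -> List[int]:
    """Hypothetical helper: apply reduction="none" updates in the given order."""
    output = list(input_x)
    for idx, value in pairs:
        output[idx] = value  # a later write to the same position overwrites an earlier one
    return output


def add_in_order(input_x: Sequence[int], pairs: Pairs) -> List[int]:
    """Hypothetical helper: apply reduction="add" updates in the given order."""
    output = list(input_x)
    for idx, value in pairs:
        output[idx] += value
    return output


x = [0, 0, 0]
pairs = [(1, 10), (1, 20)]  # both updates target position 1

print(assign_in_order(x, pairs))                  # [0, 20, 0]
print(assign_in_order(x, list(reversed(pairs))))  # [0, 10, 0]
print(add_in_order(x, pairs), add_in_order(x, list(reversed(pairs))))  # [0, 30, 0] [0, 30, 0]
```

With "none" the surviving value depends on which write lands last. With "add" the integer sum is the same in either order, but for floating-point inputs the last bits can still differ, because floating-point addition is not associative.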

Note

If any value in indices is out of bounds for input_x along axis, no index error is raised; the corresponding update is simply not applied to input_x. The backward pass is supported only when updates.shape == indices.shape.
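The out-of-bounds rule can be modeled in a few lines of pure Python. This is a hedged sketch of the documented behavior, not MindSpore code: `scatter_1d_skip_oob` is a hypothetical name, and it treats any index outside `[0, len(input_x))` as out of bounds. That range is an assumption; whether negative indices are accepted is not stated on this page, so the example uses -7, which is out of range under either convention.

```python
from typing import List, Sequence


def scatter_1d_skip_oob(input_x: Sequence[int], indices: Sequence[int],
                        updates: Sequence[int], reduction: str = "none") -> List[int]:
    """Hypothetical model: out-of-range indices are skipped instead of raising."""
    output = list(input_x)
    size = len(output)
    for idx, value in zip(indices, updates):
        if not 0 <= idx < size:  # assumption: the valid range is [0, size)
            continue  # the update is silently dropped; no IndexError is raised
        if reduction == "add":
            output[idx] += value
        else:
            output[idx] = value
    return output


print(scatter_1d_skip_oob([1, 2, 3], [0, 5, -7], [9, 9, 9]))  # [9, 2, 3]
```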

Parameters
  • input_x (Tensor) – The input tensor. The rank must be at least 1.

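  • indices (Tensor) – The indices into input_x along axis at which updates are applied.

  • updates (Tensor) – The values to assign to, or add to, input_x at the positions given by indices.

  • axis (int) – The axis along which to index. Default: 0.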

  • reduction (str) –

    The reduction to apply. Supported values are "none" and "add". Default: "none".

    • If "none", updates are assigned to input_x at the positions given by indices.

    • If "add", updates are added to input_x at the positions given by indices.

Returns

Tensor, with the same shape and data type as input_x.

Supported Platforms:

Ascend GPU CPU

Examples

>>> import mindspore
>>> input_x = mindspore.tensor([[1, 2, 3, 4, 5]])
>>> indices = mindspore.tensor([[2, 4]])
>>> updates = mindspore.tensor([[8, 8]])
>>> output = mindspore.ops.tensor_scatter_elements(input_x, indices, updates, axis=1, reduction="none")
>>> print(output)
[[1 2 8 4 8]]
>>> output = mindspore.ops.tensor_scatter_elements(input_x, indices, updates, axis=1, reduction="add")
>>> print(output)
[[ 1  2 11  4 13]]