mindspore.ops.tensor_scatter_elements
mindspore.ops.tensor_scatter_elements(input_x, indices, updates, axis=0, reduction='none')
Returns a new tensor obtained by updating input_x at the positions given by indices with the values in updates, using the operation specified by reduction.
Implicit type conversion is not supported.
For example, for a 3-D tensor the output is:

    output[indices[i][j][k]][j][k] = updates[i][j][k]   # if axis == 0, reduction == "none"
    output[i][indices[i][j][k]][k] += updates[i][j][k]  # if axis == 1, reduction == "add"
    output[i][j][indices[i][j][k]] = updates[i][j][k]   # if axis == 2, reduction == "none"
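The rule above can be checked with a small pure-Python reference. This is a sketch under stated assumptions, not MindSpore's implementation: it operates on nested Python lists rather than Tensors, assumes indices and updates have the same shape and that every index is in range, and applies updates sequentially in row-major order. The names scatter_elements_ref, shape_of, get_item and set_item are hypothetical.

```python
import copy
from itertools import product


def shape_of(nested):
    """Return the shape of a rectangular, non-empty nested list."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0]
    return shape


def get_item(nested, idx):
    """Read nested[idx[0]][idx[1]]...[idx[-1]]."""
    for i in idx:
        nested = nested[i]
    return nested


def set_item(nested, idx, value):
    """Write value to nested[idx[0]][idx[1]]...[idx[-1]]."""
    for i in idx[:-1]:
        nested = nested[i]
    nested[idx[-1]] = value


def scatter_elements_ref(input_x, indices, updates, axis=0, reduction="none"):
    """Hypothetical reference for the element-wise scatter rule along `axis`."""
    if reduction not in ("none", "add"):
        raise ValueError("reduction must be 'none' or 'add'")
    output = copy.deepcopy(input_x)  # leave the caller's data untouched
    # Visit every position of `indices`; `updates` is read at the same position.
    for pos in product(*(range(n) for n in shape_of(indices))):
        target = list(pos)
        target[axis] = get_item(indices, pos)  # only the `axis` coordinate changes
        value = get_item(updates, pos)
        if reduction == "add":
            value += get_item(output, target)
        set_item(output, target, value)
    return output


x = [[1, 2, 3, 4, 5]]
print(scatter_elements_ref(x, [[2, 4]], [[8, 8]], axis=1))                  # [[1, 2, 8, 4, 8]]
print(scatter_elements_ref(x, [[2, 4]], [[8, 8]], axis=1, reduction="add"))  # [[1, 2, 11, 4, 13]]
```

Because the loop visits positions in a fixed order, this reference is deterministic even when indices contain duplicates, unlike the operator itself (see the Warning).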
Warning
The order in which updates are applied is nondeterministic: if multiple elements of indices refer to the same position, the value at that position in the output is nondeterministic.
On Ascend, reduction currently supports only "none".
On Ascend, the data type of input_x must be float16 or float32.
This is an experimental API that is subject to change or deletion.
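The first warning above can be illustrated with a toy model. Here the unspecified order of a parallel kernel is modeled by trying every permutation of the update order, with reduction "none" on a 1-D list. scatter_1d_in_order is a hypothetical helper, and this is not how MindSpore actually schedules updates.

```python
import itertools


def scatter_1d_in_order(input_x, indices, updates, order):
    """Assign updates[k] to output[indices[k]], visiting k in the given order."""
    output = list(input_x)
    for k in order:
        output[indices[k]] = updates[k]  # reduction "none": last write wins
    return output


x = [0, 0, 0]
indices = [1, 1]  # both updates target position 1
updates = [5, 7]

# Collect the result of every possible application order.
results = {
    tuple(scatter_1d_in_order(x, indices, updates, order))
    for order in itertools.permutations(range(len(indices)))
}
print(sorted(results))  # [(0, 5, 0), (0, 7, 0)]
```

Two distinct outputs appear, so code that relies on a particular winner for a duplicated index should not use reduction "none".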
Note
If some values in indices fall outside the valid index range of input_x, no index error is raised; the corresponding updates are simply not applied to input_x. Backpropagation is supported only when updates.shape == indices.shape.
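The out-of-range behavior in the note can be modeled in one dimension as below. This is a hypothetical sketch (scatter_1d_skip_out_of_range is not a MindSpore function) that assumes the valid index range is [0, len(input_x)); whether MindSpore also accepts negative indices is not stated here, so the sketch treats them as out of range.

```python
def scatter_1d_skip_out_of_range(input_x, indices, updates, reduction="none"):
    """1-D sketch: apply updates, silently skipping out-of-range indices."""
    output = list(input_x)
    for idx, value in zip(indices, updates):
        # Assumption: valid range is [0, len(output)); negative indices are
        # treated as out of range here, which may differ from MindSpore.
        if not 0 <= idx < len(output):
            continue  # no IndexError; this update is simply dropped
        output[idx] = output[idx] + value if reduction == "add" else value
    return output


print(scatter_1d_skip_out_of_range([1, 2, 3], [0, 5, 2], [10, 20, 30]))
# [10, 2, 30]
```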
Parameters
input_x (Tensor) – The input tensor. The rank must be at least 1.
indices (Tensor) – The indices into input_x along axis at which the values of updates are written.
updates (Tensor) – The values to scatter into input_x.
axis (int) – The axis along which to index. Default: 0.
reduction (str) – The operation to apply. Supported values are "none" and "add". If "none", updates are assigned to input_x according to indices. If "add", updates are added to input_x according to indices. Default: "none".
Returns
Tensor, with the same shape and data type as input_x.
Supported Platforms:
Ascend
GPU
CPU
Examples
>>> import mindspore
>>> input_x = mindspore.tensor([[1, 2, 3, 4, 5]])
>>> indices = mindspore.tensor([[2, 4]])
>>> updates = mindspore.tensor([[8, 8]])
>>> output = mindspore.ops.tensor_scatter_elements(input_x, indices, updates, axis=1, reduction="none")
>>> print(output)
[[1 2 8 4 8]]
>>> output = mindspore.ops.tensor_scatter_elements(input_x, indices, updates, axis=1, reduction="add")
>>> print(output)
[[ 1  2 11  4 13]]