mindspore.Tensor.index_put_

Tensor.index_put_(indices, values, accumulate=False)

Replaces the elements of self selected by indices with the corresponding elements of values. With the default accumulate=False, Tensor.index_put_(indices, values) is equivalent to tensor[indices] = values. The update is performed in place and self is returned.
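
For readers more familiar with NumPy, the replace behavior can be sketched with plain advanced indexing. This is only an analogy, assuming index_put_ follows the same semantics as tensor[indices] = values; it is not MindSpore's implementation, and the array contents are made up for illustration.

```python
import numpy as np

# Illustrative data, not taken from the MindSpore documentation.
x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
indices = (np.array([0, 1]), np.array([1, 2]))  # selects positions (0, 1) and (1, 2)
values = np.array([10, 20], dtype=np.int32)

# NumPy counterpart of tensor[indices] = values (accumulate=False).
x[indices] = values
print(x)
# [[ 1 10  3]
#  [ 4  5 20]]
```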

Warning

The behavior is unpredictable in the following scenario:

  • If accumulate is False and indices contains duplicate elements.
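
One way to avoid this scenario is to check for repeated positions before calling index_put_ with accumulate=False. The sketch below does the check on the host with NumPy; has_duplicate_positions is a hypothetical helper, not part of MindSpore, and it assumes integer index arrays (not bool or uint8 masks).

```python
import numpy as np

def has_duplicate_positions(indices):
    """Return True if any coordinate tuple is addressed more than once.

    Hypothetical helper: `indices` is a sequence of integer NumPy arrays
    that are broadcastable against each other.
    """
    # Broadcast the index arrays to a common shape, then view them as
    # one row of coordinates per addressed element.
    broadcast = np.broadcast_arrays(*indices)
    coords = np.stack([b.ravel() for b in broadcast], axis=1)
    # Fewer unique rows than rows means some position repeats.
    unique_rows = np.unique(coords, axis=0)
    return unique_rows.shape[0] < coords.shape[0]

# Positions (0, 1), (1, 2), (1, 1) are all distinct.
print(has_duplicate_positions([np.array([0, 1, 1]), np.array([1, 2, 1])]))  # False
# Position (0, 1) is addressed twice.
print(has_duplicate_positions([np.array([0, 0]), np.array([1, 1])]))        # True
```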

Parameters
  • indices (tuple[Tensor], list[Tensor]) – Index tensors of dtype bool, uint8, int32 or int64, used to index into self. The size of indices must be less than or equal to the rank of self, and the tensors in indices must be broadcastable.

  • values (Tensor) – Tensor with the same dtype as self. If its size is 1, it is broadcast.

  • accumulate (bool, optional) – If accumulate is True, the elements in values are added to the corresponding elements of self; otherwise they replace the corresponding elements of self. Default: False.
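
The effect of accumulate=True can be illustrated with NumPy's np.add.at, which adds a value at every indexed position. This is a hedged analogue rather than MindSpore code; it reuses the data from the Examples section and assumes the documented add semantics.

```python
import numpy as np

x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
indices = (np.array([0, 1, 1]), np.array([1, 2, 1]))
values = np.array([3], dtype=np.int32)  # size 1, broadcast to every indexed position

# accumulate=True analogue: add values at each indexed position.
np.add.at(x, indices, values)
print(x)
# [[1 5 3]
#  [4 8 9]]

# np.add.at counts every occurrence of a repeated position.
counts = np.zeros(3, dtype=np.int32)
np.add.at(counts, np.array([0, 0, 2]), 1)
print(counts)
# [2 0 1]
```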

Returns

Tensor, the updated self.

Raises
  • TypeError – If the dtype of self is not equal to the dtype of values.

  • TypeError – If the type of indices is not tuple[Tensor] or list[Tensor].

  • TypeError – If the dtype of any tensor in indices is not bool, uint8, int32 or int64.

  • TypeError – If the dtypes of tensors in indices are inconsistent.

  • TypeError – If the type of accumulate is not bool.

  • ValueError – If rank(self) == size(indices) and size(values) is neither 1 nor the maximum size of the tensors in indices.

  • ValueError – If rank(self) > size(indices) and size(values) is neither 1 nor self.shape[-1].

  • ValueError – If the tensors in indices are not broadcastable.

  • ValueError – If size(indices) > rank(self).
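
The ValueError conditions above can be restated as a small shape pre-check. This is a sketch of the documented rules only, not MindSpore's own validation; check_index_put_args is a hypothetical helper, and it assumes at least one index tensor and integer (not mask) indices.

```python
import math
import numpy as np

def check_index_put_args(self_shape, index_shapes, values_size):
    """Hypothetical pre-check mirroring the documented ValueError rules.

    Returns the broadcast shape of the index tensors if all checks pass.
    """
    rank = len(self_shape)
    if len(index_shapes) > rank:
        raise ValueError("size(indices) > rank(self)")
    try:
        # np.broadcast_shapes raises ValueError for incompatible shapes.
        broadcast_shape = np.broadcast_shapes(*index_shapes)
    except ValueError as exc:
        raise ValueError("tensors in indices are not broadcastable") from exc
    max_index_size = max(math.prod(shape) for shape in index_shapes)
    if len(index_shapes) == rank:
        allowed = {1, max_index_size}
    else:
        allowed = {1, self_shape[-1]}
    if values_size not in allowed:
        raise ValueError(f"size(values) must be one of {sorted(allowed)}")
    return broadcast_shape

# Shapes from the Examples section: self is (2, 3), two index tensors of shape (3,).
print(check_index_put_args((2, 3), [(3,), (3,)], 1))  # (3,)
```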

Supported Platforms:

Ascend

Examples

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor
>>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32))
>>> values = Tensor(np.array([3]).astype(np.int32))
>>> indices = [Tensor(np.array([0, 1, 1]).astype(np.int32)), Tensor(np.array([1, 2, 1]).astype(np.int32))]
>>> accumulate = True
>>> # With accumulate=True, 3 is added at positions (0, 1), (1, 2) and (1, 1).
>>> x.index_put_(indices, values, accumulate)
>>> print(x)
[[1 5 3]
 [4 8 9]]