mindspore.ops.index_fill

mindspore.ops.index_fill(x, axis, index, value)

Fills the elements of the input Tensor x with value along dimension axis, at the positions selected by index.
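To make the semantics concrete, here is a minimal NumPy sketch of the operation, assuming NumPy arrays stand in for MindSpore Tensors. The function name index_fill_ref is hypothetical and is not part of MindSpore. It shows what the result should contain, not how MindSpore computes it.

```python
import numpy as np

def index_fill_ref(x, axis, index, value):
    """Hypothetical NumPy reference for the documented index_fill semantics."""
    out = np.array(x, copy=True)  # never modify the caller's array
    idx = np.asarray(index).reshape(-1)  # rank-0 or rank-1 index -> 1-D
    # Select every position along the other axes and only `idx` along `axis`.
    selector = [slice(None)] * out.ndim
    selector[axis] = idx  # a negative axis works through Python list indexing
    out[tuple(selector)] = np.asarray(value, dtype=out.dtype)  # cast value to x's dtype
    return out

x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)
y = index_fill_ref(x, 1, [0, 2], -2.0)  # fill columns 0 and 2
```

Negative index values are handled by NumPy's fancy indexing, which matches the documented range [-x.shape[axis], x.shape[axis] - 1].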

Parameters
  • x (Tensor) – Input Tensor. The supported data type is Number or Bool.

  • axis (Union[int, Tensor]) – Dimension along which to fill the input Tensor. Must be an int or a 0-dimensional Tensor with dtype int32 or int64.

  • index (Tensor) – Indices along axis of the elements to fill. Its rank must be at most 1 and its dtype must be int32.

  • value (Union[bool, int, float, Tensor]) – Value to fill in. If value is a Tensor, it must be 0-dimensional and have the same dtype as x. Otherwise, value is cast to a 0-dimensional Tensor with the same dtype as x.
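The rules for value can be sketched as a small normalization step. This assumes NumPy arrays stand in for Tensors; the helper coerce_value is hypothetical and only mirrors the documented behavior, not MindSpore's internal code.

```python
import numpy as np

def coerce_value(value, x_dtype):
    """Hypothetical helper: turn `value` into a 0-d array with x's dtype, per the docs."""
    x_dtype = np.dtype(x_dtype)
    if isinstance(value, np.ndarray):  # stands in for a MindSpore Tensor
        if value.ndim != 0:
            raise ValueError("value must be a 0-dimensional Tensor")
        if value.dtype != x_dtype:
            raise TypeError("value and x must have the same dtype")
        return value
    if isinstance(value, (bool, int, float)):
        # Python scalars are cast to a 0-d array with x's dtype.
        return np.array(value, dtype=x_dtype)
    raise TypeError("value must be a bool, int, float, or Tensor")
```

A Tensor value is checked strictly, while a Python scalar is converted silently; this asymmetry is why only the Tensor case can raise a dtype mismatch error.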

Returns

Tensor, with the same dtype and shape as x.

Raises
  • TypeError – If x is not a Tensor.

  • TypeError – If axis is neither an int nor a Tensor.

  • TypeError – If axis is a Tensor and its dtype is neither int32 nor int64.

  • TypeError – If index is not a Tensor.

  • TypeError – If dtype of index is not int32.

  • TypeError – If value is not a bool, int, float, or Tensor.

  • TypeError – If value is a Tensor and its dtype differs from that of x.

  • ValueError – If axis is a Tensor and its rank is not equal to 0.

  • ValueError – If the rank of index is greater than 1.

  • ValueError – If value is a Tensor and its rank is not equal to 0.

  • RuntimeError – If the value of axis is out of the range [-x.ndim, x.ndim - 1].

  • RuntimeError – If any value in index is out of the range [-x.shape[axis], x.shape[axis] - 1].
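The range conditions listed above can be checked up front before calling the operator. This is a sketch under the assumption that x is described by its shape tuple and index by a NumPy array; check_ranges is a hypothetical helper, not a MindSpore function.

```python
import numpy as np

def check_ranges(x_shape, axis, index):
    """Hypothetical pre-checks mirroring the documented ValueError/RuntimeError cases."""
    ndim = len(x_shape)
    if not -ndim <= axis <= ndim - 1:
        raise RuntimeError(f"axis {axis} out of range [{-ndim}, {ndim - 1}]")
    index = np.asarray(index)
    if index.ndim > 1:
        raise ValueError("index must have rank 0 or 1")
    dim = x_shape[axis]
    # Valid indices lie in [-dim, dim - 1]; negative values count from the end.
    bad = (index < -dim) | (index > dim - 1)
    if bad.any():
        raise RuntimeError(f"index values must lie in [{-dim}, {dim - 1}]")
    return axis % ndim  # normalized, non-negative axis
```

For a 3x3 input, axis=-1 normalizes to 1, and index=[-3, 2] is accepted because both values fall in [-3, 2].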

Supported Platforms:

Ascend GPU CPU

Examples

>>> import mindspore
>>> import numpy as np
>>> from mindspore import ops
>>> from mindspore import Tensor
>>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(np.float32))
>>> index = Tensor([0, 2], mindspore.int32)
>>> value = Tensor(-2.0, mindspore.float32)
>>> y = ops.index_fill(x, 1, index, value)
>>> print(y)
[[-2.  2. -2.]
 [-2.  5. -2.]
 [-2.  8. -2.]]
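As a cross-check, the example above can be reproduced with plain NumPy indexing, which does not involve MindSpore at all. The second fill uses the same index with axis=0 to show that axis decides whether rows or columns are overwritten.

```python
import numpy as np

x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)

# axis=1, index=[0, 2]: overwrite columns 0 and 2 (same as the example above).
cols = x.copy()
cols[:, [0, 2]] = -2.0

# axis=0, index=[0, 2]: the same index now selects rows 0 and 2.
rows = x.copy()
rows[[0, 2], :] = -2.0

print(cols)
print(rows)
```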