mindspore.Tensor.index_select

Tensor.index_select(axis, index) → Tensor

Returns a new Tensor containing the entries of self selected along dimension axis by the indices in index. The result has the same number of dimensions as self. The size of dimension axis equals the length of index, and all other dimensions keep their original sizes.
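The shape rule above can be illustrated without MindSpore. The following is a minimal pure-Python sketch on nested lists, not the library's implementation; the helper names index_select and shape are hypothetical and exist only for this illustration.

```python
def index_select(data, axis, index):
    """Pure-Python reference for index_select on nested lists (illustrative only)."""
    if axis == 0:
        # Pick whole sub-blocks along the leading dimension.
        return [data[i] for i in index]
    # Otherwise recurse into each sub-block with the axis shifted by one.
    return [index_select(sub, axis - 1, index) for sub in data]


def shape(data):
    """Return the shape of a regular nested list."""
    dims = []
    while isinstance(data, list):
        dims.append(len(data))
        if not data:
            break
        data = data[0]
    return tuple(dims)


# Same values as the 2 x 2 x 4 tensor used in the Examples section.
x = [[[0, 1, 2, 3], [4, 5, 6, 7]],
     [[8, 9, 10, 11], [12, 13, 14, 15]]]

y = index_select(x, 1, [0])
print(y)         # [[[0, 1, 2, 3]], [[8, 9, 10, 11]]]
print(shape(y))  # (2, 1, 4)
```

Only the size of the selected dimension changes (from 2 to 1, the length of index); the other two dimensions stay at 2 and 4.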

Note

Every value in index must be in the range [0, self.shape[axis]). The result is undefined for out-of-range values.
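Because out-of-range indices give an undefined result rather than a guaranteed error, a caller may want to check them before the call. A minimal sketch, assuming the index values are available as a plain Python list; out_of_range is a hypothetical helper, not a MindSpore function.

```python
def out_of_range(index_values, axis_size):
    """Return the index values that fall outside [0, axis_size)."""
    return [i for i in index_values if not 0 <= i < axis_size]


tensor_shape = (2, 2, 4)  # shape of the tensor in the Examples section
axis = 1

print(out_of_range([0, 1], tensor_shape[axis]))      # []
print(out_of_range([0, 2, -1], tensor_shape[axis]))  # [2, -1]
```

Note that negative values are reported as well, since the documented range starts at 0.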

Parameters
  • axis (int) – The dimension to be indexed.

  • index (Tensor) – A 1-D Tensor with the indices to access in self along the specified axis.

Returns

Tensor, has the same dtype as self Tensor.

Raises
  • TypeError – If index is not a Tensor.

  • TypeError – If axis is not an int.

  • ValueError – If axis is outside the range [-self.ndim, self.ndim - 1].

  • ValueError – If index is not a 1-D Tensor.
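The accepted range for axis includes negative values. The sketch below shows one way such a value could be checked and mapped to a non-negative dimension, assuming the usual convention that a negative axis counts back from the last dimension. normalize_axis is a hypothetical helper, and the error messages are illustrative, not MindSpore's.

```python
def normalize_axis(axis, ndim):
    """Validate axis against [-ndim, ndim - 1] and return its non-negative form."""
    if not isinstance(axis, int):
        raise TypeError(f"axis must be an int, got {type(axis).__name__}")
    if not -ndim <= axis <= ndim - 1:
        raise ValueError(f"axis {axis} is outside [{-ndim}, {ndim - 1}]")
    # Assumption: a negative axis counts back from the last dimension.
    return axis + ndim if axis < 0 else axis


print(normalize_axis(1, 3))   # 1
print(normalize_axis(-1, 3))  # 2
```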

Supported Platforms:

Ascend GPU CPU

Examples

>>> import mindspore
>>> from mindspore import Tensor
>>> import numpy as np
>>> input = Tensor(np.arange(16).astype(np.float32).reshape(2, 2, 4))
>>> print(input)
[[[ 0.  1.  2.  3.]
  [ 4.  5.  6.  7.]]
 [[ 8.  9. 10. 11.]
  [12. 13. 14. 15.]]]
>>> index = Tensor([0,], mindspore.int32)
>>> y = input.index_select(1, index)
>>> print(y)
[[[ 0.  1.  2.  3.]]
 [[ 8.  9. 10. 11.]]]

Tensor.index_select(dim, index) → Tensor

Note

Every value in index must be in the range [0, self.shape[dim]). The result is undefined for out-of-range values.

Parameters
  • dim (int) – The dimension to be indexed.

  • index (Tensor) – A 1-D Tensor with the indices to access in self along the specified dim.

Returns

Tensor, has the same dtype as self Tensor.

Raises
  • TypeError – If index is not a Tensor.

  • TypeError – If dim is not an int.

  • ValueError – If dim is outside the range [-self.ndim, self.ndim - 1].

  • ValueError – If index is not a 1-D Tensor.

Supported Platforms:

Ascend GPU CPU

Examples

>>> import mindspore
>>> from mindspore import Tensor
>>> import numpy as np
>>> input = Tensor(np.arange(16).astype(np.float32).reshape(2, 2, 4))
>>> print(input)
[[[ 0.  1.  2.  3.]
  [ 4.  5.  6.  7.]]
 [[ 8.  9. 10. 11.]
  [12. 13. 14. 15.]]]
>>> index = Tensor([0,], mindspore.int32)
>>> y = input.index_select(1, index)
>>> print(y)
[[[ 0.  1.  2.  3.]]
 [[ 8.  9. 10. 11.]]]