mindspore.ops.index_select

查看源文件
mindspore.ops.index_select(input, axis, index)[源代码]

根据指定轴和索引对输入tensor进行选取,返回一个新tensor。

说明

  • index 的值必须在 [0, input.shape[axis]) 范围内,超出该范围的结果未定义。

  • 返回的tensor和输入tensor的维度数量相同,其第 axis 维度的大小和 index 的长度相同,其他维度和 input 相同。

参数:
  • input (Tensor) - 输入tensor。

  • axis (int) - 指定轴。

  • index (Tensor) - 指定索引,一维tensor。

返回:

Tensor

支持平台:

Ascend GPU CPU

样例:

>>> import mindspore
>>> import numpy as np
>>> input = mindspore.tensor(np.arange(16).astype(np.float32).reshape(2, 2, 4))
>>> print(input)
[[[ 0.  1.  2.  3.]
  [ 4.  5.  6.  7.]]
 [[ 8.  9. 10. 11.]
  [12. 13. 14. 15.]]]
>>> index = mindspore.tensor([0,], mindspore.int32)
>>> y = mindspore.ops.index_select(input, 1, index)
>>> print(y)
[[[ 0.  1.  2.  3.]]
 [[ 8.  9. 10. 11.]]]