mindspore.ops.index_select

查看源文件
mindspore.ops.index_select(input, axis, index)[源代码]

返回一个新的Tensor,该Tensor沿维度 axisindex 中给定的索引对 input 进行选择。

返回的Tensor和输入Tensor( input )的维度数量相同,其第 axis 维度的大小和 index 的长度相同;其他维度和 input 相同。

说明

index的值必须在 [0, input.shape[axis]) 范围内,超出该范围结果未定义。

参数:
  • input (Tensor) - 输入Tensor。

  • axis (int) - 根据索引进行选择的维度。

  • index (Tensor) - 包含索引的一维Tensor。

返回:

Tensor,数据类型与输入 input 相同。

异常:
  • TypeError - inputindex 的类型不是Tensor。

  • TypeError - axis 的类型不是int。

  • ValueError - axis 值超出范围[-input.ndim, input.ndim - 1]。

  • ValueError - index 不是一维Tensor。

支持平台:

Ascend GPU CPU

样例:

>>> import mindspore
>>> from mindspore import Tensor, ops
>>> import numpy as np
>>> input = Tensor(np.arange(16).astype(np.float32).reshape(2, 2, 4))
>>> print(input)
[[[ 0.  1.  2.  3.]
  [ 4.  5.  6.  7.]]
 [[ 8.  9. 10. 11.]
  [12. 13. 14. 15.]]]
>>> index = Tensor([0,], mindspore.int32)
>>> y = ops.index_select(input, 1, index)
>>> print(y)
[[[ 0.  1.  2.  3.]]
 [[ 8.  9. 10. 11.]]]