mindspore.mint.index_select

查看源文件
mindspore.mint.index_select(input, dim, index)[源代码]

返回一个新的Tensor,该Tensor沿维度 dimindex 中给定的索引对 input 进行选择。

返回的Tensor和输入Tensor( input )的维度数量相同,其第 dim 维度的大小和 index 的长度相同;其他维度和 input 相同。

说明

index的值必须在 [0, input.shape[dim]) 范围内,超出该范围结果未定义。

参数:
  • input (Tensor) - 输入Tensor。

  • dim (int) - 根据索引进行选择的维度。

  • index (Tensor) - 包含索引的一维Tensor。

返回:

Tensor,数据类型与输入 input 相同。

异常:
  • TypeError - inputindex 的类型不是Tensor。

  • TypeError - dim 的类型不是int。

  • ValueError - dim 值超出范围[-input.ndim, input.ndim - 1]。

  • ValueError - index 不是一维Tensor。

支持平台:

Ascend

样例:

>>> import mindspore
>>> from mindspore import Tensor, mint
>>> import numpy as np
>>> input = Tensor(np.arange(16).astype(np.float32).reshape(2, 2, 4))
>>> print(input)
[[[ 0.  1.  2.  3.]
[ 4.  5.  6.  7.]]
[[ 8.  9. 10. 11.]
[12. 13. 14. 15.]]]
>>> index = Tensor([0,], mindspore.int32)
>>> y = mint.index_select(input, 1, index)
>>> print(y)
[[[ 0.  1.  2.  3.]]
[[ 8.  9. 10. 11.]]]