mindspore.ops.index_select
- mindspore.ops.index_select(input, axis, index)[源代码]
返回一个新的Tensor,该Tensor沿维度 axis 按 index 中给定的索引对 input 进行选择。
返回的Tensor和输入Tensor( input )的维度数量相同,其第 axis 维度的大小和 index 的长度相同;其他维度和 input 相同。
说明
index的值必须在 [0, input.shape[axis]) 范围内,超出该范围结果未定义。
- 参数:
input (Tensor) - 输入Tensor。
axis (int) - 根据索引进行选择的维度。
index (Tensor) - 包含索引的一维Tensor。
- 返回:
Tensor,数据类型与输入 input 相同。
- 异常:
TypeError - input 或 index 的类型不是Tensor。
TypeError - axis 的类型不是int。
ValueError - axis 值超出范围[-input.ndim, input.ndim - 1]。
ValueError - index 不是一维Tensor。
- 支持平台:
Ascend
GPU
CPU
样例:
>>> import mindspore >>> from mindspore import Tensor, ops >>> import numpy as np >>> input = Tensor(np.arange(16).astype(np.float32).reshape(2, 2, 4)) >>> print(input) [[[ 0. 1. 2. 3.] [ 4. 5. 6. 7.]] [[ 8. 9. 10. 11.] [12. 13. 14. 15.]]] >>> index = Tensor([0,], mindspore.int32) >>> y = ops.index_select(input, 1, index) >>> print(y) [[[ 0. 1. 2. 3.]] [[ 8. 9. 10. 11.]]]