mindspore.ops.gather_elements
- mindspore.ops.gather_elements(input, dim, index)[源代码]
根据指定维度和索引获取元素。
说明
input 与 index 维度大小一致,当 input 中的 axis != dim 时 , index.shape[axis] <= input.shape[axis] 。
警告
在Ascend后端,以下场景将导致不可预测的行为:
正向执行流程中, 当 index 的取值不在范围 [-input.shape[dim], input.shape[dim]) 内;
反向执行流程中, 当 index 的取值不在范围 [0, input.shape[dim]) 内。
- 参数:
input (Tensor) - 输入tensor。
dim (int) - 指定维度。
index (Tensor) - 指定索引。
- 返回:
Tensor
- 支持平台:
Ascend
GPU
CPU
样例:
>>> import mindspore >>> import numpy as np >>> x = mindspore.tensor(np.array([[1, 2], [3, 4]]), mindspore.int32) >>> index = mindspore.tensor(np.array([[0, 0], [1, 0]]), mindspore.int32) >>> dim = 1 >>> output = mindspore.ops.gather_elements(x, dim, index) >>> print(output) [[1 1] [4 3]]