mindspore.ops.gather_elements

查看源文件
mindspore.ops.gather_elements(input, dim, index)[源代码]

根据指定维度和索引获取元素。

说明

inputindex 维度大小一致,当 input 中的 axis != dim 时 , index.shape[axis] <= input.shape[axis]

警告

在Ascend后端,以下场景将导致不可预测的行为:

  • 正向执行流程中, 当 index 的取值不在范围 [-input.shape[dim], input.shape[dim]) 内;

  • 反向执行流程中, 当 index 的取值不在范围 [0, input.shape[dim]) 内。

参数:
  • input (Tensor) - 输入tensor。

  • dim (int) - 指定维度。

  • index (Tensor) - 指定索引。

返回:

Tensor

支持平台:

Ascend GPU CPU

样例:

>>> import mindspore
>>> import numpy as np
>>> x = mindspore.tensor(np.array([[1, 2], [3, 4]]), mindspore.int32)
>>> index = mindspore.tensor(np.array([[0, 0], [1, 0]]), mindspore.int32)
>>> dim = 1
>>> output = mindspore.ops.gather_elements(x, dim, index)
>>> print(output)
[[1 1]
 [4 3]]