mindspore.ops.GatherD

class mindspore.ops.GatherD[源代码]

获取指定轴的元素。

更多参考详见 mindspore.ops.gather_elements()

输入:
  • x (Tensor) - 输入Tensor。

  • dim (int) - 获取元素的轴。数据类型为int32或int64。只能是常量值。

  • index (Tensor) - 获取收集元素的索引。支持的数据类型包括:int32,int64。每个索引元素的取值范围为[-x_rank[dim], x_rank[dim])。

输出:

Tensor,数据类型与 x 相同。

支持平台:

Ascend GPU CPU

样例:

>>> x = Tensor(np.array([[1, 2], [3, 4]]), mindspore.int32)
>>> index = Tensor(np.array([[0, 0], [1, 0]]), mindspore.int32)
>>> dim = 1
>>> output = ops.GatherD()(x, dim, index)
>>> print(output)
[[1 1]
 [4 3]]