mindspore.ops.gather_elements

mindspore.ops.gather_elements(input, dim, index)

Gathers elements of input along dimension dim at the positions given by index.

Note

input and index must have the same number of dimensions, and index.shape[axis] <= input.shape[axis] for every axis of input other than dim.
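The gathering rule can be illustrated with a minimal pure-Python sketch for the 2-D case. The helper name gather_elements_2d is hypothetical and is not part of MindSpore; it only mirrors the semantics described above, assuming nested lists in place of tensors.

```python
def gather_elements_2d(x, dim, index):
    """Hypothetical reference for the 2-D case, using nested lists.

    The output has the same shape as index. For dim == 0 the value at
    (i, j) is x[index[i][j]][j]; for dim == 1 it is x[i][index[i][j]].
    """
    rows, cols = len(index), len(index[0])
    out = [[None] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            k = index[i][j]  # position to read along `dim`
            out[i][j] = x[k][j] if dim == 0 else x[i][k]
    return out


x = [[1, 2], [3, 4]]
index = [[0, 0], [1, 0]]
print(gather_elements_2d(x, 1, index))  # [[1, 1], [4, 3]]
print(gather_elements_2d(x, 0, index))  # [[1, 2], [3, 2]]
```

Only the coordinate along dim is replaced by the value from index; every other coordinate is taken from the output position itself, which is why index may be no larger than input along those other axes.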

Warning

On Ascend, the behavior is unpredictable in the following cases:

  • a value in index is outside the range [-input.shape[dim], input.shape[dim]) in the forward pass;

  • a value in index is outside the range [0, input.shape[dim]) in the backward pass.
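One way to avoid both cases is to check the indices before calling the operator. The sketch below is an assumption about how such a check might look, not a MindSpore API: indices_in_range is a hypothetical helper that works on nested Python lists and tests the stricter range [0, input.shape[dim]), which satisfies the forward and the backward condition at once.

```python
def indices_in_range(index, size):
    """Hypothetical check: True if every value in a nested list is in [0, size)."""
    stack = [index]
    while stack:
        item = stack.pop()
        if isinstance(item, list):
            stack.extend(item)      # descend into nested lists
        elif not 0 <= item < size:
            return False            # negative or too-large index found
    return True


input_shape = (2, 2)  # shape of the example input
dim = 1
print(indices_in_range([[0, 0], [1, 0]], input_shape[dim]))   # True
print(indices_in_range([[0, -1], [1, 0]], input_shape[dim]))  # False
print(indices_in_range([[0, 2], [1, 0]], input_shape[dim]))   # False
```

Negative values are rejected here even though the forward range allows them, because they fall outside the range listed for the backward pass.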

Parameters
  • input (Tensor) – The input tensor.

  • dim (int) – The dimension of input along which to gather.

  • index (Tensor) – The indices of the elements to gather along dim.

Returns

Tensor, with the same shape as index and the same data type as input.

Supported Platforms:

Ascend GPU CPU

Examples

>>> import mindspore
>>> x = mindspore.tensor([[1, 2], [3, 4]], mindspore.int32)
>>> index = mindspore.tensor([[0, 0], [1, 0]], mindspore.int32)
>>> dim = 1
>>> output = mindspore.ops.gather_elements(x, dim, index)
>>> print(output)
[[1 1]
 [4 3]]