mindspore.Tensor.gather

Tensor.gather(dim, index) → Tensor

Gather data from a tensor by indices.

\[output[(i_0, i_1, ..., i_{dim}, i_{dim+1}, ..., i_n)] = input[(i_0, i_1, ..., index[(i_0, i_1, ..., i_{dim}, i_{dim+1}, ..., i_n)], i_{dim+1}, ..., i_n)]\]
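The formula above can be read as a loop over every position of index. The following is a minimal NumPy sketch of that reading, not the MindSpore implementation; the helper name gather_reference is hypothetical and introduced only for illustration.

```python
import numpy as np

def gather_reference(x, dim, index):
    """Hypothetical helper: loop-based reading of the gather formula."""
    x = np.asarray(x)
    index = np.asarray(index)
    dim = dim % x.ndim  # normalize a negative dim
    out = np.empty(index.shape, dtype=x.dtype)
    for pos in np.ndindex(*index.shape):
        src = list(pos)
        src[dim] = index[pos]  # only the coordinate along `dim` is replaced
        out[pos] = x[tuple(src)]
    return out

x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]], dtype=np.float32)
index = np.array([[0, 0], [1, 1]], dtype=np.int32)
result = gather_reference(x, 1, index)
print(result)
```

When index has the same shape as the input on every axis other than dim, this sketch should agree with np.take_along_axis(x, index, axis=dim).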

Warning

On Ascend, the behavior is unpredictable in the following cases:

  • the value of index is not in the range [-self.shape[dim], self.shape[dim]) in forward;

  • the value of index is not in the range [0, self.shape[dim]) in backward.

Parameters
  • dim (int) – the axis to index along, must be in range [-self.rank, self.rank).

  • index (Tensor) –

    The index tensor, with int32 or int64 data type. A valid index must satisfy:

    • index.rank == self.rank;

    • for axis != dim, index.shape[axis] <= self.shape[axis];

    • the value of index is in range [-self.shape[dim], self.shape[dim]).
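The three conditions on index can be checked before calling gather. The sketch below is an assumption-laden illustration in plain NumPy: check_gather_index is a hypothetical helper, not part of MindSpore, and it only mirrors the conditions listed above.

```python
import numpy as np

def check_gather_index(x_shape, dim, index):
    """Hypothetical helper: return the list of violated conditions (empty if valid)."""
    index = np.asarray(index)
    rank = len(x_shape)
    if not -rank <= dim < rank:
        return ["dim out of range"]
    dim %= rank
    problems = []
    if index.dtype not in (np.int32, np.int64):
        problems.append("index dtype must be int32 or int64")
    if index.ndim != rank:
        problems.append("index.rank != self.rank")
        return problems
    for axis in range(rank):
        if axis != dim and index.shape[axis] > x_shape[axis]:
            problems.append(f"index.shape[{axis}] > self.shape[{axis}]")
    size = x_shape[dim]
    if index.size and (index.min() < -size or index.max() >= size):
        problems.append("index value out of range")
    return problems

ok = check_gather_index((2, 3), 1, np.array([[0, 0], [1, 1]], dtype=np.int32))
bad = check_gather_index((2, 3), 1, np.array([[0, 3]], dtype=np.int64))
print(ok)
print(bad)
```

A check of this kind is mainly useful on Ascend, where out-of-range values are not guaranteed to raise an error.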

Returns

Tensor, has the same type as self and the same shape as index.

Raises
  • ValueError – If the shape of index is illegal.

  • ValueError – If dim is not in [-self.rank, self.rank).

  • ValueError – If the value of index is out of the valid range.

  • TypeError – If the type of index is illegal.

Supported Platforms:

Ascend GPU CPU

Examples

>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor
>>> input = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
>>> index = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
>>> output = input.gather(1, index)
>>> print(output)
[[-0.1 -0.1]
 [ 0.5  0.5]]
Tensor.gather(input_indices, axis, batch_dims=0) → Tensor

Returns the slice of the input tensor corresponding to the elements of input_indices on the specified axis.

The following figure illustrates how Gather is typically computed:

../../../_images/Gather.png

where params denotes the input tensor input_params, and indices denotes the index tensor input_indices.
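For batch_dims equal to 0, the slicing described here behaves like NumPy's np.take along the chosen axis. The sketch below uses NumPy as a stand-in to show the effect of axis; it is an analogy under that assumption, not MindSpore code.

```python
import numpy as np

params = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=np.float32)
indices = np.array([0, 2], dtype=np.int32)

rows = np.take(params, indices, axis=0)  # selects rows 0 and 2
cols = np.take(params, indices, axis=1)  # selects columns 0 and 2
print(rows.shape, cols.shape)
```

With axis=0 whole rows are selected, while axis=1 selects the same positions within every row.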

Note

  1. The values of input_indices must be in the range [0, input_params.shape[axis]). On CPU and GPU, an error is raised if an out-of-bounds index is found. On Ascend, the results may be undefined.

  2. On Ascend, the data type of self currently cannot be bool_.

Parameters
  • input_indices (Tensor) – Index tensor of shape \((y_1, y_2, ..., y_S)\), specifying which elements of the original Tensor to gather. The data type can be int32 or int64.

  • axis (Union(int, Tensor[int])) – Specifies the dimension along which to gather. It must be greater than or equal to batch_dims. When axis is a Tensor, its size must be 1.

  • batch_dims (int) – Specifies the number of batch dimensions. It must be less than or equal to the rank of input_indices. Default: 0 .

Returns

Tensor, the shape of tensor is \(input\_params.shape[:axis] + input\_indices.shape[batch\_dims:] + input\_params.shape[axis + 1:]\).
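The shape rule can be evaluated with plain tuple arithmetic. The helper below, gather_output_shape, is hypothetical and assumes a non-negative integer axis; it simply transcribes the expression above.

```python
def gather_output_shape(params_shape, indices_shape, axis, batch_dims=0):
    """Hypothetical helper: transcribe the documented output-shape rule."""
    return (
        tuple(params_shape[:axis])
        + tuple(indices_shape[batch_dims:])
        + tuple(params_shape[axis + 1:])
    )

# params (7,), indices (2, 2), axis 0
shape_a = gather_output_shape((7,), (2, 2), axis=0)
# params (3, 4), indices (2,), axis 0
shape_b = gather_output_shape((3, 4), (2,), axis=0)
# params (3, 4), indices (3,), axis 1, batch_dims 1
shape_c = gather_output_shape((3, 4), (3,), axis=1, batch_dims=1)
print(shape_a, shape_b, shape_c)
```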

Raises
  • TypeError – If axis is not an int or Tensor.

  • ValueError – If axis is a Tensor and its size is not 1.

  • TypeError – If self is not a tensor.

  • TypeError – If input_indices is not a tensor of type int.

  • RuntimeError – If a value of input_indices is out of range [0, input_params.shape[axis]) on CPU or GPU.

Supported Platforms:

Ascend GPU CPU

Examples

>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor
>>> # case1: input_indices is a Tensor with shape (5, ).
>>> input_params = Tensor(np.array([1, 2, 3, 4, 5, 6, 7]), mindspore.float32)
>>> input_indices = Tensor(np.array([0, 2, 4, 2, 6]), mindspore.int32)
>>> axis = 0
>>> output = input_params.gather(input_indices=input_indices, axis=axis)
>>> print(output)
[1. 3. 5. 3. 7.]
>>> # case2: input_indices is a Tensor with shape (2, 2). When the input_params has one dimension,
>>> # the output shape is equal to the input_indices shape.
>>> input_indices = Tensor(np.array([[0, 2], [2, 6]]), mindspore.int32)
>>> axis = 0
>>> output = input_params.gather(input_indices=input_indices, axis=axis)
>>> print(output)
[[1. 3.]
 [3. 7.]]
>>> # case3: input_indices is a Tensor with shape (2, ) and
>>> # input_params is a Tensor with shape (3, 4) and axis is 0.
>>> input_params = Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]), mindspore.float32)
>>> input_indices = Tensor(np.array([0, 2]), mindspore.int32)
>>> axis = 0
>>> output = input_params.gather(input_indices=input_indices, axis=axis)
>>> print(output)
[[ 1.  2.  3.  4.]
 [ 9. 10. 11. 12.]]
>>> # case4: input_indices is a Tensor with shape (3, ) and
>>> # input_params is a Tensor with shape (3, 4) and axis is 1, batch_dims is 1.
>>> input_params = Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]), mindspore.float32)
>>> input_indices = Tensor(np.array([0, 2, 1]), mindspore.int32)
>>> axis = 1
>>> batch_dims = 1
>>> output = input_params.gather(input_indices, axis, batch_dims)
>>> print(output)
[ 1.  7. 10.]
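One way to read batch_dims is that the leading batch_dims axes of the input and of input_indices are paired up, and an ordinary gather is applied within each pair. The NumPy sketch below follows that reading; gather_batched is a hypothetical helper, assumes a non-negative axis that is at least batch_dims, and is not the MindSpore implementation.

```python
import numpy as np

def gather_batched(params, indices, axis, batch_dims):
    """Hypothetical helper: gather with paired leading batch dimensions."""
    params = np.asarray(params)
    indices = np.asarray(indices)
    if batch_dims == 0:
        return np.take(params, indices, axis=axis)
    # Peel off one batch dimension and recurse on each paired slice.
    return np.stack([
        gather_batched(p, i, axis - 1, batch_dims - 1)
        for p, i in zip(params, indices)
    ])

params = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=np.float32)
indices = np.array([0, 2, 1], dtype=np.int32)
batched = gather_batched(params, indices, axis=1, batch_dims=1)
print(batched)
```

In this reading, row b of the input is indexed only by entry b of input_indices, which is why the result has one value per row rather than a full (3, 3) block.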