mindspore.mint.gather
- mindspore.mint.gather(input, dim, index)
Gather data from a tensor by indices.
\[\text{output}[i_0, \ldots, i_{dim-1}, i_{dim}, i_{dim+1}, \ldots, i_n] = \text{input}[i_0, \ldots, i_{dim-1}, \text{index}[i_0, \ldots, i_{dim-1}, i_{dim}, i_{dim+1}, \ldots, i_n], i_{dim+1}, \ldots, i_n]\]
Warning
On Ascend, the behavior is unpredictable in the following cases:
a value of index is outside the range [-input.shape[dim], input.shape[dim]) in the forward pass;
a value of index is outside the range [0, input.shape[dim]) in the backward pass.
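Because the forward pass accepts negative indices but the backward pass is only well-defined for [0, input.shape[dim]), one option is to map negative indices to their non-negative equivalents before calling gather. The sketch below works on a flat Python list of index values and assumes that a negative value counts from the end of the axis, as in ordinary Python indexing. `normalize_index` is a hypothetical helper, not a MindSpore API; with a real Tensor the same arithmetic (adding input.shape[dim] to negative entries) would be applied elementwise.

```python
def normalize_index(values, size):
    """Map index values from [-size, size) to [0, size).

    Hypothetical helper (not part of MindSpore). It raises ValueError for
    values outside [-size, size) instead of passing them on unchecked.
    """
    normalized = []
    for v in values:
        if not -size <= v < size:
            raise ValueError(f"index {v} out of range [{-size}, {size})")
        # Assumption: a negative index counts from the end of the axis.
        normalized.append(v + size if v < 0 else v)
    return normalized


# Axis of length 3: -1 refers to the last position (2), -3 to the first (0).
print(normalize_index([0, -1, -3, 2], 3))  # [0, 2, 0, 2]
```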
- Parameters
input (Tensor) – The target tensor to gather values.
dim (int) – The axis to index along. Must be in the range [-input.rank, input.rank).
index (Tensor) –
The index tensor, with data type int32 or int64. A valid index must satisfy:
index.rank == input.rank;
for every axis != dim, index.shape[axis] <= input.shape[axis];
every value of index lies in the range [-input.shape[dim], input.shape[dim]).
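The constraints on dim and index can be checked before the call. The following is a minimal sketch using a hypothetical helper, `check_gather_args` (not part of MindSpore), which works on plain shape tuples and a flat iterable of index values rather than on real Tensors. It raises ValueError for the same conditions listed under Raises and, as a simplification, skips the index data-type check.

```python
from typing import Iterable, Sequence


def check_gather_args(input_shape: Sequence[int], dim: int,
                      index_shape: Sequence[int],
                      index_values: Iterable[int]) -> int:
    """Check the documented mint.gather constraints; return dim as non-negative.

    Hypothetical helper for illustration: it inspects plain shapes and a flat
    iterable of index values, not real Tensors, and skips the dtype check.
    """
    rank = len(input_shape)
    # dim must lie in [-input.rank, input.rank).
    if not -rank <= dim < rank:
        raise ValueError(f"dim {dim} not in [{-rank}, {rank})")
    dim %= rank
    # index.rank == input.rank.
    if len(index_shape) != rank:
        raise ValueError(f"index rank {len(index_shape)} != input rank {rank}")
    # For every axis other than dim: index.shape[axis] <= input.shape[axis].
    for axis, (i_size, x_size) in enumerate(zip(index_shape, input_shape)):
        if axis != dim and i_size > x_size:
            raise ValueError(
                f"index.shape[{axis}]={i_size} exceeds input.shape[{axis}]={x_size}")
    # Every index value must lie in [-input.shape[dim], input.shape[dim]).
    size = input_shape[dim]
    for value in index_values:
        if not -size <= value < size:
            raise ValueError(f"index value {value} not in [{-size}, {size})")
    return dim


# Shapes from the documented example: input (2, 3), index (2, 2), dim=1.
print(check_gather_args((2, 3), 1, (2, 2), [0, 0, 1, 1]))   # 1
print(check_gather_args((2, 3), -1, (2, 2), [0, -3, 2, 1]))  # 1
try:
    check_gather_args((2, 3), 1, (2, 2), [0, 3, 1, 1])
except ValueError as err:
    print(err)  # index value 3 not in [-3, 3)
```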
- Returns
Tensor, has the same type as input and the same shape as index.
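To make the formula and the return-shape rule concrete, here is a pure-Python reference sketch on nested lists. It is not MindSpore's implementation: `gather_ref`, `shape_of`, and `get_item` are hypothetical names, and the sketch assumes the arguments already satisfy the constraints listed under Parameters. Negative index values rely on Python list indexing, which matches the forward range [-input.shape[dim], input.shape[dim]).

```python
def shape_of(nested):
    """Return the shape of a rectangular nested list as a tuple."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        if not nested:
            break
        nested = nested[0]
    return tuple(shape)


def get_item(nested, position):
    """Read nested[p0][p1]...[pk]; negative entries count from the end."""
    for p in position:
        nested = nested[p]
    return nested


def gather_ref(x, dim, index):
    """Reference gather on nested lists; the output has the shape of index.

    output[i_0, ..., i_n] = x[i_0, ..., index[i_0, ..., i_n], ..., i_n],
    where the index value replaces coordinate number dim.
    """
    rank = len(shape_of(x))
    dim %= rank  # assumes dim was already checked to lie in [-rank, rank)

    def walk(sub, path):
        if len(path) == rank:
            # sub is now the scalar index[path]; substitute it at position dim.
            source = list(path)
            source[dim] = sub
            return get_item(x, source)
        return [walk(child, path + (i,)) for i, child in enumerate(sub)]

    return walk(index, ())


x = [[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]
print(gather_ref(x, 1, [[0, 0], [1, 1]]))  # [[-0.1, -0.1], [0.5, 0.5]]
print(gather_ref(x, 0, [[1, 0, 1]]))       # [[0.4, 0.3, -3.2]]
print(gather_ref(x, 1, [[-3, -1]]))        # [[-0.1, 3.6]]
```

The result follows the shape of index rather than input: a (1, 3) index along dim=0 yields a (1, 3) output even though input is (2, 3).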
- Raises
ValueError – If the shape of index is invalid.
ValueError – If dim is not in [-input.rank, input.rank).
ValueError – If a value of index is out of the valid range.
TypeError – If the data type of index is invalid.
- Supported Platforms:
Ascend
Examples
>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor, mint
>>> input = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
>>> index = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
>>> output = mint.gather(input, 1, index)
>>> print(output)
[[-0.1 -0.1]
 [0.5 0.5]]