mindspore.ops.Gather
- class mindspore.ops.Gather(*args, **kwargs)
Returns a slice of the input tensor based on the specified indices and axis.
Elements are selected along dimension axis at the positions given by input_indices; all other dimensions of the input are kept unchanged. The examples below make this concrete.
- Inputs:
input_params (Tensor) - The original tensor, with shape \((x_1, x_2, ..., x_R)\).
input_indices (Tensor) - The indices of the elements to gather along axis, with shape \((y_1, y_2, ..., y_S)\). Each value must be in the range [0, input_params.shape[axis]).
axis (int) - The dimension of input_params along which to gather.
- Outputs:
Tensor, the shape of tensor is \(input\_params.shape[:axis] + input\_indices.shape + input\_params.shape[axis + 1:]\).
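The output-shape rule above can be checked with a minimal pure-Python sketch of gather semantics on nested lists. This is not MindSpore's implementation: gather and shape are hypothetical helpers written only for illustration, the sketch supports only a non-negative axis, and raising IndexError for indices outside [0, input_params.shape[axis]) is an assumption about how to treat the documented precondition, not a statement of what ops.Gather does on each backend.

```python
def shape(x):
    """Shape of a nested-list 'tensor' (assumes rectangular nesting)."""
    if isinstance(x, list):
        return (len(x),) + (shape(x[0]) if x else ())
    return ()


def _take(row, indices):
    """Replace every scalar in the (possibly nested) indices with row[index]."""
    if isinstance(indices, list):
        return [_take(row, i) for i in indices]
    # Enforce the documented range [0, len(row)); plain Python lists would
    # otherwise silently accept negative indices.
    if not 0 <= indices < len(row):
        raise IndexError(f"index {indices} out of range [0, {len(row)})")
    return row[indices]


def gather(params, indices, axis):
    """Hypothetical reference gather. Result shape is
    shape(params)[:axis] + shape(indices) + shape(params)[axis + 1:]."""
    if axis < 0:
        raise ValueError("this sketch only supports axis >= 0")
    if axis == 0:
        # Indices select whole sub-tensors along the leading dimension.
        return _take(params, indices)
    # Keep the leading dimension and recurse into each sub-tensor.
    return [gather(sub, indices, axis - 1) for sub in params]


params = [[1, 2, 7, 42], [3, 4, 54, 22], [2, 2, 55, 3]]
print(gather(params, [1, 2], 1))             # [[2, 7], [4, 54], [2, 55]]
print(gather(params, [1, 2], 0))             # [[3, 4, 54, 22], [2, 2, 55, 3]]
print(shape(gather(params, [[0], [2]], 1)))  # (3, 2, 1)
```

The last line shows why the rule splices in the whole shape of input_indices: with params of shape (3, 4), indices of shape (2, 1) and axis 1, the result shape is (3,) + (2, 1) + () = (3, 2, 1).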
- Raises:
TypeError – If axis is not an int.
- Supported Platforms:
Ascend
GPU
CPU
Examples
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, ops
>>> input_params = Tensor(np.array([[1, 2, 7, 42], [3, 4, 54, 22], [2, 2, 55, 3]]), mindspore.float32)
>>> input_indices = Tensor(np.array([1, 2]), mindspore.int32)
>>> axis = 1
>>> output = ops.Gather()(input_params, input_indices, axis)
>>> print(output)
[[ 2.  7.]
 [ 4. 54.]
 [ 2. 55.]]
>>> axis = 0
>>> output = ops.Gather()(input_params, input_indices, axis)
>>> print(output)
[[ 3.  4. 54. 22.]
 [ 2.  2. 55.  3.]]