mindspore.ops.SparseGatherV2

class mindspore.ops.SparseGatherV2[源代码]

基于指定的索引和axis返回输入Tensor的切片。

输入:

  • input_params (Tensor) - 被切片的Tensor。shape: \((x_1, x_2, ..., x_R)\)

  • input_indices (Tensor)- shape: \((y_1, y_2, ..., y_S)\) 。 指定切片的索引,取值须在 [0, input_params.shape[axis]) 范围内。

  • axis (int) - 进行索引的axis。

输出:

Tensor,shape: \((z_1, z_2, ..., z_N)\)

异常:

  • TypeError - axis 不是int类型。

支持平台:

Ascend GPU

样例:

>>> input_params = Tensor(np.array([[1, 2, 7, 42], [3, 4, 54, 22], [2, 2, 55, 3]]), mindspore.float32)
>>> input_indices = Tensor(np.array([1, 2]), mindspore.int32)
>>> axis = 1
>>> out = ops.SparseGatherV2()(input_params, input_indices, axis)
>>> print(out)
[[2. 7.]
 [4. 54.]
 [2. 55.]]