mindspore.ops.gather_nd
- mindspore.ops.gather_nd(input_x, indices)[源代码]
根据索引获取输入Tensor指定位置上的元素。
indices 是K维integer Tensor。假设 indices 是一个(K-1)维的张量,它的每个元素定义了 input_x 的一个slice:
\[output[(i_0, ..., i_{K-2})] = input\_x[indices[(i_0, ..., i_{K-2})]]\]indices 的最后一维的长度不能超过 input_x 的秩: \(indices.shape[-1] <= input\_x.rank\) 。
- 参数:
input_x (Tensor) - GatherNd的输入。
indices (Tensor) - 索引Tensor,其数据类型为int32或int64。
- 返回:
Tensor,数据类型与 input_x 相同,shape为 \(indices\_shape[:-1] + input\_x\_shape[indices\_shape[-1]:]\) 。
- 异常:
ValueError - input_x 的shape长度小于 indices 的最后一维的长度。
- 支持平台:
Ascend
GPU
CPU
样例:
>>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32) >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32) >>> output = ops.gather_nd(input_x, indices) >>> print(output) [-0.1 0.5]