mindspore.Tensor.take

查看源文件
mindspore.Tensor.take(indices, axis=None, mode='clip')

在指定维度上获取Tensor中的元素。

参数:
  • indices (Tensor) - 待提取的值的shape为 \((Nj...)\) 的索引。

  • axis (int, 可选) - 在指定维度上选择值。默认情况下,使用展开的输入数组。默认值: None

  • mode (str, 可选) - 支持 'raise''wrap''clip'

    • raise:抛出错误。

    • wrap:绕接。

    • clip:裁剪到范围。 clip 模式意味着所有过大的索引都会被在指定轴方向上指向最后一个元素的索引替换。注:这将禁用具有负数的索引。

    默认值: 'clip'

返回:

Tensor,索引的结果。

异常:
  • ValueError - axis 超出范围,或 mode 被设置为 'raise''wrap''clip' 以外的值。

支持平台:

Ascend GPU CPU

样例:

>>> import numpy as np
>>> from mindspore import Tensor
>>> a = Tensor(np.array([4, 3, 5, 7, 6, 8]))
>>> indices = Tensor(np.array([0, 1, 4]))
>>> output = a.take(indices)
>>> print(output)
[4 3 6]
mindspore.Tensor.take(index)

选取给定索引 index 处的 self 元素。

警告

这是一个实验性API,后续可能修改或删除。

参数:
  • index (LongTensor) - 输入张量的索引张量。

返回:

Tensor,shape与索引的shape相同。

异常:
  • TypeError - 如果 index 的数据类型不是long。

样例:

>>> import mindspore as ms
>>> from mindspore import Tensor
>>> input = Tensor([[4, 3, 5],[6, 7, 8]], ms.float32)
>>> index = Tensor([0, 2, 5], ms.int64)
>>> output = input.take(index)
>>> print(output)
[4, 5, 8]