比较与torch.take的功能差异
torch.take
torch.take(input, index)
更多内容详见torch.take。
mindspore.Tensor.take
mindspore.Tensor.take(indices, axis=None, mode="clip")
更多内容详见mindspore.Tensor.take。
使用方式
基础功能为根据传入的索引从输入Tensor中获取对应的元素。
torch.take
首先将原始Tensor拉长,然后根据index
获取元素,index
设置值需小于输入Tensor的元素数。
mindspore.Tensor.take
默认状态下(axis=None
)同样先对Tensor做ravel
操作,再按照indices
返回元素。除此之外,可以通过axis
设定按照指定axis
选取元素。indices
数值可以超出Tensor元素数目,此时可以通过入参mode
设置不同的返回策略,具体说明请参考API注释。
代码示例
import mindspore as ms
import numpy as np
a = ms.Tensor([[1, 2, 8],[3, 4, 6]], ms.float32)
indices = ms.Tensor(np.array([1, 10]))
# take(self, indices, axis=None, mode='clip'):
print(a.take(indices))
# [2. 6.]
print(a.take(indices, axis=1))
# [[2. 8.]
# [4. 6.]]
print(a.take(indices, mode="wrap"))
# [2. 4.]
import torch
b = torch.tensor([[1, 2, 8],[3, 4, 6]])
indices = torch.tensor([1, 5])
print(torch.take(b, indices))
# tensor([2, 6])