比较与torch.bucketize的差异
torch.bucketize
torch.bucketize(input, boundaries, *, out_int32=False, right=False, out=None)
更多内容详见torch.bucketize。
mindspore.ops.bucketize
class mindspore.ops.bucketize(input, boundaries, *, right=False)
更多内容详见mindspore.ops.bucketize。
使用方式
MindSpore此API功能与PyTorch一致,参数支持的数据类型有差异。
PyTorch:input
支持scalar和Tensor类型,boundaries
支持Tensor类型,且可以通过 out_int32
指定返回的索引的数据类型。
MindSpore:input
支持Tensor类型,boundaries
支持list类型,无 out_int32
参数。
分类 |
子类 |
PyTorch |
MindSpore |
差异 |
---|---|---|---|---|
参数 |
参数1 |
input |
input |
功能一致,支持数据类型不同 |
参数2 |
boundaries |
boundaries |
功能一致,支持数据类型不同 |
|
参数3 |
out_int32 |
- |
PyTorch的 |
|
参数4 |
right |
right |
一致 |
|
参数5 |
out |
- |
PyTorch的 |
代码示例
import torch
boundaries = torch.tensor([1, 3, 5, 7, 9])
v = torch.tensor([[3, 6, 9], [3, 6, 9]])
out1 = torch.bucketize(v, boundaries)
out2 = torch.bucketize(v, boundaries, right=True)
print(out1)
# Out:
# tensor([[1, 3, 4],
# [1, 3, 4]])
print(out2)
# Out:
# tensor([[2, 3, 5],
# [2, 3, 5]])
from mindspore import Tensor, ops
boundaries = [1, 3, 5, 7, 9]
v = Tensor([[3, 6, 9], [3, 6, 9]])
out1 = ops.bucketize(v, boundaries)
out2 = ops.bucketize(v, boundaries, right=True)
print(out1)
# Out:
# [[1 3 4]
# [1 3 4]]
print(out2)
# Out:
# [[2 3 5]
# [2 3 5]]