比较与torch.bmm的功能差异
torch.bmm
torch.bmm(input, mat2, *, deterministic=False, out=None) -> Tensor
更多内容详见torch.bmm。
mindspore.ops.BatchMatMul
mindspore.ops.BatchMatMul(transpose_a=False, transpose_b=False)(x, y) -> Tensor
更多内容详见mindspore.ops.BatchMatMul。
差异对比
PyTorch:对input和mat2执行批量矩阵乘积,其中input和mat2必须是3-D的tensor。如果input是一个(b, n, m)的tensor,mat2是一个(b, n, p)的tensor,两者矩阵乘积的结果out为(b, n, p)。
MindSpore:MindSpore此API实现功能与PyTorch基本一致,不过MindSpore支持3D以及更高维度的矩阵乘法计算,其中MindSpore的transpose_a若为True,会把输入相乘的第一个tensor的最后两维进行交换。
分类 |
子类 |
PyTorch |
MindSpore |
差异 |
---|---|---|---|---|
参数 |
参数1 |
input |
x |
功能一致,参数名不同 |
参数2 |
mat2 |
y |
功能一致,参数名不同 |
|
参数3 |
deterministic |
- |
此参数只适用于稀疏的稀疏密集的CUDA bmm,MindSpore无此参数 |
|
参数4 |
out |
- |
不涉及 |
|
参数5 |
- |
transpose_a |
transpose_a若为True,会把输入相乘的第一个tensor的最后两维进行交换。 |
|
参数6 |
- |
transpose_b |
transpose_b若为True,会把输入相乘的第二个tensor的最后两维进行交换。 |
代码示例1
两API实现功能一致,用法相同。
# PyTorch
import numpy as np
import torch
from torch import tensor
input = torch.tensor(np.ones(shape=[2, 1, 5]), dtype=torch.float32)
mat2 = torch.tensor(np.ones(shape=[2, 5, 2]), dtype=torch.float32)
output = torch.bmm(input, mat2).numpy()
print(output)
# [[[5. 5.]]
# [[5. 5.]]]
# MindSpore
import numpy as np
import mindspore
import mindspore.ops as ops
from mindspore import Tensor
x = Tensor(np.ones(shape=[2, 1, 5]), mindspore.float32)
y = Tensor(np.ones(shape=[2, 5, 2]), mindspore.float32)
batmatmul = ops.BatchMatMul()
output = batmatmul(x, y)
print(output)
# [[[5. 5.]]
# [[5. 5.]]]
代码示例2
PyTorch只支持3D的tensor,MindSpore支持3D以及更高维度的矩阵乘法计算。
import numpy as np
import mindspore
import mindspore.ops as ops
from mindspore import Tensor
x = Tensor(np.ones(shape=[3, 5, 1, 3]), mindspore.float32)
y = Tensor(np.ones(shape=[3, 5, 3, 4]), mindspore.float32)
batmatmul = ops.BatchMatMul()
output = batmatmul(x, y)
print(output.shape)
# (3, 5, 1, 4)
代码示例3
MindSpore的transpose_a若为True,会把输入相乘的第一个tensor的最后两维进行交换,transpose_b若为True,会把输入相乘的第二个tensor的最后两维进行交换。
import numpy as np
import mindspore
import mindspore.ops as ops
from mindspore import Tensor
x = Tensor(np.ones(shape=[3, 5, 3, 1]), mindspore.float32)
y = Tensor(np.ones(shape=[3, 5, 3, 4]), mindspore.float32)
batmatmul = ops.BatchMatMul(transpose_a=True)
output = batmatmul(x, y)
print(output.shape)
# (3, 5, 1, 4)