比较与torch.cat的功能差异
torch.cat
torch.cat(
tensors,
dim=0,
out=None
)
更多内容详见torch.cat。
mindspore.ops.Concat
class mindspore.ops.Concat(
axis=0
)(input_x)
更多内容详见mindspore.ops.Concat。
使用方式
PyTorch: 输入tensor的数据类型不同时,低精度tensor会自动转成高精度tensor。
MindSpore: 当前要求输入tensor的数据类型保持一致,若不一致时可通过ops.Cast把低精度tensor转成高精度类型再调用Concat算子。
代码示例
import mindspore as ms
import mindspore.ops as ops
import torch
import numpy as np
# In MindSpore,converting low precision to high precision is needed before concat.
a = ms.Tensor(np.ones([2, 3]).astype(np.float16))
b = ms.Tensor(np.ones([2, 3]).astype(np.float32))
concat_op = ops.Concat()
cast_op = ops.Cast()
output = concat_op((cast_op(a, ms.float32), b))
print(output.shape)
# Out:
# (4, 3)
# In Pytorch.
a = torch.tensor(np.ones([2, 3]).astype(np.float16))
b = torch.tensor(np.ones([2, 3]).astype(np.float32))
output = torch.cat((a, b))
print(output.size())
# Out:
# torch.Size([4, 3])