Function Differences with torch.cat

View Source On Gitee

torch.cat

torch.cat(
    tensors,
    dim=0,
    out=None
)

For more information, see torch.cat.

mindspore.ops.Concat

class mindspore.ops.Concat(
    axis=0
)(input_x)

For more information, see mindspore.ops.Concat.

Differences

PyTorch: When the data type of the input tensors are different, the low-precision tensor will be automatically converted to a high-precision tensor.

MindSpore: Currently, the data type of the input tensors are required to remain the same. If not, the low-precision tensor can be converted to a high-precision tensor through ops.Cast and then call the Concat operator.

Code Example

import mindspore
import mindspore.ops as ops
import torch
import numpy as np

# In MindSpore,converting low precision to high precision is needed before concat.
a = ms.Tensor(np.ones([2, 3]).astype(np.float16))
b = ms.Tensor(np.ones([2, 3]).astype(np.float32))
concat_op = ops.Concat()
cast_op = ops.Cast()
output = concat_op((cast_op(a, ms.float32), b))
print(output.shape)
# Out:
# (4, 3)

# In Pytorch.
a = torch.tensor(np.ones([2, 3]).astype(np.float16))
b = torch.tensor(np.ones([2, 3]).astype(np.float32))
output = torch.cat((a, b))
print(output.size())
# Out:
# torch.Size([4, 3])