比较与torch.unsqueeze的功能差异
torch.unsqueeze
torch.unsqueeze(input, dim) -> Tensor
更多内容详见torch.unsqueeze。
mindspore.ops.expand_dims
mindspore.ops.expand_dims(input_x, axis) -> Tensor
更多内容详见mindspore.ops.expand_dims。
差异对比
PyTorch:对输入input在给定的轴上添加额外维度。
MindSpore:MindSpore此API实现功能与PyTorch一致,仅参数名不同。
分类 |
子类 |
PyTorch |
MindSpore |
差异 |
---|---|---|---|---|
参数 |
参数1 |
input |
input_x |
功能一致,参数名不同 |
参数2 |
dim |
axis |
功能一致,参数名不同 |
代码示例
两API实现功能一致,用法相同。
# PyTorch
import torch
from torch import tensor
x = tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=torch.float32)
dim = 1
out = torch.unsqueeze(x,dim).numpy()
print(out)
# [[[ 1. 2. 3. 4.]]
# [[ 5. 6. 7. 8.]]
# [[ 9. 10. 11. 12.]]]
# MindSpore
import mindspore
import numpy as np
import mindspore.ops as ops
from mindspore import Tensor
input_params = Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]), mindspore.float32)
axis = 1
output = ops.expand_dims(input_params, axis)
print(output)
# [[[ 1. 2. 3. 4.]]
# [[ 5. 6. 7. 8.]]
# [[ 9. 10. 11. 12.]]]