比较与torch.unsqueeze的功能差异

torch.unsqueeze

torch.unsqueeze(input, dim) -> Tensor

更多内容详见torch.unsqueeze

mindspore.ops.expand_dims

mindspore.ops.expand_dims(input_x, axis) -> Tensor

更多内容详见mindspore.ops.expand_dims

差异对比

PyTorch:对输入input在给定的轴上添加额外维度。

MindSpore:MindSpore此API实现功能与PyTorch一致,仅参数名不同。

分类

子类

PyTorch

MindSpore

差异

参数

参数1

input

input_x

功能一致,参数名不同

参数2

dim

axis

功能一致,参数名不同

代码示例

两API实现功能一致,用法相同。

# PyTorch
import torch
from torch import tensor

x = tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=torch.float32)
dim = 1
out = torch.unsqueeze(x,dim).numpy()
print(out)
# [[[ 1.  2.  3.  4.]]
#  [[ 5.  6.  7.  8.]]
#  [[ 9. 10. 11. 12.]]]

# MindSpore
import mindspore
import numpy as np
import mindspore.ops as ops
from mindspore import Tensor

input_params = Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]), mindspore.float32)
axis = 1
output = ops.expand_dims(input_params,  axis)
print(output)
# [[[ 1.  2.  3.  4.]]
#  [[ 5.  6.  7.  8.]]
#  [[ 9. 10. 11. 12.]]]