比较与torch.nn.unfold的差异
torch.nn.Unfold
torch.nn.Unfold(kernel_size, dilation=1, padding=0, stride=1)(input) -> Tensor
更多内容详见torch.nn.Unfold。
mindspore.nn.Unfold
class mindspore.nn.Unfold(ksizes, strides, rates, padding='valid')(x) -> Tensor
更多内容详见mindspore.nn.Unfold。
差异对比
PyTorch:从批处理输入张量中提取出滑动的局部区域块。输入张量的shape为(N, C, H, W),其中N为批处理大小,C为通道数,H表示高度,W表示宽度。输出是三维的Tensor。
MindSpore:MindSpore此API实现功能与PyTorch功能有差异。PyTorch的kernel_size、stride和dilation支持int和tuple输入,padding支持在输入的两侧添加的隐式零填充。而MindSpore的ksizes、strides和rates三个参数的格式必须是(1, row, col, 1),padding参数支持两种格式same和valid。MindSpore输入是四维张量,shape为(in_batch, in_depth, in_row, int_col),输出是shape为(out_batch, out_depth, out_row, out_col)的四维Tensor,其中out_batch和in_batch相同。
分类 |
子类 |
PyTorch |
MindSpore |
差异 |
---|---|---|---|---|
输入 |
单输入 |
input |
x |
都是输入四维的Tensor,数据格式为NCHW |
参数 |
参数1 |
kernel_size |
ksizes |
功能一致,但输入格式不一致。都表示滑动窗口的大小,PyTorch支持int和tuple输入,如果kernel_size是一个int,其值将在所有维度上进行复制;MindSpore支持格式为(1, ksize_row, ksize_col, 1)的tuple或list |
参数2 |
dilation |
rates |
功能一致,但输入格式不一致。dilation表示控制滑动过程中所跨越元素的个数,支持int和tuple输入,默认值是1,如果dilation是一个int,其值将在所有维度上进行复制;rates表示滑窗元素之间的空洞个数,支持格式为(1, rate_row, rate_col, 1)的tuple或list |
|
参数3 |
padding |
padding |
功能不一致。都表示填充模式,PyTorch是在输入的两侧进行零填充,支持int和tuple输入,默认值是0,如果padding是一个int,其值将在所有维度上进行复制;MindSpore支持str输入,可选值有"same"或"valid",默认值是"valid",表示所提取的区域块被原始输入所覆盖,取值为"same"时表示所提取的区域块的部分区域可以在原始输入之外进行零填充 |
|
参数4 |
stride |
strides |
功能一致,但输入格式不一致。都表示空间维度上滑动窗口的步长,PyTorch支持int和tuple输入,默认值是1,如果stride是一个int,其值将在所有维度上进行复制;MindSpore支持格式为(1, stride_row, stride_col, 1)的tuple或list |
代码示例1
PyTorch的stride默认值是1,dilation默认值是1,padding默认值是0,由于是输入是四维Tensor且这三个参数默认值都是int,将在所有维度上进行复制。为了得到与PyTorch相同的结果,MindSpore首先分别将Unfold算子的strides、rates和padding分别设置为(1, 1, 1, 1)、(1, 1, 1, 1)和"valid",若kernel_size为一个int时即kernel_size=a时,将ksizes设置为(1, a, a, 1);若kernel_size为一个tuple时即kernel_size=(a,b)时,将ksizes设置为(1, a, b, 1),其次为了输出结果完全一致,首先将MindSpore输出结果进行Reshape操作,然后通过下面的操作进行Concat得到最终结果。
# PyTorch
import torch
from torch import Tensor
import numpy as np
input = Tensor(np.arange(32).reshape((1, 2, 4, 4)).astype(np.float32))
torch_unfold = torch.nn.Unfold(kernel_size=(2, 2), dilation=1, padding=0, stride=1)
torch_output = torch_unfold(input)
# torch_output.shape=(1, 8, 9)
print(torch_output.numpy())
# [[[ 0. 1. 2. 4. 5. 6. 8. 9. 10.]
# [ 1. 2. 3. 5. 6. 7. 9. 10. 11.]
# [ 4. 5. 6. 8. 9. 10. 12. 13. 14.]
# [ 5. 6. 7. 9. 10. 11. 13. 14. 15.]
# [16. 17. 18. 20. 21. 22. 24. 25. 26.]
# [17. 18. 19. 21. 22. 23. 25. 26. 27.]
# [20. 21. 22. 24. 25. 26. 28. 29. 30.]
# [21. 22. 23. 25. 26. 27. 29. 30. 31.]]]
# MindSpore
import mindspore
from mindspore import Tensor
import numpy as np
input = Tensor(np.arange(32).reshape((1, 2, 4, 4)).astype(np.float32))
ms_unfold = mindspore.nn.Unfold(ksizes=(1, 2, 2, 1), rates=(1, 1, 1, 1), padding="valid", strides=(1, 1, 1, 1))
ms_output = ms_unfold(input)
# ms_output.shape = (1, 8, 3, 3)
print(ms_output.asnumpy())
# [[[[ 0. 1. 2.]
# [ 4. 5. 6.]
# [ 8. 9. 10.]]
# [[16. 17. 18.]
# [20. 21. 22.]
# [24. 25. 26.]]
# [[ 1. 2. 3.]
# [ 5. 6. 7.]
# [ 9. 10. 11.]]
# [[17. 18. 19.]
# [21. 22. 23.]
# [25. 26. 27.]]
# [[ 4. 5. 6.]
# [ 8. 9. 10.]
# [12. 13. 14.]]
# [[20. 21. 22.]
# [24. 25. 26.]
# [28. 29. 30.]]
# [[ 5. 6. 7.]
# [ 9. 10. 11.]
# [13. 14. 15.]]
# [[21. 22. 23.]
# [25. 26. 27.]
# [29. 30. 31.]]]]
out_batch, out_depth, out_row, out_col = ms_output.shape
ms_reshape = mindspore.ops.Reshape()
ms_output = ms_reshape(ms_output, (out_batch, out_depth, out_row * out_col))
# ms_output.shape = (1, 8, 9)
print(ms_output.asnumpy())
# [[[ 0. 1. 2. 4. 5. 6. 8. 9. 10.]
# [16. 17. 18. 20. 21. 22. 24. 25. 26.]
# [ 1. 2. 3. 5. 6. 7. 9. 10. 11.]
# [17. 18. 19. 21. 22. 23. 25. 26. 27.]
# [ 4. 5. 6. 8. 9. 10. 12. 13. 14.]
# [20. 21. 22. 24. 25. 26. 28. 29. 30.]
# [ 5. 6. 7. 9. 10. 11. 13. 14. 15.]
# [21. 22. 23. 25. 26. 27. 29. 30. 31.]]]
ms_concat = mindspore.ops.Concat()
output = None
for i in range(out_batch):
odd = None
even = None
for j in range(out_depth):
data = ms_output[i,j,:]
data = ms_reshape(data, (1, data.shape[0]))
if j % 2 == 0:
if even is None:
even = data
else:
even = ms_concat((even, data))
else:
if odd is None:
odd = data
else:
odd = ms_concat((odd, data))
temp = ms_concat((even, odd))
temp = ms_reshape(temp, (1, temp.shape[0], temp.shape[1]))
if i == 0:
output = temp
else:
output = ms_concat((output, temp))
ms_output = output
print(ms_output.asnumpy())
# [[[ 0. 1. 2. 4. 5. 6. 8. 9. 10.]
# [ 1. 2. 3. 5. 6. 7. 9. 10. 11.]
# [ 4. 5. 6. 8. 9. 10. 12. 13. 14.]
# [ 5. 6. 7. 9. 10. 11. 13. 14. 15.]
# [16. 17. 18. 20. 21. 22. 24. 25. 26.]
# [17. 18. 19. 21. 22. 23. 25. 26. 27.]
# [20. 21. 22. 24. 25. 26. 28. 29. 30.]
# [21. 22. 23. 25. 26. 27. 29. 30. 31.]]]
代码示例2
在PyTorch的padding参数取默认值0,MindSpore的padding取默认值"valid"前提下,当其余三个参数相对应设置时,若kernel_size为一个int时即kernel_size=a时,将ksizes设置为(1, a, a, 1);若kernel_size为一个tuple时即kernel_size=(a,b)时,将ksizes设置为(1, a, b, 1);若stride为一个int时即stride=a时,将strides设置为(1, a, a, 1);若stride为一个tuple时即stride=(a,b)时,将strides设置为(1, a, b, 1);若dilation为一个int时即dilation=a时,将rates设置为(1, a, a, 1);若dilation为一个tuple时即dilation=(a,b)时,将rates设置为(1, a, b, 1)。其次为了输出结果完全一致,首先将MindSpore输出结果进行Reshape操作,然后通过下面的操作进行Concat得到最终结果。
# PyTorch
import torch
from torch import Tensor
import numpy as np
input = Tensor(np.arange(32).reshape((1, 2, 4, 4)).astype(np.float32))
torch_unfold = torch.nn.Unfold(kernel_size=(2,2), dilation=(1, 1), padding=0, stride=(1, 1))
torch_output = torch_unfold(input)
# torch_output.shape=(1, 8, 9)
print(torch_output.numpy())
# [[[ 0. 1. 2. 4. 5. 6. 8. 9. 10.]
# [ 1. 2. 3. 5. 6. 7. 9. 10. 11.]
# [ 4. 5. 6. 8. 9. 10. 12. 13. 14.]
# [ 5. 6. 7. 9. 10. 11. 13. 14. 15.]
# [16. 17. 18. 20. 21. 22. 24. 25. 26.]
# [17. 18. 19. 21. 22. 23. 25. 26. 27.]
# [20. 21. 22. 24. 25. 26. 28. 29. 30.]
# [21. 22. 23. 25. 26. 27. 29. 30. 31.]]]
# MindSpore
import mindspore
from mindspore import Tensor
import numpy as np
input = Tensor(np.arange(32).reshape((1, 2, 4, 4)).astype(np.float32))
ms_unfold = mindspore.nn.Unfold(ksizes=(1, 2, 2, 1), rates=(1, 1, 1, 1), padding="valid", strides=(1, 1, 1, 1))
ms_output = ms_unfold(input)
# ms_output.shape = (1, 8, 3, 3)
print(ms_output.asnumpy())
# [[[[ 0. 1. 2.]
# [ 4. 5. 6.]
# [ 8. 9. 10.]]
# [[16. 17. 18.]
# [20. 21. 22.]
# [24. 25. 26.]]
# [[ 1. 2. 3.]
# [ 5. 6. 7.]
# [ 9. 10. 11.]]
# [[17. 18. 19.]
# [21. 22. 23.]
# [25. 26. 27.]]
# [[ 4. 5. 6.]
# [ 8. 9. 10.]
# [12. 13. 14.]]
# [[20. 21. 22.]
# [24. 25. 26.]
# [28. 29. 30.]]
# [[ 5. 6. 7.]
# [ 9. 10. 11.]
# [13. 14. 15.]]
# [[21. 22. 23.]
# [25. 26. 27.]
# [29. 30. 31.]]]]
out_batch, out_depth, out_row, out_col = ms_output.shape
ms_reshape = mindspore.ops.Reshape()
ms_output = ms_reshape(ms_output, (out_batch, out_depth, out_row * out_col))
# ms_output.shape = (1, 8, 9)
print(ms_output.asnumpy())
# [[[ 0. 1. 2. 4. 5. 6. 8. 9. 10.]
# [16. 17. 18. 20. 21. 22. 24. 25. 26.]
# [ 1. 2. 3. 5. 6. 7. 9. 10. 11.]
# [17. 18. 19. 21. 22. 23. 25. 26. 27.]
# [ 4. 5. 6. 8. 9. 10. 12. 13. 14.]
# [20. 21. 22. 24. 25. 26. 28. 29. 30.]
# [ 5. 6. 7. 9. 10. 11. 13. 14. 15.]
# [21. 22. 23. 25. 26. 27. 29. 30. 31.]]]
ms_concat = mindspore.ops.Concat()
output = None
for i in range(out_batch):
odd = None
even = None
for j in range(out_depth):
data = ms_output[i,j,:]
data = ms_reshape(data, (1, data.shape[0]))
if j % 2 == 0:
if even is None:
even = data
else:
even = ms_concat((even, data))
else:
if odd is None:
odd = data
else:
odd = ms_concat((odd, data))
temp = ms_concat((even, odd))
temp = ms_reshape(temp, (1, temp.shape[0], temp.shape[1]))
if i == 0:
output = temp
else:
output = ms_concat((output, temp))
ms_output = output
print(ms_output.asnumpy())
# [[[ 0. 1. 2. 4. 5. 6. 8. 9. 10.]
# [ 1. 2. 3. 5. 6. 7. 9. 10. 11.]
# [ 4. 5. 6. 8. 9. 10. 12. 13. 14.]
# [ 5. 6. 7. 9. 10. 11. 13. 14. 15.]
# [16. 17. 18. 20. 21. 22. 24. 25. 26.]
# [17. 18. 19. 21. 22. 23. 25. 26. 27.]
# [20. 21. 22. 24. 25. 26. 28. 29. 30.]
# [21. 22. 23. 25. 26. 27. 29. 30. 31.]]]
代码示例3
PyTorch的padding表示输入的两侧进行零填充,支持int和tuple输入,默认值是0,与MindSpore的padding取默认值"valid"相对应。当PyTorch的padding取其他值时,MindSpore的padding的可取值只有"valid"和"same",所以没有与之对应的取值,故输出结果不一致。
# PyTorch
import torch
from torch import Tensor
import numpy as np
input = Tensor(np.arange(32).reshape((1, 2, 4, 4)).astype(np.float32))
torch_unfold = torch.nn.Unfold(kernel_size=(2,2), dilation=1, padding=1, stride=1)
torch_output = torch_unfold(input)
# ms_output.shape = (1, 8, 25)
print(torch_output.numpy())
# [[[ 0. 0. 0. 0. 0. 0. 0. 1. 2. 3. 0. 4. 5. 6. 7. 0. 8.
# 9. 10. 11. 0. 12. 13. 14. 15.]
# [ 0. 0. 0. 0. 0. 0. 1. 2. 3. 0. 4. 5. 6. 7. 0. 8. 9.
# 10. 11. 0. 12. 13. 14. 15. 0.]
# [ 0. 0. 1. 2. 3. 0. 4. 5. 6. 7. 0. 8. 9. 10. 11. 0. 12.
# 13. 14. 15. 0. 0. 0. 0. 0.]
# [ 0. 1. 2. 3. 0. 4. 5. 6. 7. 0. 8. 9. 10. 11. 0. 12. 13.
# 14. 15. 0. 0. 0. 0. 0. 0.]
# [ 0. 0. 0. 0. 0. 0. 16. 17. 18. 19. 0. 20. 21. 22. 23. 0. 24.
# 25. 26. 27. 0. 28. 29. 30. 31.]
# [ 0. 0. 0. 0. 0. 16. 17. 18. 19. 0. 20. 21. 22. 23. 0. 24. 25.
# 26. 27. 0. 28. 29. 30. 31. 0.]
# [ 0. 16. 17. 18. 19. 0. 20. 21. 22. 23. 0. 24. 25. 26. 27. 0. 28.
# 29. 30. 31. 0. 0. 0. 0. 0.]
# [16. 17. 18. 19. 0. 20. 21. 22. 23. 0. 24. 25. 26. 27. 0. 28. 29.
# 30. 31. 0. 0. 0. 0. 0. 0.]]]
# MindSpore
import numpy as np
import mindspore
from mindspore import Tensor
# MindSpore
import mindspore
from mindspore import Tensor
import numpy as np
input = Tensor(np.arange(32).reshape((1, 2, 4, 4)).astype(np.float32))
ms_unfold = mindspore.nn.Unfold(ksizes=(1, 2, 2, 1), rates=(1, 1, 1, 1), padding="same", strides=(1, 1, 1, 1))
ms_output = ms_unfold(input)
# ms_output.shape = (1, 8, 4, 4)
print(ms_output.asnumpy())
# [[[[ 0. 1. 2. 3.]
# [ 4. 5. 6. 7.]
# [ 8. 9. 10. 11.]
# [12. 13. 14. 15.]]
# [[16. 17. 18. 19.]
# [20. 21. 22. 23.]
# [24. 25. 26. 27.]
# [28. 29. 30. 31.]]
# [[ 1. 2. 3. 0.]
# [ 5. 6. 7. 0.]
# [ 9. 10. 11. 0.]
# [13. 14. 15. 0.]]
# [[17. 18. 19. 0.]
# [21. 22. 23. 0.]
# [25. 26. 27. 0.]
# [29. 30. 31. 0.]]
# [[ 4. 5. 6. 7.]
# [ 8. 9. 10. 11.]
# [12. 13. 14. 15.]
# [ 0. 0. 0. 0.]]
# [[20. 21. 22. 23.]
# [24. 25. 26. 27.]
# [28. 29. 30. 31.]
# [ 0. 0. 0. 0.]]
# [[ 5. 6. 7. 0.]
# [ 9. 10. 11. 0.]
# [13. 14. 15. 0.]
# [ 0. 0. 0. 0.]]
# [[21. 22. 23. 0.]
# [25. 26. 27. 0.]
# [29. 30. 31. 0.]
# [ 0. 0. 0. 0.]]]]
out_batch, out_depth, out_row, out_col = ms_output.shape
ms_reshape = mindspore.ops.Reshape()
ms_output = ms_reshape(ms_output, (out_batch, out_depth, out_row * out_col))
# ms_output.shape = (1, 8, 16)
print(ms_output.asnumpy())
# [[[ 0. 1. 2. 4. 5. 6. 8. 9. 10.]
# [16. 17. 18. 20. 21. 22. 24. 25. 26.]
# [ 1. 2. 3. 5. 6. 7. 9. 10. 11.]
# [17. 18. 19. 21. 22. 23. 25. 26. 27.]
# [ 4. 5. 6. 8. 9. 10. 12. 13. 14.]
# [20. 21. 22. 24. 25. 26. 28. 29. 30.]
# [ 5. 6. 7. 9. 10. 11. 13. 14. 15.]
# [21. 22. 23. 25. 26. 27. 29. 30. 31.]]]
ms_concat = mindspore.ops.Concat()
output = None
for i in range(out_batch):
odd = None
even = None
for j in range(out_depth):
data = ms_output[i,j,:]
data = ms_reshape(data, (1, data.shape[0]))
if j % 2 == 0:
if even is None:
even = data
else:
even = ms_concat((even, data))
else:
if odd is None:
odd = data
else:
odd = ms_concat((odd, data))
temp = ms_concat((even, odd))
temp = ms_reshape(temp, (1, temp.shape[0], temp.shape[1]))
if i == 0:
output = temp
else:
output = ms_concat((output, temp))
ms_output = output
print(ms_output.asnumpy())
# [[[ 0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15.]
# [ 1. 2. 3. 0. 5. 6. 7. 0. 9. 10. 11. 0. 13. 14. 15. 0.]
# [ 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15. 0. 0. 0. 0.]
# [ 5. 6. 7. 0. 9. 10. 11. 0. 13. 14. 15. 0. 0. 0. 0. 0.]
# [16. 17. 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31.]
# [17. 18. 19. 0. 21. 22. 23. 0. 25. 26. 27. 0. 29. 30. 31. 0.]
# [20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31. 0. 0. 0. 0.]
# [21. 22. 23. 0. 25. 26. 27. 0. 29. 30. 31. 0. 0. 0. 0. 0.]]]