Function Differences with torch.nn.Module.buffers
torch.nn.Module.buffers
torch.nn.Module.buffers(recurse=True)
For more information, see torch.nn.Module.buffers.
mindspore.nn.Cell.untrainable_params
mindspore.nn.Cell.untrainable_params(recurse=True)
Usage
In PyTorch, a network involves three concepts: parameter, buffer, and state, where state is the union of parameter and buffer. A parameter carries a requires_grad attribute that indicates whether it should be optimized; a buffer is usually a quantity of the network that is not optimized. For example, when a network is defined, the running_mean and running_var of a BN layer are automatically registered as buffers. Users can also register parameters and buffers themselves through the corresponding interfaces.
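As a minimal sketch of manual registration (the StatsBlock module and its scale/running_sum names are made up for illustration, not taken from the original), a tensor added with register_parameter shows up in parameters(), a tensor added with register_buffer shows up in buffers(), and both appear in state_dict, reflecting that state is the union of the two:

import torch
import torch.nn as nn

class StatsBlock(nn.Module):
    def __init__(self):
        super(StatsBlock, self).__init__()
        # Registered as a parameter: returned by parameters(), requires_grad=True by default.
        self.register_parameter("scale", nn.Parameter(torch.ones(3)))
        # Registered as a buffer: part of the state, but never returned by parameters().
        self.register_buffer("running_sum", torch.zeros(3))

block = StatsBlock()
print([n for n, _ in block.named_parameters()])  # ['scale']
print([n for n, _ in block.named_buffers()])     # ['running_sum']
print(list(block.state_dict().keys()))           # ['scale', 'running_sum']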
torch.nn.Module.buffers: returns the buffers of the network; the return type is a generator.

torch.nn.Module.named_buffers: returns the names of the buffers together with the buffers themselves; the return type is also a generator.
MindSpore currently has only the concept of parameter, and the requires_grad attribute distinguishes whether a parameter in the network should be optimized. For example, when a network is defined, the moving_mean and moving_var of a BN layer are defined as parameters with requires_grad=False.
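Accordingly, the closest MindSpore counterpart of a PyTorch buffer is a Parameter declared with requires_grad=False. A minimal sketch (the StatsBlock cell and its running_sum name are again made up for illustration):

import numpy as np
import mindspore
from mindspore import nn, Parameter, Tensor

class StatsBlock(nn.Cell):
    def __init__(self):
        super(StatsBlock, self).__init__()
        # A non-trainable Parameter plays the role of a PyTorch buffer.
        self.running_sum = Parameter(Tensor(np.zeros(3), mindspore.float32),
                                     name="running_sum", requires_grad=False)

    def construct(self, x):
        return x + self.running_sum

block = StatsBlock()
print([p.name for p in block.untrainable_params()])  # ['running_sum']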
mindspore.nn.Cell.untrainable_params: returns the parameters of the network that the optimizer does not update; the return type is a list. A Parameter in MindSpore has a name attribute, so after retrieving the parameters with untrainable_params, their names can be read from that attribute.
Code example
import mindspore
import numpy as np
from mindspore import nn

class ConvBN(nn.Cell):
    def __init__(self):
        super(ConvBN, self).__init__()
        self.conv = nn.Conv2d(3, 64, 3)
        self.bn = nn.BatchNorm2d(64)

    def construct(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x

class MyNet(nn.Cell):
    def __init__(self):
        super(MyNet, self).__init__()
        self.build_block = nn.SequentialCell(ConvBN(), nn.ReLU())

    def construct(self, x):
        return self.build_block(x)

# The following implements mindspore.nn.Cell.untrainable_params() with MindSpore.
net = MyNet()
print(type(net.untrainable_params()), "\n")
for params in net.untrainable_params():
    print("Name: ", params.name)
    print("params: ", params)
# Out:
<class 'list'>
Name: build_block.0.bn.moving_mean
params: Parameter (name=build_block.0.bn.moving_mean, shape=(64,), dtype=Float32, requires_grad=False)
Name: build_block.0.bn.moving_variance
params: Parameter (name=build_block.0.bn.moving_variance, shape=(64,), dtype=Float32, requires_grad=False)
import torch.nn as nn

class ConvBN(nn.Module):
    def __init__(self):
        super(ConvBN, self).__init__()
        self.conv = nn.Conv2d(3, 64, 3)
        self.bn = nn.BatchNorm2d(64)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x

class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.build_block = nn.Sequential(ConvBN(), nn.ReLU())

    def forward(self, x):
        return self.build_block(x)

# The following implements torch.nn.Module.buffers() with torch.
net = MyNet()
print(type(net.buffers()), "\n")
for name, params in net.named_buffers():
    print("Name: ", name)
    print("params: ", params.size())
# Out:
<class 'generator'>
Name: build_block.0.bn.running_mean
params: torch.Size([64])
Name: build_block.0.bn.running_var
params: torch.Size([64])
Name: build_block.0.bn.num_batches_tracked
params: torch.Size([])