mindspore.nn.Cell

查看源文件
class mindspore.nn.Cell(auto_prefix=True, flags=None)[源代码]

MindSpore中神经网络的基本构成单元。模型或神经网络层应当继承该基类。

mindspore.nn 中神经网络层也是Cell的子类,如 mindspore.nn.Conv2dmindspore.nn.ReLU 等。Cell在GRAPH_MODE(静态图模式)下将编译为一张计算图,在PYNATIVE_MODE(动态图模式)下作为神经网络的基础模块。

说明

Cell默认情况下是推理模式。对于继承Cell的类,如果训练和推理具有不同结构,子类会默认执行推理分支。设置训练模式,请参考 mindspore.nn.Cell.set_train

警告

在Cell的子类中不能定义名为'cast'的方法,不能定义名为'phase'和'cells'的属性, 否则会报错。

参数:
  • auto_prefix (bool,可选) - 是否自动为Cell及其子Cell生成NameSpace。该参数同时会影响 Cell 中权重参数的名称。如果设置为 True ,则自动给权重参数的名称添加前缀,否则不添加前缀。通常情况下,骨干网络应设置为 True ,否则会产生重名问题。用于训练骨干网络的优化器、 mindspore.nn.TrainOneStepCell 等,应设置为 False ,否则骨干网络的权重参数名会被误改。默认值: True

  • flags (dict,可选) - Cell的配置信息,目前用于绑定Cell和数据集。用户也可通过该参数自定义Cell属性。默认值: None

支持平台:

Ascend GPU CPU

样例:

>>> import mindspore.nn as nn
>>> from mindspore import ops
>>> class MyCell(nn.Cell):
...     def __init__(self, forward_net):
...         super(MyCell, self).__init__(auto_prefix=False)
...         self.net = forward_net
...         self.relu = ops.ReLU()
...
...     def construct(self, x):
...         y = self.net(x)
...         return self.relu(y)
>>>
>>> inner_net = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal')
>>> my_net = MyCell(inner_net)
>>> print(my_net.trainable_params())
... # If the 'auto_prefix' set to True or not set when call the '__init__' method of the parent class,
... # the parameter's name will be 'net.weight'.
[Parameter (name=weight, shape=(240, 120, 4, 4), dtype=Float32, requires_grad=True)]
add_flags(**flags)[源代码]

为Cell添加自定义属性。

在实例化Cell类时,如果入参flags不为空,会调用此方法。

参数:
  • flags (dict) - Cell的配置信息,目前用于绑定Cell和数据集。用户也可通过该参数自定义Cell属性。

样例:

>>> import mindspore as ms
>>> from mindspore import Tensor, nn
...
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.relu = nn.ReLU()
...
...     def construct(self, x):
...         x = self.relu(x)
...         return x
>>> net = Net()
>>> net.add_flags(sink_mode=True)
>>> print(net.sink_mode)
True
add_flags_recursive(**flags)[源代码]

如果Cell含有多个子Cell,此方法会递归地给所有子Cell添加自定义属性。

参数:
  • flags (dict) - Cell的配置信息,目前用于绑定Cell和数据集。用户也可通过该参数自定义Cell属性。

样例:

>>> import mindspore as ms
>>> from mindspore import Tensor, nn
...
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.relu = nn.ReLU()
...
...     def construct(self, x):
...         x = self.relu(x)
...         return x
>>> net = Net()
>>> net.add_flags_recursive(sink_mode=True)
>>> print(net.sink_mode)
True
apply(fn)[源代码]

递归地将 fn 应用于每个子Cell(由 .cells() 返回)以及自身。通常用于初始化模型的参数。

参数:
  • fn (function) - 被执行于每个Cell的function。

返回:

Cell类型,Cell本身。

样例:

>>> import mindspore.nn as nn
>>> from mindspore.common.initializer import initializer, One
>>> net = nn.SequentialCell(nn.Dense(2, 2), nn.Dense(2, 2))
>>> def func(cell):
...     if isinstance(cell, nn.Dense):
...         cell.weight.set_data(initializer(One(), cell.weight.shape, cell.weight.dtype))
>>> net.apply(func)
SequentialCell(
  (0): Dense(input_channels=2, output_channels=2, has_bias=True)
  (1): Dense(input_channels=2, output_channels=2, has_bias=True)
)
>>> print(net[0].weight.asnumpy())
[[1. 1.]
 [1. 1.]]
auto_cast_inputs(inputs)[源代码]

在混合精度下,自动对输入进行类型转换。

参数:
  • inputs (tuple) - construct方法的输入。

返回:

Tuple类型,经过类型转换后的输入。

property bprop_debug

在图模式下使用,用于标识是否使用自定义的反向传播函数。

buffers(recurse=True)[源代码]

返回Cell缓冲区的迭代器,只包含缓冲区本身。

参数:
  • recurse (bool,可选) - 如果为 True ,则生成此Cell及其子Cell的缓冲区。否则,仅生成此Cell的缓冲区。默认 True

返回:

Iterator[Tensor],缓冲区的迭代器。

样例:

>>> import mindspore
...
...
>>> class NetB(mindspore.nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.buffer_b = mindspore.nn.Buffer(mindspore.tensor([1, 2, 3]))
...
...     def construct(self, x):
...         return x + self.buffer_b
...
...
>>> class NetA(mindspore.nn.Cell):
...     def __init__(self, net_b):
...         super().__init__()
...         self.net_b = net_b
...         self.buffer_a = mindspore.nn.Buffer(mindspore.tensor([4, 5, 6]))
...
...     def construct(self, x):
...         return self.net_b(x) + self.buffer_a
...
...
>>> net_b = NetB()
>>> net_a = NetA(net_b)
>>>
>>> for buffer in net_a.buffers():
>>>     print(f'buffer is {buffer}')
buffer is [4, 5, 6]
buffer is [1, 2, 3]
cast_inputs(inputs, dst_type)[源代码]

将输入转换为指定类型。

参数:
  • inputs (tuple[Tensor]) - 输入。

  • dst_type (mindspore.dtype) - 指定的数据类型。

返回:

tuple[Tensor]类型,转换类型后的结果。

cast_param(param)[源代码]

在PyNative模式下,根据自动混合精度的精度设置转换Cell中参数的类型。

该接口目前在自动混合精度场景下使用。

参数:
  • param (Parameter) - 需要被转换类型的输入参数。

返回:

Parameter类型,转换类型后的参数。

cells()[源代码]

返回当前Cell的子Cell的迭代器。

返回:

Iteration类型,Cell的子Cell。

样例:

>>> import mindspore as ms
>>> from mindspore import Tensor, nn
...
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.dense = nn.Dense(2, 2)
...
...     def construct(self, x):
...         x = self.dense(x)
...         return x
>>> net = Net()
>>> print(net.cells())
odict_values([Dense(input_channels=2, output_channels=2, has_bias=True)])
cells_and_names(cells=None, name_prefix='')[源代码]

递归地获取当前Cell及输入 cells 的所有子Cell的迭代器,包括Cell的名称及其本身。

参数:
  • cells (str) - 需要进行迭代的Cell。默认值: None

  • name_prefix (str) - 作用域。默认值: ''

返回:

Iteration类型,当前Cell及输入 cells 的所有子Cell和相对应的名称。

样例:

>>> from mindspore import nn
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.conv = nn.Conv2d(3, 64, 3)
...     def construct(self, x):
...         out = self.conv(x)
...         return out
>>> names = []
>>> n = Net()
>>> for m in n.cells_and_names():
...     if m[0]:
...         names.append(m[0])
check_names()[源代码]

检查Cell中的网络参数名称是否重复。

compile(*args, **kwargs)[源代码]

编译Cell为计算图,输入需与construct中定义的输入一致。

参数:
  • args (tuple) - Cell的输入。

  • kwargs (dict) - Cell的输入。

compile_and_run(*args, **kwargs)[源代码]

编译并运行Cell,输入需与construct中定义的输入一致。

说明

不推荐使用该函数,建议直接调用Cell实例。

参数:
  • args (tuple) - Cell的输入。

  • kwargs (dict) - Cell的输入。

返回:

Object类型,执行的结果。

construct(*args, **kwargs)[源代码]

定义要执行的计算逻辑。所有子类都必须重写此方法。

说明

当前不支持inputs同时输入tuple类型和非tuple类型。

参数:
  • args (tuple) - 可变参数列表,默认值: ()

  • kwargs (dict) - 可变的关键字参数的字典,默认值: {}

返回:

Tensor类型,返回计算结果。

extend_repr()[源代码]

在原有描述基础上扩展Cell的描述。

若需要在print时输出个性化的扩展信息,请在您的网络中重新实现此方法。

flatten_weights(fusion_size=0)[源代码]

重置权重参数(即可训练参数)使用的数据内存,让这些参数按数据类型分组使用连续内存块。

说明

默认情况下,具有相同数据类型的参数会使用同一个连续内存块。但对于某些具有大量参数的模型, 将一个大的连续内存块分为多个小一点的内存块,有可能提升性能,对于这种情况, 可以通过 fusion_size 参数来限制最大连续内存块的的大小。

参数:
  • fusion_size (int) - 最大连续内存块的大小(以字节为单位), 0 表示不限制大小。默认值: 0

generate_scope()[源代码]

为网络中的每个Cell对象生成NameSpace。

get_buffer(target)[源代码]

返回给定 target 的缓冲区,如果不存在则抛出错误。

请参阅 get_sub_cell 的文档,了解有关此方法功能的更详细说明以及如何正确指定 target

参数:
  • target (str) - 要查找的缓冲区的完全限定字符串名称。(请参阅 get_sub_cell 了解如何指定完全限定字符串。)

返回:

Tensor

样例:

>>> import mindspore
...
...
>>> class NetC(mindspore.nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.buffer_c = mindspore.nn.Buffer(mindspore.tensor([0, 0, 0]))
...
...     def construct(self, x):
...         return x + self.buffer_c
...
...
>>> class NetB(mindspore.nn.Cell):
...     def __init__(self, net_c):
...         super().__init__()
...         self.net_c = net_c
...         self.buffer_b = mindspore.nn.Buffer(mindspore.tensor([1, 2, 3]))
...
...     def construct(self, x):
...         return self.net_c(x) + self.buffer_b
...
...
>>> class NetA(mindspore.nn.Cell):
...     def __init__(self, net_b):
...         super().__init__()
...         self.net_b = net_b
...         self.buffer_a = mindspore.nn.Buffer(mindspore.tensor([4, 5, 6]))
...
...     def construct(self, x):
...         return self.net_b(x) + self.buffer_a
...
...
>>> net_c = NetC()
>>> net_b = NetB(net_c)
>>> net_a = NetA(net_b)
>>> buffer_c = net_a.get_buffer("net_b.net_c.buffer_c")
>>> print(f'buffer_c is {buffer_c}')
buffer_c is [0, 0, 0]
get_extra_state()[源代码]

返回要包含在Cell的 state_dict 中的任何额外状态。

当构建Cell的 state_dict() 时,将调用此函数。 如果您需要存储额外状态,实现此方法,并为您的Cell实现相应的 set_extra_state()

请注意,额外状态应为可序列化对象(picklable),以确保state_dict的序列化可用性。 仅对tensor的序列化提供向后兼容性保证; 对于其他对象,如果其序列化的pickled形式发生变化,可能会导致向后兼容性问题。

返回:

object,要存储在Cell的state_dict中的额外状态。

get_flags()[源代码]

获取该Cell的自定义属性,自定义属性通过 add_flags 方法添加。

样例:

>>> import mindspore as ms
>>> from mindspore import Tensor, nn
...
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.relu = nn.ReLU()
...
...     def construct(self, x):
...         x = self.relu(x)
...         return x
>>> net = Net()
>>> net.add_flags(sink_mode=True)
>>> print(net.get_flags())
{'sink_mode':True}
get_func_graph_proto()[源代码]

返回图的二进制原型。

get_inputs()[源代码]

返回编译计算图所设置的输入。

返回:

Tuple类型,编译计算图所设置的输入。

警告

这是一个实验性API,后续可能修改或删除。

样例:

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import nn, Tensor
>>>
>>> class ReluNet(nn.Cell):
...     def __init__(self):
...         super(ReluNet, self).__init__()
...         self.relu = nn.ReLU()
...     def construct(self, x):
...         return self.relu(x)
>>>
>>> net = ReluNet()
>>> input_dyn = Tensor(shape=[3, None], dtype=ms.float32)
>>> net.set_inputs(input_dyn)
>>> get_inputs = net.get_inputs()
>>> print(get_inputs)
(Tensor(shape=[3, -1], dtype=Float32, value= ),)
get_parameters(expand=True)[源代码]

返回Cell中parameter的迭代器。

获取Cell的参数。如果 expandtrue ,获取此cell和所有subcells的参数。关于subcell,请看下面的示例。

参数:
  • expand (bool) - 如果为 True ,则递归地获取当前Cell和所有子Cell的parameter。否则,只生成当前Cell的subcell的parameter。默认值: True

返回:

Iteration类型,Cell的parameter。

样例:

>>> import mindspore as ms
>>> from mindspore import nn, ops, Tensor
>>> import numpy as np
>>> class TestNet(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.my_w1 = ms.Parameter(Tensor(np.ones([4, 4]), ms.float32))
...         self.my_w2 = ms.Parameter(Tensor(np.ones([16]), ms.float32))
...     def construct(self, x):
...         x += self.my_w1
...         x = ops.reshape(x, (16,)) - self.my_w2
...         return x
>>> class TestNet2(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.my_t1 = ms.Parameter(Tensor(np.ones([4, 4]), ms.float32))
...         # self.subcell is a subcell of TestNet2, when using expand=True, the parameters of TestNet will
...         # also be gathered.
...         self.subcell = TestNet()
...     def construct(self, x):
...         x += self.my_w1
...         x = ops.reshape(x, (16,)) - self.my_w2
...         return x
>>> net = TestNet2()
>>> print([p for p in net.get_parameters(expand=True)])
[Parameter (name=my_t1, shape=(4, 4), dtype=Float32, requires_grad=True), Parameter (name=subcell.my_w1,
shape=(4, 4), dtype=Float32, requires_grad=True), Parameter (name=subcell.my_w2, shape=(16,), dtype=Float32,
requires_grad=True)]
get_scope()[源代码]

返回Cell的作用域。

返回:

String类型,网络的作用域。

get_sub_cell(target)[源代码]

返回给定 target 的子Cell,如果不存在则抛出错误。

例如,假设你有一个 nn.Cell A,如下所示:

A(
    (net_b): NetB(
        (net_c): NetC(
            (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
        )
        (dense): Dense(in_features=100, out_features=200, bias=True)
    )
)

(该图显示了 nn.Cell AA 有一个嵌套的子Cell`net_b`, 而后者本身又有两个子Cell net_cdensenet_c 则有一个子Cell conv 。)

要检查是否拥有子Cell dense ,我们将调用 get_sub_cell("net_b.dense") 。要检查是否拥有子Cell conv ,我们将调用 get_sub_cell("net_b.net_c.conv")

get_sub_cell 的运行时间受 target 中Cell嵌套程度的限制。使用 name_cells 的查询可获得相同的结果,但传递的Cell的数量级为O(N)。 因此,为了简单检查是否存在某些子Cell,应始终使用 get_sub_cell

参数:
  • target (str) - 要查找的子Cell的完全限定字符串名称。(请参阅上述示例以了解如何指定完全限定字符串。)

返回:

Cell

样例:

>>> import mindspore
...
...
>>> class NetC(mindspore.nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.buffer_c = mindspore.nn.Buffer(mindspore.tensor([0, 0, 0]))
...         self.dense_c = mindspore.nn.Dense(5, 3)
...
...     def construct(self, x):
...         return self.dense_c(x) + self.buffer_c
...
...
>>> class NetB(mindspore.nn.Cell):
...     def __init__(self, net_c):
...         super().__init__()
...         self.net_c = net_c
...         self.buffer_b = mindspore.nn.Buffer(mindspore.tensor([1, 2, 3]))
...
...     def construct(self, x):
...         return self.net_c(x) + self.buffer_b
...
...
>>> class NetA(mindspore.nn.Cell):
...     def __init__(self, net_b):
...         super().__init__()
...         self.net_b = net_b
...         self.buffer_a = mindspore.nn.Buffer(mindspore.tensor([4, 5, 6]))
...
...     def construct(self, x):
...         return self.net_b(x) + self.buffer_a
...
...
>>> net_c = NetC()
>>> net_b = NetB(net_c)
>>> net_a = NetA(net_b)
>>> net_c = net_a.get_sub_cell("net_b.net_c")
>>> print(f'net_c is {net_c}')
net_c is NetC(
    (dense_c): Dense(input_channels=5, output_channels=3, has_bias=True)
)
infer_param_pipeline_stage()[源代码]

推导Cell中当前 pipeline_stage 的参数。

说明

  • 这个接口在2.3版本废弃,并且会在未来版本移除。

返回:

属于当前 pipeline_stage 的参数。

异常:
  • RuntimeError - 如果参数不属于任何stage。

init_parameters_data(auto_parallel_mode=False)[源代码]

初始化并替换Cell中所有的parameter的值。

说明

在调用 init_parameters_data 后,trainable_params() 或其他相似的接口可能返回不同的参数对象,不建议保存这些结果。

参数:
  • auto_parallel_mode (bool) - 是否在自动并行模式下执行。默认值: False

返回:

Dict[Parameter, Parameter],返回一个原始参数和替换参数的字典。

样例:

>>> import mindspore as ms
>>> from mindspore import Tensor, nn
...
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.dense = nn.Dense(2, 2)
...
...     def construct(self, x):
...         x = self.dense(x)
...         return x
>>> net = Net()
>>> print(net.init_parameters_data())
{Parameter (name=dense.weight, shape=(2,2), dtype=Float32, requires_grad=True):
 Parameter (name=dense.weight, shape=(2,2), dtype=Float32, requires_grad=True),
 Parameter (name=dense.bias, shape=(2,), dtype=Float32, requires_grad=True):
 Parameter (name=dense.bias, shape=(2,), dtype=Float32, requires_grad=True)}
insert_child_to_cell(child_name, child_cell)[源代码]

将一个给定名称的子Cell添加到当前Cell。

参数:
  • child_name (str) - 子Cell名称。

  • child_cell (Cell) - 要插入的子Cell。

异常:
  • KeyError - 如果子Cell的名称不正确或与其他子Cell名称重复。

  • TypeError - 如果 child_name 的类型不为str类型。

  • TypeError - 如果子Cell的类型不正确。

样例:

>>> import mindspore as ms
>>> from mindspore import Tensor, nn
...
>>> net1 = nn.ReLU()
>>> net2 = nn.Dense(2, 2)
>>> net1.insert_child_to_cell("child", net2)
>>> print(net1)
ReLU(
  (child): Dense(input_channels=2, output_channels=2, has_bias=True)
)
insert_param_to_cell(param_name, param, check_name_contain_dot=True)[源代码]

向当前Cell添加参数。

将指定名称的参数添加到Cell中。目前在 mindspore.nn.Cell.__setattr__ 中使用。

参数:
  • param_name (str) - 参数名称。

  • param (Parameter) - 要插入到Cell的参数。

  • check_name_contain_dot (bool) - 是否对 param_name 中的"."进行检查。默认值: True

异常:
  • KeyError - 如果参数名称为空或包含"."。

  • TypeError - 如果参数的类型不是Parameter。

样例:

>>> import mindspore as ms
>>> from mindspore import Tensor, nn, Parameter
...
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.relu = nn.ReLU()
...
...     def construct(self, x):
...         x = self.relu(x)
...         return x
>>> net = Net()
>>> net.insert_param_to_cell("bias", Parameter(Tensor([1, 2, 3])))
>>> print(net.bias)
Parameter(name=bias, shape=(3,), dtype=Int64, requires_grad=True)
load_state_dict(state_dict, strict=True)[源代码]

state_dict 中的参数和缓冲区复制到当前Cell及其子Cell中。

如果 strict 设置为 True ,则 state_dict 的键必须与该Cell的 mindspore.nn.Cell.state_dict() 方法返回的键完全匹配。

参数:
  • state_dict (dict) - 包含参数和持久缓冲区的字典。

  • strict (bool,可选) - 是否严格要求输入 state_dict 中的键必须与当前Cell的 mindspore.nn.Cell.state_dict() 方法返回的键匹配。默认 True

返回:

一个包含 missing_keysunexpected_keys 字段的namedtuple,

  • missing_keys 是一个包含字符串的列表,表示当前Cell需要但在state_dict中缺失的键。

  • unexpected_keys 是一个包含字符串的列表,表示state_dict中存在但当前Cell不需要的键。

说明

如果某个参数或缓冲区被注册为 None ,但其对应的键在 state_dict 中存在,则 mindspore.nn.Cell.load_state_dict() 将会抛出 RuntimeError

样例:

>>> import mindspore
>>> import os
>>> class Model(mindspore.nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.buffer_a = mindspore.nn.Buffer(mindspore.tensor([4, 5, 6]))
...         self.param_a = mindspore.Parameter(mindspore.tensor([1, 2, 3]))
...
...     def construct(self, x):
...         return x + self.buffer_a + self.param_a
...
...
>>> model = Model()
>>> print(model.state_dict())
>>> mindspore.save_checkpoint(model.state_dict(), './model_state_dict_ckpt')
>>> new_model = Model()
>>> new_model.load_state_dict(mindspore.load_checkpoint('./model_state_dict_ckpt'))
>>> print(new_model.state_dict())
>>> os.remove('./model_state_dict_ckpt')
OrderedDict([('param_a', Parameter (name=param_a, shape=(3,), dtype=Int64, requires_grad=True)), \
('buffer_a', Tensor(shape=[3], dtype=Int64, value= [4, 5, 6]))])
OrderedDict([('param_a', Parameter (name=param_a, shape=(3,), dtype=Int64, requires_grad=True)), \
('buffer_a', Tensor(shape=[3], dtype=Int64, value= [4, 5, 6]))])
name_cells()[源代码]

递归地获取一个Cell中所有子Cell的迭代器。

包括Cell名称和Cell本身。

返回:

Dict[String, Cell],Cell中的所有子Cell及其名称。

样例:

>>> import mindspore as ms
>>> from mindspore import Tensor, nn
...
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.dense = nn.Dense(2, 2)
...
...     def construct(self, x):
...         x = self.dense(x)
...         return x
>>> net = Net()
>>> print(net.name_cells())
OrderedDict([('dense', Dense(input_channels=2, output_channels=2, has_bias=True))])
named_buffers(prefix='', recurse=True, remove_duplicate=True)[源代码]

返回Cell中缓冲区的迭代器,包含缓冲区的名称以及缓冲区本身。

参数:
  • prefix (str,可选) - 添加到所有缓冲区名称前面的前缀。默认 ""

  • recurse (bool,可选) - 如果为 True ,则生成此Cell和所有子Cell的缓冲区。否则,仅生成此Cell的缓冲区。默认 True

  • remove_duplicate (bool,可选) - 是否删除结果中的重复缓冲区。默认 True

返回:

Iterator[Tuple[str, Tensor]],包含名称和缓冲区的元组的迭代器。

样例:

>>> import mindspore
...
...
>>> class NetB(mindspore.nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.buffer_b = mindspore.nn.Buffer(mindspore.tensor([1, 2, 3]))
...
...     def construct(self, x):
...         return x + self.buffer_b
...
...
>>> class NetA(mindspore.nn.Cell):
...     def __init__(self, net_b):
...         super().__init__()
...         self.net_b = net_b
...         self.buffer_a = mindspore.nn.Buffer(mindspore.tensor([4, 5, 6]))
...
...     def construct(self, x):
...         return self.net_b(x) + self.buffer_a
...
...
>>> net_b = NetB()
>>> net_a = NetA(net_b)
>>>
>>> for name, buffer in net_a.named_buffers():
>>>     print(f'buffer name is {name}, buffer is {buffer}')
buffer name is buffer_a, buffer is [4, 5, 6]
buffer name is net_b.buffer_b, buffer is [1, 2, 3]
offload(backward_prefetch='Auto')[源代码]

设置Cell激活值卸载,设置后该Cell中所有的Primitive类会被使能激活值卸载标签。若激活值需要在反向阶段被用于计算 梯度,则该激活值会在正向阶段被搬运至host侧,不会存储在device侧,并在反向阶段使用其之前,预取回device侧。

说明

  • 当某个Cell被标记为offload时,运行模型必须为"GRAPH_MODE"模式。

  • 当某个Cell被标记为offload时,需要开启lazyinline。

参数:
  • backward_prefetch (Union[str, int],可选) - 设置反向阶段提前预取激活值的时机。默认值: "Auto" 。当为 "Auto" 时,框架将提前一个算子开始预取激活值;当为正整数时,框架将提前 backward_prefetch 个算子开始预期激活值,例如1、20、100。

样例:

>>> import mindspore.nn as nn
>>> from mindspore import ops
>>> from mindspore.common import Tensor, Parameter
>>>
>>> class Block(nn.Cell):
...     def __init__(self):
...         super(Block, self).__init__()
...         self.transpose1 = ops.Transpose()
...         self.transpose2 = ops.Transpose()
...         self.transpose3 = ops.Transpose()
...         self.transpose4 = ops.Transpose()
...         self.real_div1 = ops.RealDiv()
...         self.real_div2 = ops.RealDiv()
...         self.batch_matmul1 = ops.BatchMatMul()
...         self.batch_matmul2 = ops.BatchMatMul()
...         self.softmax = ops.Softmax(-1)
...         self.expand_dims = ops.ExpandDims()
...         self.sub = ops.Sub()
...         self.y = Parameter(Tensor(np.ones((1024, 128, 128)).astype(np.float32)))
...     def construct(self, x):
...         transpose1 = self.transpose1(x, (0, 2, 1, 3))
...         real_div1 = self.real_div1(transpose1, Tensor(2.37891))
...         transpose2 = self.transpose2(x, (0, 2, 3, 1))
...         real_div2 = self.real_div2(transpose2, Tensor(2.37891))
...         batch_matmul1 = self.batch_matmul1(real_div1, real_div2)
...         expand_dims = self.expand_dims(self.y, 1)
...         sub = self.sub(Tensor([1.0]), expand_dims)
...         soft_max = self.softmax(sub)
...         transpose3 = self.transpose3(x, (0, 2, 1, 3))
...         batch_matmul2 = self.batch_matmul2(soft_max[0], transpose3)
...         transpose4 = self.transpose4(batch_matmul2, (0, 2, 1, 3))
...         return transpose4
>>>
>>> class OuterBlock(nn.Cell):
...     @lazy_inline
...     def __init__(self):
...         super(OuterBlock, self).__init__()
...         self.block = Block()
...     def construct(self, x):
...         return self.block(x)
>>>
>>> class Nets(nn.Cell):
...     def __init__(self):
...         super(Nets, self).__init__()
...         self.blocks = nn.CellList()
...         for _ in range(3):
...             b = OuterBlock()
...             b.offload()
...             self.blocks.append(b)
...     def construct(self, x):
...         out = x
...         for i in range(3):
...             out = self.blocks[i](out)
...         return out
property param_prefix

当前Cell的子Cell的参数名前缀。

样例:

>>> import mindspore as ms
>>> from mindspore import Tensor, nn
...
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.dense = nn.Dense(2, 2)
...
...     def construct(self, x):
...         x = self.dense(x)
...         return x
>>> net = Net()
>>> net.update_cell_prefix()
>>> print(net.dense.param_prefix)
dense
property parameter_layout_dict

parameter_layout_dict 表示一个参数的张量layout,这种张量layout是由分片策略和分布式算子信息推断出来的。

parameters_and_names(name_prefix='', expand=True)[源代码]

返回Cell中parameter的迭代器。

包含参数名称和参数本身。

参数:
  • name_prefix (str) - 作用域。默认值: ''

  • expand (bool) - 如果为True,则递归地获取当前Cell和所有子Cell的参数及名称;如果为 False ,只生成当前Cell的子Cell的参数及名称。默认值: True

返回:

迭代器,Cell的名称和Cell本身。

样例:

>>> from mindspore import nn
>>> n = nn.Dense(3, 4)
>>> names = []
>>> for m in n.parameters_and_names():
...     if m[0]:
...         names.append(m[0])
教程样例:
parameters_broadcast_dict(recurse=True)[源代码]

获取这个Cell的参数广播字典。

参数:
  • recurse (bool) - 是否包含子Cell的参数。默认值: True

返回:

OrderedDict,返回参数广播字典。

parameters_dict(recurse=True)[源代码]

获取此Cell的parameter字典。

参数:
  • recurse (bool) - 是否递归地包含所有子Cell的parameter。默认值: True

返回:

OrderedDict类型,返回参数字典。

样例:

>>> import mindspore as ms
>>> from mindspore import Tensor, nn, Parameter
...
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.dense = nn.Dense(2, 2)
...
...     def construct(self, x):
...         x = self.dense(x)
...         return x
>>> net = Net()
>>> print(net.parameters_dict())
OrderedDict([('dense.weight', Parameter(name=dense.weight, shape=(2, 2), dtype=Float32,
requires_grad=True)), ('dense.bias', Parameter(name=dense.bias, shape=(2,), dtype=Float32,
requires_grad=True))])
property pipeline_stage

pipeline_stage 表示当前Cell所在的stage。

place(role, rank_id)[源代码]

为该Cell中所有算子设置标签。此标签告诉MindSpore编译器此Cell在哪个进程上启动。 每个标签都由进程角色 rolerank_id 组成,因此,通过对不同Cell设置不同标签,这些Cell将在不同进程启动,使用户可以进行分布式训练/推理等任务。

说明

  • 此接口只在成功调用 mindspore.communication.init() 完成动态组网后才能生效。

参数:
  • role (str) - 算子执行所在进程的角色。只支持'MS_WORKER'。

  • rank_id (int) - 算子执行所在进程的ID。在相同进程角色间, rank_id 是唯一的。

样例:

>>> from mindspore import context
>>> import mindspore.nn as nn
>>> context.set_context(mode=context.GRAPH_MODE)
>>> fc = nn.Dense(2, 3)
>>> fc.place('MS_WORKER', 0)
recompute(**kwargs)[源代码]

设置Cell重计算。Cell中输出算子以外的所有算子将被设置为重计算。如果一个算子的计算结果被输出到一些反向节点来进行梯度计算,且被设置成重计算,那么我们会在反向传播中重新计算它,而不去存储在前向传播中的中间激活层的计算结果。

说明

  • 如果计算涉及到诸如随机化或全局变量之类的操作,那么目前还不能保证等价。

  • 如果该Cell中算子的重计算API也被调用,则该算子的重计算模式以算子的重计算API的设置为准。

  • 该接口仅配置一次,即当父Cell配置了,子Cell不需再配置。

  • Cell的输出算子默认不做重计算,这一点是基于我们减少内存占用的配置经验。如果一个Cell里面只有一个算子,且想要把这个算子设置为重计算的,那么请使用算子的重计算API。

  • 当应用了重计算且内存充足时,可以配置'mp_comm_recompute=False'来提升性能。

  • 当应用了重计算但内存不足时,可以配置'parallel_optimizer_comm_recompute=True'来节省内存。有相同融合group的Cell应该配置相同的parallel_optimizer_comm_recompute。

参数:
  • mp_comm_recompute (bool) - 表示在自动并行或半自动并行模式下,指定Cell内部由模型并行引入的通信操作是否重计算。默认值: True

  • parallel_optimizer_comm_recompute (bool) - 表示在自动并行或半自动并行模式下,指定Cell内部由优化器并行引入的AllGather通信是否重计算。默认值: False

register_backward_hook(hook_fn)[源代码]

设置Cell对象的反向hook函数。

说明

  • register_backward_hook(hook_fn) 在图模式下,或者在PyNative模式下使用 jit 装饰器功能时不起作用。

  • hook_fn必须有如下代码定义:cell 是已注册Cell对象的信息, grad_input 是Cell对象的反向输出梯度, grad_output 是反向传递给Cell对象的梯度。 用户可以在hook_fn中返回None或者返回新的梯度。

  • hook_fn返回None或者新的相应于 grad_input 的梯度:hook_fn(cell, grad_input, grad_output) -> New grad_input or None。

  • 为了避免脚本在切换到图模式时运行失败,不建议在Cell对象的 construct 函数中调用 register_backward_hook(hook_fn)

  • PyNative模式下,如果在Cell对象的 construct 函数中调用 register_backward_hook(hook_fn) ,那么Cell对象每次运行都将增加一个 hook_fn

参数:
  • hook_fn (function) - 捕获Cell对象信息和反向输入,输出梯度的 hook_fn 函数。

返回:

返回与 hook_fn 函数对应的 handle 对象。可通过调用 handle.remove() 来删除添加的 hook_fn 函数。

异常:
  • TypeError - 如果 hook_fn 不是Python函数。

样例:

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import Tensor, nn, ops
>>> ms.set_context(mode=ms.PYNATIVE_MODE)
>>> def backward_hook_fn(cell, grad_input, grad_output):
...     print("backward input: ", grad_output)
...     print("backward output: ", grad_input)
...
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.relu = nn.ReLU()
...         self.handle = self.relu.register_backward_hook(backward_hook_fn)
...
...     def construct(self, x):
...         x = x + x
...         x = self.relu(x)
...         return x
>>> grad = ops.GradOperation(get_all=True)
>>> net = Net()
>>> output = grad(net)(Tensor(np.ones([1]).astype(np.float32)))
backward input: (Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]),)
backward output: (Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]),)
>>> print(output)
(Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
register_backward_pre_hook(hook_fn)[源代码]

设置Cell对象的反向pre_hook函数。

说明

  • register_backward_pre_hook(hook_fn) 在图模式下,或者在PyNative模式下使用 jit 装饰器功能时不起作用。

  • hook_fn必须有如下代码定义:cell 是已注册Cell对象的信息, grad_output 是反向传递给Cell对象的梯度。用户可以在hook_fn中返回None或者返回新的梯度。

  • hook_fn返回None或者新的相应于 grad_output 的梯度:hook_fn(cell, grad_output) -> New grad_output or None。

  • register_backward_pre_hook(hook_fn) 在PyThon环境中运行。为了避免脚本在切换到图模式时运行失败,不建议在Cell对象的 construct 函数中调用 register_backward_pre_hook(hook_fn)

  • PyNative模式下,如果在Cell对象的 construct 函数中调用 register_backward_pre_hook(hook_fn) ,那么Cell对象每次运行都将增加一个 hook_fn

参数:
  • hook_fn (function) - 捕获Cell对象信息和反向输入梯度的 hook_fn 函数。

返回:

返回与 hook_fn 函数对应的 handle 对象。可通过调用 handle.remove() 来删除添加的 hook_fn 函数。

异常:
  • TypeError - 如果 hook_fn 不是Python函数。

样例:

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import Tensor, nn, ops
>>> ms.set_context(mode=ms.PYNATIVE_MODE)
>>> def backward_pre_hook_fn(cell, grad_output):
...     print("backward input: ", grad_output)
...
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.relu = nn.ReLU()
...         self.handle = self.relu.register_backward_pre_hook(backward_pre_hook_fn)
...
...     def construct(self, x):
...         x = x + x
...         x = self.relu(x)
...         return x
>>> grad = ops.GradOperation(get_all=True)
>>> net = Net()
>>> output = grad(net)(Tensor(np.ones([1]).astype(np.float32)))
backward input: (Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]),)
>>> print(output)
(Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
register_buffer(name, tensor, persistent=True)[源代码]

在Cell添加一个缓冲区 buffer

这通常用于注册不应被视为模型参数的缓冲区。例如,BatchNorm的 running_mean 不是参数,而是Cell状态的一部分。 默认情况下,缓冲区是持久的,将与参数一起保存。可以通过将 persistent 设置为 False 来更改此行为。 持久缓冲区和非持久缓冲区之间的唯一区别是后者不会成为此Cell的 state_dict 的一部分。

可以使用指定的名称将缓冲区作为属性访问。

参数:
  • name (str) - 缓冲区的名字。可以使用给定的名称访问此Cell的缓冲区 。

  • tensor (Tensor) - 待注册的缓冲区。如果为 None ,则此Cell的 state_dict 不会包括该缓冲区。

  • persistent (bool, 可选) - 缓冲区是否是此Cell的 state_dict 的一部分。默认 True

样例:

>>> import mindspore
...
>>> class Net(mindspore.nn.Cell):
...    def __init__(self):
...        super().__init__()
...        self.register_buffer("buffer0", mindspore.tensor([1, 2, 3]))
...
...    def construct(self, x):
...        return x + self.net_buffer
...
>>> net = Net()
>>> net.register_buffer(mindspore.tensor("buffer0", [4, 5, 6]))
>>> print(net.buffer0)
[1 2 3]
register_forward_hook(hook_fn)[源代码]

设置Cell对象的正向hook函数。

说明

  • register_forward_hook(hook_fn) 在图模式下,或者在PyNative模式下使用 jit 装饰器功能时不起作用。

  • hook_fn必须有如下代码定义。 cell 是已注册Cell对象。 inputs 是网络正向传播时Cell对象的输入数据。 outputs 是网络正向传播时Cell对象的输出数据。用户可以在hook_fn中打印数据或者返回新的输出数据。

  • hook_fn返回新的输出数据或者None:hook_fn(cell, inputs, outputs) -> New outputs or None。

  • 为了避免脚本在切换到图模式时运行失败,不建议在Cell对象的 construct 函数中调用 register_forward_hook(hook_fn)

  • PyNative模式下,如果在Cell对象的 construct 函数中调用 register_forward_hook(hook_fn) ,那么Cell对象每次运行都将增加一个 hook_fn

参数:
  • hook_fn (function) - 捕获Cell对象信息和正向输入,输出数据的 hook_fn 函数。

返回:

返回与 hook_fn 函数对应的 handle 对象。可通过调用 handle.remove() 来删除添加的 hook_fn 函数。

异常:
  • TypeError - 如果 hook_fn 不是Python函数。

样例:

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import Tensor, nn, ops
>>> ms.set_context(mode=ms.PYNATIVE_MODE)
>>> def forward_hook_fn(cell, inputs, output):
...     print("forward inputs: ", inputs)
...     print("forward output: ", output)
...
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.mul = nn.MatMul()
...         self.handle = self.mul.register_forward_hook(forward_hook_fn)
...
...     def construct(self, x, y):
...         x = x + x
...         x = self.mul(x, y)
...         return x
>>> grad = ops.GradOperation(get_all=True)
>>> net = Net()
>>> output = grad(net)(Tensor(np.ones([1]).astype(np.float32)), Tensor(np.ones([1]).astype(np.float32)))
forward inputs: (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1],
                dtype=Float32, value= [ 1.00000000e+00]))
forward output: 2.0
>>> print(output)
(Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32,
value= [ 2.00000000e+00]))
register_forward_pre_hook(hook_fn)[源代码]

设置Cell对象的正向pre_hook函数。

说明

  • register_forward_pre_hook(hook_fn) 在图模式下,或者在PyNative模式下使用 jit 装饰器功能时不起作用。

  • hook_fn必须有如下代码定义。 cell 是已注册Cell对象。 inputs 是网络正向传播时Cell对象的输入数据。用户可以在hook_fn中打印输入数据或者返回新的输入数据。

  • hook_fn返回新的输入数据或者None:hook_fn(cell, inputs) -> New inputs or None。

  • 为了避免脚本在切换到图模式时运行失败,不建议在Cell对象的 construct 函数中调用 register_forward_pre_hook(hook_fn)

  • PyNative模式下,如果在Cell对象的 construct 函数中调用 register_forward_pre_hook(hook_fn) ,那么Cell对象每次运行都将增加一个 hook_fn

参数:
  • hook_fn (function) - 捕获Cell对象信息和正向输入数据的hook_fn函数。

返回:

返回与 hook_fn 函数对应的 handle 对象。可通过调用 handle.remove() 来删除添加的 hook_fn 函数。

异常:
  • TypeError - 如果 hook_fn 不是Python函数。

样例:

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import Tensor, nn, ops
>>> ms.set_context(mode=ms.PYNATIVE_MODE)
>>> def forward_pre_hook_fn(cell, inputs):
...     print("forward inputs: ", inputs)
...
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.mul = nn.MatMul()
...         self.handle = self.mul.register_forward_pre_hook(forward_pre_hook_fn)
...
...     def construct(self, x, y):
...         x = x + x
...         x = self.mul(x, y)
...         return x
>>> grad = ops.GradOperation(get_all=True)
>>> net = Net()
>>> output = grad(net)(Tensor(np.ones([1]).astype(np.float32)), Tensor(np.ones([1]).astype(np.float32)))
forward inputs: (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1],
                dtype=Float32, value= [ 1.00000000e+00]))
>>> print(output)
(Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32,
value= [ 2.00000000e+00]))
register_load_state_dict_post_hook(hook)[源代码]

mindspore.nn.Cell.load_state_dict() 方法注册一个后钩子。

它应该具有以下签名:

hook(cell, incompatible_keys) -> None

参数 cell 是此钩子注册的当前cell,参数 incompatible_keys 是一个 NamedTuple ,由属性 missing_keysunexpected_keys 组成。missing_keys 是包含缺失键的 list , 而 unexpected_keys 是包含意外键的 list

请注意,正如预期的那样,在使用 strict=True 调用:func: load_state_dict 时执行的检查会受到钩子对 missing_keysunexpected_keys 所做修改的影响。 当 strict=True 时,添加任何一组键都会导致抛出错误,而清除缺失和意外的键将避免错误。

参数:
  • hook (Callable) - 在调用load_state_dict之后执行的钩子。

返回:

RemovableHandle,一个句柄,可以通过调用 handle.remove() 来移除已添加的钩子。

register_load_state_dict_pre_hook(hook)[源代码]

mindspore.nn.Cell.load_state_dict() 方法注册一个预钩子。

它应该具有以下签名:

hook(cell, state_dict, prefix, local_metadata, strict, missing_keys, expected_keys, error_msgs) -> None

注册的钩子可以就地修改 state_dict

参数:
  • hook (Callable) - 在调用load_state_dict之前执行的钩子。

返回:

RemovableHandle,一个句柄,可以通过调用 handle.remove() 来移除已添加的钩子。

register_state_dict_post_hook(hook)[源代码]

mindspore.nn.Cell.state_dict() 方法注册一个后钩子。

它应该具有以下签名:

hook(cell, state_dict, prefix, local_metadata) -> None

注册的钩子可用于在调用 state_dict 之后执行后处理。

参数:
  • hook (Callable) - 在调用state_dict之后执行的钩子。

返回:

RemovableHandle,一个句柄,可以通过调用 handle.remove() 来移除已添加的钩子。

register_state_dict_pre_hook(hook)[源代码]

mindspore.nn.Cell.state_dict() 方法注册一个预钩子。

它应该具有以下签名:

hook(cell, prefix, keep_vars) -> None

注册的钩子可用于在调用 state_dict 之前执行预处理。

参数:
  • hook (Callable) - 在调用state_dict之前执行的钩子。

返回:

RemovableHandle,一个句柄,可以通过调用 handle.remove() 来移除已添加的钩子。

样例:

Examples: >>> import mindspore … … >>> class NetA(mindspore.nn.Cell): … def __init__(self): … super().__init__() … self.buffer_a = mindspore.nn.Buffer(mindspore.tensor([1, 2, 3])) … self.param_a = mindspore.Parameter(mindspore.tensor([1, 2, 3])) … … def construct(self, x): … return x + self.buffer_a + self.param_a … … >>> def _add_extra_param(cell, prefix, keep_vars): … cell._params["extra_param"] = mindspore.Parameter(mindspore.tensor([4, 5, 6])) … … >>> net = NetA() >>> handle = net.register_state_dict_pre_hook(_add_extra_param) >>> net_state_dict = net.state_dict() >>> handle.remove() >>> print("extra_param" in net_state_dict) True

remove_redundant_parameters()[源代码]

删除冗余参数。

这个接口通常不需要显式调用。

run_construct(cast_inputs, kwargs)[源代码]

运行construct方法。

说明

该函数已经弃用,将会在未来版本中删除。不推荐使用此函数。

参数:
  • cast_inputs (tuple) - Cell的输入。

  • kwargs (dict) - 关键字参数。

返回:

Cell的输出。

set_boost(boost_type)[源代码]

为了提升网络性能,可以配置boost内的算法让框架自动使能该算法来加速网络训练。

请确保 boost_type 所选择的算法在 algorithm library 算法库中。

说明

部分加速算法可能影响网络精度,请谨慎选择。

参数:
  • boost_type (str) - 加速算法。

返回:

Cell类型,Cell本身。

异常:
  • ValueError - 如果 boost_type 不在boost算法库内。

set_broadcast_flag(mode=True)[源代码]

设置该Cell的参数广播模式。

参数:
  • mode (bool) - 指定当前模式是否进行参数广播。默认值: True

set_comm_fusion(fusion_type, recurse=True)[源代码]

为Cell中的参数设置融合类型。请参考 mindspore.Parameter.comm_fusion 的描述。

说明

当函数被多次调用时,此属性值将被重写。

参数:
  • fusion_type (int) - Parameter的 comm_fusion 属性的设置值。

  • recurse (bool) - 是否递归地设置子Cell的可训练参数。默认值: True

set_data_parallel()[源代码]

在非自动策略搜索的情况下,如果此Cell的所有算子(包括此Cell内含嵌套的cell)未指定并行策略,则将为这些基本算子设置为数据并行策略。

说明

仅在图模式,使用auto_parallel_context = ParallelMode.AUTO_PARALLEL生效。

样例:

>>> import mindspore.nn as nn
>>> net = nn.Dense(3, 4)
>>> net.set_data_parallel()
set_extra_state(state)[源代码]

设置加载的 state_dict 中包含的额外状态。

此方法由 load_state_dict 调用,以处理 state_dict 中的任何额外状态。 如果您的 Cell 需要在 state_dict 中存储额外状态,请实现此方法及相应的 get_extra_state 方法。

参数:
  • state (dict) - state_dict 的额外状态。

set_grad(requires_grad=True)[源代码]

Cell的梯度设置。

参数:
  • requires_grad (bool) - 指定网络是否需要梯度,如果为 True ,PyNative模式下Cell将构建反向网络。默认值: True

返回:

Cell类型,Cell本身。

set_inputs(*inputs, **kwargs)[源代码]

设置编译计算图所需的输入。输入数量需与数据集数量一致。若使用Model接口,请确保所有传入Model的网络和损失函数都配置了set_inputs。 输入Tensor的shape可以为动态或静态。

说明

有两种配置模式:

  • 全量配置模式:输入将被用作图编译时的完整编译参数。

  • 增量配置模式:输入被配置到Cell的部分输入上,这些输入将替换图编译对应位置上的参数。

只能传入inputs和kwargs的其中一个。inputs用于全量配置模式,kwargs用于增量配置模式。

参数:
  • inputs (tuple) - 全量配置模式的参数。

  • kwargs (dict) - 增量配置模式的参数。可设置的key值为 self.construct 中定义的参数名。

警告

这是一个实验性API,后续可能修改或删除。

样例:

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import nn, Tensor
>>>
>>> class ReluNet(nn.Cell):
...     def __init__(self):
...         super(ReluNet, self).__init__()
...         self.relu = nn.ReLU()
...     def construct(self, x):
...         return self.relu(x)
>>>
>>> net = ReluNet()
>>> input_dyn = Tensor(shape=[3, None], dtype=ms.float32)
>>> net.set_inputs(input_dyn)
>>> input = Tensor(np.random.random([3, 10]), dtype=ms.float32)
>>> output = net(input)
>>>
>>> net2 = ReluNet()
>>> net2.set_inputs(x=input_dyn)
>>> output = net2(input)
set_jit_config(jit_config)[源代码]

为Cell设置编译时所使用的JitConfig配置项。

参数:

样例:

>>> import mindspore as ms
>>> from mindspore import Tensor, nn
...
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.relu = nn.ReLU()
...
...     def construct(self, x):
...         x = self.relu(x)
...         return x
>>> net = Net()
>>> jitconfig = ms.JitConfig()
>>> net.set_jit_config(jitconfig)
set_param_ps(recurse=True, init_in_server=False)[源代码]

设置可训练参数是否由参数服务器更新,以及是否在服务器上初始化可训练参数。

说明

只在运行的任务处于参数服务器模式时有效。 只支持在图模式下调用。

参数:
  • recurse (bool) - 是否设置子网络的可训练参数。默认值: True

  • init_in_server (bool) - 是否在服务器上初始化由参数服务器更新的可训练参数。默认值: False

set_train(mode=True)[源代码]

将Cell设置为训练模式。

设置当前Cell和所有子Cell的训练模式。对于训练和预测具有不同结构的网络层(如 BatchNorm),将通过这个属性区分分支。如果设置为True,则执行训练分支,否则执行另一个分支。

说明

当执行 mindspore.train.Model.train() 的时候,框架会默认调用Cell.set_train(True)。 当执行 mindspore.train.Model.eval() 的时候,框架会默认调用Cell.set_train(False)。

参数:
  • mode (bool) - 指定模型是否为训练模式。默认值: True

返回:

Cell类型,Cell本身。

教程样例:
shard(in_strategy, out_strategy=None, parameter_plan=None, device='Ascend', level=0)[源代码]

指定输入/输出Tensor的分布策略,通过其余算子的策略推导得到。在PyNative模式下,可以利用此方法指定某个Cell以图模式进行分布式执行。 在图模式下, 可以利用此方法设置某个模块的分布式切分策略,未设置的会自动通过策略传播方式配置。 in_strategy/out_strategy需要为元组类型, 其中的每一个元素指定对应的输入/输出的Tensor分布策略,可参考: mindspore.ops.Primitive.shard() 的描述。 其余算子的并行策略由输入输出指定的策略推导得到。

说明

调用该方法后,并行模式(parallel_mode)会自动设置为"auto_parallel"且搜索模式(search_mode)自动设置为"sharding_propagation"。 如果输入含有Parameter,其对应的策略应该在 in_strategy 里设置。

参数:
  • in_strategy (tuple) - 指定各输入的切分策略,输入元组的每个元素元组,元组即具体指定输入每一维的切分策略。

  • out_strategy (Union[None, tuple]) - 指定各输出的切分策略,用法同in_strategy,目前未使能。默认值: None

  • parameter_plan (Union[dict, None]) - 指定各参数的切分策略,传入字典时,键是str类型的参数名,值是一维整数tuple表示相应的切分策略, 如果参数名错误或对应参数已经设置了切分策略,该参数的设置会被跳过。默认值: None

  • device (str) - 指定执行设备,可以为[ "CPU" , "GPU" , "Ascend" ]中任意一个,目前未使能。默认值: "Ascend"

  • level (int) - 指定搜索切分策略的目标函数,即是最大化计算通信比、最小化内存消耗、最大化执行速度等。可以为[ 0 , 1 , 2 ]中任意一个,默认值: 0 。目前仅支持最大化计算通信比,其余模式未使能。

返回:

Function,返回一个在自动并行流程下执行的函数。

样例:

>>> import mindspore.nn as nn
>>>
>>> class Block(nn.Cell):
...   def __init__(self):
...     self.dense1 = nn.Dense(10, 10)
...     self.relu = nn.ReLU()
...     self.dense2 = nn.Dense2(10, 10)
...   def construct(self, x):
...     x = self.relu(self.dense2(self.relu(self.dense1(x))))
...     return x
>>>
>>> class example(nn.Cell):
...   def __init__(self):
...     self.block1 = Block()
...     self.block2 = Block()
...     self.block2_shard = self.block2.shard(in_strategy=((2, 1),),
...                                           parameter_plan={'self.block2.shard.dense1.weight': (4, 1)})
...   def construct(self, x):
...     x = self.block1(x)
...     x = self.block2_shard(x)
...     return x
state_dict(*args, destination=None, prefix='', keep_vars=False)[源代码]

返回一个包含对Cell整个状态的引用的字典。

参数和持久缓冲区(例如运行平均值)都包括在内。键是相应的参数和缓冲区名称。设置为 None 的参数和缓冲区不包括在内。

说明

返回的对象是一个浅拷贝。它包含对该Cell的参数和缓冲区的引用。

警告

  • 目前 state_dict() 还按顺序接受 destinationprefixkeep_vars 的位置参数。但是这即将被弃用,关键字参数将在未来的版本中强制执行。

  • 请避免使用参数 destination ,因为它不是为最终用户设计的。

参数:
  • destination (dict,可选) - 如果提供,Cell的状态将更新到此字典中,并返回相同的对象。否则,将创建并返回 OrderedDict 。默认 None

  • prefix (bool,可选) - 添加到参数和缓冲区名称的前缀,用于组成state_dict中的键。默认 ""

  • keep_vars (bool,可选) - 状态字典返回值是否为拷贝。默认 False ,返回引用。

返回:

Dict,包含整个Cell状态的字典。

样例:

>>> import mindspore
>>> class Model(mindspore.nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.buffer_a = mindspore.nn.Buffer(mindspore.tensor([4, 5, 6]))
...         self.param_a = mindspore.Parameter(mindspore.tensor([1, 2, 3]))
...
...     def construct(self, x):
...         return x + self.buffer_a + self.param_a
...
...
>>> model = Model()
>>> print(model.state_dict())
OrderedDict([('param_a', Parameter (name=param_a, shape=(3,), dtype=Int64, requires_grad=True)), \
('buffer_a', Tensor(shape=[3], dtype=Int64, value= [4, 5, 6]))])
to_float(dst_type)[源代码]

在Cell和所有子Cell的输入上添加类型转换,以使用特定的浮点类型运行。

如果 dst_typemindspore.dtype.float16 ,Cell的所有输入(包括作为常量的input、Parameter、Tensor)都会被转换为float16。请参考 mindspore.amp.build_train_network() 的源代码中的用法。

说明

多次调用将产生覆盖。

参数:
  • dst_type (mindspore.dtype) - Cell转换为 dst_type 类型运行。 dst_type 可以是 mindspore.dtype.float16mindspore.dtype.float32 或者 mindspore.dtype.bfloat16

返回:

Cell类型,Cell本身。

异常:
  • ValueError - 如果 dst_type 不是 mindspore.dtype.float32 ,不是 mindspore.dtype.float16 , 也不是 mindspore.dtype.bfloat16

支持平台:

Ascend GPU CPU

样例:

>>> import mindspore.nn as nn
>>> from mindspore import dtype as mstype
>>>
>>> net = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal')
>>> net.to_float(mstype.float16)
Conv2d(input_channels=120, output_channels=240, kernel_size=(4, 4), stride=(1, 1), pad_mode=same,
padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=None, format=NCHW)
trainable_params(recurse=True)[源代码]

返回Cell的一个可训练参数的列表。

参数:
  • recurse (bool) - 是否递归地包含当前Cell的所有子Cell的可训练参数。默认值: True

返回:

List类型,可训练参数列表。

教程样例:
untrainable_params(recurse=True)[源代码]

返回Cell的一个不可训练参数的列表。

参数:
  • recurse (bool) - 是否递归地包含当前Cell的所有子Cell的不可训练参数。默认值: True

返回:

List类型,不可训练参数列表。

update_cell_prefix()[源代码]

递归地更新所有子Cell的 param_prefix

在调用此方法后,可以通过Cell的 param_prefix 属性获取该Cell的所有子Cell的名称前缀。

update_cell_type(cell_type)[源代码]

量化感知训练网络场景下,更新当前Cell的类型。

此方法将Cell类型设置为 cell_type

参数:
  • cell_type (str) - 被更新的类型,cell_type 可以是"quant"或"second-order"。

update_parameters_name(prefix='', recurse=True)[源代码]

给网络参数名称添加 prefix 前缀字符串。

参数:
  • prefix (str) - 前缀字符串。默认值: ''

  • recurse (bool) - 是否递归地包含所有子Cell的参数。默认值: True