mindspore.nn.Cell
- class mindspore.nn.Cell(auto_prefix=True, flags=None)[source]
The basic building block of neural networks in MindSpore. The model or neural network layer should inherit this base class.
Layers in mindspore.nn are also subclasses of Cell, such as mindspore.nn.Conv2d and mindspore.nn.ReLU. A Cell is compiled into a computation graph in GRAPH_MODE (static graph mode) and used as the basic module of neural networks in PYNATIVE_MODE (dynamic graph mode).
Note
A Cell is in inference mode by default. For a class that inherits Cell, if training and inference have different structures, the subclass executes the inference branch by default. To set the training mode, refer to mindspore.nn.Cell.set_train().
Warning
A subclass of Cell must not define a method named 'cast' or an attribute named 'phase' or 'cells'; otherwise, an error will be raised.
- Parameters
auto_prefix (bool, optional) – Whether to automatically generate a namespace for the Cell and its child cells. It also affects the names of parameters in the Cell. If set to True, the parameter names are automatically prefixed; otherwise they are not. In general, the backbone network should be set to True, otherwise duplicate-name problems will appear. The cell that trains the backbone network, such as an optimizer or mindspore.nn.TrainOneStepCell, should be set to False, otherwise the parameter names in the backbone will be changed by mistake. Default: True.
flags (dict, optional) – Network configuration information, currently used for the binding of network and dataset. Users can also customize network attributes by this parameter. Default: None.
- Supported Platforms:
Ascend
GPU
CPU
Examples
>>> import mindspore.nn as nn
>>> from mindspore import ops
>>> class MyCell(nn.Cell):
...     def __init__(self, forward_net):
...         super(MyCell, self).__init__(auto_prefix=False)
...         self.net = forward_net
...         self.relu = ops.ReLU()
...
...     def construct(self, x):
...         y = self.net(x)
...         return self.relu(y)
>>>
>>> inner_net = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal')
>>> my_net = MyCell(inner_net)
>>> print(my_net.trainable_params())
... # If 'auto_prefix' is set to True or not set when calling the '__init__' method of the parent class,
... # the parameter's name will be 'net.weight'.
[Parameter (name=weight, shape=(240, 120, 4, 4), dtype=Float32, requires_grad=True)]
- add_flags(**flags)[source]
Add customized attributes for cell.
This method is also called when the cell class is instantiated and the class parameter 'flags' is set to True.
- Parameters
flags (dict) – Network configuration information, currently it is used for the binding of network and dataset. Users can also customize network attributes by this parameter.
Examples
>>> import mindspore as ms
>>> from mindspore import Tensor, nn
...
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.relu = nn.ReLU()
...
...     def construct(self, x):
...         x = self.relu(x)
...         return x
>>> net = Net()
>>> net.add_flags(sink_mode=True)
>>> print(net.sink_mode)
True
- add_flags_recursive(**flags)[source]
If a cell contains child cells, this method can recursively customize attributes of all cells.
- Parameters
flags (dict) – Network configuration information, currently it is used for the binding of network and dataset. Users can also customize network attributes by this parameter.
Examples
>>> import mindspore as ms
>>> from mindspore import Tensor, nn
...
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.relu = nn.ReLU()
...
...     def construct(self, x):
...         x = self.relu(x)
...         return x
>>> net = Net()
>>> net.add_flags_recursive(sink_mode=True)
>>> print(net.sink_mode)
True
- apply(fn)[source]
Applies fn recursively to every subcell (as returned by .cells()) as well as self. Typical use includes initializing the parameters of a model.
- Parameters
fn (function) – function to be applied to each subcell.
- Returns
Cell, self.
Examples
>>> import mindspore.nn as nn
>>> from mindspore.common.initializer import initializer, One
>>> net = nn.SequentialCell(nn.Dense(2, 2), nn.Dense(2, 2))
>>> def func(cell):
...     if isinstance(cell, nn.Dense):
...         cell.weight.set_data(initializer(One(), cell.weight.shape, cell.weight.dtype))
>>> net.apply(func)
SequentialCell(
  (0): Dense(input_channels=2, output_channels=2, has_bias=True)
  (1): Dense(input_channels=2, output_channels=2, has_bias=True)
)
>>> print(net[0].weight.asnumpy())
[[1. 1.]
 [1. 1.]]
- auto_cast_inputs(inputs)[source]
Auto cast inputs in mixed precision scenarios.
- Parameters
inputs (tuple) – the inputs of construct.
- Returns
Tuple, the inputs after data type cast.
- property bprop_debug
Get whether cell custom bprop debug is enabled.
- buffers(recurse: bool = True)[source]
Return an iterator over cell buffers.
- Parameters
recurse (bool) – If True, yields buffers of this cell and all sub cells. Otherwise, yields only buffers that are direct members of this cell. Default: True.
- Returns
Iterator[Tensor], an iterator over buffers.
Examples
>>> import mindspore
...
>>> class NetB(mindspore.nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.buffer_b = mindspore.nn.Buffer(mindspore.tensor([1, 2, 3]))
...
...     def construct(self, x):
...         return x + self.buffer_b
...
>>> class NetA(mindspore.nn.Cell):
...     def __init__(self, net_b):
...         super().__init__()
...         self.net_b = net_b
...         self.buffer_a = mindspore.nn.Buffer(mindspore.tensor([4, 5, 6]))
...
...     def construct(self, x):
...         return self.net_b(x) + self.buffer_a
...
>>> net_b = NetB()
>>> net_a = NetA(net_b)
>>>
>>> for buffer in net_a.buffers():
...     print(f'buffer is {buffer}')
buffer is [4, 5, 6]
buffer is [1, 2, 3]
- cast_inputs(inputs, dst_type)[source]
Cast inputs to specified type.
- Parameters
inputs (tuple) – The cell inputs to be cast.
dst_type (mindspore.dtype) – The specified data type.
- Returns
tuple[Tensor], the result with destination data type.
- cast_param(param)[source]
Cast parameter according to auto mix precision level in pynative mode.
This interface is currently used in the case of auto mixed precision and usually does not need to be called explicitly.
- Parameters
param (Parameter) – Parameters, the type of which should be cast.
- Returns
Parameter, the input parameter with type automatically cast.
- cells()[source]
Returns an iterator over immediate cells.
- Returns
Iteration, the immediate cells in the cell.
Examples
>>> import mindspore as ms
>>> from mindspore import Tensor, nn
...
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.dense = nn.Dense(2, 2)
...
...     def construct(self, x):
...         x = self.dense(x)
...         return x
>>> net = Net()
>>> print(net.cells())
odict_values([Dense(input_channels=2, output_channels=2, has_bias=True)])
- cells_and_names(cells=None, name_prefix='')[source]
Returns an iterator over all cells in the network, including the cell's name and itself.
- Parameters
- Returns
Iteration, all the child cells and corresponding names in the cell.
Examples
>>> from mindspore import nn
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.conv = nn.Conv2d(3, 64, 3)
...     def construct(self, x):
...         out = self.conv(x)
...         return out
>>> names = []
>>> n = Net()
>>> for m in n.cells_and_names():
...     if m[0]:
...         names.append(m[0])
- compile(*args, **kwargs)[source]
Compile the Cell into a computation graph; the inputs must be consistent with those defined in construct.
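Examples
As an illustrative sketch (ReluNet and the input shape are hypothetical, not part of the API), a cell can be compiled ahead of execution in graph mode by passing inputs that match construct:
>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import nn, Tensor
>>> ms.set_context(mode=ms.GRAPH_MODE)
>>>
>>> class ReluNet(nn.Cell):
...     def __init__(self):
...         super(ReluNet, self).__init__()
...         self.relu = nn.ReLU()
...     def construct(self, x):
...         return self.relu(x)
>>>
>>> net = ReluNet()
>>> x = Tensor(np.ones([3, 4]), ms.float32)
>>> net.compile(x)   # compile the computation graph with inputs matching construct
>>> output = net(x)  # run the cell with the same input signature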
- compile_and_run(*args, **kwargs)[source]
Compile and run the Cell; the inputs must be consistent with those defined in construct.
Note
It is not recommended to call directly.
- construct(*args, **kwargs)[source]
Defines the computation to be performed. This method must be overridden by all subclasses.
Note
It is currently not supported for the inputs to contain both tuple and non-tuple types at the same time.
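Examples
A minimal sketch of overriding construct in a subclass (MyMul is an illustrative class, not part of the API):
>>> import mindspore as ms
>>> from mindspore import nn, Tensor
>>>
>>> class MyMul(nn.Cell):
...     def construct(self, x, y):
...         # the computation performed when the cell is called
...         return x * y
>>>
>>> net = MyMul()
>>> out = net(Tensor([1.0, 2.0], ms.float32), Tensor([3.0, 4.0], ms.float32))
>>> print(out)
[3. 8.]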
- property exist_names
Get the names of existing parameters added through a tuple or list of parameters.
- extend_repr()[source]
Expand the description of Cell.
To print customized extended information, re-implement this method in your own cells.
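Examples
A sketch of re-implementing extend_repr to expose custom information in a cell's printed representation (MyBlock and hidden_size are illustrative):
>>> from mindspore import nn
>>>
>>> class MyBlock(nn.Cell):
...     def __init__(self, hidden_size):
...         super(MyBlock, self).__init__()
...         self.hidden_size = hidden_size
...         self.dense = nn.Dense(hidden_size, hidden_size)
...     def construct(self, x):
...         return self.dense(x)
...     def extend_repr(self):
...         # the returned string is appended to this cell's representation
...         return f'hidden_size={self.hidden_size}'
>>>
>>> block = MyBlock(8)
>>> print('hidden_size=8' in str(block))
True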
- flatten_weights(fusion_size=0)[source]
Reset data for weight parameters so that they are using contiguous memory chunks grouped by data type.
Note
By default, parameters with the same data type use a single contiguous memory chunk. However, for some models with a huge number of parameters, splitting a large memory chunk into several smaller ones can improve performance; in such cases, 'fusion_size' can be used to limit the maximum memory chunk size.
- Parameters
fusion_size (int) – Maximum memory chunk size in bytes, 0 for unlimited. Default: 0.
- get_buffer(target: str)[source]
Return the buffer given by target if it exists, otherwise throw an error.
See the docstring for get_sub_cell for a more detailed explanation of this method's functionality as well as how to correctly specify target .
- Parameters
target (str) – The fully-qualified string name of the buffer to look for. (See get_sub_cell for how to specify a fully-qualified string.)
- Returns
Tensor
Examples
>>> import mindspore ... ... >>> class NetC(mindspore.nn.Cell): ... def __init__(self): ... super().__init__() ... self.buffer_c = mindspore.nn.Buffer(mindspore.tensor([0, 0, 0])) ... ... def construct(self, x): ... return x + self.buffer_c ... ... >>> class NetB(mindspore.nn.Cell): ... def __init__(self, net_c): ... super().__init__() ... self.net_c = net_c ... self.buffer_b = mindspore.nn.Buffer(mindspore.tensor([1, 2, 3])) ... ... def construct(self, x): ... return self.net_c(x) + self.buffer_b ... ... >>> class NetA(mindspore.nn.Cell): ... def __init__(self, net_b): ... super().__init__() ... self.net_b = net_b ... self.buffer_a = mindspore.nn.Buffer(mindspore.tensor([4, 5, 6])) ... ... def construct(self, x): ... return self.net_b(x) + self.buffer_a ... ... >>> net_c = NetC() >>> net_b = NetB(net_c) >>> net_a = NetA(net_b) >>> buffer_c = net_a.get_buffer("net_b.net_c.buffer_c") >>> print(f'buffer_c is {buffer_c}') buffer_c is [0, 0, 0]
- get_extra_state()[source]
Return any extra state to include in the cell's state_dict.
This function is called from state_dict. Implement this and a corresponding set_extra_state for your cell if you need to store extra state.
Note that extra state should be picklable to ensure working serialization of the state_dict. Backwards-compatibility guarantees are provided only for serializing tensors; other objects may break backwards compatibility if their serialized pickled form changes.
- Returns
object, any extra state to store in the cell's state_dict.
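Examples
A sketch of pairing get_extra_state with set_extra_state so that custom picklable state travels with the state_dict (Net and the 'version' entry are illustrative):
>>> import mindspore
>>>
>>> class Net(mindspore.nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.param_a = mindspore.Parameter(mindspore.tensor([1, 2, 3]))
...         self.version = 1
...     def construct(self, x):
...         return x + self.param_a
...     def get_extra_state(self):
...         # picklable extra state stored alongside the regular state_dict entries
...         return {'version': self.version}
...     def set_extra_state(self, state):
...         self.version = state['version']
>>>
>>> net = Net()
>>> net.version = 2
>>> sd = net.state_dict()
>>> new_net = Net()
>>> incompatible = new_net.load_state_dict(sd)
>>> print(new_net.version)
2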
- get_flags()[source]
Get the self-defined attributes of the cell, which can be added by the add_flags method.
Examples
>>> import mindspore as ms
>>> from mindspore import Tensor, nn
...
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.relu = nn.ReLU()
...
...     def construct(self, x):
...         x = self.relu(x)
...         return x
>>> net = Net()
>>> net.add_flags(sink_mode=True)
>>> print(net.get_flags())
{'sink_mode': True}
- get_inputs()[source]
Returns the dynamic_inputs of a cell object in one network.
- Returns
inputs (tuple), Inputs of the Cell object.
Warning
This is an experimental API that is subject to change or deletion.
Examples
>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import nn, Tensor
>>>
>>> class ReluNet(nn.Cell):
...     def __init__(self):
...         super(ReluNet, self).__init__()
...         self.relu = nn.ReLU()
...     def construct(self, x):
...         return self.relu(x)
>>>
>>> net = ReluNet()
>>> input_dyn = Tensor(shape=[3, None], dtype=ms.float32)
>>> net.set_inputs(input_dyn)
>>> get_inputs = net.get_inputs()
>>> print(get_inputs)
(Tensor(shape=[3, -1], dtype=Float32, value= ),)
- get_parameters(expand=True)[source]
Returns an iterator over cell parameters.
Yields parameters of this cell. If expand is True, yields parameters of this cell and all subcells. For more details about subcells, please see the example below.
- Parameters
expand (bool) – If True, yields parameters of this cell and all subcells. Otherwise, yields only parameters that are direct members of this cell. Default: True.
- Returns
Iteration, all parameters of the cell.
Examples
>>> import mindspore as ms >>> from mindspore import nn, ops, Tensor >>> import numpy as np >>> class TestNet(nn.Cell): ... def __init__(self): ... super().__init__() ... self.my_w1 = ms.Parameter(Tensor(np.ones([4, 4]), ms.float32)) ... self.my_w2 = ms.Parameter(Tensor(np.ones([16]), ms.float32)) ... def construct(self, x): ... x += self.my_w1 ... x = ops.reshape(x, (16,)) - self.my_w2 ... return x >>> class TestNet2(nn.Cell): ... def __init__(self): ... super().__init__() ... self.my_t1 = ms.Parameter(Tensor(np.ones([4, 4]), ms.float32)) ... # self.subcell is a subcell of TestNet2, when using expand=True, the parameters of TestNet will ... # also be gathered. ... self.subcell = TestNet() ... def construct(self, x): ... x += self.my_w1 ... x = ops.reshape(x, (16,)) - self.my_w2 ... return x >>> net = TestNet2() >>> print([p for p in net.get_parameters(expand=True)]) [Parameter (name=my_t1, shape=(4, 4), dtype=Float32, requires_grad=True), Parameter (name=subcell.my_w1, shape=(4, 4), dtype=Float32, requires_grad=True), Parameter (name=subcell.my_w2, shape=(16,), dtype=Float32, requires_grad=True)]
- get_scope()[source]
Returns the scope of a cell object in one network.
- Returns
String, scope of the cell.
- get_sub_cell(target: str)[source]
Return the sub cell given by target if it exists, otherwise throw an error.
For example, let's say you have an nn.Cell A that looks like this:

A(
  (net_b): NetB(
    (net_c): NetC(
      (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
    )
    (dense): Dense(in_features=100, out_features=200, bias=True)
  )
)
(The diagram shows an nn.Cell A. A has a nested sub cell net_b, which itself has two sub cells net_c and dense. net_c then has a sub cell conv.)
To check whether we have the dense sub cell, we would call get_sub_cell("net_b.dense"). To check whether we have the conv sub cell, we would call get_sub_cell("net_b.net_c.conv").
The runtime of get_sub_cell is bounded by the degree of cell nesting in target. A query against name_cells achieves the same result, but it is O(N) in the number of transitive cells. So, for a simple check to see if some sub cell exists, get_sub_cell should always be used.
- Parameters
target (str) – The fully-qualified string name of the sub cell to look for. (See above example for how to specify a fully-qualified string.)
- Returns
Cell
Examples
>>> import mindspore ... ... >>> class NetC(mindspore.nn.Cell): ... def __init__(self): ... super().__init__() ... self.buffer_c = mindspore.nn.Buffer(mindspore.tensor([0, 0, 0])) ... self.dense_c = mindspore.nn.Dense(5, 3) ... ... def construct(self, x): ... return self.dense_c(x) + self.buffer_c ... ... >>> class NetB(mindspore.nn.Cell): ... def __init__(self, net_c): ... super().__init__() ... self.net_c = net_c ... self.buffer_b = mindspore.nn.Buffer(mindspore.tensor([1, 2, 3])) ... ... def construct(self, x): ... return self.net_c(x) + self.buffer_b ... ... >>> class NetA(mindspore.nn.Cell): ... def __init__(self, net_b): ... super().__init__() ... self.net_b = net_b ... self.buffer_a = mindspore.nn.Buffer(mindspore.tensor([4, 5, 6])) ... ... def construct(self, x): ... return self.net_b(x) + self.buffer_a ... ... >>> net_c = NetC() >>> net_b = NetB(net_c) >>> net_a = NetA(net_b) >>> net_c = net_a.get_sub_cell("net_b.net_c") >>> print(f'net_c is {net_c}') net_c is NetC( (dense_c): Dense(input_channels=5, output_channels=3, has_bias=True) )
- infer_param_pipeline_stage()[source]
Infer pipeline stages of all parameters in the cell.
Note
The interface is deprecated from version 2.3 and will be removed in a future version.
- Returns
The params belong to current stage in pipeline parallel.
- Raises
RuntimeError – If there is a parameter that does not belong to any stage.
- init_parameters_data(auto_parallel_mode=False)[source]
Initialize all parameters and replace the original saved parameters in cell.
Note
trainable_params() and other similar interfaces may return different parameter instances after init_parameters_data is called; it is not recommended to save the returned results.
- Parameters
auto_parallel_mode (bool) – Whether running in auto_parallel_mode. Default: False.
- Returns
Dict[Parameter, Parameter], returns a dict of original parameter and replaced parameter.
Examples
>>> import mindspore as ms >>> from mindspore import Tensor, nn ... >>> class Net(nn.Cell): ... def __init__(self): ... super(Net, self).__init__() ... self.dense = nn.Dense(2, 2) ... ... def construct(self, x): ... x = self.dense(x) ... return x >>> net = Net() >>> print(net.init_parameters_data()) {Parameter (name=dense.weight, shape=(2,2), dtype=Float32, requires_grad=True): Parameter (name=dense.weight, shape=(2,2), dtype=Float32, requires_grad=True), Parameter (name=dense.bias, shape=(2,), dtype=Float32, requires_grad=True): Parameter (name=dense.bias, shape=(2,), dtype=Float32, requires_grad=True)}
- insert_child_to_cell(child_name, child_cell)[source]
Adds a child cell to the current cell with a given name.
- Parameters
- Raises
Examples
>>> import mindspore as ms
>>> from mindspore import Tensor, nn
...
>>> net1 = nn.ReLU()
>>> net2 = nn.Dense(2, 2)
>>> net1.insert_child_to_cell("child", net2)
>>> print(net1)
ReLU(
  (child): Dense(input_channels=2, output_channels=2, has_bias=True)
)
- insert_param_to_cell(param_name, param, check_name_contain_dot=True)[source]
Adds a parameter to the current cell.
Inserts a parameter with given name to the cell. The method is currently used in mindspore.nn.Cell.__setattr__.
- Parameters
- Raises
Examples
>>> import mindspore as ms
>>> from mindspore import Tensor, nn, Parameter
...
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.relu = nn.ReLU()
...
...     def construct(self, x):
...         x = self.relu(x)
...         return x
>>> net = Net()
>>> net.insert_param_to_cell("bias", Parameter(Tensor([1, 2, 3])))
>>> print(net.bias)
Parameter(name=bias, shape=(3,), dtype=Int64, requires_grad=True)
- load_state_dict(state_dict: Mapping[str, Any], strict: bool = True)[source]
Copy parameters and buffers from state_dict into this cell and its descendants.
If strict is True, then the keys of state_dict must exactly match the keys returned by this cell's mindspore.nn.Cell.state_dict() function.
- Parameters
state_dict (dict) – A dict containing parameters and persistent buffers.
strict (bool, optional) – Whether to strictly enforce that the keys in the input state_dict match the keys returned by this cell's mindspore.nn.Cell.state_dict() function. Default: True.
- Returns
A namedtuple with missing_keys and unexpected_keys fields:
- missing_keys is a list of str containing any keys that are expected by this cell but missing from the provided state_dict.
- unexpected_keys is a list of str containing the keys that are not expected by this cell but present in the provided state_dict.
Note
If a parameter or buffer is registered as None and its corresponding key exists in state_dict, mindspore.nn.Cell.load_state_dict() will raise a RuntimeError.
Examples
>>> import mindspore >>> import os >>> class Model(mindspore.nn.Cell): ... def __init__(self): ... super().__init__() ... self.buffer_a = mindspore.nn.Buffer(mindspore.tensor([4, 5, 6])) ... self.param_a = mindspore.Parameter(mindspore.tensor([1, 2, 3])) ... ... def construct(self, x): ... return x + self.buffer_a + self.param_a ... ... >>> model = Model() >>> print(model.state_dict()) >>> mindspore.save_checkpoint(model.state_dict(), './model_state_dict_ckpt') >>> new_model = Model() >>> new_model.load_state_dict(mindspore.load_checkpoint('./model_state_dict_ckpt')) >>> print(new_model.state_dict()) >>> os.remove('./model_state_dict_ckpt') OrderedDict([('param_a', Parameter (name=param_a, shape=(3,), dtype=Int64, requires_grad=True)), \ ('buffer_a', Tensor(shape=[3], dtype=Int64, value= [4, 5, 6]))]) OrderedDict([('param_a', Parameter (name=param_a, shape=(3,), dtype=Int64, requires_grad=True)), \ ('buffer_a', Tensor(shape=[3], dtype=Int64, value= [4, 5, 6]))])
- name_cells()[source]
Returns an iterator over all immediate cells in the network.
Include name of the cell and cell itself.
- Returns
Dict, all the child cells and corresponding names in the cell.
Examples
>>> import mindspore as ms
>>> from mindspore import Tensor, nn
...
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.dense = nn.Dense(2, 2)
...
...     def construct(self, x):
...         x = self.dense(x)
...         return x
>>> net = Net()
>>> print(net.name_cells())
OrderedDict([('dense', Dense(input_channels=2, output_channels=2, has_bias=True))])
- named_buffers(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True)[source]
Return an iterator over cell buffers, yielding both the name of the buffer as well as the buffer itself.
- Parameters
prefix (str, optional) – Prefix to prepend to all buffer names. Default: "".
recurse (bool, optional) – If True, yields buffers of this cell and all sub cells. Otherwise, yields only buffers that are direct members of this cell. Default: True.
remove_duplicate (bool, optional) – Whether to remove the duplicated buffers in the result. Default: True.
- Returns
Iterator[Tuple[str, Tensor]], an iterator of tuple containing the name and buffer.
Examples
>>> import mindspore
...
>>> class NetB(mindspore.nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.buffer_b = mindspore.nn.Buffer(mindspore.tensor([1, 2, 3]))
...
...     def construct(self, x):
...         return x + self.buffer_b
...
>>> class NetA(mindspore.nn.Cell):
...     def __init__(self, net_b):
...         super().__init__()
...         self.net_b = net_b
...         self.buffer_a = mindspore.nn.Buffer(mindspore.tensor([4, 5, 6]))
...
...     def construct(self, x):
...         return self.net_b(x) + self.buffer_a
...
>>> net_b = NetB()
>>> net_a = NetA(net_b)
>>>
>>> for name, buffer in net_a.named_buffers():
...     print(f'buffer name is {name}, buffer is {buffer}')
buffer name is buffer_a, buffer is [4, 5, 6]
buffer name is net_b.buffer_b, buffer is [1, 2, 3]
- offload(backward_prefetch='Auto')[source]
Set offloading for the cell. All primitive ops in the cell will be set to offload. The intermediate activations computed by these primitive ops are not saved in the forward pass; instead, they are offloaded and loaded back in the backward pass.
Note
If Cell.offload is called, the mode should be set to "GRAPH_MODE".
If Cell.offload is called, lazy_inline should be enabled.
- Parameters
backward_prefetch (Union[str, int], optional) – The timing for prefetching activations in the backward pass. Default: "Auto". If set to "Auto", the framework starts prefetching activations one operator in advance. If set to a positive int value, such as 1, 20 or 100, the framework starts prefetching activations backward_prefetch operators in advance.
Examples
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import ops, lazy_inline
>>> from mindspore.common import Tensor, Parameter
>>>
>>> class Block(nn.Cell):
...     def __init__(self):
...         super(Block, self).__init__()
...         self.transpose1 = ops.Transpose()
...         self.transpose2 = ops.Transpose()
...         self.transpose3 = ops.Transpose()
...         self.transpose4 = ops.Transpose()
...         self.real_div1 = ops.RealDiv()
...         self.real_div2 = ops.RealDiv()
...         self.batch_matmul1 = ops.BatchMatMul()
...         self.batch_matmul2 = ops.BatchMatMul()
...         self.softmax = ops.Softmax(-1)
...         self.expand_dims = ops.ExpandDims()
...         self.sub = ops.Sub()
...         self.y = Parameter(Tensor(np.ones((1024, 128, 128)).astype(np.float32)))
...     def construct(self, x):
...         transpose1 = self.transpose1(x, (0, 2, 1, 3))
...         real_div1 = self.real_div1(transpose1, Tensor(2.37891))
...         transpose2 = self.transpose2(x, (0, 2, 3, 1))
...         real_div2 = self.real_div2(transpose2, Tensor(2.37891))
...         batch_matmul1 = self.batch_matmul1(real_div1, real_div2)
...         expand_dims = self.expand_dims(self.y, 1)
...         sub = self.sub(Tensor([1.0]), expand_dims)
...         soft_max = self.softmax(sub)
...         transpose3 = self.transpose3(x, (0, 2, 1, 3))
...         batch_matmul2 = self.batch_matmul2(soft_max[0], transpose3)
...         transpose4 = self.transpose4(batch_matmul2, (0, 2, 1, 3))
...         return transpose4
>>>
>>> class OuterBlock(nn.Cell):
...     @lazy_inline
...     def __init__(self):
...         super(OuterBlock, self).__init__()
...         self.block = Block()
...     def construct(self, x):
...         return self.block(x)
>>>
>>> class Nets(nn.Cell):
...     def __init__(self):
...         super(Nets, self).__init__()
...         self.blocks = nn.CellList()
...         for _ in range(3):
...             b = OuterBlock()
...             b.offload()
...             self.blocks.append(b)
...     def construct(self, x):
...         out = x
...         for i in range(3):
...             out = self.blocks[i](out)
...         return out
- property param_prefix
Param prefix is the prefix of current cell's direct child parameter.
Examples
>>> import mindspore as ms
>>> from mindspore import Tensor, nn
...
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.dense = nn.Dense(2, 2)
...
...     def construct(self, x):
...         x = self.dense(x)
...         return x
>>> net = Net()
>>> net.update_cell_prefix()
>>> print(net.dense.param_prefix)
dense
- property parameter_layout_dict
parameter_layout_dict represents the tensor layout of a parameter, which is inferred by shard strategy and distributed operator information.
- parameters_and_names(name_prefix='', expand=True)[source]
Returns an iterator over cell parameters.
Includes the parameter's name and itself.
- Parameters
- Returns
Iteration, all the names and corresponding parameters in the cell.
Examples
>>> from mindspore import nn
>>> n = nn.Dense(3, 4)
>>> names = []
>>> for m in n.parameters_and_names():
...     if m[0]:
...         names.append(m[0])
- Tutorial Examples:
- parameters_broadcast_dict(recurse=True)[source]
Gets the parameters broadcast dictionary of this cell.
- Parameters
recurse (bool) – Whether to include the parameters of subcells. Default: True.
- Returns
OrderedDict, return parameters broadcast dictionary.
- parameters_dict(recurse=True)[source]
Gets the parameters dictionary of this cell.
- Parameters
recurse (bool) – Whether to include the parameters of subcells. Default: True.
- Returns
OrderedDict, return parameters dictionary.
Examples
>>> import mindspore as ms
>>> from mindspore import Tensor, nn, Parameter
...
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.dense = nn.Dense(2, 2)
...
...     def construct(self, x):
...         x = self.dense(x)
...         return x
>>> net = Net()
>>> print(net.parameters_dict())
OrderedDict([('dense.weight', Parameter(name=dense.weight, shape=(2, 2), dtype=Float32, requires_grad=True)),
('dense.bias', Parameter(name=dense.bias, shape=(2,), dtype=Float32, requires_grad=True))])
- property pipeline_stage
pipeline_stage represents the pipeline stage of current Cell.
- place(role, rank_id)[source]
Set the label for all operators in this cell. This label tells the MindSpore compiler on which process this cell should be launched. Each process's label consists of the input role and rank_id, so by assigning different labels to different cells, which are then launched on different processes, users can run a distributed training or prediction job.
Note
This method is effective only after mindspore.communication.init() is called for dynamic cluster building.
- Parameters
Examples
>>> from mindspore import context
>>> import mindspore.nn as nn
>>> context.set_context(mode=context.GRAPH_MODE)
>>> fc = nn.Dense(2, 3)
>>> fc.place('MS_WORKER', 0)
- recompute(**kwargs)[source]
Set the cell to be recomputed. All the primitives in the cell, except the outputs, will be set to be recomputed. If a primitive set to recompute feeds into some backward nodes for computing gradients, then rather than storing the intermediate activation computed in the forward pass, the activation is recomputed in the backward pass.
Note
If the computation involves something like randomization or global variable, the equivalence is not guaranteed currently.
If the recompute api of a primitive in this cell is also called, the recompute mode of this primitive is subject to the recompute api of the primitive.
The interface can be configured only once. Therefore, when the parent cell is configured, the child cell should not be configured.
The outputs of the cell are excluded from recomputation by default, based on our configuration experience, to reduce memory footprint. If a cell has only one primitive and that primitive should be recomputed, use the recompute api of the primitive.
When memory remains sufficient after applying recomputation, configure 'mp_comm_recompute=False' to improve performance if necessary.
When memory is still insufficient after applying recomputation, configure 'parallel_optimizer_comm_recompute=True' to save more memory if necessary. Cells in the same fusion group should have the same parallel_optimizer_comm_recompute configuration.
- Parameters
mp_comm_recompute (bool) – Specifies whether the model parallel communication operators in the cell are recomputed in auto parallel or semi auto parallel mode. Default: True.
parallel_optimizer_comm_recompute (bool) – Specifies whether the communication operator allgathers introduced by optimizer shard are recomputed in auto parallel or semi auto parallel mode. Default: False.
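Examples
A sketch of enabling recomputation on the sub-blocks of a larger network (Block and Net are illustrative classes):
>>> from mindspore import nn
>>>
>>> class Block(nn.Cell):
...     def __init__(self):
...         super(Block, self).__init__()
...         self.dense = nn.Dense(16, 16)
...         self.relu = nn.ReLU()
...     def construct(self, x):
...         return self.relu(self.dense(x))
>>>
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.blocks = nn.CellList()
...         for _ in range(4):
...             block = Block()
...             # recompute this block's activations in the backward pass
...             block.recompute()
...             self.blocks.append(block)
...     def construct(self, x):
...         for block in self.blocks:
...             x = block(x)
...         return x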
- register_backward_hook(hook_fn)[source]
Register the backward hook function.
Note
The register_backward_hook(hook_fn) does not work in graph mode or functions decorated with 'jit'.
The 'hook_fn' must be defined as the following code. cell is the registered Cell object. grad_input is the gradient computed and passed to the next Cell or primitive; the hook can return a new gradient or None. grad_output is the gradient passed to the Cell.
The 'hook_fn' should have the following signature: hook_fn(cell, grad_input, grad_output) -> New grad_input gradient or none.
The 'hook_fn' is executed in the python environment. In order to prevent running failed when switching to graph mode, it is not recommended to write it in the construct function of Cell object. In the pynative mode, if the register_backward_hook function is called in the construct function of the Cell object, a hook function will be added at each run time of Cell object.
- Parameters
hook_fn (function) – Python function. Backward hook function.
- Returns
A handle corresponding to the hook_fn . The handle can be used to remove the added hook_fn by calling handle.remove() .
- Raises
TypeError – If the hook_fn is not a Python function.
Supported Platforms:
Ascend
GPU
CPU
Examples
>>> import numpy as np >>> import mindspore as ms >>> from mindspore import Tensor, nn, ops >>> ms.set_context(mode=ms.PYNATIVE_MODE) >>> def backward_hook_fn(cell, grad_input, grad_output): ... print("backward input: ", grad_output) ... print("backward output: ", grad_input) ... >>> class Net(nn.Cell): ... def __init__(self): ... super(Net, self).__init__() ... self.relu = nn.ReLU() ... self.handle = self.relu.register_backward_hook(backward_hook_fn) ... ... def construct(self, x): ... x = x + x ... x = self.relu(x) ... return x >>> grad = ops.GradOperation(get_all=True) >>> net = Net() >>> output = grad(net)(Tensor(np.ones([1]).astype(np.float32))) backward input: (Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]),) backward output: (Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]),) >>> print(output) (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
- register_backward_pre_hook(hook_fn)[source]
Register the backward pre hook function.
Note
The register_backward_pre_hook(hook_fn) does not work in graph mode or functions decorated with 'jit'.
The 'hook_fn' must be defined as the following code. cell is the Cell object. grad_output is the gradient passed to the Cell.
The 'hook_fn' should have the following signature: hook_fn(cell, grad_output) -> New grad_output gradient or None.
The 'hook_fn' is executed in the python environment. In order to prevent running failed when switching to graph mode, it is not recommended to write it in the construct function of Cell object.
In the pynative mode, if the register_backward_pre_hook function is called in the construct function of the Cell object, a hook function will be added at each run time of Cell object.
- Parameters
hook_fn (function) – Python function. Backward pre hook function.
- Returns
A handle corresponding to the hook_fn . The handle can be used to remove the added hook_fn by calling handle.remove() .
- Raises
TypeError – If the hook_fn is not a Python function.
Supported Platforms:
Ascend
GPU
CPU
Examples
>>> import numpy as np >>> import mindspore as ms >>> from mindspore import Tensor, nn, ops >>> ms.set_context(mode=ms.PYNATIVE_MODE) >>> def backward_pre_hook_fn(cell, grad_output): ... print("backward input: ", grad_output) ... >>> class Net(nn.Cell): ... def __init__(self): ... super(Net, self).__init__() ... self.relu = nn.ReLU() ... self.handle = self.relu.register_backward_pre_hook(backward_pre_hook_fn) ... ... def construct(self, x): ... x = x + x ... x = self.relu(x) ... return x >>> grad = ops.GradOperation(get_all=True) >>> net = Net() >>> output = grad(net)(Tensor(np.ones([1]).astype(np.float32))) backward input: (Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]),) >>> print(output) (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
- register_buffer(name: str, tensor: Optional[Tensor], persistent: bool = True)[source]
Add a buffer to the cell.
This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm's running_mean is not a parameter, but is part of the cell's state. Buffers are persistent by default and will be saved alongside parameters. This behavior can be changed by setting persistent to False. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this cell's state_dict.
Buffers can be accessed as attributes using the given names.
- Parameters
name (str) – name of the buffer. The buffer can be accessed from this cell using the given name.
tensor (Tensor) – Buffer to be registered. If None, the buffer is not included in the cell's state_dict.
persistent (bool, optional) – Whether the buffer is part of this cell's state_dict. Default: True.
Examples
>>> import mindspore
...
>>> class Net(mindspore.nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.register_buffer("buffer0", mindspore.tensor([1, 2, 3]))
...
...     def construct(self, x):
...         return x + self.buffer0
...
>>> net = Net()
>>> print(net.buffer0)
[1 2 3]
- register_forward_hook(hook_fn)[source]
Set the Cell forward hook function.
Note
The register_forward_hook(hook_fn) does not work in graph mode or functions decorated with 'jit'.
'hook_fn' must be defined as the following code. cell is the object of registered Cell. inputs is the forward input objects passed to the Cell. output is the forward output object of the Cell. The 'hook_fn' can modify the forward output object by returning new forward output object.
It should have the following signature: hook_fn(cell, inputs, output) -> new output object or none.
In order to prevent running failed when switching to graph mode, it is not recommended to write it in the construct function of Cell object. In the pynative mode, if the register_forward_hook function is called in the construct function of the Cell object, a hook function will be added at each run time of Cell object.
- Parameters
hook_fn (function) – Python function. Forward hook function.
- Returns
A handle corresponding to the hook_fn . The handle can be used to remove the added hook_fn by calling handle.remove() .
- Raises
TypeError – If the hook_fn is not a Python function.
Supported Platforms:
Ascend
GPU
CPU
Examples
>>> import numpy as np >>> import mindspore as ms >>> from mindspore import Tensor, nn, ops >>> ms.set_context(mode=ms.PYNATIVE_MODE) >>> def forward_hook_fn(cell, inputs, output): ... print("forward inputs: ", inputs) ... print("forward output: ", output) ... >>> class Net(nn.Cell): ... def __init__(self): ... super(Net, self).__init__() ... self.mul = nn.MatMul() ... self.handle = self.mul.register_forward_hook(forward_hook_fn) ... ... def construct(self, x, y): ... x = x + x ... x = self.mul(x, y) ... return x >>> grad = ops.GradOperation(get_all=True) >>> net = Net() >>> output = grad(net)(Tensor(np.ones([1]).astype(np.float32)), Tensor(np.ones([1]).astype(np.float32))) forward inputs: (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00])) forward output: 2.0 >>> print(output) (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]))
- register_forward_pre_hook(hook_fn)[source]
Register forward pre hook function for Cell object.
Note
The register_forward_pre_hook(hook_fn) does not work in graph mode or functions decorated with 'jit'.
'hook_fn' must be defined as the following code. cell is the object of registered Cell. inputs is the forward input objects passed to the Cell. The 'hook_fn' can modify the forward input objects by returning new forward input objects.
It should have the following signature: hook_fn(cell, inputs) -> new input objects or none.
In order to prevent running failed when switching to graph mode, it is not recommended to write it in the construct function of Cell object. In the pynative mode, if the register_forward_pre_hook function is called in the construct function of the Cell object, a hook function will be added at each run time of Cell object.
- Parameters
hook_fn (function) – Python function. Forward pre hook function.
- Returns
A handle corresponding to the hook_fn . The handle can be used to remove the added hook_fn by calling handle.remove() .
- Raises
TypeError – If the hook_fn is not a Python function.
Supported Platforms:
Ascend
GPU
CPU
Examples
>>> import numpy as np >>> import mindspore as ms >>> from mindspore import Tensor, nn, ops >>> ms.set_context(mode=ms.PYNATIVE_MODE) >>> def forward_pre_hook_fn(cell, inputs): ... print("forward inputs: ", inputs) ... >>> class Net(nn.Cell): ... def __init__(self): ... super(Net, self).__init__() ... self.mul = nn.MatMul() ... self.handle = self.mul.register_forward_pre_hook(forward_pre_hook_fn) ... ... def construct(self, x, y): ... x = x + x ... x = self.mul(x, y) ... return x >>> grad = ops.GradOperation(get_all=True) >>> net = Net() >>> output = grad(net)(Tensor(np.ones([1]).astype(np.float32)), Tensor(np.ones([1]).astype(np.float32))) forward inputs: (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00])) >>> print(output) (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]))
- register_load_state_dict_post_hook(hook)[source]
Register a post-hook to be run after the cell's mindspore.nn.Cell.load_state_dict() is called.
It should have the following signature:
hook(cell, incompatible_keys) -> None
The cell argument is the current cell that this hook is registered on, and the incompatible_keys argument is a NamedTuple consisting of the attributes missing_keys and unexpected_keys. missing_keys is a list of str containing the missing keys and unexpected_keys is a list of str containing the unexpected keys.
The given incompatible_keys can be modified in place if needed.
Note that the checks performed when calling load_state_dict() with strict=True are affected by modifications the hook makes to missing_keys or unexpected_keys, as expected. Additions to either set of keys will result in an error being thrown when strict=True, and clearing out both missing and unexpected keys will avoid an error.
- Parameters
hook (Callable) – The hook function after load_state_dict is called.
- Returns
RemovableHandle, a handle that can be used to remove the added hook by calling handle.remove().
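Examples
A sketch of a post-hook that inspects the incompatible keys after loading (Net and check_keys_hook are illustrative):
>>> import mindspore
>>>
>>> class Net(mindspore.nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.param_a = mindspore.Parameter(mindspore.tensor([1, 2, 3]))
...     def construct(self, x):
...         return x + self.param_a
>>>
>>> def check_keys_hook(cell, incompatible_keys):
...     # incompatible_keys carries the missing_keys and unexpected_keys lists
...     print(len(incompatible_keys.missing_keys), len(incompatible_keys.unexpected_keys))
>>>
>>> net = Net()
>>> handle = net.register_load_state_dict_post_hook(check_keys_hook)
>>> incompatible = net.load_state_dict(net.state_dict())
0 0
>>> handle.remove()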
- register_load_state_dict_pre_hook(hook)[source]
Register a pre-hook to be run before the cell's mindspore.nn.Cell.load_state_dict() is called.
It should have the following signature:
hook(cell, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None
- Parameters
hook (Callable) – The hook function before load_state_dict is called.
- Returns
RemovableHandle, a handle that can be used to remove the added hook by calling handle.remove().
- register_state_dict_post_hook(hook)[source]
Register a post-hook for the mindspore.nn.Cell.state_dict() method.
It should have the following signature:
hook(cell, state_dict, prefix, local_metadata) -> None
The registered hooks can modify the state_dict in place.
- Parameters
hook (Callable) – The hook function after state_dict is called.
- Returns
RemovableHandle, a handle that can be used to remove the added hook by calling handle.remove().
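Examples
A sketch of a post-hook that edits the collected state_dict in place (Net and the buffer filtering are illustrative):
>>> import mindspore
>>>
>>> class Net(mindspore.nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.param_a = mindspore.Parameter(mindspore.tensor([1, 2, 3]))
...         self.buffer_a = mindspore.nn.Buffer(mindspore.tensor([4, 5, 6]))
...     def construct(self, x):
...         return x + self.param_a + self.buffer_a
>>>
>>> def drop_buffers_hook(cell, state_dict, prefix, local_metadata):
...     # remove buffer entries from the resulting state_dict in place
...     for key in list(state_dict.keys()):
...         if key.endswith('buffer_a'):
...             del state_dict[key]
>>>
>>> net = Net()
>>> handle = net.register_state_dict_post_hook(drop_buffers_hook)
>>> print(list(net.state_dict().keys()))
['param_a']
>>> handle.remove()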
- register_state_dict_pre_hook(hook)[source]
Register a pre-hook for the mindspore.nn.Cell.state_dict() method.
It should have the following signature:
hook(cell, prefix, keep_vars) -> None
The registered hooks can be used to perform pre-processing before the state_dict call is made.
- Parameters
hook (Callable) – The hook function before state_dict is called.
- Returns
RemovableHandle, a handle that can be used to remove the added hook by calling handle.remove().
Examples
>>> import mindspore
...
>>> class NetA(mindspore.nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.buffer_a = mindspore.nn.Buffer(mindspore.tensor([1, 2, 3]))
...         self.param_a = mindspore.Parameter(mindspore.tensor([1, 2, 3]))
...
...     def construct(self, x):
...         return x + self.buffer_a + self.param_a
...
>>> def _add_extra_param(cell, prefix, keep_vars):
...     cell._params["extra_param"] = mindspore.Parameter(mindspore.tensor([4, 5, 6]))
...
>>> net = NetA()
>>> handle = net.register_state_dict_pre_hook(_add_extra_param)
>>> net_state_dict = net.state_dict()
>>> handle.remove()
>>> print("extra_param" in net_state_dict)
True
- remove_redundant_parameters()[source]
Remove the redundant parameters.
This interface usually does not need to be called explicitly.
- run_construct(cast_inputs, kwargs)[source]
Run the construct function.
Note
This function will be removed in a future version. It is not recommended to call this function.
- set_boost(boost_type)[source]
In order to improve network performance, this interface configures the network to automatically enable acceleration algorithms from the algorithm library.
If boost_type is not in the algorithm library, please check which algorithms the library supports.
Note
Some acceleration algorithms may affect the accuracy of the network, please choose carefully.
- Parameters
boost_type (str) – The acceleration algorithm to enable.
- Returns
Cell, the cell itself.
- Raises
ValueError – If boost_type is not in the algorithm library.
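Examples
A minimal sketch, assuming the 'less_bn' algorithm is available in the boost algorithm library of the installed MindSpore version:
>>> import mindspore.nn as nn
>>> net = nn.Dense(3, 4)
>>> net = net.set_boost('less_bn')  # returns the cell itself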
- set_broadcast_flag(mode=True)[source]
Set parameter broadcast mode for this cell.
- Parameters
mode (bool) – Specifies whether the mode is parameter broadcast. Default: True.
- set_comm_fusion(fusion_type, recurse=True)[source]
Set comm_fusion for all the parameters in this cell. Please refer to the description of mindspore.Parameter.comm_fusion.
Note
The value of the attribute will be overwritten when the function is called multiple times.
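Examples
A minimal sketch of assigning a fusion index to all parameters of a cell (the fusion value 2 is arbitrary):
>>> import mindspore.nn as nn
>>> net = nn.Dense(3, 4)
>>> net.set_comm_fusion(2)
>>> print(net.weight.comm_fusion)
2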
- set_data_parallel()[source]
For all primitive ops in this cell (including ops of cells wrapped by this cell), if a parallel strategy is not specified, then instead of auto-searching, a data parallel strategy will be generated for those primitive ops.
Note
Only effective while using auto_parallel_context = ParallelMode.AUTO_PARALLEL under graph mode.
Examples
>>> import mindspore.nn as nn
>>> net = nn.Dense(3, 4)
>>> net.set_data_parallel()
- set_extra_state(state: Any)[source]
Set extra state contained in the loaded state_dict.
This function is called from load_state_dict to handle any extra state found within the state_dict. Implement this function and a corresponding get_extra_state for your cell if you need to store extra state within its state_dict.
- Parameters
state (dict) – Extra state from the state_dict.
- set_grad(requires_grad=True)[source]
Sets the cell flag for gradient.
- Parameters
requires_grad (bool) – Specifies whether the net needs gradients. If True, the cell will construct the backward network in PyNative mode. Default: True.
- Returns
Cell, the cell itself.
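Examples
A sketch of enabling backward-graph construction before computing gradients in PyNative mode (illustrative usage):
>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import nn, ops, Tensor
>>> ms.set_context(mode=ms.PYNATIVE_MODE)
>>>
>>> net = nn.Dense(2, 2)
>>> net = net.set_grad(True)  # build the backward network when the cell runs
>>> grad_fn = ops.GradOperation(get_all=True)(net)
>>> grads = grad_fn(Tensor(np.ones([1, 2]), ms.float32))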
- set_inputs(*inputs, **kwargs)[source]
Save the set inputs for computation graph compilation. The number of inputs should be the same as that of the dataset. When using Model for dynamic shape, please make sure that all networks and loss functions passed to the Model are configured with set_inputs. The shape of the input Tensor can be either dynamic or static.
Note
There are two mode:
Full mode: arguments will be used as all compile inputs for graph-compiling.
Incremental mode: arguments will set to some of the Cell inputs, which will be substituted into the input at the corresponding position for graph-compiling.
Only one of inputs or kwargs can be set. Inputs for full mode and kwargs for incremental mode.
- Parameters
Warning
This is an experimental API that is subject to change or deletion.
Examples
>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import nn, Tensor
>>>
>>> class ReluNet(nn.Cell):
...     def __init__(self):
...         super(ReluNet, self).__init__()
...         self.relu = nn.ReLU()
...     def construct(self, x):
...         return self.relu(x)
>>>
>>> net = ReluNet()
>>> input_dyn = Tensor(shape=[3, None], dtype=ms.float32)
>>> net.set_inputs(input_dyn)
>>> input = Tensor(np.random.random([3, 10]), dtype=ms.float32)
>>> output = net(input)
>>>
>>> net2 = ReluNet()
>>> net2.set_inputs(x=input_dyn)
>>> output = net2(input)
- set_jit_config(jit_config)[source]
Set jit config for cell.
- Parameters
jit_config (JitConfig) – Jit config for compilation. For details, please refer to mindspore.JitConfig.
Examples
>>> import mindspore as ms
>>> from mindspore import Tensor, nn
...
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.relu = nn.ReLU()
...
...     def construct(self, x):
...         x = self.relu(x)
...         return x
>>> net = Net()
>>> jitconfig = ms.JitConfig()
>>> net.set_jit_config(jitconfig)
- set_param_ps(recurse=True, init_in_server=False)[source]
Set whether the trainable parameters are updated by parameter server and whether the trainable parameters are initialized on server.
Note
It only works when a running task is in the parameter server mode. It is only supported in graph mode.
- set_train(mode=True)[source]
Sets the cell to training mode.
The cell itself and all child cells will be set to training mode. Layers that have different structures for training and inference, such as BatchNorm, will distinguish between the branches by this attribute. If set to True, the training branch will be executed; otherwise, the inference branch will be.
Note
When the function Model.train() is executed, the framework calls Cell.set_train(True). When Model.eval() is executed, the framework calls Cell.set_train(False).
- Parameters
mode (bool) – Specifies whether the model is in training mode. Default: True.
- Returns
Cell, the cell itself.
- Tutorial Examples:
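Examples
A sketch of switching between the training and inference branches of a cell that contains BatchNorm (illustrative):
>>> from mindspore import nn
>>>
>>> net = nn.SequentialCell(nn.Dense(4, 4), nn.BatchNorm1d(4))
>>> net = net.set_train(True)    # training branch: batch statistics are used and updated
>>> net = net.set_train(False)   # inference branch: moving statistics are used
>>> print(net.training)
False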
- shard(in_strategy, out_strategy=None, parameter_plan=None, device='Ascend', level=0)[source]
Defines the input and output layouts of this cell; the parallel strategies of the remaining ops will be generated by sharding propagation. In PyNative mode, use this method to specify a Cell for distributed execution in graph mode. In Graph mode, use this method to specify the distribution strategy for a Cell; strategies for others will be set by sharding propagation. in_strategy and out_strategy define the input and output layouts respectively. Each should be a tuple in which every element corresponds to the desired layout of the corresponding input/output; refer to the description of mindspore.ops.Primitive.shard(). The parallel strategies of the remaining operators are derived from the strategies specified for the inputs and outputs.
Note
If Cell.shard is called, the parallel mode in set_auto_parallel_context (parallel_mode) will be set to "auto_parallel" and the search mode (search_mode) to "sharding_propagation". If the input contains a Parameter, its strategy should be set in in_strategy.
- Parameters
in_strategy (tuple) – Define the layout of inputs; each element of the tuple should be a tuple that defines the layout of the corresponding input.
out_strategy (Union[None, tuple]) – Define the layout of outputs, similar to in_strategy. It is not in use right now. Default: None.
parameter_plan (Union[dict, None]) – Define the layout for the specified parameters. Each element in the dict defines the layout of a parameter like "param_name: layout". The key is a parameter name of type 'str'. The value is a 1-D integer tuple indicating the corresponding layout. If the parameter name is incorrect or the corresponding parameter has already been set, the parameter setting will be ignored. Default: None.
device (str) – Select a certain device target. It is not in use right now. Support ["CPU", "GPU", "Ascend"]. Default: "Ascend".
level (int) – Option for the parallel strategy inference algorithm, namely the objective function: maximize computation over communication ratio, maximize speed performance, minimize memory usage, etc. It is not in use right now. Support [0, 1, 2]. Default: 0.
- Returns
Function, return the cell construct function that will be executed under auto parallel process.
Examples
>>> import mindspore.nn as nn
>>>
>>> class Block(nn.Cell):
...     def __init__(self):
...         super(Block, self).__init__()
...         self.dense1 = nn.Dense(10, 10)
...         self.relu = nn.ReLU()
...         self.dense2 = nn.Dense(10, 10)
...     def construct(self, x):
...         x = self.relu(self.dense2(self.relu(self.dense1(x))))
...         return x
>>>
>>> class example(nn.Cell):
...     def __init__(self):
...         super(example, self).__init__()
...         self.block1 = Block()
...         self.block2 = Block()
...         self.block2_shard = self.block2.shard(in_strategy=((2, 1),),
...                                               parameter_plan={'self.block2.shard.dense1.weight': (4, 1)})
...     def construct(self, x):
...         x = self.block1(x)
...         x = self.block2_shard(x)
...         return x
- state_dict(*args, destination=None, prefix='', keep_vars=False)[source]
Return a dictionary containing references to the whole state of the cell.
Both parameters and persistent buffers (e.g. running averages) are included. Keys are the corresponding parameter and buffer names. Parameters and buffers set to None are not included.
Note
The returned object is a shallow copy. It contains references to the cell's parameters and buffers.
Warning
- Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.
- Please avoid the use of the argument destination as it is not designed for end-users.
- Parameters
destination (dict, optional) – If provided, the state of the cell will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.
prefix (str, optional) – A prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.
keep_vars (bool, optional) – Whether the state_dict returns a copy. Default: False, returns a reference.
- Returns
Dict, a dictionary containing a whole state of the cell.
Examples
>>> import mindspore
>>> class Model(mindspore.nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.buffer_a = mindspore.nn.Buffer(mindspore.tensor([4, 5, 6]))
...         self.param_a = mindspore.Parameter(mindspore.tensor([1, 2, 3]))
...
...     def construct(self, x):
...         return x + self.buffer_a + self.param_a
...
>>> model = Model()
>>> print(model.state_dict())
OrderedDict([('param_a', Parameter (name=param_a, shape=(3,), dtype=Int64, requires_grad=True)), \
('buffer_a', Tensor(shape=[3], dtype=Int64, value= [4, 5, 6]))])
- to_float(dst_type)[source]
Add cast on all inputs of cell and child cells to run with certain float type.
If dst_type is mindspore.dtype.float16, all the inputs of the Cell, including input, Parameter and Tensor, will be cast to float16. Please refer to the usage in the source code of mindspore.amp.build_train_network().
Note
Multiple calls will overwrite.
- Parameters
dst_type (mindspore.dtype) – Transfer the cell to run with dst_type. dst_type can be mstype.float16, mstype.float32 or mstype.bfloat16.
- Returns
Cell, the cell itself.
- Raises
ValueError – If dst_type is not mstype.float32 , mstype.float16 or mstype.bfloat16.
- Supported Platforms:
Ascend
GPU
CPU
Examples
>>> import mindspore.nn as nn
>>> from mindspore import dtype as mstype
>>>
>>> net = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal')
>>> net.to_float(mstype.float16)
Conv2d(input_channels=120, output_channels=240, kernel_size=(4, 4), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=None, format=NCHW)
- trainable_params(recurse=True)[source]
Returns all trainable parameters.
Returns a list of all trainable parameters.
- Parameters
recurse (bool) – Whether to include the trainable parameters of subcells. Default: True.
- Returns
List, the list of trainable parameters.
- Tutorial Examples:
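Examples
A sketch of collecting trainable parameters, for example to pass to an optimizer (illustrative usage):
>>> from mindspore import nn
>>>
>>> net = nn.Dense(3, 4)
>>> params = net.trainable_params()
>>> print(len(params))
2
>>> optimizer = nn.Momentum(params, learning_rate=0.01, momentum=0.9)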
- untrainable_params(recurse=True)[source]
Returns all untrainable parameters.
Returns a list of all untrainable parameters.
- Parameters
recurse (bool) – Whether to include the untrainable parameters of subcells. Default: True.
- Returns
List, the list of untrainable parameters.
- update_cell_prefix()[source]
Update the param_prefix of all child cells.
After it is invoked, the name prefix of each child cell can be obtained through '_param_prefix'.