mindspore.nn.Cell
- class mindspore.nn.Cell(auto_prefix=True, flags=None)[source]
Base class for all neural networks.
A Cell can be a single neural network unit, such as conv2d, relu, or batch_norm, or a composition of cells constructing a network.
Note
In general, the autograd algorithm automatically generates the implementation of the gradient function, but if a back-propagation (bprop) method is implemented, it replaces the generated gradient function. The bprop implementation receives a Tensor dout containing the gradient of the loss w.r.t. the output, and a Tensor out containing the forward result. The bprop needs to compute the gradient of the loss w.r.t. the inputs; gradients of the loss w.r.t. Parameter variables are not currently supported. The bprop method must contain the self parameter.
- Parameters
auto_prefix (bool) – Recursively generate namespaces. Default: True.
- Supported Platforms:
Ascend
GPU
CPU
Examples
>>> class MyCell(nn.Cell):
...     def __init__(self):
...         super(MyCell, self).__init__()
...         self.relu = P.ReLU()
...
...     def construct(self, x):
...         return self.relu(x)
- property bprop_debug
Get whether cell custom bprop debug is enabled.
- cast_param(param)[source]
Cast the parameter according to the auto mixed precision level in PyNative mode.
- Parameters
param (Parameter) – The parameter to cast.
- cells_and_names(cells=None, name_prefix='')[source]
Returns an iterator over all cells in the network.
Includes each cell’s name and the cell itself.
- Parameters
Examples
>>> n = Net()
>>> names = []
>>> for m in n.cells_and_names():
...     if m[0]:
...         names.append(m[0])
- compile_and_run(*inputs)[source]
Compiles and runs the cell.
- Parameters
inputs (tuple) – Input parameters.
- Returns
Object, the result of execution.
- construct(*inputs, **kwargs)[source]
Defines the computation to be performed. This method must be overridden by all subclasses.
- Returns
Tensor, returns the computed result.
- extend_repr()[source]
Sets the extended representation of the Cell.
To print customized extended information, re-implement this method in your own cells.
- get_parameters(expand=True)[source]
Returns an iterator over cell parameters.
Yields parameters of this cell. If expand is True, yield parameters of this cell and all subcells.
- Parameters
expand (bool) – If True, yields parameters of this cell and all subcells. Otherwise, only yields parameters that are direct members of this cell. Default: True.
Examples
>>> net = Net()
>>> parameters = []
>>> for item in net.get_parameters():
...     parameters.append(item)
- init_parameters_data(auto_parallel_mode=False)[source]
Initialize all parameters and replace the original saved parameters in cell.
Note
trainable_params() and other similar interfaces may return different parameter instances after init_parameters_data is called; do not save these results.
- Parameters
auto_parallel_mode (bool) – If running in auto_parallel_mode.
- Returns
Dict[Parameter, Parameter], a dict mapping each original parameter to its replacement.
- insert_child_to_cell(child_name, child_cell)[source]
Adds a child cell to the current cell with a given name.
- insert_param_to_cell(param_name, param, check_name=True)[source]
Adds a parameter to the current cell.
Inserts a parameter with the given name into the cell. Please refer to the usage in the source code of mindspore.nn.Cell.__setattr__.
- Parameters
- Raises
KeyError – If the parameter name is empty or contains a dot.
AttributeError – If the cell’s __init__() has not been called first.
TypeError – If the type of parameter is not Parameter.
- load_parameter_slice(params)[source]
Replace parameters with sliced tensors by parallel strategies.
Please refer to the usage in source code of mindspore.common._Executor.compile.
- Parameters
params (dict) – The parameters dictionary used for initializing the data graph.
- name_cells()[source]
Returns an iterator over all cells in the network.
Includes the name of the cell and the cell itself.
- property param_prefix
The parameter prefix used for the current cell’s direct child parameters.
- parameters_and_names(name_prefix='', expand=True)[source]
Returns an iterator over cell parameters.
Includes the parameter’s name and itself.
- Parameters
Examples
>>> n = Net()
>>> names = []
>>> for m in n.parameters_and_names():
...     if m[0]:
...         names.append(m[0])
- parameters_dict(recurse=True)[source]
Gets parameters dictionary.
Gets the parameters dictionary of this cell.
- Parameters
recurse (bool) – Whether to include the parameters of subcells. Default: True.
- Returns
OrderedDict, the parameters dictionary.
- recompute(mode=True)[source]
Sets the cell to be recomputed. All primitives in the cell will be set to be recomputed. If a recomputed primitive feeds into backward nodes for computing the gradient, its intermediate activation is not stored in the forward pass; instead, it is recomputed in the backward pass.
Note
If the computation involves randomization or global variables, the equivalence is not currently guaranteed.
If the recompute API of a primitive inside this cell is also called, the recompute mode of that primitive takes precedence over this setting.
- Parameters
mode (bool) – Specifies whether the cell is recomputed. Default: True.
- register_backward_hook(fn)[source]
Sets the cell backward hook function. Note that this function is only supported in PyNative mode.
Note
fn must be defined as follows: hook_fn(cell_name, grad_input, grad_output) -> Tensor or None, where cell_name is the name of the registered cell, grad_input is the gradient passed into the cell, and grad_output is the gradient computed and passed to the next cell or primitive, which may be modified and returned.
- Parameters
fn (function) – Specifies the hook function with grad as input.
- set_auto_parallel()[source]
Set the cell to auto parallel mode.
Note
If a cell needs to use the auto parallel or semi auto parallel mode for training, evaluation or prediction, this interface needs to be called by the cell.
- set_broadcast_flag(mode=True)[source]
Set the cell to data_parallel mode.
- Parameters
mode (bool) – Specifies whether the model is data_parallel. Default: True.
- set_comm_fusion(fusion_type, recurse=True)[source]
Set comm_fusion for all the parameters in the Net. Please refer to the description of mindspore.common.parameter.comm_fusion.
Note
The value of the attribute will be overwritten when the function is called multiple times.
- set_grad(requires_grad=True)[source]
Sets the cell flag for gradient. In PyNative mode, this parameter specifies whether the network requires gradients. If True, the backward network needed to compute the gradients will be generated when the forward network is executed.
- Parameters
requires_grad (bool) – Specifies whether the network needs gradients. If True, the cell will construct the backward network in PyNative mode. Default: True.
- set_parallel_input_with_inputs(*inputs)[source]
Slice the input tensors by parallel strategies, and set the sliced inputs to _parallel_input_run.
- Parameters
inputs (tuple) – inputs of construct method.
- set_param_ps(recurse=True, init_in_server=False)[source]
Set whether the trainable parameters are updated by parameter server and whether the trainable parameters are initialized on server.
Note
It only works when a running task is in the parameter server mode.
- set_train(mode=True)[source]
Sets the cell to training mode.
The cell itself and all children cells will be set to training mode. Layers that have different constructions for training and predicting, such as BatchNorm, will distinguish between the branches by this attribute. If set to True, the training branch will be executed; otherwise, the inference branch.
- Parameters
mode (bool) – Specifies whether the model is training. Default: True.
- to_float(dst_type)[source]
Add cast on all inputs of cell and child cells to run with certain float type.
If dst_type is mindspore.dtype.float16, all the inputs of Cell including input, Parameter, Tensor as const will be cast to float16. Please refer to the usage in source code of mindspore.train.amp.build_train_network.
Note
Multiple calls will overwrite the previous setting.
- Parameters
dst_type (mindspore.dtype) – Transfer the Cell to run with dst_type. dst_type can be mindspore.dtype.float16 or mindspore.dtype.float32.
- Raises
ValueError – If dst_type is not float32 nor float16.
- trainable_params(recurse=True)[source]
Returns all trainable parameters.
Returns a list of all trainable parameters.
- Parameters
recurse (bool) – Whether to include the trainable parameters of subcells. Default: True.
- Returns
List, the list of trainable parameters.
- untrainable_params(recurse=True)[source]
Returns all untrainable parameters.
Returns a list of all untrainable parameters.
- Parameters
recurse (bool) – Whether to include the untrainable parameters of subcells. Default: True.
- Returns
List, the list of untrainable parameters.
- update_cell_prefix()[source]
Updates self.param_prefix of all child cells.
After being invoked, the name prefix of every child cell can be obtained via ‘_param_prefix’.