# Source code for mindspore.nn.cell

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""cell"""
import time
import gc
from collections import OrderedDict
from mindspore import log as logger
from .. import context
from ..common import dtype as mstype
from ..common.api import _executor, _pynative_exec
from .._checkparam import _check_str_by_regular
from ..common.parameter import Parameter, ParameterTuple
from .._c_expression import init_backend
from ..ops.primitive import Primitive
from ..ops.operations import HookBackward
from ..parallel._tensor import _load_tensor_by_layout
from ..common.tensor import Tensor


class Cell:
    """
    Base class for all neural networks.

    A 'Cell' could be a single neural network cell, such as conv2d, relu, batch_norm, etc. or a composition of
    cells to construct a network.

    Note:
        In general, the autograd algorithm will automatically generate the implementation of the gradient function,
        but if bprop method is implemented, the gradient function will be replaced by the bprop. The bprop
        implementation will receive a Tensor `dout` containing the gradient of the loss w.r.t. the output, and a
        Tensor `out` containing the forward result. The bprop needs to compute the gradient of the loss w.r.t. the
        inputs; the gradient of the loss w.r.t. Parameter variables is not supported currently.

    Args:
        auto_prefix (bool): Recursively generate namespaces. Default: True.
        flags (dict): Mapping of flag names to values, applied with `add_flags(**flags)` during
            initialization when non-empty. Default: None.

    Examples:
        >>> class MyCell(Cell):
        >>>     def __init__(self):
        >>>         super(MyCell, self).__init__()
        >>>         self.relu = P.ReLU()
        >>>
        >>>     def construct(self, x):
        >>>         return self.relu(x)
    """
    def __init__(self, auto_prefix=True, flags=None):
        # `__setattr__` consults `_params` and `_cells` on every assignment, so they are created first.
        self._params = OrderedDict()
        self._cells = OrderedDict()
        self.training = False
        self.requires_grad = False
        self.pynative = False
        self._param_prefix = ''
        self._auto_prefix = auto_prefix
        self._scope = None
        self._phase = 'train'
        self._parameter_layout_dict = {}
        # Creation timestamp in nanoseconds; its string form keys executor resources for this cell.
        self._create_time = int(time.time() * 1e9)
        init_backend()
        # call gc to release GE session resources used by non-used cell objects
        gc.collect()
        self._construct_inputs_num = 0
        self._construct_inputs_names = []
        self._auto_parallel_mode = False
        self._parallel_inputs_run = None
        if flags:
            self.add_flags(**flags)
        self._backward_hook = None
        self.enable_hook = False
        self._bprop_debug = False
        self._is_run = False
        self.cell_type = None

    @property
    def is_run(self):
        """Whether the non-graph-mode path of `__call__` has completed for this cell."""
        return self._is_run

    @is_run.setter
    def is_run(self, value):
        self._is_run = value

    @property
    def create_time(self):
        """Creation timestamp of the cell in nanoseconds (derived from `time.time()`)."""
        return self._create_time

    @property
    def cell_init_args(self):
        """
        String stored via the `cell_init_args` setter.

        Note:
            `__init__` does not set `_cell_init_args`, so reading this before it has been assigned raises
            AttributeError.
        """
        return self._cell_init_args

    @property
    def param_prefix(self):
        """
        Param prefix is the prefix of current cell's direct child parameter.
        """
        return self._param_prefix

    @property
    def bprop_debug(self):
        """
        Get whether cell custom bprop debug is enabled.
        """
        return self._bprop_debug

    @bprop_debug.setter
    def bprop_debug(self, value):
        """
        Set whether to enable cell custom bprop debug.

        Note:
            When bprop is defined in cell, the bprop function will be executed
            in python interpreter when bprop debug is true, and will be parsed
            and added to the graph when bprop debug is false.

        Args:
            value (bool): Specifies whether to enable bprop debug. Default: False.

        Raises:
            TypeError: If `value` is not a bool.
        """
        if not isinstance(value, bool):
            raise TypeError("'bprop debug' value must be bool type.")
        self._bprop_debug = value
[docs] def update_cell_prefix(self): """ Update the all child cells' self.param_prefix. After invoked, can get all the cell's children's name prefix by '_param_prefix'. """ cells_name = self.cells_and_names() for cell_name, cell in cells_name: cell._param_prefix = cell_name
[docs] def update_cell_type(self, cell_type): """ Update current cell type mainly identify if quantization aware training network. After invoked, can set the cell type to 'cell_type'. """ self.cell_type = cell_type
@cell_init_args.setter def cell_init_args(self, value): if not isinstance(value, str): raise TypeError("'cell_init_args' must be string type.") self._cell_init_args = value @property def phase(self): return self._phase @phase.setter def phase(self, value): if not isinstance(value, str): raise TypeError("'phase' must be string type.") self._phase = value @property def parameter_layout_dict(self): return self._parameter_layout_dict @property def cls_name(self): return self.__class__.__name__ @parameter_layout_dict.setter def parameter_layout_dict(self, value): if not isinstance(value, dict): raise TypeError("'parameter_layout_dict' must be dict type.") self._parameter_layout_dict = value
[docs] def get_func_graph_proto(self): """Return graph binary proto.""" return _executor._get_func_graph_proto(self.phase + "." + str(self.create_time), "anf_ir", True)
    def __getattr__(self, name):
        """
        Fall back to registered parameters and child cells when normal attribute lookup fails.

        Raises:
            AttributeError: If `name` is neither a registered parameter nor a registered child cell.
        """
        if '_params' in self.__dict__:
            params = self.__dict__['_params']
            if name in params:
                return params[name]
        if '_cells' in self.__dict__:
            cells = self.__dict__['_cells']
            if name in cells:
                return cells[name]
        raise AttributeError("'{}' object has no attribute '{}'.".format(type(self).__name__, name))

    def __del__(self):
        """Release executor resources associated with this cell."""
        _pynative_exec.clear("resource")
        # `_create_time` may be absent (e.g. if `__init__` did not get that far), hence the guard.
        if hasattr(self, "_create_time"):
            _executor.del_net_res(str(self._create_time))

    def __delattr__(self, name):
        """Delete a registered parameter, a registered child cell, or an ordinary attribute named `name`."""
        if name in self._params:
            del self._params[name]
        elif name in self._cells:
            del self._cells[name]
        else:
            object.__delattr__(self, name)

    def __call__(self, *inputs):
        """
        Run the cell on `inputs`.

        In GRAPH_MODE the cell is compiled and executed via `compile_and_run`. Otherwise parameter data is
        initialized and `construct` (or `_hook_construct` when `enable_hook` is set) runs directly. When
        `requires_grad` is True, the forward pass is bracketed by `_pynative_exec.new_graph` / `end_graph`,
        and every immediate child has `set_grad(True)` applied temporarily.

        Returns:
            The output of `construct`; a Parameter output is replaced by its `.data`.
        """
        if context.get_context("mode") == context.GRAPH_MODE:
            out = self.compile_and_run(*inputs)
            return out
        self.init_parameters_data()
        orign_grad = []
        if self.requires_grad is True:
            _pynative_exec.set_grad_flag(True)
            _pynative_exec.new_graph(self, *inputs)
            # Remember each child's grad flag so it can be restored after the forward pass.
            for cell in self.cells():
                orign_grad.append(cell.requires_grad)
                cell.set_grad(True)
        else:
            _pynative_exec.set_grad_flag(False)
        # NOTE(review): there is no try/finally, so if construct raises, children keep requires_grad=True.
        if self.enable_hook:
            output = self._hook_construct(*inputs)
        else:
            output = self.construct(*inputs)
        # Return the underlying data rather than the Parameter wrapper.
        if isinstance(output, Parameter):
            output = output.data
        if self.requires_grad is True:
            _pynative_exec.end_graph(self, output, *inputs)
            # Relies on self.cells() yielding children in the same order as when the flags were saved.
            for i, cell in enumerate(self.cells()):
                cell.set_grad(orign_grad[i])
        self._is_run = True
        return output

    def __setattr__(self, name, value):
        """
        Route attribute assignment into `_params`, `_cells` or the instance dict.

        - Parameter: registered through `insert_param_to_cell`.
        - ParameterTuple: every item is registered under its own `name`, and the tuple is also stored as a
          plain attribute.
        - Cell: stored in `_cells`; when `auto_prefix` is on, the child's parameter names are prefixed with
          `name + '.'`.
        - Any other value for a name already in `_params`: a Tensor updates the existing (non-None) parameter's
          data, None clears the slot, anything else raises TypeError.
        - Any other value for a name already in `_cells`: only None is accepted (clears the slot).
        - Otherwise an ordinary attribute; a Primitive is additionally told its instance name.

        Raises:
            AttributeError: If a Parameter, ParameterTuple or Cell is assigned before `Cell.__init__` has run.
            TypeError: If the value's type conflicts with an existing attribute or registration under `name`.
        """
        # Read straight from __dict__: these may not exist yet while __init__ is running.
        cells = self.__dict__.get('_cells')
        params = self.__dict__.get('_params')
        if isinstance(value, Parameter):
            if params is None:
                raise AttributeError("Can not assign params before Cell.__init__() call.")
            if name in self.__dict__:
                # Only a None placeholder attribute may be replaced by a Parameter.
                if self.__dict__[name] is not None:
                    raise TypeError("Expected type is not in (Parameter, Cell), but got Parameter.")
                del self.__dict__[name]
            if cells and name in cells:
                raise TypeError("Expected type is Cell, but got Parameter.")
            self.insert_param_to_cell(name, value)
        elif isinstance(value, ParameterTuple):
            if params is None:
                raise AttributeError("Can not assign params before Cell.__init__() call.")
            for item in value:
                self.insert_param_to_cell(item.name, item, check_name=False)
            object.__setattr__(self, name, value)
        elif isinstance(value, Cell):
            if cells is None:
                raise AttributeError("Can not assign cells before Cell.__init__() call.")
            if name in self.__dict__:
                del self.__dict__[name]
            if params and name in params:
                raise TypeError("Expected type is Parameter, but got Cell.")
            if self._auto_prefix:
                value.update_parameters_name(name + '.')
            cells[name] = value
        elif params and name in params:
            if isinstance(value, Tensor) and self._params[name] is not None:
                self._params[name].set_parameter_data(value)
            elif value is not None:
                raise TypeError("Expected type in (Parameter, ParameterTuple), but got {}.".format(type(value)))
            else:
                self.insert_param_to_cell(name, None)
        elif cells and name in cells:
            if value is not None:
                raise TypeError("Expected type is cell, but got {}.".format(type(value)))
            self._cells[name] = None
        else:
            if isinstance(value, Primitive):
                value.set_prim_instance_name(name)
            object.__setattr__(self, name, value)
[docs] def extend_repr(self): """ Sets the extended representation of the Cell. To print customized extended information, re-implement this method in your own cells. """ return ''
def __repr__(self): extra_str = self.extend_repr() info_str = self.__class__.__name__ + '<' if self._cells: sub_str = '\n' if extra_str: sub_str += '{}\n'.format(self.extend_repr()) for key, value in self._cells.items(): sub_str += '({}): {}\n'.format(key, repr(value)) sub_str = sub_str.replace('\n', '\n ') + '>' info_str += sub_str else: info_str += extra_str + '>' return info_str
[docs] def load_parameter_slice(self, params): """ Replace parameters with sliced tensors by parallel strategies. Please refer to the usage in source code of `mindspore.common._Executor.compile`. Args: params (dict): The parameters dictionary used for init data graph. """ if params is None: for key in self.parameters_dict(): tensor = self.parameters_dict()[key].data if key not in self.parameter_layout_dict: logger.info("layout dict does not contain the key %s", key) continue if self.parameters_dict()[key].sliced: logger.debug("Param %s is already sliced.", key) continue layout = self.parameter_layout_dict[key] new_tensor = _load_tensor_by_layout(tensor, layout) self.parameters_dict()[key].set_parameter_data(new_tensor) self.parameters_dict()[key].sliced = True elif isinstance(params, OrderedDict): for key in params: tensor = params[key].data if key not in self.parameter_layout_dict: logger.info("layout dict does not contain the key %s", key) continue if params[key].sliced: logger.debug("Param %s is already sliced.", key) continue layout = self.parameter_layout_dict[key] new_tensor = _load_tensor_by_layout(tensor, layout) params[key].set_parameter_data(new_tensor) params[key].sliced = True else: raise TypeError('Parameters need OrderedDict type, but got {}'. format(type(params)))
def _load_inputs(self, *inputs): """ Slice inputs tensors by parallel strategies. Args: inputs (Function or Cell): inputs of construct method. """ parallel_inputs_run = [] if len(inputs) > self._construct_inputs_num: raise ValueError('Len of inputs: {} is bigger than self._construct_inputs_num: {}.'. format(len(inputs), self._construct_inputs_num)) for i, tensor in enumerate(inputs): key = self._construct_inputs_names[i] # if input is not used, self.parameter_layout_dict may not contain the key if key not in self.parameter_layout_dict: logger.warning("layout dict does not contain the key %s", key) parallel_inputs_run.append(tensor) else: layout = self.parameter_layout_dict[key] new_tensor = _load_tensor_by_layout(tensor, layout) parallel_inputs_run.append(new_tensor) return tuple(parallel_inputs_run)
[docs] def set_parallel_input_with_inputs(self, *inputs): """ Slice inputs tensors by parallel strategies, and set the sliced inputs to `_parallel_input_run` Args: inputs (tuple): inputs of construct method. """ self._parallel_inputs_run = self._load_inputs(*inputs)
def _get_construct_inputs_number_and_name(self): """Compute self._construct_inputs_names and self._construct_inputs_num""" import inspect from mindspore._extends.parse.parser import get_parse_method_of_class fn = get_parse_method_of_class(self) inspect.getfullargspec(fn) self._construct_inputs_num = fn.__code__.co_argcount self._construct_inputs_names = fn.__code__.co_varnames assert self._construct_inputs_num > 0 assert self._construct_inputs_names[0] == 'self' assert self._construct_inputs_num - 1 <= len(self._construct_inputs_names) self._construct_inputs_names = self._construct_inputs_names[1:self._construct_inputs_num] self._construct_inputs_num = self._construct_inputs_num - 1
[docs] def compile(self, *inputs): """ Compiles cell. Args: inputs (tuple): Input parameters. """ _executor.compile(self, *inputs, phase=self.phase, auto_parallel_mode=self._auto_parallel_mode)
[docs] def compile_and_run(self, *inputs): """ Compiles and runs cell. Args: inputs (tuple): Input parameters. Returns: Object, the result of executing. """ _executor.compile(self, *inputs, phase=self.phase, auto_parallel_mode=self._auto_parallel_mode) if self._auto_parallel_mode: if inputs and isinstance(inputs[0], Tensor) and inputs[0].virtual_flag: # get parallel inputs in sink mode, parallel inputs set in _executor.compile parallel_inputs_run = self._parallel_inputs_run else: # set parallel inputs in normal mode self._parallel_inputs_run = self._load_inputs(*inputs) parallel_inputs_run = self._parallel_inputs_run return _executor(self, *parallel_inputs_run, phase=self.phase) return _executor(self, *inputs, phase=self.phase)
[docs] def exec_checkpoint_graph(self): """Executes saving checkpoint graph operation.""" _executor(self, phase='save')
[docs] def insert_param_to_cell(self, param_name, param, check_name=True): """ Adds a parameter to the current cell. Inserts a parameter with given name to the cell. Please refer to the usage in source code of `mindspore.nn.Cell.__setattr__`. Args: param_name (str): Name of the parameter. param (Parameter): Parameter to be inserted to the cell. check_name (bool): Determines whether the name input is compatible. Default: True. Raises: KeyError: If the name of parameter is null or contains dot. AttributeError: If user did not call init() first. TypeError: If the type of parameter is not Parameter. """ if not param_name: raise KeyError("The name of parameter should not be null.") if check_name and '.' in param_name: raise KeyError("The name of parameter should not contain \".\"") if '_params' not in self.__dict__: raise AttributeError("You need call init() first.") if hasattr(self, param_name) and param_name not in self._params: raise KeyError("Duplicated parameter name '{}'.".format(param_name)) if not isinstance(param, Parameter) and param is not None: raise TypeError("The type of parameter should be 'Parameter' if not None.") self._params[param_name] = param
[docs] def insert_child_to_cell(self, child_name, child): """ Adds a child cell to the current cell. Inserts a subcell with given name to current cell. Args: child_name (str): Name of the child cell. child (Cell): The child cell to be inserted. Raises: KeyError: Child Cell's name is incorrect or duplicated with the other child name. TypeError: Child Cell's type is incorrect. """ if not child_name or '.' in child_name: raise KeyError("Child cell name is incorrect.") if hasattr(self, child_name) and child_name not in self._cells: raise KeyError("Duplicate child name '{}'.".format(child_name)) if not isinstance(child, Cell) and child is not None: raise TypeError("Child cell type is incorrect.") self._cells[child_name] = child
[docs] def construct(self, *inputs): """ Defines the computation to be performed. This method should be overridden by all subclasses. Note: The inputs of the top cell only allow Tensor. Other types (tuple, list, int etc.) are forbidden. Returns: Tensor, returns the computed result. """ raise NotImplementedError
[docs] def init_parameters_data(self, recurse=True, auto_parallel_mode=False): """Init parameters' data.""" for param in self.get_parameters(expand=recurse): if not auto_parallel_mode: param.init_data() elif param.name not in self.parameter_layout_dict: logger.debug("Layout dict does not contain the key %s.", param.name) param.init_data(set_sliced=True) else: layout = self.parameter_layout_dict[param.name] param.init_data(layout, set_sliced=True)
[docs] def parameters_dict(self, recurse=True): """ Gets parameters dictionary. Gets the parameters dictionary of this cell. Args: recurse (bool): Whether contains the parameters of subcells. Default: True. Returns: OrderedDict, return parameters dictionary. """ param_dict = OrderedDict() for param in self.get_parameters(expand=recurse): param_dict[param.name] = param return param_dict
def parameters_broadcast_dict(self, recurse=True): param_dict = OrderedDict() for param in self.get_parameters(expand=recurse): if param.layerwise_parallel is False: param_dict[param.name] = param if not param_dict: return None return param_dict
[docs] def update_parameters_name(self, prefix='', recurse=True): """ Updates the names of parameters with given prefix string. Adds the given prefix to the names of parameters. Args: prefix (str): The prefix string. recurse (bool): Whether contains the parameters of subcells. Default: True. """ _check_str_by_regular(prefix) for name, param in self.parameters_and_names(expand=recurse): if prefix != '': param.is_init = False param.name = prefix + name
[docs] def trainable_params(self, recurse=True): """ Returns all trainable parameters. Returns a list of all trainable parmeters. Args: recurse (bool): Whether contains the trainable parameters of subcells. Default: True. Returns: List, the list of trainable parameters. """ return list(filter(lambda x: x.requires_grad, self.get_parameters(expand=recurse)))
[docs] def untrainable_params(self, recurse=True): """ Returns all untrainable parameters. Returns a list of all untrainable parmeters. Args: recurse (bool): Whether contains the untrainable parameters of subcells. Default: True. Returns: List, the list of untrainable parameters. """ return list(filter(lambda x: not x.requires_grad, self.get_parameters(expand=recurse)))
[docs] def get_parameters(self, expand=True): """ Returns an iterator over cell parameters. Yields parameters of this cell. If `expand` is True, yield parameters of this cell and all subcells. Args: expand (bool): If True, yields parameters of this cell and all subcells. Otherwise, yields only parameters that are direct members of this cell. Default: True. Examples: >>> net = Net() >>> for item in net.get_parameters(): >>> print(item) """ for _, param in self.parameters_and_names(expand=expand): yield param
def check_names(self): names = set("") for value, param in self.parameters_and_names(): if param.name in names: raise ValueError("The value of {} is {}, its name '{}' already exists.". format(value, param, param.name)) names.add(param.name)
[docs] def parameters_and_names(self, name_prefix='', expand=True): """ Returns an iterator over cell parameters. Includes the parameter's name and itself. Args: name_prefix (str): Namespace. Default: ''. expand (bool): If True, yields parameters of this cell and all subcells. Otherwise, yields only parameters that are direct members of this cell. Default: True. Examples: >>> n = Net() >>> names = [] >>> for m in n.parameters_and_names(): >>> if m[0]: >>> names.append(m[0]) """ cells = [] if expand: cells = self.cells_and_names(name_prefix=name_prefix) else: cells.append((name_prefix, self)) params_set = set() for cell_name, cell in cells: params = cell._params.items() for par_name, par in params: if par and par not in params_set: params_set.add(par) par_new_name = par_name if cell_name: par_new_name = cell_name + '.' + par_new_name yield par_new_name, par
[docs] def cells_and_names(self, cells=None, name_prefix=''): """ Returns an iterator over all cells in the network. Includes the cell's name and itself. Args: cells (str): Cells to iterate over. Default: None. name_prefix (str): Namespace. Default: ''. Examples: >>> n = Net() >>> names = [] >>> for m in n.cells_and_names(): >>> if m[0]: >>> names.append(m[0]) """ t_cells = cells if cells else set() if self in t_cells: return t_cells.add(self) yield name_prefix, self for name, cell in self._cells.items(): if cell: cells_name_prefix = name if name_prefix: cells_name_prefix = name_prefix + '.' + cells_name_prefix for ele in cell.cells_and_names(t_cells, cells_name_prefix): yield ele
[docs] def cells(self): """Returns an iterator over immediate cells.""" return self.name_cells().values()
def _set_scope(self, name): """Sets the name on the first time.""" if self._scope is None: self._scope = name def _children_scope_recursive(self, parent_prefix='Default'): """Generates the scope of each layer of the network recursively.""" reserve_class_name_in_scope = context.get_context("reserve_class_name_in_scope") for name, cell in self.name_cells().items(): yield parent_prefix + "/" + name + (("-" + cell.__class__.__name__) if reserve_class_name_in_scope else ""), cell for name, cell in self.name_cells().items(): for key, value in cell._children_scope_recursive(parent_prefix + "/" + name + (("-" + cell.__class__.__name__) if reserve_class_name_in_scope else "")): yield key, value
[docs] def get_scope(self): """Returns the scope of a cell object in one network.""" return self._scope
[docs] def generate_scope(self): """Generate the scope for every cell object in the network.""" for name, cell in self._children_scope_recursive(): cell._set_scope(name)
[docs] def name_cells(self): """ Returns an iterator over all cells in the network. Include name of the cell and cell itself. """ value_set = set() cells = OrderedDict() for name, cell in self._cells.items(): if cell is not None and cell not in value_set: value_set.add(cell) cells[name] = cell return cells
def add_flags(self, **flags): if not hasattr(self, "_mindspore_flags"): self._mindspore_flags = {} self._mindspore_flags.update({**flags}) self.__dict__.update({**flags}) return self def add_flags_recursive(self, **flags): self.add_flags(**flags) if hasattr(self, '_cell_init_args'): self._cell_init_args += str({**flags}) for cell in self.cells(): cell.add_flags_recursive(**flags) return self def get_flags(self): if not hasattr(self, "_mindspore_flags"): self._mindspore_flags = {} return self._mindspore_flags
[docs] def to_float(self, dst_type): """ Add cast on all inputs of cell and child cells to run with certain float type. If `dst_type is mindspore.dtype.float16`, all the inputs of Cell including input, Parameter, Tensor as const will be cast to float16. Please refer to the usage in source code of `mindspore.train.amp.build_train_network`. Note: Call multiple times will overwrite the previous. Args: dst_type (:class:`mindspore.dtype`): Transfer Cell to Run with dst_type. dst_type can be `mindspore.dtype.float16` or `mindspore.dtype.float32`. Raises: ValueError: If dst_type is not float32 or float16. """ if dst_type not in (mstype.float16, mstype.float32): raise ValueError("dst_type should inside float32 or float16.") flags = {'fp16': dst_type == mstype.float16, 'fp32': dst_type == mstype.float32} self.add_flags_recursive(**flags) return self
def set_grad(self, mode=True): self.requires_grad = mode return self
[docs] def set_train(self, mode=True): """ Sets the cell to training mode. The cell itself and all children cells will be set to training mode. Args: mode (bool): Specifies whether the model is training. Default: True. """ if mode is False: self._phase = 'predict' else: self._phase = 'train' self.add_flags_recursive(training=mode) return self
[docs] def set_broadcast_flag(self, mode=True): """ Set the cell to data_parallel mode. The cell can be accessed as an attribute using the given name. Args: mode (bool): Specifies whether the model is data_parallel. Default: True. """ self.add_flags_recursive(broadcast_flag=mode) return self
[docs] def set_auto_parallel(self): """ Set the cell to auto parallel mode. Note: If a cell needs to use auto parallel or semi auto parallel mode for training, evaluation or prediction, this interface needs to be called for the cell. """ self._auto_parallel_mode = True self.add_flags(auto_parallel=True) self._get_construct_inputs_number_and_name()
def _hook_construct(self, *inputs): """Hook construct method to replace original construct method when hook function enabled.""" inputs = self._backward_hook(*inputs) inputs = self.construct(inputs) outputs = self._backward_hook(inputs) return outputs
[docs] def register_backward_hook(self, fn): """ Set the cell backward hook function. Note: fn should be defined as following code shows, `cell_name` is the name of registered cell, `grad_input` is gradient passed to the cell, `grad_output` is the gradient computed and pass to next cell or primitve, which may be modified and return. >>> hook_fn(cell_name, grad_input, grad_output) -> Tensor or None Args: fn (function): Specifies the hook function with grad as input. """ self._backward_hook = HookBackward(fn, self.cls_name + "(" + str(id(self)) + ")") self.enable_hook = True
class GraphKernel(Cell):
    """
    Base class for GraphKernel.

    A `GraphKernel` is a composite of basic primitives and can be compiled into a fused kernel automatically
    when context.set_context(enable_graph_kernel=True).

    Args:
        auto_prefix (bool): Recursively generate namespaces. Default: True.
        pips: Forwarded to `Cell.__init__` as its `flags` argument. Default: None.

    Examples:
        >>> class Relu(GraphKernel):
        >>>     def __init__(self):
        >>>         super(Relu, self).__init__()
        >>>         self.max = P.Maximum()
        >>>
        >>>     def construct(self, x):
        >>>         return self.max(P.Fill()(P.DType()(x), P.Shape()(x), 0.0), x)
    """
    def __init__(self, auto_prefix=True, pips=None):
        super().__init__(auto_prefix, pips)
        # The concrete subclass name is recorded as the `graph_kernel` flag.
        self.add_flags(graph_kernel=self.__class__.__name__)

    def construct(self):
        raise NotImplementedError