mindspore.nn.CellDict
- class mindspore.nn.CellDict(*args, **kwargs)[源代码]
构造Cell字典。关于 Cell 的介绍,可参考
mindspore.nn.Cell
。CellDict 可以像普通Python字典一样使用。
- 参数:
args (iterable,可选) - 一个可迭代对象,通过它可迭代若干键值对(key, Cell),其中键值对的类型为(string, Cell);或者是一个从string到Cell的映射(字典)。Cell的类型不能为CellDict, CellList或者SequentialCell。key不能与类Cell中的属性重名,不能包含‘.’,不能是一个空串。通过类型为string的键可以在CellDict中查找其对应的Cell。
kwargs (dict) - 为待扩展的关键字参数预留。
- 支持平台:
Ascend
GPU
CPU
样例:
>>> import collections >>> from collections import OrderedDict >>> import mindspore as ms >>> import numpy as np >>> from mindspore import Tensor, nn >>> >>> cell_dict = nn.CellDict({'conv': nn.Conv2d(10, 6, 5), ... 'relu': nn.ReLU(), ... 'max_pool2d': nn.MaxPool2d(kernel_size=4, stride=4)}) >>> print(len(cell_dict)) 3 >>> cell_dict.clear() >>> print(len(cell_dict)) 0 >>> ordered_cells = OrderedDict([('conv', nn.Conv2d(10, 6, 5, pad_mode='valid')), ... ('relu', nn.ReLU()), ... ('max_pool2d', nn.MaxPool2d(kernel_size=2, stride=2))]) >>> cell_dict.update(ordered_cells) >>> x = Tensor(np.ones([1, 10, 6, 10]), ms.float32) >>> for cell in cell_dict.values(): ... x = cell(x) >>> print(x.shape) (1, 6, 1, 3) >>> x = Tensor(np.ones([1, 10, 6, 10]), ms.float32) >>> for item in cell_dict.items(): ... x = item[1](x) >>> print(x.shape) (1, 6, 1, 3) >>> print(cell_dict.keys()) odict_keys(['conv', 'relu', 'max_pool2d']) >>> pop_cell = cell_dict.pop('conv') >>> x = Tensor(np.ones([1, 10, 6, 5]), ms.float32) >>> x = pop_cell(x) >>> print(x.shape) (1, 6, 2, 1) >>> print(len(cell_dict)) 2
- pop(key)[源代码]
从CellDict中移除键为 key 的Cell,并将这个Cell返回。
- 参数:
key (string) - 从CellDict中移除的Cell的键。
- 异常:
KeyError - key 对应的Cell在CellDict中不存在。
- update(cells)[源代码]
使用映射或者可迭代对象中的键值对来更新CellDict中已存在的Cell。
- 参数:
cells (iterable) - 一个可迭代对象,通过它可迭代若干键值对(key, Cell),其中键值对的类型为(string, Cell);或者是一个从string到Cell的映射(字典)。Cell的类型不能为CellDict, CellList或者SequentialCell。key不能与类Cell中的属性重名,不能包含‘.’,不能是一个空串。
说明
如果 cells 是一个CellDict、一个OrderedDict或者是一个包含键值对的可迭代对象,那么新增元素的顺序在CellDict中仍会被保留。
- 异常:
TypeError - 如果 cells 不是一个可迭代对象。
TypeError - 如果 cells 中的键值对不是可迭代对象。
ValueError - 如果 cells 中键值对的长度不是2。
TypeError - 如果 cells 中的cell是None。
TypeError - 如果 cells 中的cell的类型不是Cell。
TypeError - 如果 cells 中的cell的类型是CellDict,CellList或者SequentialCell。
TypeError - 如果 cells 中的key的类型不是String类型。
KeyError - 如果 cells 中的key与类Cell中的属性重名。
KeyError - 如果 cells 中的key包含“.”。
KeyError - 如果 cells 中的key是一个空串。