mindspore.nn.CellDict

查看源文件
class mindspore.nn.CellDict(*args, **kwargs)[源代码]

构造Cell字典。关于 Cell 的介绍,可参考 mindspore.nn.Cell

CellDict 可以像普通Python字典一样使用。

参数:
  • args (iterable,可选) - 一个可迭代对象,通过它可迭代若干键值对(key, Cell),其中键值对的类型为(string, Cell);或者是一个从string到Cell的映射(字典)。Cell的类型不能为CellDict, CellList或者SequentialCell。key不能与类Cell中的属性重名,不能包含‘.’,不能是一个空串。通过类型为string的键可以在CellDict中查找其对应的Cell。

  • kwargs (dict) - 为待扩展的关键字参数预留。

支持平台:

Ascend GPU CPU

样例:

>>> import collections
>>> from collections import OrderedDict
>>> import mindspore as ms
>>> import numpy as np
>>> from mindspore import Tensor, nn
>>>
>>> cell_dict = nn.CellDict({'conv': nn.Conv2d(10, 6, 5),
...                          'relu': nn.ReLU(),
...                          'max_pool2d': nn.MaxPool2d(kernel_size=4, stride=4)})
>>> print(len(cell_dict))
3
>>> cell_dict.clear()
>>> print(len(cell_dict))
0
>>> ordered_cells = OrderedDict([('conv', nn.Conv2d(10, 6, 5, pad_mode='valid')),
...                              ('relu', nn.ReLU()),
...                              ('max_pool2d', nn.MaxPool2d(kernel_size=2, stride=2))])
>>> cell_dict.update(ordered_cells)
>>> x = Tensor(np.ones([1, 10, 6, 10]), ms.float32)
>>> for cell in cell_dict.values():
...     x = cell(x)
>>> print(x.shape)
(1, 6, 1, 3)
>>> x = Tensor(np.ones([1, 10, 6, 10]), ms.float32)
>>> for item in cell_dict.items():
...     x = item[1](x)
>>> print(x.shape)
(1, 6, 1, 3)
>>> print(cell_dict.keys())
odict_keys(['conv', 'relu', 'max_pool2d'])
>>> pop_cell = cell_dict.pop('conv')
>>> x = Tensor(np.ones([1, 10, 6, 5]), ms.float32)
>>> x = pop_cell(x)
>>> print(x.shape)
(1, 6, 2, 1)
>>> print(len(cell_dict))
2
clear()[源代码]

移除CellDict中的所有Cell。

items()[源代码]

返回包含CellDict中所有键值对的可迭代对象。

返回:

一个可迭代对象。

keys()[源代码]

返回包含CellDict中所有键的可迭代对象。

返回:

一个可迭代对象。

pop(key)[源代码]

从CellDict中移除键为 key 的Cell,并将这个Cell返回。

参数:
  • key (string) - 从CellDict中移除的Cell的键。

异常:
  • KeyError - key 对应的Cell在CellDict中不存在。

update(cells)[源代码]

使用映射或者可迭代对象中的键值对来更新CellDict中已存在的Cell。

参数:
  • cells (iterable) - 一个可迭代对象,通过它可迭代若干键值对(key, Cell),其中键值对的类型为(string, Cell);或者是一个从string到Cell的映射(字典)。Cell的类型不能为CellDict, CellList或者SequentialCell。key不能与类Cell中的属性重名,不能包含‘.’,不能是一个空串。

说明

如果 cells 是一个CellDict、一个OrderedDict或者是一个包含键值对的可迭代对象,那么新增元素的顺序在CellDict中仍会被保留。

异常:
  • TypeError - 如果 cells 不是一个可迭代对象。

  • TypeError - 如果 cells 中的键值对不是可迭代对象。

  • ValueError - 如果 cells 中键值对的长度不是2。

  • TypeError - 如果 cells 中的cell是None。

  • TypeError - 如果 cells 中的cell的类型不是Cell。

  • TypeError - 如果 cells 中的cell的类型是CellDict,CellList或者SequentialCell。

  • TypeError - 如果 cells 中的key的类型不是String类型。

  • KeyError - 如果 cells 中的key与类Cell中的属性重名。

  • KeyError - 如果 cells 中的key包含“.”。

  • KeyError - 如果 cells 中的key是一个空串。

values()[源代码]

返回包含CellDict中所有Cell的可迭代对象。

返回:

一个可迭代对象。