mindspore.nn.SequentialCell

class mindspore.nn.SequentialCell(*args)

Sequential Cell container. For more details about Cell, please refer to Cell.

Cells are added in the order they are passed to the constructor. Alternatively, an OrderedDict of Cells can be passed in.

Note

SequentialCell differs from torch.nn.ModuleList: ModuleList is only a list for storing modules, whereas the layers in a SequentialCell are connected in a cascading way, so each layer's output is fed to the next layer.
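
The distinction in the note can be sketched in plain Python. This is an illustration only, not MindSpore or PyTorch code: `run_sequential`, `double`, and `add_one` are hypothetical names, and ordinary callables stand in for Cells.

```python
from functools import reduce

# Hypothetical stand-ins for layers: any callable that takes one value.
def double(v):
    return v * 2

def add_one(v):
    return v + 1

# A plain list only stores the layers; nothing connects them.
layers = [double, add_one]

def run_sequential(layers, x):
    """Cascade: feed each layer's output into the next layer, in order."""
    return reduce(lambda value, layer: layer(value), layers, x)

result = run_sequential(layers, 3)  # double(3) gives 6, then add_one(6) gives 7
```

Reversing the list changes the result, which is why the order in which cells are passed matters for a sequential container but is only bookkeeping for a storage list.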

Parameters

args (list, OrderedDict) – List or OrderedDict of instances of Cell subclasses.

Inputs:
  • x (Tensor) - Tensor with shape according to the first Cell in the sequence.

Outputs:

Tensor, the output Tensor with shape depending on the input x and defined sequence of Cells.

Raises

TypeError – If args is neither a list nor an OrderedDict.

Supported Platforms:

Ascend GPU CPU

Examples

>>> from mindspore import Tensor
>>> import mindspore
>>> import mindspore.nn as nn
>>> import numpy as np
>>>
>>> conv = nn.Conv2d(3, 2, 3, pad_mode='valid', weight_init="ones")
>>> relu = nn.ReLU()
>>> seq = nn.SequentialCell([conv, relu])
>>> x = Tensor(np.ones([1, 3, 4, 4]), dtype=mindspore.float32)
>>> output = seq(x)
>>> print(output)
[[[[27. 27.]
   [27. 27.]]
  [[27. 27.]
   [27. 27.]]]]
>>> from collections import OrderedDict
>>> d = OrderedDict()
>>> d["conv"] = conv
>>> d["relu"] = relu
>>> seq = nn.SequentialCell(d)
>>> x = Tensor(np.ones([1, 3, 4, 4]), dtype=mindspore.float32)
>>> output = seq(x)
>>> print(output)
[[[[27. 27.]
   [27. 27.]]
  [[27. 27.]
   [27. 27.]]]]
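
As a rough check of the values printed above, the arithmetic can be reproduced in plain Python. The sketch assumes a stride of 1 and no bias term in the convolution (both inferred from the example, not stated in it); the variable names are illustrative only.

```python
# Settings taken from the example: Conv2d(3, 2, 3, pad_mode='valid', weight_init="ones")
in_channels, kernel_size, input_size = 3, 3, 4
stride = 1  # assumption: the example does not set a stride

# 'valid' padding adds no border, so each spatial dimension shrinks.
out_size = (input_size - kernel_size) // stride + 1

# Every weight and every input element is 1, so each output element
# is the count of multiplied pairs: channels * kernel height * kernel width.
conv_value = in_channels * kernel_size * kernel_size * 1.0

# ReLU leaves positive values unchanged.
relu_value = max(conv_value, 0.0)
```

Under these assumptions each of the 2 output channels is a 2x2 plane filled with 27, matching the printed tensor.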

append(cell)

Appends a given Cell to the end of the list.

Parameters

cell (Cell) – The Cell to be appended.

Examples

>>> from mindspore import Tensor
>>> import mindspore
>>> import mindspore.nn as nn
>>> import numpy as np
>>>
>>> conv = nn.Conv2d(3, 2, 3, pad_mode='valid', weight_init="ones")
>>> bn = nn.BatchNorm2d(2)
>>> relu = nn.ReLU()
>>> seq = nn.SequentialCell([conv, bn])
>>> seq.append(relu)
>>> x = Tensor(np.ones([1, 3, 4, 4]), dtype=mindspore.float32)
>>> output = seq(x)
>>> print(output)
[[[[26.999863 26.999863]
   [26.999863 26.999863]]
  [[26.999863 26.999863]
   [26.999863 26.999863]]]]
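
The output is slightly below 27 because of the BatchNorm2d layer between the convolution and the ReLU. A minimal sketch of that arithmetic follows, assuming the layer runs in inference mode with a moving mean of 0, a moving variance of 1, gamma 1, beta 0, and eps 1e-5. These values are assumptions chosen to be consistent with the printed result; the example itself does not state them.

```python
import math

conv_value = 27.0                   # convolution output for all-ones weights and input
moving_mean, moving_var = 0.0, 1.0  # assumed initial running statistics
gamma, beta = 1.0, 0.0              # assumed scale and shift
eps = 1e-5                          # assumed numerical-stability constant

# Inference-mode batch normalization applied to a single value.
bn_value = gamma * (conv_value - moving_mean) / math.sqrt(moving_var + eps) + beta

# ReLU leaves the positive value unchanged.
relu_value = max(bn_value, 0.0)
```

Dividing by sqrt(1 + eps) rather than by exactly 1 accounts for the small drop below 27; the remaining difference from the printed 26.999863 is within float32 rounding.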