# Network Construction

## Basic Logic

The basic logic of PyTorch and MindSpore is similar: both generally go through network definition, forward computation, backward computation, and gradient update in their implementation flow.

- Network definition: the forward network, loss function, and optimizer are usually defined here. The forward network is defined in Net(): a PyTorch network inherits from nn.Module, and similarly a MindSpore network inherits from nn.Cell. In MindSpore, besides the loss functions and optimizers it provides, users can also use custom optimizers; see [Model Module Customization](https://mindspore.cn/tutorials/zh-CN/r2.3.0rc1/advanced/modules.html). The required forward network, loss function, and optimizer can be assembled with the functional/nn interfaces.

- Forward computation: run the instantiated network to obtain the logit, then compute the loss with the logit and the target as inputs. Note that if the forward function has multiple outputs, you need to consider how those outputs affect the result during backward computation.

- Backward computation: once the loss is obtained, backward computation can be performed. In PyTorch, loss.backward() computes the gradients; in MindSpore, first define the backward propagation function net_backward with mindspore.grad(), then pass the inputs into net_backward to obtain the gradients. If the forward function has multiple outputs, set has_aux to True so that only the first output contributes to the derivative and the other outputs are returned directly. For the differences in the backward-computation interfaces, see [Automatic Differentiation Comparison](./gradient.md).

- Gradient update: apply the computed gradients to the network Parameters. In PyTorch this is optim.step(); in MindSpore, pass the gradients of the Parameters into the defined optim to complete the update. A minimal sketch that strings these steps together is given after the Cell example below.

## Cell, the Basic Building Block of a Network

MindSpore builds the network graph mainly with [Cell](https://www.mindspore.cn/docs/zh-CN/r2.3.0rc1/api_python/nn/mindspore.nn.Cell.html#mindspore.nn.Cell). Users define a class that inherits from the `Cell` base class, declare the APIs and submodules to be used in `__init__`, and perform computation in `construct`. In `GRAPH_MODE` (static graph mode) a `Cell` is compiled into a computational graph; in `PYNATIVE_MODE` (dynamic graph mode) it serves as the basic building block of the neural network.

The basic `Cell` building process in PyTorch and MindSpore is as follows:
PyTorch:

```python
import torch.nn as torch_nn

class MyCell_pt(torch_nn.Module):
    def __init__(self, forward_net):
        super(MyCell_pt, self).__init__()
        self.net = forward_net
        self.relu = torch_nn.ReLU()

    def forward(self, x):
        y = self.net(x)
        return self.relu(y)

inner_net_pt = torch_nn.Conv2d(120, 240, kernel_size=4, bias=False)
pt_net = MyCell_pt(inner_net_pt)
for i in pt_net.parameters():
    print(i.shape)
```

Result:

```text
torch.Size([240, 120, 4, 4])
```

MindSpore:

```python
import mindspore.nn as nn
import mindspore.ops as ops

class MyCell(nn.Cell):
    def __init__(self, forward_net):
        super(MyCell, self).__init__(auto_prefix=True)
        self.net = forward_net
        self.relu = ops.ReLU()

    def construct(self, x):
        y = self.net(x)
        return self.relu(y)

inner_net = nn.Conv2d(120, 240, 4, has_bias=False)
my_net = MyCell(inner_net)
print(my_net.trainable_params())
```

Result:

```text
[Parameter (name=net.weight, shape=(240, 120, 4, 4), dtype=Float32, requires_grad=True)]
```
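Combining the forward computation, backward computation, and gradient update steps from the Basic Logic section above, the following is a minimal training-step sketch that reuses the `MyCell` defined above; the MSELoss loss, SGD optimizer, and input/target shapes are illustrative assumptions rather than part of the original example:

```python
import numpy as np
import mindspore as ms
from mindspore import nn

# Reuse the MyCell defined above with a fresh inner convolution
net = MyCell(nn.Conv2d(120, 240, 4, has_bias=False))
loss_fn = nn.MSELoss()
optim = nn.SGD(net.trainable_params(), learning_rate=0.01)

def forward_fn(x, target):
    logit = net(x)                  # forward computation
    loss = loss_fn(logit, target)
    return loss, logit              # multiple outputs: with has_aux=True only loss is differentiated

# Backward computation: define the backward propagation function net_backward
net_backward = ms.grad(forward_fn, grad_position=None,
                       weights=net.trainable_params(), has_aux=True)

x = ms.Tensor(np.random.randn(1, 120, 32, 32), ms.float32)
target = ms.Tensor(np.random.randn(1, 240, 32, 32), ms.float32)

grads, aux = net_backward(x, target)   # compute gradients; the other outputs are returned directly
optim(grads)                           # gradient update
```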
Setting the model data type in PyTorch:

```python
import torch
import torch.nn as nn

class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 12, kernel_size=3, padding=1),
            nn.BatchNorm2d(12),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(12, 4, kernel_size=3, padding=1),
            nn.BatchNorm2d(4),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.pool = nn.AdaptiveMaxPool2d((5, 5))
        self.fc = nn.Linear(100, 10)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        out = self.fc(x)
        return out

net = Network()
net = net.to(torch.float32)
for name, module in net.named_modules():
    if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        module.to(torch.float32)
loss = nn.CrossEntropyLoss(reduction='mean')
loss = loss.to(torch.float32)
```
Setting the model data type in MindSpore:

```python
import mindspore as ms
from mindspore import nn

# Define the model
class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.SequentialCell([
            nn.Conv2d(3, 12, kernel_size=3, pad_mode='pad', padding=1),
            nn.BatchNorm2d(12),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        ])
        self.layer2 = nn.SequentialCell([
            nn.Conv2d(12, 4, kernel_size=3, pad_mode='pad', padding=1),
            nn.BatchNorm2d(4),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        ])
        self.pool = nn.AdaptiveMaxPool2d((5, 5))
        self.fc = nn.Dense(100, 10)

    def construct(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.pool(x)
        x = x.view((-1, 100))
        out = self.fc(x)
        return out

net = Network()
net.to_float(ms.float16)  # Flag all operations in net as float16; the framework inserts cast ops on the inputs at compile time
for _, cell in net.cells_and_names():
    if isinstance(cell, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        cell.to_float(ms.float32)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean').to_float(ms.float32)
net_with_loss = nn.WithLossCell(net, loss_fn=loss)
```
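The wrapped `net_with_loss` can then be called like any other Cell. The following is a minimal sketch; the input shape, the labels, and the assumption of a float16-capable backend (typically GPU/Ascend) are for illustration only:

```python
import numpy as np
import mindspore as ms

# Hypothetical input: 2 images of shape 3x20x20 and their class labels (illustration only)
data = ms.Tensor(np.random.randn(2, 3, 20, 20), ms.float32)
label = ms.Tensor(np.array([1, 3]), ms.int32)

loss_value = net_with_loss(data, label)
print(loss_value)  # scalar loss computed in float32
```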
PyTorch:

```python
import torch.nn as nn

net = nn.Linear(2, 1)

for name, param in net.named_parameters():
    print("Parameter Name:", name)

for name, param in net.named_parameters():
    if "bias" in name:
        param.requires_grad = False

for name, param in net.named_parameters():
    if param.requires_grad:
        print("Parameter Name:", name)
```

Result:

```text
Parameter Name: weight
Parameter Name: bias
Parameter Name: weight
```

MindSpore:

```python
import mindspore.nn as nn

net = nn.Dense(2, 1, has_bias=True)
print(net.trainable_params())

for param in net.trainable_params():
    param_name = param.name
    if "bias" in param_name:
        param.requires_grad = False
print(net.trainable_params())
```

Result:

```text
[Parameter (name=weight, shape=(1, 2), dtype=Float32, requires_grad=True), Parameter (name=bias, shape=(1,), dtype=Float32, requires_grad=True)]
[Parameter (name=weight, shape=(1, 2), dtype=Float32, requires_grad=True)]
```
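After freezing, it is common to pass only `trainable_params()` to the optimizer so that the frozen parameters are never updated. A minimal sketch; the Momentum optimizer and its hyper-parameters are illustrative assumptions:

```python
import mindspore.nn as nn

net = nn.Dense(2, 1, has_bias=True)
for param in net.trainable_params():
    if "bias" in param.name:
        param.requires_grad = False

# Only the still-trainable parameters are handed to the optimizer,
# so the frozen bias does not participate in the update.
optim = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
print(len(net.trainable_params()))  # 1, only the weight remains trainable
```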
PyTorch:

```python
import torch
import torch.nn as nn

linear_layer = nn.Linear(2, 1, bias=True)

linear_layer.weight.data.fill_(1.0)
linear_layer.bias.data.zero_()

print("Original linear layer parameters:")
print(linear_layer.weight)
print(linear_layer.bias)

torch.save(linear_layer.state_dict(), 'linear_layer_params.pth')

new_linear_layer = nn.Linear(2, 1, bias=True)
new_linear_layer.load_state_dict(torch.load('linear_layer_params.pth'))

# Print the loaded Parameters; they should match the original ones
print("Loaded linear layer parameters:")
print(new_linear_layer.weight)
print(new_linear_layer.bias)
```

Result:

```text
Original linear layer parameters:
Parameter containing:
tensor([[1., 1.]], requires_grad=True)
Parameter containing:
tensor([0.], requires_grad=True)
Loaded linear layer parameters:
Parameter containing:
tensor([[1., 1.]], requires_grad=True)
Parameter containing:
tensor([0.], requires_grad=True)
```

MindSpore:

```python
import mindspore as ms
import mindspore.ops as ops
import mindspore.nn as nn

net = nn.Dense(2, 1, has_bias=True)
for param in net.get_parameters():
    print(param.name, param.data.asnumpy())

ms.save_checkpoint(net, "dense.ckpt")
dense_params = ms.load_checkpoint("dense.ckpt")
print(dense_params)
new_params = {}
for param_name in dense_params:
    print(param_name, dense_params[param_name].data.asnumpy())
    new_params[param_name] = ms.Parameter(ops.ones_like(dense_params[param_name].data), name=param_name)

ms.load_param_into_net(net, new_params)
for param in net.get_parameters():
    print(param.name, param.data.asnumpy())
```

Result:

```text
weight [[-0.0042482  -0.00427286]]
bias [0.]
{'weight': Parameter (name=weight, shape=(1, 2), dtype=Float32, requires_grad=True), 'bias': Parameter (name=bias, shape=(1,), dtype=Float32, requires_grad=True)}
weight [[-0.0042482  -0.00427286]]
bias [0.]
weight [[1. 1.]]
bias [1.]
```
torch.nn.init:

```python
import torch

x = torch.empty(2, 2)
torch.nn.init.uniform_(x)
```

mindspore.common.initializer:

```python
import mindspore
from mindspore.common.initializer import initializer, Uniform

x = initializer(Uniform(), [1, 2, 3], mindspore.float32)
```
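The initializer utilities are also commonly used to specify weight initialization when constructing a layer. A small sketch; the Normal sigma and the layer sizes are illustrative assumptions:

```python
from mindspore import nn
from mindspore.common.initializer import Normal

# Initialize the Dense weight from a Normal distribution and the bias with zeros
dense = nn.Dense(3, 4, weight_init=Normal(0.02), bias_init='zeros')
print(dense.weight.asnumpy().shape)  # (4, 3)
```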
PyTorch:

```python
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Add a submodule with add_module
        self.add_module('conv3', nn.Conv2d(64, 128, 3, 1))
        self.sequential_block = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(128, 256, 3, 1),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.sequential_block(x)
        return x

module = MyModule()
# Iterate over all submodules (direct and indirect) with named_modules
for name, module_instance in module.named_modules():
    print(f"Module name: {name}, type: {type(module_instance)}")
```

Result:

```text
Module name: , type: <class '__main__.MyModule'>
Module name: conv1, type: <class 'torch.nn.modules.conv.Conv2d'>
Module name: conv2, type: <class 'torch.nn.modules.conv.Conv2d'>
Module name: conv3, type: <class 'torch.nn.modules.conv.Conv2d'>
Module name: sequential_block, type: <class 'torch.nn.modules.container.Sequential'>
Module name: sequential_block.0, type: <class 'torch.nn.modules.activation.ReLU'>
Module name: sequential_block.1, type: <class 'torch.nn.modules.conv.Conv2d'>
Module name: sequential_block.2, type: <class 'torch.nn.modules.activation.ReLU'>
```
MindSpore:

```python
from mindspore import nn

class MyCell(nn.Cell):
    def __init__(self):
        super(MyCell, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Add a submodule with insert_child_to_cell
        self.insert_child_to_cell('conv3', nn.Conv2d(64, 128, 3, 1))
        self.sequential_block = nn.SequentialCell(
            nn.ReLU(),
            nn.Conv2d(128, 256, 3, 1),
            nn.ReLU()
        )

    def construct(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.sequential_block(x)
        return x

module = MyCell()
# Iterate over all submodules (direct and indirect) with cells_and_names
for name, cell_instance in module.cells_and_names():
    print(f"Cell name: {name}, type: {type(cell_instance)}")
```

Result:

```text
Cell name: , type: <class '__main__.MyCell'>
Cell name: conv1, type: <class 'mindspore.nn.layer.conv.Conv2d'>
Cell name: conv2, type: <class 'mindspore.nn.layer.conv.Conv2d'>
Cell name: conv3, type: <class 'mindspore.nn.layer.conv.Conv2d'>
Cell name: sequential_block, type: <class 'mindspore.nn.layer.container.SequentialCell'>
Cell name: sequential_block.0, type: <class 'mindspore.nn.layer.activation.ReLU'>
Cell name: sequential_block.1, type: <class 'mindspore.nn.layer.conv.Conv2d'>
Cell name: sequential_block.2, type: <class 'mindspore.nn.layer.activation.ReLU'>
```
PyTorch:

```python
import torch

torch_net = torch.nn.Linear(3, 4)
torch_net.cpu()
```

MindSpore:

```python
import mindspore

mindspore.set_context(device_target="CPU")
ms_net = mindspore.nn.Dense(3, 4)
```
PyTorch:

```python
def box_select_torch(box, iou_score):
    mask = iou_score > 0.3
    return box[mask]
```

MindSpore:

```python
import mindspore as ms
from mindspore import ops

ms.set_seed(1)

def box_select_ms(box, iou_score):
    mask = (iou_score > 0.3).expand_dims(1)
    return ops.masked_select(box, mask)
```
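A simple call sketch; box is assumed to be an [N, 4] tensor of candidate boxes and iou_score an [N] tensor of scores, with random values used purely for illustration:

```python
import numpy as np
import mindspore as ms

box = ms.Tensor(np.random.randn(5, 4), ms.float32)
iou_score = ms.Tensor(np.random.rand(5), ms.float32)

output = box_select_ms(box, iou_score)
# Unlike the PyTorch version, masked_select returns a flattened 1-D tensor
print(output.shape)
```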
PyTorch:

```python
import torch
import torch.nn as torch_nn

class ClassLoss_pt(torch_nn.Module):
    def __init__(self):
        super(ClassLoss_pt, self).__init__()
        self.con_loss = torch_nn.CrossEntropyLoss(reduction='none')

    # Use torch.topk to take the top 70% of the positive samples
    def forward(self, pred, label):
        mask = label > 0
        valid_label = label * mask
        pos_num = torch.clamp(mask.sum() * 0.7, 1).int()
        con = self.con_loss(pred, valid_label.long()) * mask
        loss, unused_value = torch.topk(con, k=pos_num)
        return loss.mean()
```

MindSpore:

```python
import mindspore as ms
from mindspore import ops
from mindspore import nn as ms_nn

class ClassLoss_ms(ms_nn.Cell):
    def __init__(self):
        super(ClassLoss_ms, self).__init__()
        self.con_loss = ms_nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="none")
        self.sort_descending = ops.Sort(descending=True)

    # MindSpore currently does not support a variable K for TopK.
    # Workaround: find the K-th largest value, then build the top-k mask from it.
    def construct(self, pred, label):
        mask = label > 0
        valid_label = label * mask
        pos_num = ops.maximum(mask.sum() * 0.7, 1).astype(ms.int32)
        con = self.con_loss(pred, valid_label.astype(ms.int32)) * mask
        con_sort, unused_value = self.sort_descending(con)
        con_k = con_sort[pos_num - 1]
        con_mask = (con >= con_k).astype(con.dtype)
        loss = con * con_mask
        return loss.sum() / con_mask.sum()
```
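A simple call sketch; the batch size, number of classes, and random inputs are illustrative assumptions:

```python
import numpy as np
import mindspore as ms

pred = ms.Tensor(np.random.randn(8, 5), ms.float32)              # logits for 8 samples and 5 classes
label = ms.Tensor(np.array([0, 2, 1, 0, 3, 4, 0, 1]), ms.int32)  # 0 marks a negative sample

class_loss = ClassLoss_ms()
print(class_loss(pred, label))
```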