# Model Training

## Hyperparameters

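$w_{t+1}=w_{t}-\eta \frac{1}{n} \sum_{x \in \mathcal{B}} \nabla l\left(x, w_{t}\right)$

Here $n$ is the batch size, $\eta$ is the learning rate, $\mathcal{B}$ is the mini-batch, $w_{t}$ is the weight at training step $t$, and $\nabla l$ is the gradient of the loss function.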

[1]:

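epochs = 10           # number of passes over the training set
batch_size = 32       # samples per mini-batch (n in the formula)
momentum = 0.9        # momentum coefficient used by the optimizer
learning_rate = 1e-2  # step size (eta in the formula)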
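
To make the update rule concrete, here is a minimal sketch in plain Python (no MindSpore) of a single mini-batch step. It assumes a hypothetical one-parameter model with the toy loss $l(x, w) = \frac{1}{2}(w - x)^2$, whose gradient is $w - x$. The function name `sgd_step` and the toy loss are illustrative assumptions, not part of MindSpore.

```python
def sgd_step(w, batch, lr):
    """One mini-batch SGD update: w <- w - lr * mean(grad l(x, w))."""
    # Gradient of the toy loss l(x, w) = 0.5 * (w - x) ** 2 with respect to w.
    grads = [w - x for x in batch]
    mean_grad = sum(grads) / len(batch)
    return w - lr * mean_grad


learning_rate = 1e-2
w = 0.0
batch = [1.0, 2.0, 3.0, 4.0]  # a toy batch B with n = 4 samples

w_next = sgd_step(w, batch, learning_rate)
print(w_next)
```

With these values the mean gradient is -2.5, so the step moves w from 0 to about 0.025, toward the batch mean of 2.5.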


## Loss Function

$\text { L1 Loss Function }=\sum_{i=1}^{n}\left|y_{\text{true}, i}-y_{\text{predicted}, i}\right|$

mindspore.nn.loss also provides many other commonly used loss functions, such as SoftmaxCrossEntropyWithLogits, MSELoss, and SmoothL1Loss.

[2]:

import numpy as np
import mindspore.nn as nn
import mindspore as ms

loss = nn.L1Loss()
output_data = ms.Tensor(np.array([[1, 2, 3], [2, 3, 4]]).astype(np.float32))
target_data = ms.Tensor(np.array([[0, 2, 5], [3, 1, 1]]).astype(np.float32))

print(loss(output_data, target_data))

1.5
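
The printed value 1.5 is the mean of the absolute errors rather than their sum, which is consistent with a mean reduction being applied to the six elements. The check below reproduces both numbers in plain Python. The helper `l1_loss` is a hypothetical name written only for this comparison, not a MindSpore API.

```python
def l1_loss(predicted, target, reduction="mean"):
    """Absolute-error loss over two equally shaped nested lists."""
    # Flatten both inputs and take element-wise absolute differences.
    diffs = [
        abs(p - t)
        for row_p, row_t in zip(predicted, target)
        for p, t in zip(row_p, row_t)
    ]
    total = sum(diffs)
    if reduction == "sum":
        return total
    if reduction == "mean":
        return total / len(diffs)
    raise ValueError(f"unsupported reduction: {reduction}")


output_data = [[1, 2, 3], [2, 3, 4]]
target_data = [[0, 2, 5], [3, 1, 1]]

print(l1_loss(output_data, target_data, reduction="sum"))   # 9
print(l1_loss(output_data, target_data, reduction="mean"))  # 1.5
```

The absolute differences are 1, 0, 2, 1, 2, and 3, which sum to 9 and average to 1.5.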


## Optimizer

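In MindSpore, all optimization logic is encapsulated in Optimizer objects. Here we use the Momentum optimizer. mindspore.nn also provides many other commonly used optimizers, such as Adam, SGD, and RMSProp.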

[3]:

from mindspore import nn
from mindvision.classification.models import lenet

net = lenet(num_classes=10, pretrained=False)
optim = nn.Momentum(net.trainable_params(), learning_rate, momentum)
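
For intuition about what the momentum hyperparameter does, the following sketch applies the classic (non-Nesterov) momentum rule, $v \leftarrow \mu v + g$ followed by $w \leftarrow w - \eta v$, to a toy quadratic loss. This is a plain-Python illustration of the general technique under that assumption. It is not MindSpore's implementation, and `momentum_step` is a hypothetical helper.

```python
def momentum_step(w, v, grad, lr, momentum):
    """Classic momentum update; returns the new (w, v) pair."""
    v_next = momentum * v + grad   # accumulate a decaying sum of gradients
    w_next = w - lr * v_next       # step along the accumulated direction
    return w_next, v_next


learning_rate = 1e-2
momentum = 0.9

w, v = 1.0, 0.0
for _ in range(3):
    grad = 2.0 * w  # gradient of the toy loss l(w) = w ** 2
    w, v = momentum_step(w, v, grad, learning_rate, momentum)
    print(w, v)
```

Because the velocity carries over 90% of its previous value, each step is larger than a plain gradient step would be while the gradients keep pointing the same way.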


## Model Training

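1. Build the dataset.

2. Define the neural network.

3. Define the hyperparameters, loss function, and optimizer.

4. Pass in the number of epochs and the dataset, then train.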

[4]:

import mindspore.nn as nn
import mindspore as ms

from mindvision.classification.dataset import Mnist
from mindvision.classification.models import lenet
from mindvision.engine.callback import LossMonitor

# 1. Build the dataset
download_train = Mnist(path="./mnist", split="train", batch_size=batch_size, repeat_num=1, shuffle=True, resize=32, download=True)
dataset_train = download_train.run()

# 2. Define the neural network
network = lenet(num_classes=10, pretrained=False)
# 3.1 Define the loss function
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# 3.2 Define the optimizer
net_opt = nn.Momentum(network.trainable_params(), learning_rate=learning_rate, momentum=momentum)
# 3.3 Wrap the network, loss function, and optimizer in a Model
model = ms.Model(network, loss_fn=net_loss, optimizer=net_opt, metrics={'acc'})

# 4. Train the network; the loss is printed every 1875 steps (once per epoch)
model.train(epochs, dataset_train, callbacks=[LossMonitor(learning_rate, 1875)])

Epoch:[  0/ 10], step:[ 1875/ 1875], loss:[0.189/1.176], time:2.254 ms, lr:0.01000
Epoch time: 4286.163 ms, per step time: 2.286 ms, avg loss: 1.176
Epoch:[  1/ 10], step:[ 1875/ 1875], loss:[0.085/0.080], time:1.895 ms, lr:0.01000
Epoch time: 4064.532 ms, per step time: 2.168 ms, avg loss: 0.080
Epoch:[  2/ 10], step:[ 1875/ 1875], loss:[0.021/0.054], time:1.901 ms, lr:0.01000
Epoch time: 4194.333 ms, per step time: 2.237 ms, avg loss: 0.054
Epoch:[  3/ 10], step:[ 1875/ 1875], loss:[0.284/0.041], time:2.130 ms, lr:0.01000
Epoch time: 4252.222 ms, per step time: 2.268 ms, avg loss: 0.041
Epoch:[  4/ 10], step:[ 1875/ 1875], loss:[0.003/0.032], time:2.176 ms, lr:0.01000
Epoch time: 4216.039 ms, per step time: 2.249 ms, avg loss: 0.032
Epoch:[  5/ 10], step:[ 1875/ 1875], loss:[0.003/0.027], time:2.205 ms, lr:0.01000
Epoch time: 4400.771 ms, per step time: 2.347 ms, avg loss: 0.027
Epoch:[  6/ 10], step:[ 1875/ 1875], loss:[0.000/0.024], time:1.973 ms, lr:0.01000
Epoch time: 4554.252 ms, per step time: 2.429 ms, avg loss: 0.024
Epoch:[  7/ 10], step:[ 1875/ 1875], loss:[0.008/0.022], time:2.048 ms, lr:0.01000
Epoch time: 4361.135 ms, per step time: 2.326 ms, avg loss: 0.022
Epoch:[  8/ 10], step:[ 1875/ 1875], loss:[0.000/0.018], time:2.130 ms, lr:0.01000
Epoch time: 4547.597 ms, per step time: 2.425 ms, avg loss: 0.018
Epoch:[  9/ 10], step:[ 1875/ 1875], loss:[0.008/0.017], time:2.135 ms, lr:0.01000
Epoch time: 4601.861 ms, per step time: 2.454 ms, avg loss: 0.017
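
The step count of 1875 in each log line follows from the dataset size and the batch size: the MNIST training split has 60,000 images, and 60000 / 32 = 1875. A small check in plain Python, where `steps_per_epoch` is a hypothetical helper and the `drop_remainder` flag is an assumption that makes no difference here because the division is exact:

```python
import math


def steps_per_epoch(num_samples, batch_size, drop_remainder=True):
    """Number of optimizer steps needed to see the dataset once."""
    if drop_remainder:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


mnist_train_size = 60000  # size of the MNIST training split
batch_size = 32

print(steps_per_epoch(mnist_train_size, batch_size))  # 1875
```

The same number ties the timing columns together: about 2.29 ms per step over 1875 steps is roughly 4.3 s, in line with the logged epoch time of the first epoch.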