自定义训练
MindSpore提供了model.train
接口来进行模型训练。使用方式可以参考初级教程-初学入门。此外,还可以使用TrainOneStepCell
,该接口当前支持GPU、Ascend环境。
作为高阶接口,model.train
封装了TrainOneStepCell
,可以直接利用设定好的网络、损失函数与优化器进行训练。用户也可以选择使用TrainOneStepCell
实现更加灵活的训练,例如控制训练数据集、实现多输入多输出网络、或自定义训练过程。
TrainOneStepCell说明
TrainOneStepCell
中包含三种入参:
network (Cell):参与训练的网络,当前仅接受单输出网络。
optimizer (Cell):所使用的优化器。
sens (Number):反向传播的缩放比例。
下面使用TrainOneStepCell
替换model.train
,实现简单的线性网络训练过程。
TrainOneStepCell使用示例
创建模型并生成数据
本小节详细解释说明可参考初级教程-初学入门。
定义网络LinearNet,内部有两层全连接层组成的网络, 包含5个入参和1个出参的神经网络。
[1]:
import numpy as np
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.dataset as ds
from mindspore import ParameterTuple
class LinearNet(nn.Cell):
def __init__(self):
super().__init__()
self.relu = nn.ReLU()
self.dense1 = nn.Dense(5, 32)
self.dense2 = nn.Dense(32, 1)
def construct(self, x):
x = self.dense1(x)
x = self.relu(x)
x = self.dense2(x)
return x
产生输入数据。
[2]:
np.random.seed(4)
class DatasetGenerator:
def __init__(self):
self.data = np.random.randn(5, 5).astype(np.float32)
self.label = np.random.randn(5, 1).astype(np.int32)
def __getitem__(self, index):
return self.data[index], self.label[index]
def __len__(self):
return len(self.data)
数据处理。
[3]:
# 对输入数据进行处理
dataset_generator = DatasetGenerator()
dataset = ds.GeneratorDataset(dataset_generator, ["data", "label"], shuffle=True)
dataset = dataset.batch(32)
# 实例化网络
net = LinearNet()
定义TrainOneStepCell
在TrainOneStepCell
中,可以实现对训练过程的个性化设定。
[4]:
class TrainOneStepCell(nn.Cell):
def __init__(self, network, optimizer, sens=1.0):
"""参数初始化"""
super(TrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network
# 使用tuple包装weight
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer
# 定义梯度函数
self.grad = ops.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens
def construct(self, data, label):
"""构建训练过程"""
weights = self.weights
loss = self.network(data, label)
# 为反向传播设定系数
sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(data, label, sens)
return loss, self.optimizer(grads)
网络训练
在使用TrainOneStepCell
时,需要利用WithLossCell
接口引入损失函数,共同完成训练过程。下面利用之前设定好的参数训练LinearNet网络,并获取loss值。
[5]:
# 设定损失函数
crit = nn.MSELoss()
# 设定优化器
opt = nn.Adam(params=net.trainable_params())
# 引入损失函数
net_with_criterion = nn.WithLossCell(net, crit)
# 自定义网络训练
train_net = TrainOneStepCell(net_with_criterion, opt)
# 获取训练过程数据
for d in dataset.create_dict_iterator():
for i in range(300):
train_net(d["data"], d["label"])
print(net_with_criterion(d["data"], d["label"]))
0.7998974
0.79927444
0.7986423
0.7979911
0.79732
... ...
0.042837422
0.041227795
0.039638687
... ...
9.276913e-06
8.4145695e-06
7.625091e-06
6.904066e-06
6.2513377e-06