保存、加载与转化模型
概述
在模型训练或者加载模型的过程中,有时需要替换掉模型文件中某些优化器或者其他超参数以及分类函数中的全连接层改动,但是又不想改动太大,或者从0开始训练模型,针对这种情况,MindSpore提供了只调整模型部分权重的CheckPoint进阶用法,并将方法应用在模型调优过程中。
基础用法可参考:保存加载参数。
准备工作
本篇以LeNet网络为例子,介绍在MindSpore中对模型进行保存,加载和转化等操作方法。
在进行操作前,需做好如下准备好以下几个文件:
MNIST数据集。
LeNet网络的预训练模型文件
checkpoint-lenet_1-1875.ckpt
。数据增强文件
dataset_process.py
,使用其中的数据增强方法create_dataset
,可参考官网实现一个图片分类应用中定义的数据增强方法create_dataset
。定义LeNet网络。
执行下述代码,完成前3项准备工作。
[1]:
!mkdir -p ./datasets/MNIST_Data/train ./datasets/MNIST_Data/test
!wget -NP ./datasets/MNIST_Data/train https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/train-labels-idx1-ubyte
!wget -NP ./datasets/MNIST_Data/train https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/train-images-idx3-ubyte
!wget -NP ./datasets/MNIST_Data/test https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/t10k-labels-idx1-ubyte
!wget -NP ./datasets/MNIST_Data/test https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/t10k-images-idx3-ubyte
!wget https://mindspore-website.obs.myhuaweicloud.com/notebook/source-codes/dataset_process.py -N
!wget -N https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/checkpoint_lenet-1_1875.zip
!unzip -o checkpoint_lenet-1_1875.zip
定义LeNet网络模型,具体定义过程如下。
[2]:
from mindspore.common.initializer import Normal
import mindspore.nn as nn
class LeNet5(nn.Cell):
"""Lenet network structure."""
# define the operator required
def __init__(self, num_class=10, num_channel=1):
super(LeNet5, self).__init__()
self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
# use the preceding operators to construct networks
def construct(self, x):
x = self.max_pool2d(self.relu(self.conv1(x)))
x = self.max_pool2d(self.relu(self.conv2(x)))
x = self.flatten(x)
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
return x
高级用法
保存
手动保存CheckPoint
使用save_checkpoint
,手动保存CheckPoint文件。
应用场景:
保存网络的初始值。
手动保存指定网络。
执行以下代码,在对预训练模型checkpoint_lenet-1_1875.ckpt
训练过100个batch的数据集后,使用save_checkpoint
手动保存出模型mindspore_lenet.ckpt
。
[3]:
from mindspore import Model, load_checkpoint, save_checkpoint, load_param_into_net
from mindspore import context, Tensor
from dataset_process import create_dataset
import mindspore.nn as nn
network = LeNet5()
net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
net_loss = nn.loss.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
params = load_checkpoint("checkpoint_lenet-1_1875.ckpt")
load_param_into_net(network, params)
net_with_criterion = nn.WithLossCell(network, net_loss)
train_net = nn.TrainOneStepCell(net_with_criterion, net_opt)
train_net.set_train()
train_path = "./datasets/MNIST_Data/train"
ds_train = create_dataset(train_path)
count = 0
for item in ds_train.create_dict_iterator():
input_data = item["image"]
labels = item["label"]
train_net(input_data, labels)
count += 1
if count==100:
print(train_net.trainable_params())
save_checkpoint(train_net, "mindspore_lenet.ckpt")
break
[Parameter (name=conv1.weight), Parameter (name=conv2.weight), Parameter (name=fc1.weight), Parameter (name=fc1.bias), Parameter (name=fc2.weight), Parameter (name=fc2.bias), Parameter (name=fc3.weight), Parameter (name=fc3.bias), Parameter (name=learning_rate), Parameter (name=momentum), Parameter (name=moments.conv1.weight), Parameter (name=moments.conv2.weight), Parameter (name=moments.fc1.weight), Parameter (name=moments.fc1.bias), Parameter (name=moments.fc2.weight), Parameter (name=moments.fc2.bias), Parameter (name=moments.fc3.weight), Parameter (name=moments.fc3.bias)]
从上述打印信息可以看出mindspore_lenet.ckpt
的权重参数,包括了前向传播过程中LeNet网络中各隐藏层中的权重参数、学习率、优化率以及反向传播中优化各权重层的优化器函数的权重。
保存指定的Cell
使用方法:CheckpointConfig
类的saved_network
参数。
应用场景:
只保存推理网络模型的参数(不保存优化器的参数会使生成的CheckPoint文件大小减小一倍)。
保存子网的参数,用于Fine-tune(模型微调)任务。
在回调函数中使用方法CheckpointConfig
,并指定保存模型的Cell为network
即前向传播的LeNet网络。
[4]:
import os
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
ds_train = create_dataset(train_path)
epoch_size = 1
model = Model(train_net)
config_ck = CheckpointConfig(saved_network=network)
ckpoint = ModelCheckpoint(prefix="lenet", config=config_ck)
model.train(epoch_size, ds_train, callbacks=[ckpoint, LossMonitor(625)])
epoch: 1 step: 625, loss is 0.116291314
epoch: 1 step: 1250, loss is 0.09527888
epoch: 1 step: 1875, loss is 0.23090823
模型经过训练后,保存出模型文件lenet-1_1875.ckpt
。接下来对比指定保存的模型cell和原始模型在大小和具体权重有何区别。
[5]:
model_with_opt = os.path.getsize("./checkpoint_lenet-1_1875.ckpt") // 1024
params_without_change = load_checkpoint("./checkpoint_lenet-1_1875.ckpt")
print("with_opt size:", model_with_opt, "kB")
print(params_without_change)
print("\n=========after train===========\n")
model_without_opt = os.path.getsize("./lenet-1_1875.ckpt") // 1024
params_with_change = load_checkpoint("./lenet-1_1875.ckpt")
print("without_opt size:", model_without_opt, "kB")
print(params_with_change)
with_opt size: 482 kB
{'conv1.weight': Parameter (name=conv1.weight), 'conv2.weight': Parameter (name=conv2.weight), 'fc1.weight': Parameter (name=fc1.weight), 'fc1.bias': Parameter (name=fc1.bias), 'fc2.weight': Parameter (name=fc2.weight), 'fc2.bias': Parameter (name=fc2.bias), 'fc3.weight': Parameter (name=fc3.weight), 'fc3.bias': Parameter (name=fc3.bias), 'learning_rate': Parameter (name=learning_rate), 'momentum': Parameter (name=momentum), 'moments.conv1.weight': Parameter (name=moments.conv1.weight), 'moments.conv2.weight': Parameter (name=moments.conv2.weight), 'moments.fc1.weight': Parameter (name=moments.fc1.weight), 'moments.fc1.bias': Parameter (name=moments.fc1.bias), 'moments.fc2.weight': Parameter (name=moments.fc2.weight), 'moments.fc2.bias': Parameter (name=moments.fc2.bias), 'moments.fc3.weight': Parameter (name=moments.fc3.weight), 'moments.fc3.bias': Parameter (name=moments.fc3.bias)}
=========after train===========
without_opt size: 241 kB
{'conv1.weight': Parameter (name=conv1.weight), 'conv2.weight': Parameter (name=conv2.weight), 'fc1.weight': Parameter (name=fc1.weight), 'fc1.bias': Parameter (name=fc1.bias), 'fc2.weight': Parameter (name=fc2.weight), 'fc2.bias': Parameter (name=fc2.bias), 'fc3.weight': Parameter (name=fc3.weight), 'fc3.bias': Parameter (name=fc3.bias)}
训练后,保存出来的模型lenet-1_1875.ckpt
,模型权重文件大小为241kB,跟原始完整模型大小482kB相比,整体减少了将近一半;
具体对比模型中的参数,可以看出lenet-1_1875.ckpt
中参数相比checkpoint_lenet-1_1875.ckpt
减少了学习率、优化率和反向优化等相关的权重参数,只保留了前向传播网络LeNet的权重参数。符合预期效果。
异步保存CheckPoint
使用方法:CheckpointConfig
类的async_save
参数。
应用场景:训练的模型参数量较大,可以边训练边保存,节省保存CheckPoint文件时的写入时间。
[6]:
config_ck = CheckpointConfig(async_save=True)
ckpoint = ModelCheckpoint(prefix="lenet", config=config_ck)
model.train(epoch_size, ds_train, callbacks=ckpoint)
保存自定义参数字典
使用方法:构造一个obj_dict
传入save_checkpoint
方法。
使用场景:
训练过程中需要额外保存参数(
lr
、epoch_size
等)为CheckPoint文件。修改CheckPoint里面的参数值后重新保存。
把PyTorch、TensorFlow的CheckPoint文件转化为MindSpore的CheckPoint文件。
根据具体场景分为两种情况:
已有CheckPoint文件,修改内容后重新保存。
[7]:
params = load_checkpoint("./lenet-1_1875.ckpt")
# eg: param_list = [{"name": param_name, "data": param_data},...]
param_list = [{"name": k, "data":v} for k,v in params.items()]
print("==========param_list===========\n")
print(param_list)
# del element
del param_list[2]
print("\n==========after delete param_list[2]===========\n")
print(param_list)
# add element "epoch_size"
param = {"name": "epoch_size"}
param["data"] = Tensor(10)
param_list.append(param)
print("\n==========after add element===========\n")
print(param_list)
# modify element
param_list[3]["data"] = Tensor(66)
# save a new checkpoint file
print("\n==========after modify element===========\n")
print(param_list)
save_checkpoint(param_list, 'modify.ckpt')
==========param_list===========
[{'name': 'conv1.weight', 'data': Parameter (name=conv1.weight)}, {'name': 'conv2.weight', 'data': Parameter (name=conv2.weight)}, {'name': 'fc1.weight', 'data': Parameter (name=fc1.weight)}, {'name': 'fc1.bias', 'data': Parameter (name=fc1.bias)}, {'name': 'fc2.weight', 'data': Parameter (name=fc2.weight)}, {'name': 'fc2.bias', 'data': Parameter (name=fc2.bias)}, {'name': 'fc3.weight', 'data': Parameter (name=fc3.weight)}, {'name': 'fc3.bias', 'data': Parameter (name=fc3.bias)}]
==========after delete param_list[2]===========
[{'name': 'conv1.weight', 'data': Parameter (name=conv1.weight)}, {'name': 'conv2.weight', 'data': Parameter (name=conv2.weight)}, {'name': 'fc1.bias', 'data': Parameter (name=fc1.bias)}, {'name': 'fc2.weight', 'data': Parameter (name=fc2.weight)}, {'name': 'fc2.bias', 'data': Parameter (name=fc2.bias)}, {'name': 'fc3.weight', 'data': Parameter (name=fc3.weight)}, {'name': 'fc3.bias', 'data': Parameter (name=fc3.bias)}]
==========after add element===========
[{'name': 'conv1.weight', 'data': Parameter (name=conv1.weight)}, {'name': 'conv2.weight', 'data': Parameter (name=conv2.weight)}, {'name': 'fc1.bias', 'data': Parameter (name=fc1.bias)}, {'name': 'fc2.weight', 'data': Parameter (name=fc2.weight)}, {'name': 'fc2.bias', 'data': Parameter (name=fc2.bias)}, {'name': 'fc3.weight', 'data': Parameter (name=fc3.weight)}, {'name': 'fc3.bias', 'data': Parameter (name=fc3.bias)}, {'name': 'epoch_size', 'data': Tensor(shape=[], dtype=Int64, value= 10)}]
==========after modify element===========
[{'name': 'conv1.weight', 'data': Parameter (name=conv1.weight)}, {'name': 'conv2.weight', 'data': Parameter (name=conv2.weight)}, {'name': 'fc1.bias', 'data': Parameter (name=fc1.bias)}, {'name': 'fc2.weight', 'data': Tensor(shape=[], dtype=Int64, value= 66)}, {'name': 'fc2.bias', 'data': Parameter (name=fc2.bias)}, {'name': 'fc3.weight', 'data': Parameter (name=fc3.weight)}, {'name': 'fc3.bias', 'data': Parameter (name=fc3.bias)}, {'name': 'epoch_size', 'data': Tensor(shape=[], dtype=Int64, value= 10)}]
将加载的模型文件转换成list类型后,可以对模型参数进行删除,添加,修改等操作,并使用save_checkpoint
手动保存,完成对模型权重的内容修改操作。
自定义参数列表保存成CheckPoint文件。
[8]:
param_list = []
# save epoch_size
param = {"name": "epoch_size"}
param["data"] = Tensor(10)
param_list.append(param)
# save learning rate
param = {"name": "learning_rate"}
param["data"] = Tensor(0.01)
param_list.append(param)
# save a new checkpoint file
print(param_list)
save_checkpoint(param_list, 'hyperparameters.ckpt')
[{'name': 'epoch_size', 'data': Tensor(shape=[], dtype=Int64, value= 10)}, {'name': 'learning_rate', 'data': Tensor(shape=[], dtype=Float64, value= 0.01)}]
加载
严格匹配参数名
CheckPoint文件中的权重参数到net
中的时候,会优先匹配net
和CheckPoint中name相同的parameter。匹配完成后,发现net中存在没有加载的parameter,会匹配net中后缀名称与ckpt相同的parameter。
例如:会把CheckPoint中名为conv.0.weight
的参数值加载到net中名为net.conv.0.weight
的parameter中。
如果想取消这种模糊匹配,只采取严格匹配机制,可以通过方法load_param_into_net
中的strict_load
参数控制,默认为False,表示采取模糊匹配机制。
[9]:
net = LeNet5()
params = load_checkpoint("lenet-1_1875.ckpt")
load_param_into_net(net, params, strict_load=True)
print("==========strict load mode===========")
print(params)
==========strict load mode===========
{'conv1.weight': Parameter (name=conv1.weight), 'conv2.weight': Parameter (name=conv2.weight), 'fc1.weight': Parameter (name=fc1.weight), 'fc1.bias': Parameter (name=fc1.bias), 'fc2.weight': Parameter (name=fc2.weight), 'fc2.bias': Parameter (name=fc2.bias), 'fc3.weight': Parameter (name=fc3.weight), 'fc3.bias': Parameter (name=fc3.bias)}
过滤指定前缀
使用方法:load_checkpoint
的filter_prefix
参数。
使用场景:加载CheckPoint时,想要过滤某些包含特定前缀的parameter。
加载CheckPoint时,不加载优化器中的
parameter(eg:filter_prefix=’moments’)
。不加载卷积层的
parameter(eg:filter_prefix=’conv1’)
。
[10]:
net = LeNet5()
print("=============net params=============")
params = load_checkpoint("checkpoint_lenet-1_1875.ckpt")
load_param_into_net(net, params)
print(params)
net = LeNet5()
print("\n=============after filter_prefix moments=============")
params = load_checkpoint("checkpoint_lenet-1_1875.ckpt", filter_prefix='moments')
load_param_into_net(net, params)
print(params)
=============net params=============
{'conv1.weight': Parameter (name=conv1.weight), 'conv2.weight': Parameter (name=conv2.weight), 'fc1.weight': Parameter (name=fc1.weight), 'fc1.bias': Parameter (name=fc1.bias), 'fc2.weight': Parameter (name=fc2.weight), 'fc2.bias': Parameter (name=fc2.bias), 'fc3.weight': Parameter (name=fc3.weight), 'fc3.bias': Parameter (name=fc3.bias), 'learning_rate': Parameter (name=learning_rate), 'momentum': Parameter (name=momentum), 'moments.conv1.weight': Parameter (name=moments.conv1.weight), 'moments.conv2.weight': Parameter (name=moments.conv2.weight), 'moments.fc1.weight': Parameter (name=moments.fc1.weight), 'moments.fc1.bias': Parameter (name=moments.fc1.bias), 'moments.fc2.weight': Parameter (name=moments.fc2.weight), 'moments.fc2.bias': Parameter (name=moments.fc2.bias), 'moments.fc3.weight': Parameter (name=moments.fc3.weight), 'moments.fc3.bias': Parameter (name=moments.fc3.bias)}
=============after filter_prefix moments=============
{'conv1.weight': Parameter (name=conv1.weight), 'conv2.weight': Parameter (name=conv2.weight), 'fc1.weight': Parameter (name=fc1.weight), 'fc1.bias': Parameter (name=fc1.bias), 'fc2.weight': Parameter (name=fc2.weight), 'fc2.bias': Parameter (name=fc2.bias), 'fc3.weight': Parameter (name=fc3.weight), 'fc3.bias': Parameter (name=fc3.bias), 'learning_rate': Parameter (name=learning_rate), 'momentum': Parameter (name=momentum)}
使用过滤前缀的机制,可以将不想载入的参数(本例为优化器权重参数)过滤掉,进行Fine-tune时,可以选用其他的优化器进行优化。
以上为使用MindSpore checkpoint功能的进阶用法,上述所有用法均可共同使用。
转化其他框架CheckPoint为MindSpore的格式
把其他框架的CheckPoint文件转化成MindSpore格式。
一般情况下,CheckPoint文件中保存的就是参数名和参数值,调用相应框架的读取接口后,获取到参数名和数值后,按照MindSpore格式,构建出对象,就可以直接调用MindSpore接口保存成MindSpore格式的CheckPoint文件了。
其中主要的工作量为对比不同框架间的parameter名称,做到两个框架的网络中所有parameter name一一对应(可以使用一个map进行映射),下面代码的逻辑转化parameter格式,不包括对应parameter name。
[11]:
import torch
from mindspore import Tensor, save_checkpoint
def pytorch2mindspore(default_file = 'torch_resnet.pth'):
# read pth file
par_dict = torch.load(default_file)['state_dict']
params_list = []
for name in par_dict:
param_dict = {}
parameter = par_dict[name]
param_dict['name'] = name
param_dict['data'] = Tensor(parameter.numpy())
params_list.append(param_dict)
save_checkpoint(params_list, 'ms_resnet.ckpt')