{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Saving, Loading, and Converting Models\n",
    "\n",
    "[![](https://gitee.com/mindspore/docs/raw/r1.2/docs/programming_guide/source_zh_cn/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r1.2/docs/programming_guide/source_zh_cn/advanced_usage_of_checkpoint.ipynb) [![](https://gitee.com/mindspore/docs/raw/r1.2/resource/_static/logo_notebook.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/r1.2/programming_guide/mindspore_advanced_usage_of_checkpoint.ipynb) [![](https://gitee.com/mindspore/docs/raw/r1.2/docs/programming_guide/source_zh_cn/_static/logo_modelarts.png)](https://console.huaweicloud.com/modelarts/?region=cn-north-4#/notebook/loading?share-url-b64=aHR0cHM6Ly9vYnMuZHVhbHN0YWNrLmNuLW5vcnRoLTQubXlodWF3ZWljbG91ZC5jb20vbWluZHNwb3JlLXdlYnNpdGUvbm90ZWJvb2svbW9kZWxhcnRzL3Byb2dyYW1taW5nX2d1aWRlL21pbmRzcG9yZV9hZHZhbmNlZF91c2FnZV9vZl9jaGVja3BvaW50LmlweW5i&image_id=65f636a0-56cf-49df-b941-7d2a07ba8c8c)"
   ]
  },
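  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Overview\n",
    "\n",
    "When training or loading a model, you sometimes need to replace the optimizer or other hyperparameters stored in the model file, or change the fully connected layer used for classification, without making large changes or training the model from scratch. For this case, MindSpore provides advanced CheckPoint usage that adjusts only part of the model weights, and this guide applies it to model tuning.\n",
    "\n",
    "For basic usage, see [Saving and Loading Parameters](https://www.mindspore.cn/tutorial/training/zh-CN/r1.2/use/save_model.html#checkpoint)."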
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preparation"
   ]
  },
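  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This guide uses the LeNet network as an example to show how to save, load, and convert models in MindSpore.\n",
    "\n",
    "Before you begin, prepare the following:\n",
    "\n",
    "- The MNIST dataset.\n",
    "- The pre-trained LeNet model file `checkpoint_lenet-1_1875.ckpt`.\n",
    "- The data augmentation file `dataset_process.py`, which provides the `create_dataset` function. It is the same `create_dataset` data augmentation method defined in the official tutorial [Implementing an Image Classification Application](https://www.mindspore.cn/tutorial/training/zh-CN/r1.2/quick_start/quick_start.html).\n",
    "- The LeNet network definition.\n",
    "\n",
    "Run the following code to complete the first three items."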
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-19T07:26:35.335788Z",
     "start_time": "2021-03-19T07:26:35.312679Z"
    }
   },
   "outputs": [],
   "source": [
    "!mkdir -p ./datasets/MNIST_Data/train ./datasets/MNIST_Data/test\n",
    "!wget -NP ./datasets/MNIST_Data/train https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/train-labels-idx1-ubyte \n",
    "!wget -NP ./datasets/MNIST_Data/train https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/train-images-idx3-ubyte\n",
    "!wget -NP ./datasets/MNIST_Data/test https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/t10k-labels-idx1-ubyte\n",
    "!wget -NP ./datasets/MNIST_Data/test https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/t10k-images-idx3-ubyte\n",
    "!wget https://mindspore-website.obs.myhuaweicloud.com/notebook/source-codes/dataset_process.py -N\n",
    "!wget -N https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/datasets/checkpoint_lenet-1_1875.zip\n",
    "!unzip -o checkpoint_lenet-1_1875.zip"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define the LeNet network model as follows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindspore.common.initializer import Normal\n",
    "import mindspore.nn as nn\n",
    "\n",
    "class LeNet5(nn.Cell):\n",
    "    \"\"\"Lenet network structure.\"\"\"\n",
    "    # define the operator required\n",
    "    def __init__(self, num_class=10, num_channel=1):\n",
    "        super(LeNet5, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')\n",
    "        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')\n",
    "        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))\n",
    "        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))\n",
    "        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))\n",
    "        self.relu = nn.ReLU()\n",
    "        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)\n",
    "        self.flatten = nn.Flatten()\n",
    "\n",
    "    # use the preceding operators to construct networks\n",
    "    def construct(self, x):\n",
    "        x = self.max_pool2d(self.relu(self.conv1(x)))\n",
    "        x = self.max_pool2d(self.relu(self.conv2(x)))\n",
    "        x = self.flatten(x)\n",
    "        x = self.relu(self.fc1(x))\n",
    "        x = self.relu(self.fc2(x))\n",
    "        x = self.fc3(x)\n",
    "        return x"
   ]
  },
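  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Advanced Usage\n",
    "\n",
    "### Saving\n",
    "\n",
    "#### Manually Saving a CheckPoint\n",
    "\n",
    "Use `save_checkpoint` to save a CheckPoint file manually.\n",
    "\n",
    "Use cases:\n",
    "\n",
    "1. Saving the initial values of a network.\n",
    "2. Manually saving a specified network.\n",
    "\n",
    "Run the following code to train the pre-trained model `checkpoint_lenet-1_1875.ckpt` on 100 batches of the dataset, then use `save_checkpoint` to save the result manually as `mindspore_lenet.ckpt`."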
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Parameter (name=conv1.weight), Parameter (name=conv2.weight), Parameter (name=fc1.weight), Parameter (name=fc1.bias), Parameter (name=fc2.weight), Parameter (name=fc2.bias), Parameter (name=fc3.weight), Parameter (name=fc3.bias), Parameter (name=learning_rate), Parameter (name=momentum), Parameter (name=moments.conv1.weight), Parameter (name=moments.conv2.weight), Parameter (name=moments.fc1.weight), Parameter (name=moments.fc1.bias), Parameter (name=moments.fc2.weight), Parameter (name=moments.fc2.bias), Parameter (name=moments.fc3.weight), Parameter (name=moments.fc3.bias)]\n"
     ]
    }
   ],
   "source": [
    "from mindspore import Model, load_checkpoint, save_checkpoint, load_param_into_net\n",
    "from mindspore import context, Tensor\n",
    "from dataset_process import create_dataset\n",
    "import mindspore.nn as nn\n",
    "\n",
    "network = LeNet5()\n",
    "net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)\n",
    "net_loss = nn.loss.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')\n",
    "\n",
    "params = load_checkpoint(\"checkpoint_lenet-1_1875.ckpt\")\n",
    "load_param_into_net(network, params)\n",
    "\n",
    "net_with_criterion = nn.WithLossCell(network, net_loss)\n",
    "train_net = nn.TrainOneStepCell(net_with_criterion, net_opt)\n",
    "train_net.set_train()\n",
    "\n",
    "train_path = \"./datasets/MNIST_Data/train\"\n",
    "ds_train = create_dataset(train_path)\n",
    "\n",
    "count = 0\n",
    "for item in ds_train.create_dict_iterator():\n",
    "    input_data = item[\"image\"]\n",
    "    labels = item[\"label\"]\n",
    "    train_net(input_data, labels)\n",
    "    count += 1\n",
    "    if count==100:\n",
    "        print(train_net.trainable_params())\n",
    "        save_checkpoint(train_net, \"mindspore_lenet.ckpt\")\n",
    "        break"
   ]
  },
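  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The printed output shows that the parameters in `mindspore_lenet.ckpt` include the weights of each LeNet layer used in forward propagation, the learning rate, the momentum, and the optimizer state (`moments.*`) that the optimizer keeps for each weight during backpropagation."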
   ]
  },
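  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Saving a Specified Cell\n",
    "\n",
    "Method: the `saved_network` parameter of the `CheckpointConfig` class.\n",
    "\n",
    "Use cases:\n",
    "\n",
    "- Saving only the parameters of the inference network (leaving out the optimizer parameters roughly halves the size of the generated CheckPoint file).\n",
    "\n",
    "- Saving the parameters of a subnet for fine-tuning tasks.\n",
    "\n",
    "Pass a `CheckpointConfig` to the `ModelCheckpoint` callback and set the Cell to save to `network`, the forward-propagation LeNet network."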
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1 step: 625, loss is 0.116291314\n",
      "epoch: 1 step: 1250, loss is 0.09527888\n",
      "epoch: 1 step: 1875, loss is 0.23090823\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor\n",
    "\n",
    "ds_train = create_dataset(train_path)\n",
    "epoch_size = 1\n",
    "model = Model(train_net)\n",
    "config_ck = CheckpointConfig(saved_network=network)\n",
    "ckpoint = ModelCheckpoint(prefix=\"lenet\", config=config_ck)\n",
    "model.train(epoch_size, ds_train, callbacks=[ckpoint, LossMonitor(625)])"
   ]
  },
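  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After training, the model file `lenet-1_1875.ckpt` is saved. Next, compare the checkpoint of the specified Cell with the original model in file size and stored parameters."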
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "with_opt size: 482 kB\n",
      "{'conv1.weight': Parameter (name=conv1.weight), 'conv2.weight': Parameter (name=conv2.weight), 'fc1.weight': Parameter (name=fc1.weight), 'fc1.bias': Parameter (name=fc1.bias), 'fc2.weight': Parameter (name=fc2.weight), 'fc2.bias': Parameter (name=fc2.bias), 'fc3.weight': Parameter (name=fc3.weight), 'fc3.bias': Parameter (name=fc3.bias), 'learning_rate': Parameter (name=learning_rate), 'momentum': Parameter (name=momentum), 'moments.conv1.weight': Parameter (name=moments.conv1.weight), 'moments.conv2.weight': Parameter (name=moments.conv2.weight), 'moments.fc1.weight': Parameter (name=moments.fc1.weight), 'moments.fc1.bias': Parameter (name=moments.fc1.bias), 'moments.fc2.weight': Parameter (name=moments.fc2.weight), 'moments.fc2.bias': Parameter (name=moments.fc2.bias), 'moments.fc3.weight': Parameter (name=moments.fc3.weight), 'moments.fc3.bias': Parameter (name=moments.fc3.bias)}\n",
      "\n",
      "=========after train===========\n",
      "\n",
      "without_opt size: 241 kB\n",
      "{'conv1.weight': Parameter (name=conv1.weight), 'conv2.weight': Parameter (name=conv2.weight), 'fc1.weight': Parameter (name=fc1.weight), 'fc1.bias': Parameter (name=fc1.bias), 'fc2.weight': Parameter (name=fc2.weight), 'fc2.bias': Parameter (name=fc2.bias), 'fc3.weight': Parameter (name=fc3.weight), 'fc3.bias': Parameter (name=fc3.bias)}\n"
     ]
    }
   ],
   "source": [
    "model_with_opt = os.path.getsize(\"./checkpoint_lenet-1_1875.ckpt\") // 1024\n",
    "params_without_change = load_checkpoint(\"./checkpoint_lenet-1_1875.ckpt\")\n",
    "print(\"with_opt size:\", model_with_opt, \"kB\")\n",
    "print(params_without_change)\n",
    "\n",
    "print(\"\\n=========after train===========\\n\")\n",
    "model_without_opt = os.path.getsize(\"./lenet-1_1875.ckpt\") // 1024\n",
    "params_with_change = load_checkpoint(\"./lenet-1_1875.ckpt\")\n",
    "print(\"without_opt size:\", model_without_opt, \"kB\")\n",
    "print(params_with_change)"
   ]
  },
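  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model saved after training, `lenet-1_1875.ckpt`, is 241 kB, about half the size of the original full model at 482 kB.\n",
    "\n",
    "Comparing the parameters, `lenet-1_1875.ckpt` omits the learning rate, momentum, and optimizer state parameters found in `checkpoint_lenet-1_1875.ckpt`, and keeps only the weights of the forward-propagation LeNet network, as expected."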
   ]
  },
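  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The roughly 2:1 size ratio can be sanity-checked with a rough parameter count. The sketch below is an illustration added for this comparison, not MindSpore code: it takes the tensor shapes from the `LeNet5` class defined earlier and assumes 4-byte float32 values, one `moments.*` tensor per weight, and no serialization overhead.\n",
    "\n",
    "```python\n",
    "# Rough size estimate for the two LeNet5 checkpoints (illustrative only).\n",
    "# Shapes follow the LeNet5 class above; the Conv2d layers have no bias,\n",
    "# matching the printed parameter lists.\n",
    "shapes = {\n",
    "    'conv1.weight': (6, 1, 5, 5),\n",
    "    'conv2.weight': (16, 6, 5, 5),\n",
    "    'fc1.weight': (120, 400),\n",
    "    'fc1.bias': (120,),\n",
    "    'fc2.weight': (84, 120),\n",
    "    'fc2.bias': (84,),\n",
    "    'fc3.weight': (10, 84),\n",
    "    'fc3.bias': (10,),\n",
    "}\n",
    "\n",
    "def count(shape):\n",
    "    # number of elements in a tensor of the given shape\n",
    "    n = 1\n",
    "    for dim in shape:\n",
    "        n *= dim\n",
    "    return n\n",
    "\n",
    "BYTES_PER_VALUE = 4  # assumption: float32 storage\n",
    "n_weights = sum(count(s) for s in shapes.values())\n",
    "net_only_kib = n_weights * BYTES_PER_VALUE / 1024\n",
    "# assumption: Momentum keeps one moments.* tensor per weight\n",
    "with_opt_kib = 2 * n_weights * BYTES_PER_VALUE / 1024\n",
    "\n",
    "print(n_weights, round(net_only_kib), round(with_opt_kib))\n",
    "```\n",
    "\n",
    "Under these assumptions the estimate is about 241 KiB without optimizer state and about 482 KiB with it, in line with the file sizes printed above."
   ]
  },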
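  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Saving a CheckPoint Asynchronously\n",
    "\n",
    "Method: the `async_save` parameter of the `CheckpointConfig` class.\n",
    "\n",
    "Use case: when the model has a large number of parameters, the CheckPoint can be saved while training continues, which saves the time otherwise spent waiting for the CheckPoint file to be written."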
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "config_ck = CheckpointConfig(async_save=True)\n",
    "ckpoint = ModelCheckpoint(prefix=\"lenet\", config=config_ck)\n",
    "model.train(epoch_size, ds_train, callbacks=ckpoint)"
   ]
  },
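  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Saving a Custom Parameter Dictionary\n",
    "\n",
    "Method: construct an `obj_dict` (a list of `name`/`data` dictionaries) and pass it to `save_checkpoint`.\n",
    "\n",
    "Use cases:\n",
    "\n",
    "- Saving additional parameters (such as `lr` and `epoch_size`) to a CheckPoint file during training.\n",
    "\n",
    "- Modifying parameter values in a CheckPoint and saving it again.\n",
    "\n",
    "- Converting PyTorch or TensorFlow CheckPoint files to MindSpore CheckPoint files.\n",
    "\n",
    "There are two cases, depending on the scenario:\n",
    "\n",
    "1. A CheckPoint file already exists; modify its contents and save it again."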
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==========param_list===========\n",
      "\n",
      "[{'name': 'conv1.weight', 'data': Parameter (name=conv1.weight)}, {'name': 'conv2.weight', 'data': Parameter (name=conv2.weight)}, {'name': 'fc1.weight', 'data': Parameter (name=fc1.weight)}, {'name': 'fc1.bias', 'data': Parameter (name=fc1.bias)}, {'name': 'fc2.weight', 'data': Parameter (name=fc2.weight)}, {'name': 'fc2.bias', 'data': Parameter (name=fc2.bias)}, {'name': 'fc3.weight', 'data': Parameter (name=fc3.weight)}, {'name': 'fc3.bias', 'data': Parameter (name=fc3.bias)}]\n",
      "\n",
      "==========after delete param_list[2]===========\n",
      "\n",
      "[{'name': 'conv1.weight', 'data': Parameter (name=conv1.weight)}, {'name': 'conv2.weight', 'data': Parameter (name=conv2.weight)}, {'name': 'fc1.bias', 'data': Parameter (name=fc1.bias)}, {'name': 'fc2.weight', 'data': Parameter (name=fc2.weight)}, {'name': 'fc2.bias', 'data': Parameter (name=fc2.bias)}, {'name': 'fc3.weight', 'data': Parameter (name=fc3.weight)}, {'name': 'fc3.bias', 'data': Parameter (name=fc3.bias)}]\n",
      "\n",
      "==========after add element===========\n",
      "\n",
      "[{'name': 'conv1.weight', 'data': Parameter (name=conv1.weight)}, {'name': 'conv2.weight', 'data': Parameter (name=conv2.weight)}, {'name': 'fc1.bias', 'data': Parameter (name=fc1.bias)}, {'name': 'fc2.weight', 'data': Parameter (name=fc2.weight)}, {'name': 'fc2.bias', 'data': Parameter (name=fc2.bias)}, {'name': 'fc3.weight', 'data': Parameter (name=fc3.weight)}, {'name': 'fc3.bias', 'data': Parameter (name=fc3.bias)}, {'name': 'epoch_size', 'data': Tensor(shape=[], dtype=Int64, value= 10)}]\n",
      "\n",
      "==========after modify element===========\n",
      "\n",
      "[{'name': 'conv1.weight', 'data': Parameter (name=conv1.weight)}, {'name': 'conv2.weight', 'data': Parameter (name=conv2.weight)}, {'name': 'fc1.bias', 'data': Parameter (name=fc1.bias)}, {'name': 'fc2.weight', 'data': Tensor(shape=[], dtype=Int64, value= 66)}, {'name': 'fc2.bias', 'data': Parameter (name=fc2.bias)}, {'name': 'fc3.weight', 'data': Parameter (name=fc3.weight)}, {'name': 'fc3.bias', 'data': Parameter (name=fc3.bias)}, {'name': 'epoch_size', 'data': Tensor(shape=[], dtype=Int64, value= 10)}]\n"
     ]
    }
   ],
   "source": [
    "params = load_checkpoint(\"./lenet-1_1875.ckpt\")\n",
    "\n",
    "# eg: param_list = [{\"name\": param_name, \"data\": param_data},...]\n",
    "param_list = [{\"name\": k, \"data\": v} for k, v in params.items()]\n",
    "print(\"==========param_list===========\\n\")\n",
    "print(param_list)\n",
    "\n",
    "# delete the third element (fc1.weight)\n",
    "del param_list[2]\n",
    "print(\"\\n==========after delete param_list[2]===========\\n\")\n",
    "print(param_list)\n",
    "\n",
    "# add element \"epoch_size\"\n",
    "param = {\"name\": \"epoch_size\"}\n",
    "param[\"data\"] = Tensor(10)\n",
    "param_list.append(param)\n",
    "print(\"\\n==========after add element===========\\n\")\n",
    "print(param_list)\n",
    "\n",
    "# modify element: after the deletion above, index 3 is fc2.weight\n",
    "param_list[3][\"data\"] = Tensor(66)\n",
    "print(\"\\n==========after modify element===========\\n\")\n",
    "print(param_list)\n",
    "\n",
    "# save a new checkpoint file\n",
    "save_checkpoint(param_list, 'modify.ckpt')"
   ]
  },
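  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After converting the loaded model file to a list, you can delete, add, or modify model parameters and then save the result manually with `save_checkpoint`, which completes the modification of the model weights."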
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "2. Save a custom parameter list as a CheckPoint file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{'name': 'epoch_size', 'data': Tensor(shape=[], dtype=Int64, value= 10)}, {'name': 'learning_rate', 'data': Tensor(shape=[], dtype=Float64, value= 0.01)}]\n"
     ]
    }
   ],
   "source": [
    "param_list = []\n",
    "# save epoch_size\n",
    "param = {\"name\": \"epoch_size\"}\n",
    "param[\"data\"] = Tensor(10)\n",
    "param_list.append(param)\n",
    "\n",
    "# save learning rate\n",
    "param = {\"name\": \"learning_rate\"}\n",
    "param[\"data\"] = Tensor(0.01)\n",
    "param_list.append(param)\n",
    "print(param_list)\n",
    "\n",
    "# save a new checkpoint file\n",
    "save_checkpoint(param_list, 'hyperparameters.ckpt')"
   ]
  },
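  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loading\n",
    "\n",
    "#### Strictly Matching Parameter Names\n",
    "\n",
    "When the weights in a CheckPoint file are loaded into `net`, parameters in `net` and in the CheckPoint that have the same name are matched first. If some parameters in `net` are still unloaded after that, they are matched against CheckPoint parameters by name suffix.\n",
    "\n",
    "For example, the value of the CheckPoint parameter named `conv.0.weight` is loaded into the parameter named `net.conv.0.weight` in `net`.\n",
    "\n",
    "To disable this fuzzy matching and use strict matching only, set the `strict_load` parameter of `load_param_into_net`. It defaults to False, which means fuzzy matching is used."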
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==========strict load mode===========\n",
      "{'conv1.weight': Parameter (name=conv1.weight), 'conv2.weight': Parameter (name=conv2.weight), 'fc1.weight': Parameter (name=fc1.weight), 'fc1.bias': Parameter (name=fc1.bias), 'fc2.weight': Parameter (name=fc2.weight), 'fc2.bias': Parameter (name=fc2.bias), 'fc3.weight': Parameter (name=fc3.weight), 'fc3.bias': Parameter (name=fc3.bias)}\n"
     ]
    }
   ],
   "source": [
    "net = LeNet5()\n",
    "params = load_checkpoint(\"lenet-1_1875.ckpt\")\n",
    "load_param_into_net(net, params, strict_load=True)\n",
    "print(\"==========strict load mode===========\")\n",
    "print(params)"
   ]
  },
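  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The matching order described in this section can be pictured with a small pure-Python sketch. It is only an illustration of the rule as stated here, not MindSpore's implementation: `match_names` is a hypothetical helper, and the real loader may differ in details.\n",
    "\n",
    "```python\n",
    "# Illustrative sketch of exact-then-suffix name matching.\n",
    "# match_names is a hypothetical helper, not a MindSpore API.\n",
    "def match_names(net_names, ckpt_names, strict_load=False):\n",
    "    # maps each net parameter name to the checkpoint name it would load from\n",
    "    matched = {}\n",
    "    # step 1: exact name match\n",
    "    for name in net_names:\n",
    "        if name in ckpt_names:\n",
    "            matched[name] = name\n",
    "    if strict_load:\n",
    "        return matched\n",
    "    # step 2: for parameters still unloaded, accept a checkpoint name\n",
    "    # that is a dotted suffix of the net parameter name\n",
    "    for name in net_names:\n",
    "        if name in matched:\n",
    "            continue\n",
    "        for ckpt_name in ckpt_names:\n",
    "            if name.endswith('.' + ckpt_name):\n",
    "                matched[name] = ckpt_name\n",
    "                break\n",
    "    return matched\n",
    "\n",
    "net_names = ['net.conv.0.weight', 'fc.bias']\n",
    "ckpt_names = ['conv.0.weight', 'fc.bias']\n",
    "fuzzy = match_names(net_names, ckpt_names)\n",
    "strict = match_names(net_names, ckpt_names, strict_load=True)\n",
    "print(fuzzy)\n",
    "print(strict)\n",
    "```\n",
    "\n",
    "In this sketch, fuzzy mode pairs `net.conv.0.weight` with `conv.0.weight`, while strict mode leaves it unmatched and only pairs `fc.bias`."
   ]
  },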
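  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Filtering by Prefix\n",
    "\n",
    "Method: the `filter_prefix` parameter of `load_checkpoint`.\n",
    "\n",
    "Use case: filtering out parameters with a specific prefix when loading a CheckPoint.\n",
    "\n",
    "- Skipping the optimizer parameters when loading a CheckPoint (for example, `filter_prefix='moments'`).\n",
    "\n",
    "- Skipping the parameters of a convolutional layer (for example, `filter_prefix='conv1'`)."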
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=============net params=============\n",
      "{'conv1.weight': Parameter (name=conv1.weight), 'conv2.weight': Parameter (name=conv2.weight), 'fc1.weight': Parameter (name=fc1.weight), 'fc1.bias': Parameter (name=fc1.bias), 'fc2.weight': Parameter (name=fc2.weight), 'fc2.bias': Parameter (name=fc2.bias), 'fc3.weight': Parameter (name=fc3.weight), 'fc3.bias': Parameter (name=fc3.bias), 'learning_rate': Parameter (name=learning_rate), 'momentum': Parameter (name=momentum), 'moments.conv1.weight': Parameter (name=moments.conv1.weight), 'moments.conv2.weight': Parameter (name=moments.conv2.weight), 'moments.fc1.weight': Parameter (name=moments.fc1.weight), 'moments.fc1.bias': Parameter (name=moments.fc1.bias), 'moments.fc2.weight': Parameter (name=moments.fc2.weight), 'moments.fc2.bias': Parameter (name=moments.fc2.bias), 'moments.fc3.weight': Parameter (name=moments.fc3.weight), 'moments.fc3.bias': Parameter (name=moments.fc3.bias)}\n",
      "\n",
      "=============after filter_prefix moments=============\n",
      "{'conv1.weight': Parameter (name=conv1.weight), 'conv2.weight': Parameter (name=conv2.weight), 'fc1.weight': Parameter (name=fc1.weight), 'fc1.bias': Parameter (name=fc1.bias), 'fc2.weight': Parameter (name=fc2.weight), 'fc2.bias': Parameter (name=fc2.bias), 'fc3.weight': Parameter (name=fc3.weight), 'fc3.bias': Parameter (name=fc3.bias), 'learning_rate': Parameter (name=learning_rate), 'momentum': Parameter (name=momentum)}\n"
     ]
    }
   ],
   "source": [
    "net = LeNet5()\n",
    "print(\"=============net params=============\")\n",
    "params = load_checkpoint(\"checkpoint_lenet-1_1875.ckpt\")\n",
    "load_param_into_net(net, params)\n",
    "print(params)\n",
    "\n",
    "net = LeNet5()\n",
    "print(\"\\n=============after filter_prefix moments=============\")\n",
    "params = load_checkpoint(\"checkpoint_lenet-1_1875.ckpt\", filter_prefix='moments')\n",
    "load_param_into_net(net, params)\n",
    "print(params)"
   ]
  },
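  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With prefix filtering, parameters that you do not want to load (the optimizer state in this example) are filtered out, so a different optimizer can be used for fine-tuning."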
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> The sections above cover advanced usage of the MindSpore checkpoint feature. All of these methods can be used together."
   ]
  },
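  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Converting CheckPoints from Other Frameworks to the MindSpore Format\n",
    "\n",
    "This section converts CheckPoint files from other frameworks to the MindSpore format.\n",
    "\n",
    "In general, a CheckPoint file stores parameter names and parameter values. After reading the names and values with the source framework's loading interface, build objects in the MindSpore format and call the MindSpore interface to save them as a MindSpore CheckPoint file.\n",
    "\n",
    "Most of the work lies in comparing parameter names between the frameworks so that every parameter name in the two networks corresponds one to one (a map can be used for this). The code below only converts the parameter format and does not map parameter names."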
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from mindspore import Tensor, save_checkpoint\n",
    "\n",
    "def pytorch2mindspore(default_file='torch_resnet.pth'):\n",
    "    # read the pth file on CPU so that .numpy() also works for tensors saved from GPU\n",
    "    par_dict = torch.load(default_file, map_location='cpu')['state_dict']\n",
    "    params_list = []\n",
    "    for name in par_dict:\n",
    "        param_dict = {}\n",
    "        parameter = par_dict[name]\n",
    "        param_dict['name'] = name\n",
    "        param_dict['data'] = Tensor(parameter.numpy())\n",
    "        params_list.append(param_dict)\n",
    "    # save the converted parameters as a MindSpore checkpoint\n",
    "    save_checkpoint(params_list, 'ms_resnet.ckpt')"
   ]
  }
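  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The name-mapping step that the conversion function leaves out could look roughly like the sketch below. It is a minimal illustration in plain Python: `rename_params` is a hypothetical helper, the parameter names in `name_map` are made-up examples, and plain lists stand in for the tensor data.\n",
    "\n",
    "```python\n",
    "# Illustrative sketch of mapping parameter names between frameworks.\n",
    "# rename_params and every name below are hypothetical examples.\n",
    "def rename_params(source_params, name_map):\n",
    "    # fail early if a source parameter has no target name\n",
    "    missing = [name for name in source_params if name not in name_map]\n",
    "    if missing:\n",
    "        raise KeyError('no mapping for: ' + ', '.join(missing))\n",
    "    # build the name/data list layout used with save_checkpoint above\n",
    "    return [{'name': name_map[name], 'data': data}\n",
    "            for name, data in source_params.items()]\n",
    "\n",
    "# placeholder values; a real conversion would carry tensor data\n",
    "source_params = {'bn1.weight': [1.0, 1.0], 'bn1.bias': [0.0, 0.0]}\n",
    "name_map = {'bn1.weight': 'bn1.gamma', 'bn1.bias': 'bn1.beta'}\n",
    "renamed = rename_params(source_params, name_map)\n",
    "print(renamed)\n",
    "```\n",
    "\n",
    "Raising an error on an unmapped name, rather than silently skipping it, makes gaps in the map visible before the checkpoint is written."
   ]
  }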
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MindSpore-1.1.1",
   "language": "python",
   "name": "mindspore-1.1.1"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}