{ "cells": [ { "cell_type": "markdown", "id": "fac77be9", "metadata": {}, "source": [ "# 学习率与优化器\n", "\n", "[![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0.0-alpha/resource/_static/logo_notebook.png)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.0.0-alpha/zh_cn/migration_guide/model_development/mindspore_learning_rate_and_optimizer.ipynb)\n", "[![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0.0-alpha/resource/_static/logo_download_code.png)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.0.0-alpha/zh_cn/migration_guide/model_development/mindspore_learning_rate_and_optimizer.py)\n", "[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0.0-alpha/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r2.0.0-alpha/docs/mindspore/source_zh_cn/migration_guide/model_development/learning_rate_and_optimizer.ipynb)\n", "\n", "在阅读本章节之前,请先阅读MindSpore官网教程[优化器](https://mindspore.cn/tutorials/zh-CN/r2.0.0-alpha/advanced/modules/optimizer.html)。\n", "\n", "MindSpore官网教程优化器章节的内容已经很详细了,这里就MindSpore的优化器的一些特殊使用方式和学习率衰减策略的原理做一个介绍。\n", "\n", "## 参数分组\n", "\n", "MindSpore的优化器支持一些特别的操作,比如对网络里所有的可训练的参数可以设置不同的学习率(lr)、权重衰减(weight_decay)和梯度中心化(grad_centralization)策略,如:" ] }, { "cell_type": "code", "execution_count": 1, "id": "eaf01165", "metadata": { "ExecuteTime": { "end_time": "2022-09-09T06:46:31.271642Z", "start_time": "2022-09-09T06:46:25.340209Z" } }, "outputs": [], "source": [ "from mindspore import nn\n", "\n", "# 定义模型\n", "class Network(nn.Cell):\n", " def __init__(self):\n", " super().__init__()\n", " self.layer1 = nn.SequentialCell([\n", " nn.Conv2d(3, 12, kernel_size=3, pad_mode='pad', padding=1),\n", " nn.BatchNorm2d(12),\n", " nn.ReLU(),\n", " nn.MaxPool2d(kernel_size=2, stride=2)\n", " ])\n", " self.layer2 = nn.SequentialCell([\n", " nn.Conv2d(12, 4, kernel_size=3, pad_mode='pad', padding=1),\n", " nn.BatchNorm2d(4),\n", " nn.ReLU(),\n", " nn.MaxPool2d(kernel_size=2, stride=2)\n", " ])\n", " self.pool = nn.AdaptiveMaxPool2d((5, 5))\n", " self.fc = nn.Dense(100, 10)\n", "\n", " def construct(self, x):\n", " x = self.layer1(x)\n", " x = self.layer2(x)\n", " x = self.pool(x)\n", " x = x.view((-1, 100))\n", " out = nn.Dense(x)\n", " return out\n", "\n", "def params_not_in(param, param_list):\n", " # 利用Parameter的id来判断一个param是否不在param_list中\n", " param_id = id(param)\n", " for p in param_list:\n", " if id(p) == param_id:\n", " return False\n", " return True\n", "\n", "net = Network()\n", "trainable_param = net.trainable_params()\n", "conv_weight, bn_weight, dense_weight = [], [], []\n", "for _, cell in net.cells_and_names():\n", " # 判断是什么API,将对应参数加到不同列表里\n", " if isinstance(cell, nn.Conv2d):\n", " conv_weight.append(cell.weight)\n", " elif isinstance(cell, nn.BatchNorm2d):\n", " bn_weight.append(cell.gamma)\n", " bn_weight.append(cell.beta)\n", " elif isinstance(cell, nn.Dense):\n", " dense_weight.append(cell.weight)\n", "\n", "other_param = []\n", "# 所有分组里的参数不能重复,并且其交集是需要做参数更新的所有参数\n", "for param in trainable_param:\n", " if params_not_in(param, conv_weight) and params_not_in(param, bn_weight) and params_not_in(param, dense_weight):\n", " other_param.append(param)\n", "\n", "group_param = [{'order_params': trainable_param}]\n", "# 每一个分组的参数列表不能是空的\n", "\n", "if conv_weight:\n", " conv_weight_lr = nn.cosine_decay_lr(0., 1e-3, total_step=1000, step_per_epoch=100, decay_epoch=10)\n", " group_param.append({'params': conv_weight, 'weight_decay': 1e-4, 'lr': conv_weight_lr})\n", "if bn_weight:\n", " group_param.append({'params': bn_weight, 'weight_decay': 0., 'lr': 1e-4})\n", "if dense_weight:\n", " group_param.append({'params': dense_weight, 'weight_decay': 1e-5, 'lr': 1e-3})\n", "if other_param:\n", " group_param.append({'params': other_param})\n", "\n", "opt = nn.Momentum(group_param, learning_rate=1e-3, weight_decay=0.0, momentum=0.9)" ] }, { "cell_type": "markdown", "id": "369baf31", "metadata": {}, "source": [ "需要注意以下几点:\n", "\n", "1. 每一个分组的参数列表不能是空的;\n", "2. 如果没有设置`weight_decay`和`lr`则使用优化器里设置的值,设置了的话使用分组参数字典里的值;\n", "3. 每个分组里的`lr`都可以是静态或动态的,但不能再分组;\n", "4. 每个分组里的`weight_decay`都需要是符合规范的浮点数;\n", "5. 所有分组里的参数不能重复,并且其交集是需要做参数更新的所有参数。\n", "\n", "## MindSpore的学习率衰减策略\n", "\n", "在训练过程中,MindSpore的学习率是以参数的形式存在于网络里的,在执行优化器更新网络可训练参数前,MindSpore会调用[get_lr](https://www.mindspore.cn/docs/zh-CN/r2.0.0-alpha/api_python/nn/mindspore.nn.Optimizer.html#mindspore.nn.Optimizer.get_lr)\n", "方法获取到当前step需要的学习率的值。\n", "\n", "MindSpore的学习率支持静态、动态、分组三种,其中静态学习率在网络里是一个float32类型的Tensor。\n", "\n", "动态学习率有两种,一种在网络里是一个长度为训练总的step数,float32类型的Tensor,如[Dynamic LR函数](https://www.mindspore.cn/docs/zh-CN/r2.0.0-alpha/api_python/mindspore.nn.html#dynamic-lr%E5%87%BD%E6%95%B0)。在优化器里有一个`global_step`的参数,每经过一次优化器更新参数会+1,MindSpore内部会根据`global_step`和`learning_rate`这两个参数来获取当前step的学习率的值;\n", "另一种是通过构图来生成学习率的值的,如[LearningRateSchedule类](https://www.mindspore.cn/docs/zh-CN/r2.0.0-alpha/api_python/mindspore.nn.html#learningrateschedule%E7%B1%BB)。\n", "\n", "分组学习率如上一小节参数分组中介绍的。\n", "\n", "因为MindSpore的学习率是参数,我们也可以通过给`learning_rate`参数赋值的方式修改训练过程中学习率的值,如[LearningRateScheduler Callback](https://www.mindspore.cn/docs/zh-CN/r2.0.0-alpha/_modules/mindspore/train/callback/_lr_scheduler_callback.html#LearningRateScheduler),这种方法只支持优化器中传入静态的学习率。关键代码如下:" ] }, { "cell_type": "code", "execution_count": 2, "id": "3284e69f", "metadata": { "ExecuteTime": { "end_time": "2022-09-09T06:46:31.316429Z", "start_time": "2022-09-09T06:46:31.274255Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.1\n", "0.01\n" ] } ], "source": [ "import mindspore as ms\n", "from mindspore import ops, nn\n", "\n", "net = nn.Dense(1, 2)\n", "optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)\n", "print(optimizer.learning_rate.data.asnumpy())\n", "new_lr = 0.01\n", "# 改写learning_rate参数的值\n", "ops.assign(optimizer.learning_rate, ms.Tensor(new_lr, ms.float32))\n", "print(optimizer.learning_rate.data.asnumpy())" ] } ], "metadata": { "kernelspec": { "display_name": "MindSpore", "language": "python", "name": "mindspore" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" } }, "nbformat": 4, "nbformat_minor": 5 }