{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3fca2d6c-be41-47be-b020-228c6e2acc98",
   "metadata": {},
   "source": [
    "[![在线运行](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.2/resource/_static/logo_modelarts.svg)](https://authoring-modelarts-cnnorth4.huaweicloud.com/console/lab?share-url-b64=aHR0cHM6Ly9taW5kc3BvcmUtd2Vic2l0ZS5vYnMuY24tbm9ydGgtNC5teWh1YXdlaWNsb3VkLmNvbS9ub3RlYm9vay9yMi4yL3R1dG9yaWFscy96aF9jbi9hZHZhbmNlZC9taW5kc3BvcmVfbWl4ZWRfcHJlY2lzaW9uLmlweW5i&imageid=4c43b3ad-9df7-4b83-a096-c775dc4ba243) [![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.2/resource/_static/logo_notebook.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.2/tutorials/zh_cn/advanced/mindspore_mixed_precision.ipynb) [![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.2/resource/_static/logo_download_code.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.2/tutorials/zh_cn/advanced/mindspore_mixed_precision.py) [![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.2/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/r2.2/tutorials/source_zh_cn/advanced/mixed_precision.ipynb)\n",
    "\n",
    "# 自动混合精度"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "49ed0572-abea-4574-9710-be45d84191b3",
   "metadata": {},
   "source": [
    "混合精度(Mix Precision)训练是指在训练时,对神经网络不同的运算采用不同的数值精度的运算策略。在神经网络运算中,部分运算对数值精度不敏感,此时使用较低精度可以达到明显的加速效果(如conv、matmul等);而部分运算由于输入和输出的数值差异大,通常需要保留较高精度以保证结果的正确性(如log、softmax等)。\n",
    "\n",
    "当前的AI加速卡通常通过针对计算密集、精度不敏感的运算设计了硬件加速模块,如NVIDIA GPU的TensorCore、Ascend NPU的Cube等。对于conv、matmul等运算占比较大的神经网络,其训练速度通常会有较大的加速比。\n",
    "\n",
    "`mindspore.amp`模块提供了便捷的自动混合精度接口,用户可以在不同的硬件后端通过简单的接口调用获得训练加速。下面我们对混合精度计算原理进行简介,而后通过实例介绍MindSpore的自动混合精度用法。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "518888c6-a7ad-4a13-8b7e-8cef642457e3",
   "metadata": {},
   "source": [
    "## 混合精度计算原理"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76ebb0cf-7c2d-49d4-8905-04d928567e76",
   "metadata": {},
   "source": [
    "浮点数据类型主要分为双精度(FP64)、单精度(FP32)、半精度(FP16)。在神经网络模型的训练过程中,一般默认采用单精度(FP32)浮点数据类型,来表示网络模型权重和其他参数。在了解混合精度训练之前,这里简单了解浮点数据类型。\n",
    "\n",
    "根据IEEE二进制浮点数算术标准([IEEE 754](https://en.wikipedia.org/wiki/IEEE_754))的定义,浮点数据类型分为双精度(FP64)、单精度(FP32)、半精度(FP16)三种,其中每一种都有三个不同的位来表示。FP64表示采用8个字节共64位,来进行的编码存储的一种数据类型;同理,FP32表示采用4个字节共32位来表示;FP16则是采用2字节共16位来表示。如图所示:\n",
    "\n",
    "![fp16-vs-fp32](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.2/tutorials/source_zh_cn/advanced/images/fp16_vs_fp32.png)\n",
    "\n",
    "从图中可以看出,与FP32相比,FP16的存储空间是FP32的一半。类似地,FP32则是FP64的一半。因此使用FP16进行运算具备以下优势:\n",
    "\n",
    "- 减少内存占用:FP16的位宽是FP32的一半,因此权重等参数所占用的内存也是原来的一半,节省下来的内存可以放更大的网络模型或者使用更多的数据进行训练。\n",
    "- 计算效率更高:在特殊的AI加速芯片如华为昇腾系列的Ascend 910和310系列,或者NVIDIA VOLTA架构的GPU上,使用FP16的执行运算性能比FP32更加快。\n",
    "- 加快通讯效率:针对分布式训练,特别是在大模型训练的过程中,通讯的开销制约了网络模型训练的整体性能,通讯的位宽少了意味着可以提升通讯性能,减少等待时间,加快数据的流通。\n",
    "\n",
    "但是使用FP16同样会带来一些问题:\n",
    "\n",
    "- 数据溢出:FP16的有效数据表示范围为 $[5.9\\times10^{-8}, 65504]$,FP32的有效数据表示范围为 $[1.4\\times10^{-45}, 1.7\\times10^{38}]$。可见FP16相比FP32的有效范围要窄很多,使用FP16替换FP32会出现上溢(Overflow)和下溢(Underflow)的情况。而在深度学习中,需要计算网络模型中权重的梯度(一阶导数),因此梯度会比权重值更加小,往往容易出现下溢情况。\n",
    "- 舍入误差:Rounding Error是指当网络模型的反向梯度很小,一般FP32能够表示,但是转换到FP16会小于当前区间内的最小间隔,会导致数据溢出。如`0.00006666666`在FP32中能正常表示,转换到FP16后会表示成为`0.000067`,不满足FP16最小间隔的数会强制舍入。\n",
    "\n",
    "因此,在使用混合精度获得训练加速和内存节省的同时,需要考虑FP16引入问题的解决。Loss Scale损失缩放,FP16类型数据下溢问题的解决方案,其主要思想是在计算损失值loss的时候,将loss扩大一定的倍数。根据链式法则,梯度也会相应扩大,然后在优化器更新权重时再缩小相应的倍数,从而避免了数据下溢。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "981dd3d3-bec1-4b43-a48e-6563547f8411",
   "metadata": {},
   "source": [
    "根据上述原理介绍,典型的混合精度计算流程如下图所示:\n",
    "\n",
    "![mix precision](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.2/tutorials/experts/source_zh_cn/optimize/images/mix_precision_fp16.png)\n",
    "\n",
    "1. 参数以FP32存储;\n",
    "2. 正向计算过程中,遇到FP16算子,需要把算子输入和参数从FP32 cast成FP16进行计算;\n",
    "3. 将Loss层设置为FP32进行计算;\n",
    "4. 反向计算过程中,首先乘以Loss Scale值,避免反向梯度过小而产生下溢;\n",
    "5. FP16参数参与梯度计算,其结果将被cast回FP32;\n",
    "6. 除以Loss scale值,还原被放大的梯度;\n",
    "7. 判断梯度是否存在溢出,如果溢出则跳过更新,否则优化器以FP32对原始参数进行更新。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81c60159-5242-4ea2-a73e-6015de4a675c",
   "metadata": {},
   "source": [
    "下面我们通过导入[快速入门](https://www.mindspore.cn/tutorials/zh-CN/r2.2/beginner/quick_start.html)中的手写数字识别模型及数据集,演示MindSpore的自动混合精度实现。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0dcf1b26-3050-42e6-ade8-889f3b3dc79d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip (10.3 MB)\n",
      "\n",
      "file_sizes: 100%|██████████████████████████| 10.8M/10.8M [00:07<00:00, 1.53MB/s]\n",
      "Extracting zip file...\n",
      "Successfully downloaded / unzipped to ./\n"
     ]
    }
   ],
   "source": [
    "import mindspore as ms\n",
    "from mindspore import nn\n",
    "from mindspore import value_and_grad\n",
    "from mindspore.dataset import vision, transforms\n",
    "from mindspore.dataset import MnistDataset\n",
    "\n",
    "# Download data from open datasets\n",
    "from download import download\n",
    "\n",
    "url = \"https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/\" \\\n",
    "      \"notebook/datasets/MNIST_Data.zip\"\n",
    "path = download(url, \"./\", kind=\"zip\", replace=True)\n",
    "\n",
    "\n",
    "def datapipe(path, batch_size):\n",
    "    image_transforms = [\n",
    "        vision.Rescale(1.0 / 255.0, 0),\n",
    "        vision.Normalize(mean=(0.1307,), std=(0.3081,)),\n",
    "        vision.HWC2CHW()\n",
    "    ]\n",
    "    label_transform = transforms.TypeCast(ms.int32)\n",
    "\n",
    "    dataset = MnistDataset(path)\n",
    "    dataset = dataset.map(image_transforms, 'image')\n",
    "    dataset = dataset.map(label_transform, 'label')\n",
    "    dataset = dataset.batch(batch_size)\n",
    "    return dataset\n",
    "\n",
    "train_dataset = datapipe('MNIST_Data/train', 64)\n",
    "\n",
    "# Define model\n",
    "class Network(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.flatten = nn.Flatten()\n",
    "        self.dense_relu_sequential = nn.SequentialCell(\n",
    "            nn.Dense(28*28, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Dense(512, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Dense(512, 10)\n",
    "        )\n",
    "\n",
    "    def construct(self, x):\n",
    "        x = self.flatten(x)\n",
    "        logits = self.dense_relu_sequential(x)\n",
    "        return logits"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b22c0430-ad94-47b5-bd04-278e2c71775c",
   "metadata": {},
   "source": [
    "## 类型转换\n",
    "\n",
    "混合精度计算需要将需要使用低精度的运算进行类型转换,将其输入转为FP16类型,得到输出后进将其重新转回FP32类型。MindSpore同时提供了自动和手动类型转换的方法,满足对易用性和灵活性的不同需求,下面我们分别对其进行介绍。\n",
    "\n",
    "### 自动类型转换"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8bdf54b3",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "`mindspore.amp.auto_mixed_precision` 接口提供对网络做自动类型转换的功能。自动类型转换遵循黑白名单机制,根据常用的运算精度习惯配置了4个等级,分别为:\n",
    "\n",
    "- 'O0':神经网络保持FP32;\n",
    "- 'O1':按白名单将运算cast为FP16;\n",
    "- 'O2':按黑名单保留FP32,其余运算cast为FP16;\n",
    "- 'O3':神经网络完全cast为FP16。\n",
    "\n",
    "下面是使用自动类型转换的示例:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ead4d2ea-95a0-44e6-a722-0701a3e11bfb",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindspore.amp import auto_mixed_precision\n",
    "\n",
    "model = Network()\n",
    "model = auto_mixed_precision(model, 'O2')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8f89dc3c-e419-484d-8d7c-4e991d85cfd1",
   "metadata": {},
   "source": [
    "### 手动类型转换\n",
    "\n",
    "通常情况下自动类型转换可以通过满足大部分混合精度训练的需求,但当用户需要精细化控制神经网络不同部分的运算精度时,可以通过手动类型转换的方式进行控制。\n",
    "\n",
    "> 手动类型转换需考虑模型各个模块的运算精度,一般仅在需要获得极致性能的情况下使用。\n",
    "\n",
    "下面我们对前文的`Network`进行改造,演示手动类型转换的不同方式。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "064acdb8-8cba-4e60-979d-06e667fc70a3",
   "metadata": {},
   "source": [
    "#### Cell粒度类型转换\n",
    "\n",
    "`nn.Cell`类提供了`to_float`方法,可以一键配置该模块的运算精度,自动将模块输入cast为指定的精度:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "112b57a8-c279-49cb-8069-cc261d92e857",
   "metadata": {},
   "outputs": [],
   "source": [
    "class NetworkFP16(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.flatten = nn.Flatten()\n",
    "        self.dense_relu_sequential = nn.SequentialCell(\n",
    "            nn.Dense(28*28, 512).to_float(ms.float16),\n",
    "            nn.ReLU(),\n",
    "            nn.Dense(512, 512).to_float(ms.float16),\n",
    "            nn.ReLU(),\n",
    "            nn.Dense(512, 10).to_float(ms.float16)\n",
    "        )\n",
    "\n",
    "    def construct(self, x):\n",
    "        x = self.flatten(x)\n",
    "        logits = self.dense_relu_sequential(x)\n",
    "        return logits"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a169af9b-df4c-4644-8377-833d4cf7100a",
   "metadata": {},
   "source": [
    "#### 自定义粒度类型转换\n",
    "\n",
    "当用户需要在单个运算,或多个模块组合配置运算精度时,Cell粒度往往无法满足,此时可以直接通过对输入数据的类型进行cast来达到自定义粒度控制的目的。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d99f7986-bde4-4254-bc97-3baa1c7353d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "class NetworkFP16Manual(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.flatten = nn.Flatten()\n",
    "        self.dense_relu_sequential = nn.SequentialCell(\n",
    "            nn.Dense(28*28, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Dense(512, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Dense(512, 10)\n",
    "        )\n",
    "\n",
    "    def construct(self, x):\n",
    "        x = self.flatten(x)\n",
    "        x = x.astype(ms.float16)\n",
    "        logits = self.dense_relu_sequential(x)\n",
    "        logits = logits.astype(ms.float32)\n",
    "        return logits"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42d9a6e5-765e-4c28-9ad4-2c1f59996a0b",
   "metadata": {},
   "source": [
    "## 损失缩放\n",
    "\n",
    "MindSpore中提供了两种Loss Scale的实现,分别为`StaticLossScaler`和`DynamicLossScaler`,其差异为损失缩放值scale value是否进行动态调整。下面以`DynamicLossScalar`为例,根据混合精度计算流程实现神经网络训练逻辑。\n",
    "\n",
    "首先,实例化LossScaler,并在定义前向网络时,手动放大loss值。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9f3abade",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from mindspore.amp import DynamicLossScaler\n",
    "\n",
    "# Instantiate loss function and optimizer\n",
    "loss_fn = nn.CrossEntropyLoss()\n",
    "optimizer = nn.SGD(model.trainable_params(), 1e-2)\n",
    "\n",
    "# Define LossScaler\n",
    "loss_scaler = DynamicLossScaler(scale_value=2**10, scale_factor=2, scale_window=50)\n",
    "\n",
    "def forward_fn(data, label):\n",
    "    logits = model(data)\n",
    "    loss = loss_fn(logits, label)\n",
    "    # scale up the loss value\n",
    "    loss = loss_scaler.scale(loss)\n",
    "    return loss, logits"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "09ca284d",
   "metadata": {},
   "source": [
    "接下来进行函数变换,获得梯度函数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "95715438-38f9-4476-afc3-5ff655b66966",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "grad_fn = value_and_grad(forward_fn, None, model.trainable_params())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2ccec7bf",
   "metadata": {},
   "source": [
    "定义训练step:计算当前梯度值并恢复损失。使用 `all_finite` 判断是否出现梯度下溢问题,如果无溢出,恢复梯度并更新网络权重;如果溢出,跳过此step。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "5a3c1cc3",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from mindspore.amp import all_finite\n",
    "\n",
    "@ms.jit\n",
    "def train_step(data, label):\n",
    "    (loss, _), grads = grad_fn(data, label)\n",
    "    loss = loss_scaler.unscale(loss)\n",
    "\n",
    "    is_finite = all_finite(grads)\n",
    "    if is_finite:\n",
    "        grads = loss_scaler.unscale(grads)\n",
    "        optimizer(grads)\n",
    "    loss_scaler.adjust(is_finite)\n",
    "\n",
    "    return loss"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70519642-58d2-485a-9614-773e270889dd",
   "metadata": {},
   "source": [
    "最后,我们训练1个epoch,观察使用自动混合精度训练的loss收敛情况。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "0bb34af6-b3e2-494b-afbe-4ff68896552c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss: 2.305425  [  0/938]\n",
      "loss: 2.289585  [100/938]\n",
      "loss: 2.259094  [200/938]\n",
      "loss: 2.176874  [300/938]\n",
      "loss: 1.856715  [400/938]\n",
      "loss: 1.398342  [500/938]\n",
      "loss: 0.889620  [600/938]\n",
      "loss: 0.709884  [700/938]\n",
      "loss: 0.750509  [800/938]\n",
      "loss: 0.482525  [900/938]\n"
     ]
    }
   ],
   "source": [
    "size = train_dataset.get_dataset_size()\n",
    "model.set_train()\n",
    "for batch, (data, label) in enumerate(train_dataset.create_tuple_iterator()):\n",
    "    loss = train_step(data, label)\n",
    "\n",
    "    if batch % 100 == 0:\n",
    "        loss, current = loss.asnumpy(), batch\n",
    "        print(f\"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17ff0777",
   "metadata": {},
   "source": [
    "可以看到loss收敛趋势正常,没有出现溢出问题。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e307279c-1694-44d6-96be-93a7fbb169eb",
   "metadata": {},
   "source": [
    "## `Cell`配置自动混合精度\n",
    "\n",
    "MindSpore支持使用Cell封装完整计算图的编程范式,此时可以使用`mindspore.amp.build_train_network`接口,自动进行类型转换,并将Loss Scale传入,作为整图计算的一部分。\n",
    "此时仅需要配置混合精度等级和`LossScaleManager`即可获得配置好自动混合精度的计算图。\n",
    "\n",
    "`FixedLossScaleManager`和`DynamicLossScaleManager`是`Cell`配置自动混合精度的Loss scale管理接口,分别与`StaticLossScalar`和`DynamicLossScalar`对应,具体详见[mindspore.amp](https://www.mindspore.cn/docs/zh-CN/r2.2/api_python/mindspore.amp.html)。\n",
    "\n",
    "> 使用`Cell`配置自动混合精度训练仅支持`GPU`和`Ascend`。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff433a1f-30b5-486b-bfc2-218bca91d85d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindspore.amp import build_train_network, FixedLossScaleManager\n",
    "\n",
    "model = Network()\n",
    "loss_scale_manager = FixedLossScaleManager()\n",
    "\n",
    "model = build_train_network(model, optimizer, loss_fn, level=\"O2\", loss_scale_manager=loss_scale_manager)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e09b86b5",
   "metadata": {},
   "source": [
    "## `Model` 配置自动混合精度"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5c8adf4-71bb-4061-8801-9a4f97168dc5",
   "metadata": {},
   "source": [
    "`mindspore.train.Model`是神经网络快速训练的高阶封装,其将`mindspore.amp.build_train_network`封装在内,因此同样只需要配置混合精度等级和`LossScaleManager`,即可进行自动混合精度训练。\n",
    "\n",
    "> 使用`Model`配置自动混合精度训练仅支持`GPU`和`Ascend`。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e855c91",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1 step: 100, loss is 2.2883859\n",
      "epoch: 1 step: 200, loss is 2.2612116\n",
      "epoch: 1 step: 300, loss is 2.1563218\n",
      "epoch: 1 step: 400, loss is 1.9420109\n",
      "epoch: 1 step: 500, loss is 1.396821\n",
      "epoch: 1 step: 600, loss is 1.0450488\n",
      "epoch: 1 step: 700, loss is 0.69754004\n",
      "epoch: 1 step: 800, loss is 0.6924556\n",
      "epoch: 1 step: 900, loss is 0.57444984\n",
      "...\n",
      "epoch: 10 step: 58, loss is 0.13086069\n",
      "epoch: 10 step: 158, loss is 0.07224723\n",
      "epoch: 10 step: 258, loss is 0.08281057\n",
      "epoch: 10 step: 358, loss is 0.09759849\n",
      "epoch: 10 step: 458, loss is 0.17265382\n",
      "epoch: 10 step: 558, loss is 0.10023793\n",
      "epoch: 10 step: 658, loss is 0.08235697\n",
      "epoch: 10 step: 758, loss is 0.10531154\n",
      "epoch: 10 step: 858, loss is 0.19084263\n"
     ]
    }
   ],
   "source": [
    "from mindspore.train import Model, LossMonitor\n",
    "# Initialize network\n",
    "model = Network()\n",
    "loss_fn = nn.CrossEntropyLoss()\n",
    "optimizer = nn.SGD(model.trainable_params(), 1e-2)\n",
    "\n",
    "loss_scale_manager = FixedLossScaleManager()\n",
    "trainer = Model(model, loss_fn=loss_fn, optimizer=optimizer, metrics={'accuracy'}, amp_level=\"O2\", loss_scale_manager=loss_scale_manager)\n",
    "\n",
    "loss_callback = LossMonitor(100)\n",
    "trainer.train(10, train_dataset, callbacks=[loss_callback])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7e24bdca",
   "metadata": {},
   "source": [
    "> 图片引用自[automatic-mixed-precision](https://developer.nvidia.com/automatic-mixed-precision)。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MindSpore",
   "language": "python",
   "name": "mindspore"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5 (default, Oct 25 2019, 15:51:11) \n[GCC 7.3.0]"
  },
  "vscode": {
   "interpreter": {
    "hash": "8c9da313289c39257cb28b126d2dadd33153d4da4d524f730c81a4aaccbd2ca7"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}