{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 下沉模式\n", "\n", "[![在线运行](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0/resource/_static/logo_modelarts.png)](https://authoring-modelarts-cnnorth4.huaweicloud.com/console/lab?share-url-b64=aHR0cHM6Ly9vYnMuZHVhbHN0YWNrLmNuLW5vcnRoLTQubXlodWF3ZWljbG91ZC5jb20vbWluZHNwb3JlLXdlYnNpdGUvbm90ZWJvb2svcjIuMC90dXRvcmlhbHMvZXhwZXJ0cy96aF9jbi9vcHRpbWl6ZS9taW5kc3BvcmVfZXhlY3V0aW9uX29wdC5pcHluYg==&imageid=b8671c1e-c439-4ae2-b9c6-69b46db134ae) [![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0/resource/_static/logo_notebook.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/r2.0/tutorials/experts/zh_cn/optimize/mindspore_execution_opt.ipynb) [![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0/resource/_static/logo_download_code.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/r2.0/tutorials/experts/zh_cn/optimize/mindspore_execution_opt.py) [![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r2.0/tutorials/experts/source_zh_cn/optimize/execution_opt.ipynb)\n", "\n", "## 概述\n", "\n", "昇腾芯片集成了AICORE和AICPU等计算单元。其中AICORE负责稠密Tensor和Vector运算,AICPU负责复杂控制逻辑的处理。\n", "\n", "为充分发挥昇腾芯片的运算、逻辑控制和任务分发能力,MindSpore提供了数据图下沉、图下沉和循环下沉功能,极大地减少Host-Device交互开销,有效地提升训练与推理的性能。MindSpore的计算图包含网络算子以及算子间的依赖关系。\n", "\n", "从用户的视角来看,网络训练的流程如下:\n", "\n", "![user-view](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0/tutorials/experts/source_zh_cn/optimize/images/image-user-view.png)\n", "\n", "本教程以训练的执行流程为例介绍数据下沉、图下沉和循环下沉的原理和使用方法。\n", "\n", "## 数据下沉\n", "\n", "为了提升网络的执行性能,通常使用专用芯片来执行算子,一个芯片对应一个Device,Host与Device的一般交互流程如下:\n", "\n", "![without-sink](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0/tutorials/experts/source_zh_cn/optimize/images/image-without-sink.png)\n", "\n", "由上图可见,每个训练迭代都需要从Host拷贝数据到Device,可通过数据下沉消除Host和Device间拷贝输入数据的开销。\n", "\n", "使能数据下沉后,MindSpore会在Device侧创建专门的数据缓存队列,MindSpore数据处理引擎使用高性能数据通道将数据的预处理结果发送到Device的数据队列上,计算图通过GetNext算子直接从数据队列拷贝输入数据,Host向数据队列发送数据和计算图从数据队列读取数据形成流水并行,执行当前迭代的同时可向数据队列发送下一个迭代的数据,从而隐藏了Host-Device数据拷贝的开销,MindSpore高性能数据处理引擎的原理参考[这里](https://www.mindspore.cn/docs/zh-CN/r2.0/design/data_engine.html)。\n", "\n", "GPU后端和昇腾后端都支持数据下沉,GPU数据下沉的Host-Device交互流程如下:\n", "\n", "![data-sink](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0/tutorials/experts/source_zh_cn/optimize/images/image-data-sink.png)\n", "\n", "用户可通过[train](https://mindspore.cn/docs/zh-CN/r2.0/api_python/train/mindspore.train.Model.html#mindspore.train.Model.train)接口的`dataset_sink_mode`控制是否使能数据下沉。\n", "\n", "## 图下沉\n", "\n", "一般情况下,每个训练迭代都需要下发并触发device上每个算子的执行,Host与Device交互频繁。\n", "\n", "为减少Host与Device的交互,在图编译时,将网络中的算子打包并一起下发到device,每次迭代只触发一次计算图的执行即可,从而提升网络的执行效率。\n", "\n", "![graph-sink](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0/tutorials/experts/source_zh_cn/optimize/images/image-graph-sink.png)\n", "\n", "GPU后端暂不支持图下沉;使用昇腾设备时,开启数据下沉会同时启用图下沉。\n", "\n", "## 循环下沉\n", "\n", "启用数据下沉和图下沉后,每个迭代的计算结果都会返回Host,并由Host判断是否需要进入下一个迭代,为减少每个迭代的Device-Host交互,可以将进入下一个迭代的循环判断下沉到Device,这样等所有迭代执行完成后再将计算结果返回到Host。循环下沉的Host-Device交互流程如下:\n", "\n", "![loop-sink](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0/tutorials/experts/source_zh_cn/optimize/images/image-loop-sink.png)\n", "\n", "用户通过[train](https://mindspore.cn/docs/zh-CN/r2.0/api_python/train/mindspore.train.Model.html#mindspore.train.Model.train)接口的`dataset_sink_mode`和`sink_size`参数控制每个epoch的下沉迭代数量,Device侧连续执行`sink_size`个迭代后才返回到Host。\n", "\n", "## 使用方法\n", "\n", "### `Model.train`实现数据下沉\n", "\n", "`Model`的`train`接口参数`dataset_sink_mode`可以控制数据是否下沉。`dataset_sink_mode`为True表示数据下沉,否则为非下沉。所谓下沉即数据通过通道直接传送到Device上。\n", "\n", "`dataset_sink_mode`参数可以配合`sink_size`控制每个`epoch`下沉的数据量大小。当`dataset_sink_mode`设置为True,即数据下沉模式时:\n", "\n", "- 如果`sink_size`为默认值-1,则每一个`epoch`训练整个数据集,理想状态下下沉数据的速度快于硬件计算的速度,保证处理数据的耗时隐藏于网络计算时间内;\n", "\n", "- 如果`sink_size`>0,此时原始数据集可以被无限次遍历,下沉数据流程仍与`sink_size`=-1相同,不同点是每个`epoch`仅训练`sink_size`大小的数据量,如果有`LossMonitor`,那么会训练`sink_size`大小的数据量就打印一次loss值,下一个`epoch`继续从上次遍历的结束位置继续遍历。\n", "\n", "下沉的总数据量由`epoch`和`sink_size`两个变量共同控制,即总数据量=`epoch`*`sink_size`。\n", "\n", "当使用`LossMonitor`、`TimeMonitor`或其它`Callback`接口时,如果`dataset_sink_mode`设置为False,Host侧和Device侧之间每个`step`交互一次,所以会每个`step`返回一个结果,如果`dataset_sink_mode`为True,因为数据在Device上通过通道传输,Host侧和Device侧之间每个`epoch`进行一次数据交互,所以每个`epoch`只返回一次结果。\n", "\n", "> - 当前CPU不支持数据下沉。\n", ">\n", "> - 当设置为GRAPH模式时,每个batch数据的shape必须相同;当设置为PYNATIVE模式时,要求每个batch的size相同。\n", ">\n", "> - 由于数据下沉对数据集的遍历是连续,当前不支持非连续遍历。\n", ">\n", "> - 如果在使用数据下沉模式时,出现`fault kernel_name=GetNext`、`GetNext... task error`或者`outputs = self.get_next()`等类似的错误,那么有可能是数据处理过程中某些样本处理太耗时,导致网络计算侧长时间拿不到数据报错,此时可以将`dataset_sink_mode`设置为False再次验证,或者对数据集使用`create_dict_iterator()`接口单独循环数据集,并参考[数据处理性能优化](https://mindspore.cn/tutorials/experts/zh-CN/r2.0/dataset/optimize.html)调优数据处理,保证数据处理高性能。\n", "\n", "代码样例如下:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The train-labels-idx1-ubyte file is downloaded and saved in the path ./datasets/MNIST_Data/train/ after processing\n", "The train-images-idx3-ubyte file is downloaded and saved in the path ./datasets/MNIST_Data/train/ after processing\n", "============== Starting Training ==============\n", "TotalTime = 1.57468, [16]\n", "[parse]: 0.00476957\n", "[symbol_resolve]: 0.379241, [1]\n", " [Cycle 1]: 0.378255, [1]\n", " [resolve]: 0.378242\n", "[combine_like_graphs]: 0.00174131\n", "[meta_unpack_prepare]: 0.000686213\n", "[abstract_specialize]: 0.0927178\n", "[auto_monad]: 0.00169479\n", "[inline]: 4.96954e-06\n", "[pipeline_split]: 8.78051e-06\n", "[optimize]: 1.05932, [16]\n", "...\n", "epoch: 1 step: 1000, loss is 0.2323482483625412\n", "epoch: 2 step: 1000, loss is 0.1581915020942688\n", "epoch: 3 step: 1000, loss is 0.0452561192214489\n", "epoch: 4 step: 1000, loss is 0.0008174572139978409\n", "epoch: 5 step: 1000, loss is 0.026678290218114853\n", "epoch: 6 step: 1000, loss is 0.24375736713409424\n", "epoch: 7 step: 1000, loss is 0.004280050750821829\n", "epoch: 8 step: 1000, loss is 0.08765432983636856\n", "epoch: 9 step: 1000, loss is 0.06880836188793182\n", "epoch: 10 step: 1000, loss is 0.05223526805639267\n" ] } ], "source": [ "import os\n", "import requests\n", "import mindspore.dataset as ds\n", "import mindspore as ms\n", "import mindspore.dataset.transforms as transforms\n", "import mindspore.dataset.vision as vision\n", "import mindspore.nn as nn\n", "from mindspore import train\n", "from mindspore.common.initializer import TruncatedNormal\n", "from mindspore.dataset.vision import Inter\n", "import mindspore.ops as ops\n", "\n", "requests.packages.urllib3.disable_warnings()\n", "\n", "def create_dataset(data_path, batch_size=32, repeat_size=1,\n", " num_parallel_workers=1):\n", " \"\"\"\n", " create dataset for train or test\n", " \"\"\"\n", " # define dataset\n", " mnist_ds = ds.MnistDataset(data_path)\n", "\n", " resize_height, resize_width = 32, 32\n", " rescale = 1.0 / 255.0\n", " shift = 0.0\n", " rescale_nml = 1 / 0.3081\n", " shift_nml = -1 * 0.1307 / 0.3081\n", "\n", " # define map operations\n", " resize_op = vision.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # Bilinear mode\n", " rescale_nml_op = vision.Rescale(rescale_nml, shift_nml)\n", " rescale_op = vision.Rescale(rescale, shift)\n", " hwc2chw_op = vision.HWC2CHW()\n", " type_cast_op = transforms.TypeCast(ms.int32)\n", "\n", " # apply map operations on images\n", " mnist_ds = mnist_ds.map(input_columns=\"label\", operations=type_cast_op, num_parallel_workers=num_parallel_workers)\n", " mnist_ds = mnist_ds.map(input_columns=\"image\", operations=resize_op, num_parallel_workers=num_parallel_workers)\n", " mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_op, num_parallel_workers=num_parallel_workers)\n", " mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers)\n", " mnist_ds = mnist_ds.map(input_columns=\"image\", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)\n", "\n", " # apply DatasetOps\n", " buffer_size = 10000\n", " mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script\n", " mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)\n", " mnist_ds = mnist_ds.repeat(repeat_size)\n", "\n", " return mnist_ds\n", "\n", "\n", "def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):\n", " \"\"\"weight initial for conv layer\"\"\"\n", " weight = weight_variable()\n", " return nn.Conv2d(in_channels, out_channels,\n", " kernel_size=kernel_size, stride=stride, padding=padding,\n", " weight_init=weight, has_bias=False, pad_mode=\"valid\")\n", "\n", "\n", "def fc_with_initialize(input_channels, out_channels):\n", " \"\"\"weight initial for fc layer\"\"\"\n", " weight = weight_variable()\n", " bias = weight_variable()\n", " return nn.Dense(input_channels, out_channels, weight, bias)\n", "\n", "\n", "def weight_variable():\n", " \"\"\"weight initial\"\"\"\n", " return TruncatedNormal(0.02)\n", "\n", "\n", "class LeNet5(nn.Cell):\n", " \"\"\"\n", " Lenet network\n", " Args:\n", " num_class (int): Num classes. Default: 10.\n", "\n", " Returns:\n", " Tensor, output tensor\n", "\n", " Examples:\n", " >>> LeNet(num_class=10)\n", " \"\"\"\n", "\n", " def __init__(self, num_class=10):\n", " super(LeNet5, self).__init__()\n", " self.num_class = num_class\n", " self.batch_size = 32\n", " self.conv1 = conv(1, 6, 5)\n", " self.conv2 = conv(6, 16, 5)\n", " self.fc1 = fc_with_initialize(16 * 5 * 5, 120)\n", " self.fc2 = fc_with_initialize(120, 84)\n", " self.fc3 = fc_with_initialize(84, self.num_class)\n", " self.relu = nn.ReLU()\n", " self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)\n", "\n", " def construct(self, x):\n", " x = self.conv1(x)\n", " x = self.relu(x)\n", " x = self.max_pool2d(x)\n", " x = self.conv2(x)\n", " x = self.relu(x)\n", " x = self.max_pool2d(x)\n", " x = ops.reshape(x, (self.batch_size, -1))\n", " x = self.fc1(x)\n", " x = self.relu(x)\n", " x = self.fc2(x)\n", " x = self.relu(x)\n", " x = self.fc3(x)\n", " return x\n", "\n", "def download_dataset(dataset_url, path):\n", " filename = dataset_url.split(\"/\")[-1]\n", " save_path = os.path.join(path, filename)\n", " if os.path.exists(save_path):\n", " return\n", " if not os.path.exists(path):\n", " os.makedirs(path)\n", " res = requests.get(dataset_url, stream=True, verify=False)\n", " with open(save_path, \"wb\") as f:\n", " for chunk in res.iter_content(chunk_size=512):\n", " if chunk:\n", " f.write(chunk)\n", " print(\"The {} file is downloaded and saved in the path {} after processing\".format(os.path.basename(dataset_url), path))\n", "\n", "\n", "if __name__ == \"__main__\":\n", " ms.set_context(mode=ms.GRAPH_MODE, device_target=\"GPU\")\n", " ds_train_path = \"./datasets/MNIST_Data/train/\"\n", " download_dataset(\"https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/train-labels-idx1-ubyte\", ds_train_path)\n", " download_dataset(\"https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/train-images-idx3-ubyte\", ds_train_path)\n", " ds_train = create_dataset(ds_train_path, 32)\n", "\n", " network = LeNet5(10)\n", " net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction=\"mean\")\n", " net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)\n", " model = train.Model(network, net_loss, net_opt)\n", "\n", " print(\"============== Starting Training ==============\")\n", " model.train(epoch=10, train_dataset=ds_train, callbacks=[train.LossMonitor()], dataset_sink_mode=True, sink_size=1000)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "batch_size为32的情况下,数据集的大小为1875,当sink_size设置为1000时,表示每个epoch下沉1000个batch的数据,下沉次数为epoch=10,下沉的总数据量为:epoch*sink_size=10000。\n", "\n", "dataset_sink_mode为True,所以每个epoch返回一次结果。训练过程中使用DatasetHelper进行数据集的迭代及数据信息的管理。如果为下沉模式,使用 mindspore.connect_network_with_dataset 函数连接当前的训练网络或评估网络 network 和 DatasetHelper,此函数使用 mindspore.ops.GetNext 包装输入网络,以实现在前向计算时,在设备(Device)侧从对应名称为 queue_name 的数据通道中获取数据,并将数据传递到输入网络。如果为非下沉模式,则在主机(Host)直接遍历数据集获取数据。\n", "\n", "dataset_sink_mode为False时,sink_size参数设置无效。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### `data_sink`实现数据下沉\n", "\n", "在MindSpore的函数式编程范式中,还可以使用[data_sink接口](https://mindspore.cn/docs/zh-CN/r2.0/api_python/mindspore/mindspore.data_sink.html)将模型的执行函数和数据集绑定,实现数据下沉。参数含义如下:\n", "\n", "- `fn`:下沉模型的执行函数;\n", "- `dataset`:数据集,可以由[mindspore.dataset](https://mindspore.cn/docs/zh-CN/r2.0/api_python/mindspore.dataset.html)生成;\n", "- `sink_size`:用来调整每次下沉执行的数据量,可以指定为任意正数,默认值为1,即每次下沉只执行一个step的数据。如需单次下沉执行整个epoch的数据,可以使用`dataset`的`get_datasize_size()`方法来指定其值。也可以单次下沉多个epoch,设置其值为`epoch * get_datasize_size()`。(多次`data_sink`的调用对数据集是连续遍历的,下一次调用是从上一次调用结束位置后继续遍历)\n", "- `jit_config`:编译时所使用的JitConfig配置项,详细可参考[mindspore.JitConfig](https://mindspore.cn/docs/zh-CN/r2.0/api_python/mindspore/mindspore.JitConfig.html#mindspore.JitConfig)。默认值:None,表示以PyNative模式运行。\n", "- `input_signature`:用于表示输入参数的Tensor。Tensor的shape和dtype将作为函数的输入shape和dtype。默认值:None。\n", "\n", "> - 当前CPU不支持数据下沉。\n", ">\n", "> - 当设置为GRAPH模式时,每个batch数据的shape必须相同;当设置为PYNATIVE模式时,要求每个batch的size相同。\n", ">\n", "> - 由于数据下沉对数据集的遍历是连续,当前不支持非连续遍历。\n", ">\n", "> - 如果在使用数据下沉模式时,出现`fault kernel_name=GetNext`、`GetNext... task error`或者`outputs = self.get_next()`等类似的错误,那么有可能是数据处理过程中某些样本处理太耗时,导致网络计算侧长时间拿不到数据报错,此时可以将`dataset_sink_mode`设置为False再次验证,或者对数据集使用`create_dict_iterator()`接口单独循环数据集,并参考[数据处理性能优化](https://mindspore.cn/tutorials/experts/zh-CN/r2.0/dataset/optimize.html)调优数据处理,保证数据处理高性能。\n", "\n", "代码示例如下:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "import requests\n", "import mindspore.dataset as ds\n", "import mindspore as ms\n", "import mindspore.dataset.transforms as transforms\n", "import mindspore.dataset.vision as vision\n", "import mindspore.nn as nn\n", "from mindspore.common.initializer import TruncatedNormal\n", "from mindspore.dataset.vision import Inter\n", "import mindspore.ops as ops\n", "\n", "requests.packages.urllib3.disable_warnings()\n", "\n", "def create_dataset(data_path, batch_size=32, repeat_size=1,\n", " num_parallel_workers=1):\n", " \"\"\"\n", " create dataset for train or test\n", " \"\"\"\n", " # define dataset\n", " mnist_ds = ds.MnistDataset(data_path)\n", "\n", " resize_height, resize_width = 32, 32\n", " rescale = 1.0 / 255.0\n", " shift = 0.0\n", " rescale_nml = 1 / 0.3081\n", " shift_nml = -1 * 0.1307 / 0.3081\n", "\n", " # define map operations\n", " resize_op = vision.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # Bilinear mode\n", " rescale_nml_op = vision.Rescale(rescale_nml, shift_nml)\n", " rescale_op = vision.Rescale(rescale, shift)\n", " hwc2chw_op = vision.HWC2CHW()\n", " type_cast_op = transforms.TypeCast(ms.int32)\n", "\n", " # apply map operations on images\n", " mnist_ds = mnist_ds.map(input_columns=\"label\", operations=type_cast_op, num_parallel_workers=num_parallel_workers)\n", " mnist_ds = mnist_ds.map(input_columns=\"image\", operations=resize_op, num_parallel_workers=num_parallel_workers)\n", " mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_op, num_parallel_workers=num_parallel_workers)\n", " mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers)\n", " mnist_ds = mnist_ds.map(input_columns=\"image\", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)\n", "\n", " # apply DatasetOps\n", " buffer_size = 10000\n", " mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script\n", " mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)\n", " mnist_ds = mnist_ds.repeat(repeat_size)\n", "\n", " return mnist_ds\n", "\n", "\n", "def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):\n", " \"\"\"weight initial for conv layer\"\"\"\n", " weight = weight_variable()\n", " return nn.Conv2d(in_channels, out_channels,\n", " kernel_size=kernel_size, stride=stride, padding=padding,\n", " weight_init=weight, has_bias=False, pad_mode=\"valid\")\n", "\n", "\n", "def fc_with_initialize(input_channels, out_channels):\n", " \"\"\"weight initial for fc layer\"\"\"\n", " weight = weight_variable()\n", " bias = weight_variable()\n", " return nn.Dense(input_channels, out_channels, weight, bias)\n", "\n", "\n", "def weight_variable():\n", " \"\"\"weight initial\"\"\"\n", " return TruncatedNormal(0.02)\n", "\n", "\n", "class LeNet5(nn.Cell):\n", " \"\"\"\n", " Lenet network\n", " Args:\n", " num_class (int): Num classes. Default: 10.\n", "\n", " Returns:\n", " Tensor, output tensor\n", "\n", " Examples:\n", " >>> LeNet(num_class=10)\n", " \"\"\"\n", "\n", " def __init__(self, num_class=10):\n", " super(LeNet5, self).__init__()\n", " self.num_class = num_class\n", " self.batch_size = 32\n", " self.conv1 = conv(1, 6, 5)\n", " self.conv2 = conv(6, 16, 5)\n", " self.fc1 = fc_with_initialize(16 * 5 * 5, 120)\n", " self.fc2 = fc_with_initialize(120, 84)\n", " self.fc3 = fc_with_initialize(84, self.num_class)\n", " self.relu = nn.ReLU()\n", " self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)\n", "\n", " def construct(self, x):\n", " x = self.conv1(x)\n", " x = self.relu(x)\n", " x = self.max_pool2d(x)\n", " x = self.conv2(x)\n", " x = self.relu(x)\n", " x = self.max_pool2d(x)\n", " x = ops.reshape(x, (self.batch_size, -1))\n", " x = self.fc1(x)\n", " x = self.relu(x)\n", " x = self.fc2(x)\n", " x = self.relu(x)\n", " x = self.fc3(x)\n", " return x\n", "\n", "def download_dataset(dataset_url, path):\n", " filename = dataset_url.split(\"/\")[-1]\n", " save_path = os.path.join(path, filename)\n", " if os.path.exists(save_path):\n", " return\n", " if not os.path.exists(path):\n", " os.makedirs(path)\n", " res = requests.get(dataset_url, stream=True, verify=False)\n", " with open(save_path, \"wb\") as f:\n", " for chunk in res.iter_content(chunk_size=512):\n", " if chunk:\n", " f.write(chunk)\n", " print(\"The {} file is downloaded and saved in the path {} after processing\".format(os.path.basename(dataset_url), path))\n", "\n", "if __name__ == \"__main__\":\n", " ms.set_context(mode=ms.GRAPH_MODE, device_target=\"GPU\")\n", " ds_train_path = \"./datasets/MNIST_Data/train/\"\n", " download_dataset(\"https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/train-labels-idx1-ubyte\", ds_train_path)\n", " download_dataset(\"https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/train-images-idx3-ubyte\", ds_train_path)\n", "\n", " network = LeNet5(10)\n", " network.set_train()\n", " weights = network.trainable_params()\n", " net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction=\"mean\")\n", " net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)\n", "\n", " def forward_fn(data, label):\n", " loss = net_loss(network(data), label)\n", " return loss\n", "\n", " grad_fn = ms.value_and_grad(forward_fn, None, weights)\n", "\n", " def train_step(data, label):\n", " loss, grads = grad_fn(data, label)\n", " net_opt(grads)\n", " return loss\n", "\n", " print(\"============== Different calling methods train 10 epochs ==============\")\n", " jit = ms.JitConfig()\n", " print(\"1. Default, execute one step data each per sink\")\n", " ds_train = create_dataset(ds_train_path, 32)\n", " data_size = ds_train.get_dataset_size()\n", " epochs = 10\n", " sink_process = ms.data_sink(train_step, ds_train, jit_config=jit)\n", " for _ in range(data_size * epochs):\n", " loss = sink_process()\n", " print(f\"step {_ + 1}, loss is {loss}\")\n", "\n", " print(\"2. Execute one epoch data per sink\")\n", " ds_train = create_dataset(ds_train_path, 32)\n", " data_size = ds_train.get_dataset_size()\n", " epochs = 10\n", " sink_process = ms.data_sink(train_step, ds_train, sink_size=data_size, jit_config=jit)\n", " for _ in range(epochs):\n", " loss = sink_process()\n", " print(f\"epoch {_ + 1}, loss is {loss}\")\n", "\n", " print(\"3. Execute multiple epoch data per sink\")\n", " ds_train = create_dataset(ds_train_path, 32)\n", " data_size = ds_train.get_dataset_size()\n", " epochs = 10\n", " sink_process = ms.data_sink(train_step, ds_train, sink_size=epochs*data_size, jit_config=jit)\n", " loss = sink_process()\n", " print(f\"loss is {loss}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "代码中分别使用3种调用方式训练10个epoch。\n", "\n", "1. 默认行为,每次下沉1个step的数据,每个step结束后返回loss,训练10个epoch需要在Host侧循环调用`ds_train.get_dataset_size() * 10`次;\n", "2. 每次下沉1个epoch的数据,每个epoch执行结束后返回loss,训练10个epoch需要在Host侧循环调用10次;\n", "3. 单次下沉10个epoch的数据,10个epoch执行结束后返回loss,无需在Host侧进行循环。\n", "\n", "上述方法中,方法1在每个step结束后与Device进行一次交互,效率较低;方法3在训练中不需要与Device进行交互,执行效率最高,但只能返回最后的一个step的loss。" ] } ], "metadata": { "jupytext": { "cell_metadata_filter": "-all", "notebook_metadata_filter": "-all", "text_representation": { "extension": ".md", "format_name": "markdown" } }, "kernelspec": { "display_name": "MindSpore", "language": "python", "name": "mindspore" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.15" }, "vscode": { "interpreter": { "hash": "dd229a603a0b59691674b90e6cb53b3f51b627d2c287521f1174e0b9cbbabb93" } } }, "nbformat": 4, "nbformat_minor": 4 }