{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 训练\n", "\n", "[![](https://gitee.com/mindspore/docs/raw/r1.2/docs/programming_guide/source_zh_cn/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r1.2/docs/programming_guide/source_zh_cn/train.ipynb) [![](https://gitee.com/mindspore/docs/raw/r1.2/docs/programming_guide/source_zh_cn/_static/logo_notebook.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/r1.2/programming_guide/mindspore_train.ipynb) [![](https://gitee.com/mindspore/docs/raw/r1.2/docs/programming_guide/source_zh_cn/_static/logo_modelarts.png)](https://console.huaweicloud.com/modelarts/?region=cn-north-4#/notebook/loading?share-url-b64=aHR0cHM6Ly9vYnMuZHVhbHN0YWNrLmNuLW5vcnRoLTQubXlodWF3ZWljbG91ZC5jb20vbWluZHNwb3JlLXdlYnNpdGUvbm90ZWJvb2svbW9kZWxhcnRzL3Byb2dyYW1taW5nX2d1aWRlL21pbmRzcG9yZV90cmFpbi5pcHluYg==&image_id=65f636a0-56cf-49df-b941-7d2a07ba8c8c)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 概述\n", "\n", "MindSpore在Model_zoo也已经提供了大量的目标检测、自然语言处理等多种网络模型,供用户直接使用,但是对于某些高级用户而言可能想要自行设计网络或者自定义训练循环,下面就对自定义训练网络、自定义训练循环和边训练边推理三种场景进行介绍,另外对On device执行方式进行详细介绍。\n", "\n", "> 本文示例适用于GPU和Ascend环境。\n", "\n", "## 自定义训练网络\n", "\n", "在自定义训练网络前,需要先了解下MindSpore的网络支持、Python源码构造网络约束和算子支持情况。\n", "\n", "- 网络支持:当前MindSpore已经支持多种网络,按类型分为计算机视觉、自然语言处理、推荐和图神经网络,可以通过[网络支持](https://www.mindspore.cn/doc/note/zh-CN/r1.2/network_list.html)查看具体支持的网络情况。如果现有网络无法满足用户需求,用户可以根据实际需要定义自己的网络。\n", "\n", "- Python源码构造网络约束:MindSpore暂不支持将任意Python源码转换成计算图,所以对于用户源码支持的写法有所限制,主要包括语法约束和网络定义约束两方面。详细情况可以查看[静态图语法支持](https://www.mindspore.cn/doc/note/zh-CN/r1.2/static_graph_syntax_support.html)了解。随着MindSpore的演进,这些约束可能会发生变化。\n", "\n", "- 算子支持:顾名思义,网络的基础是算子,所以用户自定义训练网络前要对MindSpore当前支持的算子有所了解,可以通过查看[算子支持](https://www.mindspore.cn/doc/note/zh-CN/r1.2/operator_list.html)了解不同的后端(Ascend、GPU和CPU)的算子实现情况。\n", "\n", "> 当开发网络遇到内置算子不足以满足需求时,用户也可以参考[自定义算子](https://www.mindspore.cn/tutorial/training/zh-CN/r1.2/advanced_use/custom_operator_ascend.html),方便快捷地扩展昇腾AI处理器的自定义算子。\n", "\n", "代码样例如下:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "-------loss------ [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n", " 0. 0. 0. 0. 0. 0. 0. 
0.]\n" ] } ], "source": [ "import numpy as np\n", "\n", "from mindspore import Tensor\n", "from mindspore.nn import Cell, Dense, SoftmaxCrossEntropyWithLogits, Momentum, TrainOneStepCell, WithLossCell\n", "import mindspore.ops as ops\n", "\n", "\n", "class ReLUReduceMeanDense(Cell):\n", " def __init__(self, kernel, bias, in_channel, num_class):\n", " super().__init__()\n", " self.relu = ops.ReLU()\n", " self.mean = ops.ReduceMean(keep_dims=False)\n", " self.dense = Dense(in_channel, num_class, kernel, bias)\n", "\n", " def construct(self, x):\n", " x = self.relu(x)\n", " x = self.mean(x, (2, 3))\n", " x = self.dense(x)\n", " return x\n", "\n", "\n", "if __name__ == \"__main__\":\n", " weight_np = np.ones((1000, 2048)).astype(np.float32)\n", " weight = Tensor(weight_np.copy())\n", " bias_np = np.ones((1000,)).astype(np.float32)\n", " bias = Tensor(bias_np.copy())\n", " net = ReLUReduceMeanDense(weight, bias, 2048, 1000)\n", " criterion = SoftmaxCrossEntropyWithLogits(sparse=False)\n", " optimizer = Momentum(learning_rate=0.1, momentum=0.1,\n", " params=filter(lambda x: x.requires_grad, net.get_parameters()))\n", " net_with_criterion = WithLossCell(net, criterion)\n", " train_network = TrainOneStepCell(net_with_criterion, optimizer)\n", " train_network.set_train()\n", " input_np = np.random.randn(32, 2048, 7, 7).astype(np.float32)\n", " input = Tensor(input_np.copy())\n", " label_np_onehot = np.zeros(shape=(32, 1000)).astype(np.float32)\n", " label = Tensor(label_np_onehot.copy())\n", " for i in range(1):\n", " loss = train_network(input, label)\n", " print(\"-------loss------\", loss)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 自定义训练循环\n", "\n", "在进行自定义循环训练之前,将需要使用的MNIST数据集下载下来,同时解压缩放置指定位置:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "./datasets/MNIST_Data\n", "├── test\n", "│   ├── t10k-images-idx3-ubyte\n", "│   └── t10k-labels-idx1-ubyte\n", "└── train\n", " ├── train-images-idx3-ubyte\n", " └── train-labels-idx1-ubyte\n", "\n", "2 directories, 4 files\n" ] } ], "source": [ "!mkdir -p ./datasets/MNIST_Data/train ./datasets/MNIST_Data/test\n", "!wget -NP ./datasets/MNIST_Data/train https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/train-labels-idx1-ubyte \n", "!wget -NP ./datasets/MNIST_Data/train https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/train-images-idx3-ubyte\n", "!wget -NP ./datasets/MNIST_Data/test https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/t10k-labels-idx1-ubyte\n", "!wget -NP ./datasets/MNIST_Data/test https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/t10k-images-idx3-ubyte\n", "!tree ./datasets/MNIST_Data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "用户如果不想使用MindSpore提供的Model接口,也可参考以下样例自由控制循环的迭代次数、遍历数据集等。\n", "\n", "代码样例如下:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "============== Starting Training ==============\n", "epoch: 1/10, losses: 2.3086986541748047\n", "epoch: 1/10, losses: 2.309938430786133\n", "epoch: 1/10, losses: 2.302298069000244\n", "epoch: 1/10, losses: 2.310209035873413\n", "epoch: 1/10, losses: 2.3002336025238037\n", "epoch: 1/10, losses: 2.3022992610931396\n", "... 
...\n", "epoch: 1/10, losses: 0.18848800659179688\n", "epoch: 1/10, losses: 0.15532201528549194\n", "epoch: 2/10, losses: 0.179201140999794\n", "epoch: 2/10, losses: 0.20995387434959412\n", "epoch: 2/10, losses: 0.4867479205131531\n", "... ...\n", "epoch: 10/10, losses: 0.027243722230196\n", "epoch: 10/10, losses: 0.07665436714887619\n", "epoch: 10/10, losses: 0.005962767638266087\n", "epoch: 10/10, losses: 0.026364721357822418\n", "epoch: 10/10, losses: 0.0003102901973761618\n" ] } ], "source": [ "import os\n", "\n", "import mindspore.dataset as ds\n", "import mindspore.dataset.transforms.c_transforms as CT\n", "import mindspore.dataset.vision.c_transforms as CV\n", "import mindspore.nn as nn\n", "from mindspore import context, DatasetHelper, connect_network_with_dataset\n", "from mindspore import dtype as mstype\n", "from mindspore.common.initializer import TruncatedNormal\n", "from mindspore import ParameterTuple\n", "from mindspore.dataset.vision import Inter\n", "from mindspore.nn import WithLossCell\n", "import mindspore.ops as ops\n", "\n", "\n", "def create_dataset(data_path, batch_size=32, repeat_size=1,\n", " num_parallel_workers=1):\n", " \"\"\"\n", " create dataset for train or test\n", " \"\"\"\n", " # define dataset\n", " mnist_ds = ds.MnistDataset(data_path)\n", "\n", " resize_height, resize_width = 32, 32\n", " rescale = 1.0 / 255.0\n", " shift = 0.0\n", " rescale_nml = 1 / 0.3081\n", " shift_nml = -1 * 0.1307 / 0.3081\n", "\n", " # define map operations\n", " resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # Bilinear mode\n", " rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)\n", " rescale_op = CV.Rescale(rescale, shift)\n", " hwc2chw_op = CV.HWC2CHW()\n", " type_cast_op = CT.TypeCast(mstype.int32)\n", "\n", " # apply map operations on images\n", " mnist_ds = mnist_ds.map(input_columns=\"label\", operations=type_cast_op, num_parallel_workers=num_parallel_workers)\n", " mnist_ds = mnist_ds.map(input_columns=\"image\", operations=resize_op, num_parallel_workers=num_parallel_workers)\n", " mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_op, num_parallel_workers=num_parallel_workers)\n", " mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers)\n", " mnist_ds = mnist_ds.map(input_columns=\"image\", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)\n", "\n", " # apply DatasetOps\n", " buffer_size = 10000\n", " mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script\n", " mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)\n", " mnist_ds = mnist_ds.repeat(repeat_size)\n", "\n", " return mnist_ds\n", "\n", "\n", "def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):\n", " \"\"\"weight initial for conv layer\"\"\"\n", " weight = weight_variable()\n", " return nn.Conv2d(in_channels, out_channels,\n", " kernel_size=kernel_size, stride=stride, padding=padding,\n", " weight_init=weight, has_bias=False, pad_mode=\"valid\")\n", "\n", "\n", "def fc_with_initialize(input_channels, out_channels):\n", " \"\"\"weight initial for fc layer\"\"\"\n", " weight = weight_variable()\n", " bias = weight_variable()\n", " return nn.Dense(input_channels, out_channels, weight, bias)\n", "\n", "\n", "def weight_variable():\n", " \"\"\"weight initial\"\"\"\n", " return TruncatedNormal(0.02)\n", "\n", "\n", "class LeNet5(nn.Cell):\n", " \"\"\"\n", " Lenet network\n", " Args:\n", " num_class (int): Num classes. 
Default: 10.\n", "\n", " Returns:\n", " Tensor, output tensor\n", "\n", " Examples:\n", " >>> LeNet(num_class=10)\n", " \"\"\"\n", "\n", " def __init__(self, num_class=10):\n", " super(LeNet5, self).__init__()\n", " self.num_class = num_class\n", " self.batch_size = 32\n", " self.conv1 = conv(1, 6, 5)\n", " self.conv2 = conv(6, 16, 5)\n", " self.fc1 = fc_with_initialize(16 * 5 * 5, 120)\n", " self.fc2 = fc_with_initialize(120, 84)\n", " self.fc3 = fc_with_initialize(84, self.num_class)\n", " self.relu = nn.ReLU()\n", " self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)\n", " self.reshape = ops.Reshape()\n", "\n", " def construct(self, x):\n", " x = self.conv1(x)\n", " x = self.relu(x)\n", " x = self.max_pool2d(x)\n", " x = self.conv2(x)\n", " x = self.relu(x)\n", " x = self.max_pool2d(x)\n", " x = self.reshape(x, (self.batch_size, -1))\n", " x = self.fc1(x)\n", " x = self.relu(x)\n", " x = self.fc2(x)\n", " x = self.relu(x)\n", " x = self.fc3(x)\n", " return x\n", "\n", "\n", "class TrainOneStepCell(nn.Cell):\n", " def __init__(self, network, optimizer, sens=1.0):\n", " super(TrainOneStepCell, self).__init__(auto_prefix=False)\n", " self.network = network\n", " self.weights = ParameterTuple(network.trainable_params())\n", " self.optimizer = optimizer\n", " self.grad = ops.GradOperation(get_by_list=True, sens_param=True)\n", " self.sens = sens\n", "\n", " def set_sens(self, value):\n", " self.sens = value\n", "\n", " def construct(self, data, label):\n", " weights = self.weights\n", " loss = self.network(data, label)\n", " sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)\n", " grads = self.grad(self.network, weights)(data, label, sens)\n", " return ops.depend(loss, self.optimizer(grads))\n", "\n", "\n", "if __name__ == \"__main__\":\n", " context.set_context(mode=context.GRAPH_MODE, device_target=\"GPU\")\n", " \n", " ds_data_path = \"./datasets/MNIST_Data/train/\"\n", " ds_train = create_dataset(ds_data_path, 32)\n", "\n", " network = LeNet5(10)\n", " net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction=\"mean\")\n", " net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)\n", " net = WithLossCell(network, net_loss)\n", " net = TrainOneStepCell(net, net_opt)\n", " network.set_train()\n", " print(\"============== Starting Training ==============\")\n", " epoch = 10\n", " for step in range(epoch):\n", " for inputs in ds_train:\n", " output = net(*inputs)\n", " print(\"epoch: {0}/{1}, losses: {2}\".format(step + 1, epoch, output.asnumpy(), flush=True))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> 示例中用到的MNIST数据集的获取方法,可以参照[实现一个图片分类应用](https://www.mindspore.cn/tutorial/training/zh-CN/r1.2/quick_start/quick_start.html)的下载数据集部分,下同。\n", ">\n", "> 典型的使用场景是梯度累积,详细查看[梯度累积](https://www.mindspore.cn/tutorial/training/zh-CN/r1.2/advanced_use/apply_gradient_accumulation.html)。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 边训练边推理\n", "\n", "对于某些数据量较大、训练时间较长的复杂网络,为了能掌握训练的不同阶段模型精度的指标变化情况,可以通过边训练边推理的方式跟踪精度的变化情况。具体可以参考[同步训练和验证模型](https://www.mindspore.cn/tutorial/training/zh-CN/r1.2/advanced_use/evaluate_the_model_during_training.html)。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## on-device执行\n", "\n", "当前MindSpore支持的后端包括Ascend、GPU、CPU,所谓On Device中的Device通常指Ascend(昇腾)AI处理器。\n", "\n", "昇腾芯片上集成了AICORE、AICPU和CPU。其中,AICORE负责大型Tensor Vector运算,AICPU负责标量运算,CPU负责逻辑控制和任务分发。\n", "\n", "Host侧CPU负责将图或算子下发到昇腾芯片。昇腾芯片由于具备了运算、逻辑控制和任务分发的功能,所以不需要与Host侧的CPU进行频繁的交互,只需要将计算完的最终结果返回给Host侧,实现整图下沉到Device执行,避免Host-Device频繁交互,减小了开销。\n", 
"\n", "### 计算图下沉\n", "\n", "计算图整图下沉到Device上执行,减少Host-Device交互开销。可以结合循环下沉实现多个Step下沉,进一步减少Host和Device的交互次数。\n", "\n", "循环下沉是在On Device执行的基础上的优化,目的是进一步减少Host侧和Device侧之间的交互次数。通常情况下,每个step都返回一个结果,循环下沉是控制每隔多少个step返回一次结果。\n", "\n", "默认配置下是每一个epoch返回一次结果,这样每个epoch里,Host侧和Device侧只需要进行一次数据交互。\n", "\n", "也可以结合`train`接口的`dataset_sink_mode`和`sink_size`控制每个epoch的下沉数据量。\n", "\n", "### 数据下沉\n", "\n", "`Model`的`train`接口参数`dataset_sink_mode`可以控制数据是否下沉。`dataset_sink_mode`为True表示数据下沉,否则为非下沉。所谓下沉即数据通过通道直接传送到Device上。\n", "\n", "dataset_sink_mode参数可以配合`sink_size`控制每个`epoch`下沉的数据量大小。当`dataset_sink_mode`设置为True,即数据下沉模式时:\n", "\n", "如果`sink_size`为默认值-1,则每一个`epoch`下沉的数据量为原始的整个数据集大小;\n", "\n", "如果`sink_size`>0,此时原始数据集可以被无限次遍历,每个`epoch`下沉`sink_size`大小的数据量,下一个`epoch`继续从上次遍历的结束位置继续遍历。\n", "\n", "下沉的总数据量由`epoch`和`sink_size`两个变量共同控制,即总数据量=`epoch`*`sink_size`。\n", "\n", "当使用`LossMonitor`,`TimeMonitor`或其它`Callback`接口时,如果`dateset_sink_mode`设置为False,Host侧和Device侧之间每个`step`交互一次,所以会每个`step`返回一个结果,如果`dataset_sink_mode`为True,因为数据在Device上通过通道传输, Host侧和Device侧之间每个`epoch`进行一次数据交互,所以每个`epoch`只返回一次结果。\n", "\n", "> 当前CPU和PyNative模式不支持数据下沉。\n", "\n", "代码样例如下:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "============== Starting Training ==============\n", "epoch: 1 step: 1000, loss is 0.110185064\n", "epoch: 2 step: 1000, loss is 0.12088283\n", "epoch: 3 step: 1000, loss is 0.15903473\n", "epoch: 4 step: 1000, loss is 0.030054657\n", "epoch: 5 step: 1000, loss is 0.013846226\n", "epoch: 6 step: 1000, loss is 0.052161213\n", "epoch: 7 step: 1000, loss is 0.0050197737\n", "epoch: 8 step: 1000, loss is 0.17207858\n", "epoch: 9 step: 1000, loss is 0.010310417\n", "epoch: 10 step: 1000, loss is 0.000672762\n" ] } ], "source": [ "import os\n", "\n", "import mindspore.dataset as ds\n", "import mindspore.dataset.transforms.c_transforms as CT\n", "import mindspore.dataset.vision.c_transforms as CV\n", "import mindspore.nn as nn\n", "from mindspore import context, Model\n", "from mindspore import dtype as mstype\n", "from mindspore.common.initializer import TruncatedNormal\n", "from mindspore.dataset.vision import Inter\n", "from mindspore.nn import Accuracy\n", "import mindspore.ops as ops\n", "from mindspore.train.callback import LossMonitor\n", "\n", "\n", "def create_dataset(data_path, batch_size=32, repeat_size=1,\n", " num_parallel_workers=1):\n", " \"\"\"\n", " create dataset for train or test\n", " \"\"\"\n", " # define dataset\n", " mnist_ds = ds.MnistDataset(data_path)\n", "\n", " resize_height, resize_width = 32, 32\n", " rescale = 1.0 / 255.0\n", " shift = 0.0\n", " rescale_nml = 1 / 0.3081\n", " shift_nml = -1 * 0.1307 / 0.3081\n", "\n", " # define map operations\n", " resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # Bilinear mode\n", " rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)\n", " rescale_op = CV.Rescale(rescale, shift)\n", " hwc2chw_op = CV.HWC2CHW()\n", " type_cast_op = CT.TypeCast(mstype.int32)\n", "\n", " # apply map operations on images\n", " mnist_ds = mnist_ds.map(input_columns=\"label\", operations=type_cast_op, num_parallel_workers=num_parallel_workers)\n", " mnist_ds = mnist_ds.map(input_columns=\"image\", operations=resize_op, num_parallel_workers=num_parallel_workers)\n", " mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_op, num_parallel_workers=num_parallel_workers)\n", " mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_nml_op, 
num_parallel_workers=num_parallel_workers)\n", " mnist_ds = mnist_ds.map(input_columns=\"image\", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)\n", "\n", " # apply DatasetOps\n", " buffer_size = 10000\n", " mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script\n", " mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)\n", " mnist_ds = mnist_ds.repeat(repeat_size)\n", "\n", " return mnist_ds\n", "\n", "\n", "def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):\n", " \"\"\"weight initial for conv layer\"\"\"\n", " weight = weight_variable()\n", " return nn.Conv2d(in_channels, out_channels,\n", " kernel_size=kernel_size, stride=stride, padding=padding,\n", " weight_init=weight, has_bias=False, pad_mode=\"valid\")\n", "\n", "\n", "def fc_with_initialize(input_channels, out_channels):\n", " \"\"\"weight initial for fc layer\"\"\"\n", " weight = weight_variable()\n", " bias = weight_variable()\n", " return nn.Dense(input_channels, out_channels, weight, bias)\n", "\n", "\n", "def weight_variable():\n", " \"\"\"weight initial\"\"\"\n", " return TruncatedNormal(0.02)\n", "\n", "\n", "class LeNet5(nn.Cell):\n", " \"\"\"\n", " Lenet network\n", " Args:\n", " num_class (int): Num classes. Default: 10.\n", "\n", " Returns:\n", " Tensor, output tensor\n", "\n", " Examples:\n", " >>> LeNet(num_class=10)\n", " \"\"\"\n", "\n", " def __init__(self, num_class=10):\n", " super(LeNet5, self).__init__()\n", " self.num_class = num_class\n", " self.batch_size = 32\n", " self.conv1 = conv(1, 6, 5)\n", " self.conv2 = conv(6, 16, 5)\n", " self.fc1 = fc_with_initialize(16 * 5 * 5, 120)\n", " self.fc2 = fc_with_initialize(120, 84)\n", " self.fc3 = fc_with_initialize(84, self.num_class)\n", " self.relu = nn.ReLU()\n", " self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)\n", " self.reshape = ops.Reshape()\n", "\n", " def construct(self, x):\n", " x = self.conv1(x)\n", " x = self.relu(x)\n", " x = self.max_pool2d(x)\n", " x = self.conv2(x)\n", " x = self.relu(x)\n", " x = self.max_pool2d(x)\n", " x = self.reshape(x, (self.batch_size, -1))\n", " x = self.fc1(x)\n", " x = self.relu(x)\n", " x = self.fc2(x)\n", " x = self.relu(x)\n", " x = self.fc3(x)\n", " return x\n", "\n", "\n", "if __name__ == \"__main__\":\n", " context.set_context(mode=context.GRAPH_MODE, device_target=\"GPU\")\n", " ds_train_path = \"./datasets/MNIST_Data/train/\"\n", " ds_train = create_dataset(ds_train_path, 32)\n", "\n", " network = LeNet5(10)\n", " net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction=\"mean\")\n", " net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)\n", " model = Model(network, net_loss, net_opt)\n", "\n", " print(\"============== Starting Training ==============\")\n", " model.train(epoch=10, train_dataset=ds_train, callbacks=[LossMonitor()], dataset_sink_mode=True, sink_size=1000)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`batch_size`为32的情况下,数据集的大小为1875,当`sink_size`设置为1000时,表示每个`epoch`下沉1000个batch的数据,下沉次数为`epoch`=10,下沉的总数据量为:`epoch`*`sink_size`=10000。\n", "\n", "`dataset_sink_mode`为True,所以每个`epoch`返回一次结果。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> `dataset_sink_mode`为False时,`sink_size`参数设置无效。" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", 
"pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 4 }