{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 数据迭代\n", "\n", "[![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.8/resource/_static/logo_notebook.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/r1.8/tutorials/zh_cn/advanced/dataset/mindspore_iterator.ipynb) \n", "[![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.8/resource/_static/logo_download_code.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/r1.8/tutorials/zh_cn/advanced/dataset/mindspore_iterator.py) \n", "[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.8/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r1.8/tutorials/source_zh_cn/advanced/dataset/iterator.ipynb)\n", "\n", "原始数据集通过数据集加载接口读取到内存,再通过数据增强操作进行数据变换,最后得到的数据集对象有两种常规的数据迭代方法:\n", "\n", "1. 创建`iterator`迭代器进行数据迭代。\n", "2. 数据直接传入网络模型的Model接口(如`model.train`、`model.eval`等)进行迭代训练或推理。\n", "\n", "## 创建迭代器\n", "\n", "数据集对象通常可以创建两种不同的迭代器来遍历数据,分别为:\n", "\n", "1. **元组迭代器**。创建元组迭代器的接口为`create_tuple_iterator`,通常用于`Model.train`内部使用,其迭代出来的数据可以直接用于训练。\n", "2. **字典迭代器**。创建字典迭代器的接口为`create_dict_iterator`,自定义`train`训练模式下,用户可以根据字典中的`key`进行进一步的数据处理操作,再输入到网络中,使用较为灵活。\n", "\n", "下面通过示例介绍两种迭代器的使用方式:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import mindspore.dataset as ds\n", "\n", "# 数据集\n", "np_data = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]\n", "\n", "# 加载数据\n", "dataset = ds.NumpySlicesDataset(np_data, column_names=[\"data\"], shuffle=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "然后使用`create_tuple_iterator`或者`create_dict_iterator`创建数据迭代器。" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2021-09-13T09:01:55.937317Z", "start_time": "2021-09-13T09:01:53.924910Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", " create tuple iterator\n", "item:\n", " [[1 2]\n", " [3 4]]\n", "item:\n", " [[5 6]\n", " [7 8]]\n", "\n", " create dict iterator\n", "item:\n", " [[1 2]\n", " [3 4]]\n", "item:\n", " [[5 6]\n", " [7 8]]\n", "\n", " iterate dataset object directly\n", "item:\n", " [[1 2]\n", " [3 4]]\n", "item:\n", " [[5 6]\n", " [7 8]]\n", "\n", " iterate dataset using enumerate\n", "index: 0, item:\n", " [[1 2]\n", " [3 4]]\n", "index: 1, item:\n", " [[5 6]\n", " [7 8]]\n" ] } ], "source": [ "# 创建元组迭代器\n", "print(\"\\n create tuple iterator\")\n", "for item in dataset.create_tuple_iterator():\n", " print(\"item:\\n\", item[0])\n", "\n", "# 创建字典迭代器\n", "print(\"\\n create dict iterator\")\n", "for item in dataset.create_dict_iterator():\n", " print(\"item:\\n\", item[\"data\"])\n", "\n", "# 直接遍历数据集对象(等同于创建元组迭代器)\n", "print(\"\\n iterate dataset object directly\")\n", "for item in dataset:\n", " print(\"item:\\n\", item[0])\n", "\n", "# 使用enumerate方式遍历(等同于创建元组迭代器)\n", "print(\"\\n iterate dataset using enumerate\")\n", "for index, item in enumerate(dataset):\n", " print(\"index: {}, item:\\n {}\".format(index, item[0]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "如果需要产生多个epoch的数据,可以相应地调整入参`num_epochs`的取值。相比于多次调用迭代器接口,直接设置epoch数可以提高数据迭代的性能。" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2021-09-13T09:01:55.951495Z", "start_time": "2021-09-13T09:01:55.938705Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch: 0\n", "item:\n", " [[1 2]\n", " [3 4]]\n", "item:\n", " [[5 6]\n", " [7 8]]\n", "epoch: 1\n", "item:\n", " [[1 2]\n", " [3 4]]\n", "item:\n", " [[5 6]\n", " [7 8]]\n" ] } ], "source": [ "epoch = 2 # 创建元组迭代器产生2个epoch的数据\n", "\n", "iterator = dataset.create_tuple_iterator(num_epochs=epoch)\n", "\n", "for i in range(epoch):\n", " print(\"epoch: \", i)\n", " for item in iterator:\n", " print(\"item:\\n\", item[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "迭代器默认输出的数据类型为`mindspore.Tensor`,如果希望得到`numpy.ndarray`类型的数据,可以设置入参`output_numpy=True`。" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "dtype: \n", "item:\n", " [[1 2]\n", " [3 4]]\n", "dtype: \n", "item:\n", " [[5 6]\n", " [7 8]]\n", "dtype: \n", "item:\n", " [[1 2]\n", " [3 4]]\n", "dtype: \n", "item:\n", " [[5 6]\n", " [7 8]]\n" ] } ], "source": [ "# 默认输出类型为mindspore.Tensor\n", "for item in dataset.create_tuple_iterator():\n", " print(\"dtype: \", type(item[0]), \"\\nitem:\\n\", item[0])\n", "\n", "# 设置输出类型为numpy.ndarray\n", "for item in dataset.create_tuple_iterator(output_numpy=True):\n", " print(\"dtype: \", type(item[0]), \"\\nitem:\\n\", item[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 在训练网络时使用迭代器\n", "\n", "下面我们通过一个拟合线性函数的场景,介绍在训练网络时如何使用数据迭代器,线性函数表达式为:\n", "\n", "$$output = {x_0}\\times1 + {x_1}\\times2 + {x_2}\\times3 + ··· + {x_7}\\times8$$\n", "\n", "其函数定义如下:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def func(x):\n", " \"\"\"定义线性函数表达式\"\"\"\n", " result = []\n", " for sample in x:\n", " total = 0\n", " for i, e in enumerate(sample):\n", " total += (i+1) * e\n", " result.append(total)\n", " return result" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "使用上面的线性函数构造自定义训练数据集和验证数据集。在构造自定义训练数据集时需要注意,上述线性函数表达式有8个未知数,把训练数据集的数据带入上述线性函数得出8个线性无关方程,通过解方程即可得出未知数的值。" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "class MyTrainData:\n", " \"\"\"自定义训练用数据集类\"\"\"\n", " def __init__(self):\n", " \"\"\"初始化操作\"\"\"\n", " self.__data = np.array([[[1, 1, 1, 1, 1, 1, 1, 1]],\n", " [[1, 1, 1, 1, 1, 1, 1, 0]],\n", " [[1, 1, 1, 1, 1, 1, 0, 0]],\n", " [[1, 1, 1, 1, 1, 0, 0, 0]],\n", " [[1, 1, 1, 1, 0, 0, 0, 0]],\n", " [[1, 1, 1, 0, 0, 0, 0, 0]],\n", " [[1, 1, 0, 0, 0, 0, 0, 0]],\n", " [[1, 0, 0, 0, 0, 0, 0, 0]]]).astype(np.float32)\n", " self.__label = np.array([func(x) for x in self.__data]).astype(np.float32)\n", "\n", " def __getitem__(self, index):\n", " \"\"\"定义随机访问函数\"\"\"\n", " return self.__data[index], self.__label[index]\n", "\n", " def __len__(self):\n", " \"\"\"定义获取数据集大小函数\"\"\"\n", " return len(self.__data)\n", "\n", "class MyEvalData:\n", " \"\"\"自定义验证用数据集类\"\"\"\n", " def __init__(self):\n", " \"\"\"初始化操作\"\"\"\n", " self.__data = np.array([[[1, 2, 3, 4, 5, 6, 7, 8]],\n", " [[1, 1, 1, 1, 1, 1, 1, 1]],\n", " [[8, 7, 6, 5, 4, 3, 2, 1]]]).astype(np.float32)\n", " self.__label = np.array([func(x) for x in self.__data]).astype(np.float32)\n", "\n", " def __getitem__(self, index):\n", " \"\"\"定义随机访问函数\"\"\"\n", " return self.__data[index], self.__label[index]\n", "\n", " def __len__(self):\n", " \"\"\"定义获取数据集大小函数\"\"\"\n", " return len(self.__data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "下面我们使用[mindspore.nn.Dense](https://mindspore.cn/docs/zh-CN/r1.8/api_python/nn/mindspore.nn.Dense.html#mindspore.nn.Dense)创建自定义网络,网络的输入为8×1的矩阵。" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "import mindspore.nn as nn\n", "from mindspore.common.initializer import Normal\n", "\n", "class MyNet(nn.Cell):\n", " \"\"\"自定义网络\"\"\"\n", " def __init__(self):\n", " super(MyNet, self).__init__()\n", " self.fc = nn.Dense(8, 1, weight_init=Normal(0.02))\n", "\n", " def construct(self, x):\n", " x = self.fc(x)\n", " return x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "自定义网络训练,代码如下:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--------- Train ---------\n", "epoch:0, loss: 117.58063\n", "epoch:5, loss: 0.28427964\n", "epoch:10, loss: 0.02881975\n", "epoch:15, loss: 0.050988887\n", "epoch:20, loss: 0.0087212445\n", "epoch:25, loss: 0.040158965\n", "epoch:30, loss: 0.010140566\n", "epoch:35, loss: 0.00040051914\n" ] } ], "source": [ "import mindspore.dataset as ds\n", "import mindspore as ms\n", "from mindspore import amp\n", "\n", "def train(dataset, net, optimizer, loss, epoch):\n", " \"\"\"自定义训练过程\"\"\"\n", " print(\"--------- Train ---------\")\n", "\n", " train_network = amp.build_train_network(net, optimizer, loss)\n", " for i in range(epoch):\n", " # 使用数据迭代器获取数据\n", " for item in dataset.create_dict_iterator():\n", " data = item[\"data\"]\n", " label = item[\"label\"]\n", " loss = train_network(data, label)\n", "\n", " # 每5个epoch打印一次\n", " if i % 5 == 0:\n", " print(\"epoch:{}, loss: {}\".format(i, loss))\n", "\n", "dataset = ds.GeneratorDataset(MyTrainData(), [\"data\", \"label\"], shuffle=True) # 定义数据集\n", "\n", "epoch = 40 # 定义训练轮次\n", "net = MyNet() # 定义网络\n", "loss = nn.MSELoss(reduction=\"mean\") # 定义损失函数\n", "optimizer = nn.Momentum(net.trainable_params(), 0.01, 0.9) # 定义优化器\n", "\n", "# 开始训练\n", "train(dataset, net, optimizer, loss, epoch)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "从上面的打印结果可以看出,随着训练次数逐渐增多,损失值趋于收敛。接下来我们使用上面训练好的网络进行推理,并打印预测值与目标值。" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--------- Eval ---------\n", "predict: 203.996, label: 204.000\n", "predict: 36.012, label: 36.000\n", "predict: 116.539, label: 120.000\n" ] } ], "source": [ "def eval(net, data):\n", " \"\"\"自定义推理过程\"\"\"\n", " print(\"--------- Eval ---------\")\n", "\n", " for item in data:\n", " predict = net(ms.Tensor(item[0]))[0]\n", " print(\"predict: {:7.3f}, label: {:7.3f}\".format(predict.asnumpy()[0], item[1][0]))\n", "\n", "# 开始推理\n", "eval(net, MyEvalData())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "从上面的打印结果可以看出,推理结果较为准确。\n", "\n", "> 更多关于数据迭代器的使用说明,请参考[create_tuple_iterator](https://www.mindspore.cn/docs/zh-CN/r1.8/api_python/dataset/mindspore.dataset.NumpySlicesDataset.html#mindspore.dataset.NumpySlicesDataset.create_tuple_iterator) 和[create_dict_iterator](https://www.mindspore.cn/docs/zh-CN/r1.8/api_python/dataset/mindspore.dataset.NumpySlicesDataset.html#mindspore.dataset.NumpySlicesDataset.create_dict_iterator)的API文档。\n", "\n", "## 数据迭代训练\n", "\n", "数据集对象创建后,可通过传入`Model`接口,由接口内部进行数据迭代,并送入网络执行训练或推理。实例代码如下:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2021-09-13T09:01:56.002002Z", "start_time": "2021-09-13T09:01:55.953018Z" } }, "outputs": [], "source": [ "import numpy as np\n", "import mindspore as ms\n", "from mindspore import ms_function\n", "from mindspore import nn\n", "import mindspore.dataset as ds\n", "import mindspore.ops as ops\n", "\n", "def create_dataset():\n", " \"\"\"创建自定义数据集\"\"\"\n", " np_data = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]\n", " np_data = np.array(np_data, dtype=np.float16)\n", " dataset = ds.NumpySlicesDataset(np_data, column_names=[\"data\"], shuffle=False)\n", " return dataset\n", "\n", "class Net(nn.Cell):\n", " \"\"\"创建一个神经网络\"\"\"\n", " def __init__(self):\n", " super(Net, self).__init__()\n", " self.relu = ops.ReLU()\n", " self.print = ops.Print()\n", "\n", " @ms_function\n", " def construct(self, x):\n", " self.print(x)\n", " return self.relu(x)\n", "\n", "dataset = create_dataset()\n", "\n", "network = Net()\n", "model = ms.Model(network)\n", "\n", "# 数据集传入model中,train接口进行数据迭代处理\n", "model.train(epoch=1, train_dataset=dataset)" ] } ], "metadata": { "kernelspec": { "display_name": "MindSpore", "language": "python", "name": "mindspore" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 4 }