{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Data Iteration\n",
    "\n",
    "`Ascend` `GPU` `CPU` `Data Preparation`\n",
    "\n",
    "[![Download Sample Code](https://gitee.com/mindspore/docs/raw/r1.6/resource/_static/logo_download_code.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/r1.6/programming_guide/zh_cn/mindspore_dataset_usage.py) [![Download Notebook](https://gitee.com/mindspore/docs/raw/r1.6/resource/_static/logo_notebook.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/r1.6/programming_guide/zh_cn/mindspore_dataset_usage.ipynb) [![View Source File](https://gitee.com/mindspore/docs/raw/r1.6/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r1.6/docs/mindspore/programming_guide/source_zh_cn/dataset_usage.ipynb)"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Overview\n",
    "\n",
    "A raw dataset is loaded into memory through a dataset loading API and then transformed by data augmentation operations. The resulting dataset object supports two common ways of iterating over the data:\n",
    "\n",
    "- Create an iterator and use it to iterate over the data.\n",
    "\n",
    "- Pass the dataset to a `Model` API (such as `model.train` or `model.eval`), which iterates over it for training or inference.\n",
    "\n",
    "## Creating an Iterator for Data Iteration\n",
    "\n",
    "A dataset object can usually create two kinds of iterators for traversing its data: tuple iterators and dictionary iterators.\n",
    "\n",
    "A tuple iterator is created with `create_tuple_iterator`, and a dictionary iterator is created with `create_dict_iterator`. Their usage is shown below.\n",
    "\n",
    "First, create a simple dataset object for demonstration."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "source": [
    "import mindspore.dataset as ds\n",
    "\n",
    "np_data = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]\n",
    "dataset = ds.NumpySlicesDataset(np_data, column_names=[\"data\"], shuffle=False)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "Iterators can then be created over this dataset in any of the following ways."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "source": [
    "# Create a tuple iterator; each item is a list of columns indexed by position\n",
    "print(\"\\n create tuple iterator\")\n",
    "for item in dataset.create_tuple_iterator():\n",
    "    print(\"item:\\n\", item[0])\n",
    "\n",
    "# Create a dictionary iterator; each item is a dict keyed by column name\n",
    "print(\"\\n create dict iterator\")\n",
    "for item in dataset.create_dict_iterator():\n",
    "    print(\"item:\\n\", item[\"data\"])\n",
    "\n",
    "# Iterate over the dataset object directly (equivalent to creating a tuple iterator)\n",
    "print(\"\\n iterate dataset object directly\")\n",
    "for item in dataset:\n",
    "    print(\"item:\\n\", item[0])\n",
    "\n",
    "# Iterate with enumerate (equivalent to creating a tuple iterator)\n",
    "print(\"\\n iterate dataset using enumerate\")\n",
    "for index, item in enumerate(dataset):\n",
    "    print(\"index: {}, item:\\n {}\".format(index, item[0]))"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "\n",
      " create tuple iterator\n",
      "item:\n",
      " [[1 2]\n",
      " [3 4]]\n",
      "item:\n",
      " [[5 6]\n",
      " [7 8]]\n",
      "\n",
      " create dict iterator\n",
      "item:\n",
      " [[1 2]\n",
      " [3 4]]\n",
      "item:\n",
      " [[5 6]\n",
      " [7 8]]\n",
      "\n",
      " iterate dataset object directly\n",
      "item:\n",
      " [[1 2]\n",
      " [3 4]]\n",
      "item:\n",
      " [[5 6]\n",
      " [7 8]]\n",
      "\n",
      " iterate dataset using enumerate\n",
      "index: 0, item:\n",
      " [[1 2]\n",
      " [3 4]]\n",
      "index: 1, item:\n",
      " [[5 6]\n",
      " [7 8]]\n"
     ]
    }
   ],
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-13T09:01:55.937317Z",
     "start_time": "2021-09-13T09:01:53.924910Z"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "To produce data for multiple epochs, set the `num_epochs` parameter to the number of epochs. Setting the epoch count once in this way is faster than calling the iterator API again for every epoch."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "source": [
    "# Create a tuple iterator that produces data for 2 epochs\n",
    "epoch = 2\n",
    "iterator = dataset.create_tuple_iterator(num_epochs=epoch)\n",
    "for i in range(epoch):\n",
    "    print(\"epoch: \", i)\n",
    "    for item in iterator:\n",
    "        print(\"item:\\n\", item[0])"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "epoch:  0\n",
      "item:\n",
      " [[1 2]\n",
      " [3 4]]\n",
      "item:\n",
      " [[5 6]\n",
      " [7 8]]\n",
      "epoch:  1\n",
      "item:\n",
      " [[1 2]\n",
      " [3 4]]\n",
      "item:\n",
      " [[5 6]\n",
      " [7 8]]\n"
     ]
    }
   ],
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-13T09:01:55.951495Z",
     "start_time": "2021-09-13T09:01:55.938705Z"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "By default, iterators output data of type `mindspore.Tensor`. To get data of type `numpy.ndarray` instead, set `output_numpy=True`."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "source": [
    "# The default output type is mindspore.Tensor\n",
    "for item in dataset.create_tuple_iterator():\n",
    "    print(\"dtype: \", type(item[0]), \"\\nitem:\", item[0])\n",
    "\n",
    "# Set the output type to numpy.ndarray\n",
    "for item in dataset.create_tuple_iterator(output_numpy=True):\n",
    "    print(\"dtype: \", type(item[0]), \"\\nitem:\", item[0])"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "dtype:  <class 'mindspore.common.tensor.Tensor'> \n",
      "item: [[1 2]\n",
      " [3 4]]\n",
      "dtype:  <class 'mindspore.common.tensor.Tensor'> \n",
      "item: [[5 6]\n",
      " [7 8]]\n",
      "dtype:  <class 'numpy.ndarray'> \n",
      "item: [[1 2]\n",
      " [3 4]]\n",
      "dtype:  <class 'numpy.ndarray'> \n",
      "item: [[5 6]\n",
      " [7 8]]\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "For more details, see the API documentation for [create_tuple_iterator](https://www.mindspore.cn/docs/api/zh-CN/r1.6/api_python/dataset/mindspore.dataset.NumpySlicesDataset.html#mindspore.dataset.NumpySlicesDataset.create_tuple_iterator) and [create_dict_iterator](https://www.mindspore.cn/docs/api/zh-CN/r1.6/api_python/dataset/mindspore.dataset.NumpySlicesDataset.html#mindspore.dataset.NumpySlicesDataset.create_dict_iterator)."
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Passing the Dataset to a Model API for Training or Inference\n",
    "\n",
    "After a dataset object is created, it can be passed to the `Model` API. The API iterates over the data internally and feeds it to the network for training or inference."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "source": [
    "import numpy as np\n",
    "from mindspore import ms_function\n",
    "from mindspore import context, nn, Model\n",
    "import mindspore.dataset as ds\n",
    "import mindspore.ops as ops\n",
    "\n",
    "\n",
    "def create_dataset():\n",
    "    np_data = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]\n",
    "    np_data = np.array(np_data, dtype=np.float16)\n",
    "    dataset = ds.NumpySlicesDataset(np_data, column_names=[\"col1\"], shuffle=False)\n",
    "    return dataset\n",
    "\n",
    "\n",
    "class Net(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.relu = ops.ReLU()\n",
    "        self.print = ops.Print()\n",
    "\n",
    "    @ms_function\n",
    "    def construct(self, x):\n",
    "        # Print each input batch, then apply ReLU\n",
    "        self.print(x)\n",
    "        return self.relu(x)\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    # This example can run on CPU, GPU, or Ascend\n",
    "    context.set_context(mode=context.GRAPH_MODE)\n",
    "    dataset = create_dataset()\n",
    "    network = Net()\n",
    "    model = Model(network)\n",
    "\n",
    "    # Train with data sinking enabled (dataset_sink_mode defaults to True)\n",
    "    model.train(epoch=1, train_dataset=dataset, dataset_sink_mode=True)"
   ],
   "outputs": [],
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-13T09:01:56.002002Z",
     "start_time": "2021-09-13T09:01:55.953018Z"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "The `dataset_sink_mode` parameter of the `Model` API controls whether data is sunk to the device. When sinking is disabled, the API internally creates one of the iterators described above. It then traverses the data one item at a time and feeds each item to the network. When sinking is enabled, the API sends the data directly to the device, where it is fed to the network for iterative training or inference.\n",
    "\n",
    "For more detailed usage, see [Basic Usage of Model](https://www.mindspore.cn/docs/programming_guide/zh-CN/r1.6/model_use_guide.html#id3)."
   ],
   "metadata": {}
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MindSpore",
   "language": "python",
   "name": "mindspore"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}