{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 自定义数据集\n", "\n", "[![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.8/resource/_static/logo_notebook.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/r1.8/tutorials/zh_cn/advanced/dataset/mindspore_custom.ipynb) \n", "[![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.8/resource/_static/logo_download_code.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/r1.8/tutorials/zh_cn/advanced/dataset/mindspore_custom.py) \n", "[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.8/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r1.8/tutorials/source_zh_cn/advanced/dataset/custom.ipynb)\n", "\n", "[mindspore.dataset](https://www.mindspore.cn/docs/zh-CN/r1.8/api_python/mindspore.dataset.html)提供了部分常用数据集和标准格式数据集的加载接口。对于MindSpore暂不支持直接加载的数据集,可以通过构造自定义数据集类或自定义数据集生成函数的方式来生成数据集,然后通过[mindspore.dataset.GeneratorDataset](https://www.mindspore.cn/docs/zh-CN/r1.8/api_python/dataset/mindspore.dataset.GeneratorDataset.html#mindspore.dataset.GeneratorDataset)接口实现自定义方式的数据集加载。\n", "\n", "通过**自定义数据集类**和**自定义数据集**生成函数两种方式生成的数据集,都可以完成加载、迭代等操作。由于在自定义数据集类中定义了随机访问函数和获取数据集大小函数,因此当需要随机访问数据集中某条数据或获取数据集大小时,使用自定义数据集类生成的数据集可以快速完成这些操作,而通过自定义数据集生成函数的方式生成的数据集需要对数据逐条遍历方可完成这些操作。\n", "\n", "一般情况下,当数据量较小时使用两种生成自定义数据集的方式中的任一种都可以,而当数据量过大时,优先使用自定义数据集类的方式生成数据集。\n", "\n", "本篇我们分别介绍**自定义数据集类**和**自定义数据集**两种生成自定义数据集的方式。\n", "\n", "## 自定义数据集类\n", "\n", "在用户自定义数据集类中须要自定义的类函数如下:\n", "\n", "- `__init__`:定义数据初始化等操作,在实例化数据集对象时被调用。\n", "- `__getitem__`:定义该函数后可使其支持随机访问,能够根据给定的索引值`index`,获取数据集中的数据并返回。数据返回值类型是由NumPy数组组成的Tuple。\n", "- `__len__`:返回数据集的样本数量。\n", "\n", "在完成自定义数据集类之后,可以通过`GeneratorDataset`接口按照用户定义的方式加载并访问数据集样本。下面我们通过两段示例代码来说明使用自定义数据集类的方式生成单标签数据集和多标签数据集的方法。\n", "\n", "### 生成单标签数据集\n", "\n", "通过自定义数据集类`MyDataset`生成五组数据,每组数据由两个随机数组成且只有一个标签。示例代码如下:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "data:[0.41702, 0.72032], label:[0.41919]\n", "data:[0.00011, 0.30233], label:[0.68522]\n", "data:[0.14676, 0.09234], label:[0.20445]\n", "data:[0.18626, 0.34556], label:[0.87812]\n", "data:[0.39677, 0.53882], label:[0.02739]\n", "data size: 5\n" ] } ], "source": [ "import numpy as np\n", "import mindspore.dataset as ds\n", "\n", "np.random.seed(1)\n", "\n", "class MyDataset:\n", " \"\"\"自定义数据集类\"\"\"\n", "\n", " def __init__(self):\n", " \"\"\"自定义初始化操作\"\"\"\n", " self.data = np.random.sample((5, 2)) # 自定义数据\n", " self.label = np.random.sample((5, 1)) # 自定义标签\n", "\n", " def __getitem__(self, index):\n", " \"\"\"自定义随机访问函数\"\"\"\n", " return self.data[index], self.label[index]\n", "\n", " def __len__(self):\n", " \"\"\"自定义获取样本数据量函数\"\"\"\n", " return len(self.data)\n", "\n", "# 实例化数据集类\n", "dataset_generator = MyDataset()\n", "dataset = ds.GeneratorDataset(dataset_generator, [\"data\", \"label\"], shuffle=False)\n", "\n", "# 迭代访问数据集\n", "for data in dataset.create_dict_iterator():\n", " data1 = data['data'].asnumpy()\n", " label1 = data['label'].asnumpy()\n", " print(f'data:[{data1[0]:7.5f}, {data1[1]:7.5f}], label:[{label1[0]:7.5f}]')\n", "\n", "# 打印数据条数\n", "print(\"data size:\", dataset.get_dataset_size())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "从上面的打印可以看出,通过自定义数据集类生成的数据集一共有五组,每组数据有一个标签。\n", "\n", "### 生成多标签数据集\n", "\n", "通过自定义数据集类`MyDataset`生成五组数据,每组数据由两个随机数组成且有两个标签`label1`和`lable2`。示例代码如下:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "data:[0.41702, 0.72032] label1:[0.41919] label2:[0.67047]\n", "data:[0.00011, 0.30233] label1:[0.68522] label2:[0.41730]\n", "data:[0.14676, 0.09234] label1:[0.20445] label2:[0.55869]\n", "data:[0.18626, 0.34556] label1:[0.87812] label2:[0.14039]\n", "data:[0.39677, 0.53882] label1:[0.02739] label2:[0.19810]\n", "data size: 5\n" ] } ], "source": [ "import numpy as np\n", "import mindspore.dataset as ds\n", "\n", "np.random.seed(1)\n", "\n", "class MyDataset:\n", " \"\"\"自定义数据集类\"\"\"\n", " def __init__(self):\n", " \"\"\"自定义初始化操作\"\"\"\n", " self.data = np.random.sample((5, 2)) # 自定义数据\n", " self.label1 = np.random.sample((5, 1)) # 自定义标签1\n", " self.label2 = np.random.sample((5, 1)) # 自定义标签2\n", "\n", " def __getitem__(self, index):\n", " \"\"\"自定义随机访问函数\"\"\"\n", " return self.data[index], self.label1[index], self.label2[index]\n", "\n", " def __len__(self):\n", " \"\"\"自定义获取样本数据量函数\"\"\"\n", " return len(self.data)\n", "\n", "# 实例化数据集类\n", "dataset_generator = MyDataset()\n", "# 加载数据集\n", "dataset = ds.GeneratorDataset(dataset_generator, [\"data\", \"label1\", \"label2\"], shuffle=False)\n", "\n", "# 迭代访问数据集\n", "for data in dataset.create_dict_iterator():\n", " print(\"data:[{:7.5f},\".format(data['data'].asnumpy()[0]),\n", " \"{:7.5f}] \".format(data['data'].asnumpy()[1]),\n", " \"label1:[{:7.5f}]\".format(data['label1'].asnumpy()[0]),\n", " \"label2:[{:7.5f}]\".format(data['label2'].asnumpy()[0]))\n", "\n", "# 打印数据条数\n", "print(\"data size:\", dataset.get_dataset_size())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "从上面的打印可以看出,通过自定义数据集类生成的数据集一共有五组,每组数据有两个标签。\n", "\n", "## 自定义数据集生成函数\n", "\n", "我们也可以通过使用自定义数据集生成函数的方式生成数据集,之后使用`GeneratorDataset`接口按照用户定义的方式加载并访问数据集样本。\n", "\n", "下面我们通过两段示例代码来说明如何使用自定义数据集生成函数的方式来分别生成单标签数据集和多标签数据集。\n", "\n", "### 生成单标签数据集\n", "\n", "通过自定义数据集生成函数`get_singlelabel_data`生成五组数据,每组数据由一个随机数组成且只有一个标签。示例代码如下:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "data:[ -3.73152] label:[ 0.38465] \n", "data:[ -6.60339] label:[ 0.75629] \n", "data:[ 6.01489] label:[ 0.93652] \n", "data:[ -8.29912] label:[ -0.92189] \n", "data:[ 7.52778] label:[ 0.78921] \n", "data size: 5\n" ] } ], "source": [ "import numpy as np\n", "from mindspore import dataset as ds\n", "\n", "def get_singlelabel_data(num):\n", " \"\"\"定义生成单标签数据集函数\"\"\"\n", " for _ in range(num):\n", " data = np.random.uniform(-10.0, 10.0) # 自定义数据\n", " label = np.random.uniform(-1.0, 1.0) # 自定义标签\n", " yield np.array([data]).astype(np.float32), np.array([label]).astype(np.float32)\n", "\n", "# 定义数据集\n", "dataset = ds.GeneratorDataset(list(get_singlelabel_data(5)), column_names=['data', 'label'])\n", "\n", "# 打印数据集\n", "for data in dataset.create_dict_iterator():\n", " print(\"data:[{:9.5f}] \".format(data['data'].asnumpy()[0]),\n", " \"label:[{:9.5f}] \".format(data['label'].asnumpy()[0]))\n", "\n", "# 打印数据条数\n", "print(\"data size:\", dataset.get_dataset_size())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "从上面的打印可以看出,通过自定义数据集生成函数`get_singlelabel_data`生成的数据集一共有五组,每组数据只有一个标签。\n", "\n", "### 生成多标签数据集\n", "\n", "下面我们通过自定义数据集生成函数`get_multilabel_data`来生成一个多标签数据集,每组数据由一个随机数组成且有两个标签。\n", "\n", "示例代码如下:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "data:[ 3.73002] label1:[ 0.66925] label2:[ 1.01829]\n", "data:[ 0.66331] label1:[ 0.38375] label2:[ 1.31552]\n", "data:[ 5.00289] label1:[ 0.97772] label2:[ 1.74817]\n", "data:[ -4.39112] label1:[ 0.57856] label2:[ 1.10323]\n", "data:[ -8.03306] label1:[ -0.15778] label2:[ 1.95789]\n", "data size: 5\n" ] } ], "source": [ "import numpy as np\n", "from mindspore import dataset as ds\n", "\n", "def get_multilabel_data(num, w=2.0, b=3.0):\n", " \"\"\"定义生成多标签数据集函数\"\"\"\n", " for _ in range(num):\n", " data = np.random.uniform(-10.0, 10.0) # 自定义数据\n", " label1 = np.random.uniform(-1.0, 1.0) # 自定义标签1\n", " label2 = np.random.uniform(1.0, 2.0) # 自定义标签2\n", " yield np.array([data]).astype(np.float32), np.array([label1]).astype(np.float32), np.array([label2]).astype(np.float32)\n", "\n", "# 定义数据集\n", "dataset = ds.GeneratorDataset(list(get_multilabel_data(5)), column_names=['data', 'label1', 'label2'])\n", "\n", "# 打印数据集\n", "for data in dataset.create_dict_iterator():\n", " print(\"data:[{:9.5f}] \".format(data['data'].asnumpy()[0]),\n", " \"label1:[{:9.5f}] \".format(data['label1'].asnumpy()[0]),\n", " \"label2:[{:9.5f}]\".format(data['label2'].asnumpy()[0]))\n", "\n", "# 打印数据条数\n", "print(\"data size:\", dataset.get_dataset_size())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "从上面的打印可以看出,通过自定义数据集生成函数生成的数据集共五组,每组数据有两个标签。" ] } ], "metadata": { "kernelspec": { "display_name": "MindSpore", "language": "python", "name": "mindspore" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 4 }