{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Data Loading and Processing\n", "\n", "Data is the foundation of deep learning, and high-quality input data benefits the entire deep neural network.\n", "\n", "MindSpore provides a Pipeline-based [data engine](https://www.mindspore.cn/docs/zh-CN/r2.4.1/design/data_engine.html) that achieves efficient data preprocessing through `Dataset`, `Transforms`, and batching. Specifically:\n", "\n", "1. A Dataset is the start of the Pipeline and loads raw data from storage into memory. `mindspore.dataset` provides built-in [dataset loading interfaces](https://www.mindspore.cn/docs/zh-CN/r2.4.1/api_python/mindspore.dataset.html#) for image, text, audio, and other data, as well as [interfaces for loading custom datasets](https://www.mindspore.cn/docs/zh-CN/r2.4.1/api_python/mindspore.dataset.html#%E7%94%A8%E6%88%B7%E8%87%AA%E5%AE%9A%E4%B9%89);\n", "\n", "2. 
Transforms apply further transformations to the data in memory. `mindspore.dataset.transforms` provides [general-purpose transforms](https://www.mindspore.cn/docs/zh-CN/r2.4.1/api_python/mindspore.dataset.transforms.html#%E9%80%9A%E7%94%A8), `mindspore.dataset.vision` provides [image transforms](https://www.mindspore.cn/docs/zh-CN/r2.4.1/api_python/mindspore.dataset.transforms.html#%E8%A7%86%E8%A7%89), `mindspore.dataset.text` provides [text transforms](https://www.mindspore.cn/docs/zh-CN/r2.4.1/api_python/mindspore.dataset.transforms.html#%E6%96%87%E6%9C%AC), and `mindspore.dataset.audio` provides [audio transforms](https://www.mindspore.cn/docs/zh-CN/r2.4.1/api_python/mindspore.dataset.transforms.html#%E9%9F%B3%E9%A2%91);\n", "\n", "3. Batching groups the transformed data into batches for the final neural network training. The batch operation applies to a dataset object; see [batch operation](https://www.mindspore.cn/docs/zh-CN/r2.4.1/api_python/dataset/mindspore.dataset.MindDataset.html#batch%E6%89%B9%E6%93%8D%E4%BD%9C);\n", "\n", "4. A dataset iterator outputs the final data through iteration. An iterator also applies to a dataset object; see [iterator](https://www.mindspore.cn/docs/zh-CN/r2.4.1/api_python/dataset/mindspore.dataset.MindDataset.html#%E8%BF%AD%E4%BB%A3%E5%99%A8).\n", "\n", "In addition, MindSpore's domain development libraries provide many preloaded datasets that can be downloaded and used with a single API call. This tutorial describes in detail the different dataset loading methods (custom datasets, standard-format datasets, and common datasets), data transforms, and batching." ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [], "source": [ "import os\n", "import numpy as np\n", "from mindspore import dtype as mstype\n", "from mindspore.dataset import transforms\n", "from mindspore.dataset import vision\n", "from mindspore.dataset import MindDataset, GeneratorDataset, MnistDataset, NumpySlicesDataset\n", "from mindspore.mindrecord import FileWriter\n", "import matplotlib.pyplot as plt" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset Loading\n", "\n", "The `mindspore.dataset` module provides APIs for loading custom datasets, standard-format datasets, and a number of commonly used public datasets.\n", "\n", "### Custom Datasets\n", "\n", "For datasets that MindSpore cannot yet load directly, you can write a custom data loading class or a custom dataset generator function, and then load the data through the `GeneratorDataset` interface.\n", "\n", 
"`GeneratorDataset` can build a custom dataset from a random-accessible dataset object, an iterable dataset object, or a generator. Each is described below.\n", "\n", "#### Random-Accessible Datasets\n", "\n", "A random-accessible dataset implements the `__getitem__` and `__len__` methods, so the data sample at a given position can be accessed directly by index or key.\n", "\n", "For example, `dataset[idx]` reads the idx-th sample or label of such a dataset." ] }, { "cell_type": "code", "execution_count": 51, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[Tensor(shape=[2], dtype=Float64, value= [ 1.00000000e+00, 1.00000000e+00]), Tensor(shape=[1], dtype=Float64, value= [ 0.00000000e+00])]\n", "[Tensor(shape=[2], dtype=Float64, value= [ 1.00000000e+00, 1.00000000e+00]), Tensor(shape=[1], dtype=Float64, value= [ 0.00000000e+00])]\n", "[Tensor(shape=[2], dtype=Float64, value= [ 1.00000000e+00, 1.00000000e+00]), Tensor(shape=[1], dtype=Float64, value= [ 0.00000000e+00])]\n", "[Tensor(shape=[2], dtype=Float64, value= [ 1.00000000e+00, 1.00000000e+00]), Tensor(shape=[1], dtype=Float64, value= [ 0.00000000e+00])]\n", "[Tensor(shape=[2], dtype=Float64, value= [ 1.00000000e+00, 1.00000000e+00]), Tensor(shape=[1], dtype=Float64, value= [ 0.00000000e+00])]\n" ] } ], "source": [ "# Random-accessible object as input source\n", "class RandomAccessDataset:\n", "    def __init__(self):\n", "        self._data = np.ones((5, 2))\n", "        self._label = np.zeros((5, 1))\n", "    def __getitem__(self, index):\n", "        return self._data[index], self._label[index]\n", "    def __len__(self):\n", "        return len(self._data)\n", "\n", "loader = RandomAccessDataset()\n", "dataset = GeneratorDataset(source=loader, column_names=[\"data\", \"label\"])\n", "\n", "for data in dataset:\n", "    print(data)" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[Tensor(shape=[], dtype=Int32, value= 2)]\n", "[Tensor(shape=[], dtype=Int32, value= 0)]\n", "[Tensor(shape=[], dtype=Int32, value= 1)]\n" ] } ], "source": [ "# Lists and tuples are also supported.\n", "# GeneratorDataset shuffles random-accessible sources by default, so the output order may vary.\n", "loader = [np.array(0), 
np.array(1), np.array(2)]\n", "dataset = GeneratorDataset(source=loader, column_names=[\"data\"])\n", "\n", "for data in dataset:\n", "    print(data)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "#### Iterable Datasets\n", "\n", "An iterable dataset implements the `__iter__` and `__next__` methods, so data samples are obtained step by step through iteration. This type of dataset is particularly suitable when random access is too expensive or not feasible.\n", "\n", "For example, accessing a dataset through `iter(dataset)` lets you read a data stream returned from a database or a remote server.\n", "\n", "The following builds a simple iterator and loads it into `GeneratorDataset`." ] }, { "cell_type": "code", "execution_count": 53, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[Tensor(shape=[], dtype=Int32, value= 1)]\n", "[Tensor(shape=[], dtype=Int32, value= 2)]\n", "[Tensor(shape=[], dtype=Int32, value= 3)]\n", "[Tensor(shape=[], dtype=Int32, value= 4)]\n" ] } ], "source": [ "# Iterator as input source\n", "class IterableDataset():\n", "    def __init__(self, start, end):\n", "        '''Store the bounds of the range to iterate over.'''\n", "        self.start = start\n", "        self.end = end\n", "    def __next__(self):\n", "        '''Return the next sample; StopIteration ends the iteration.'''\n", "        return next(self.data)\n", "    def __iter__(self):\n", "        '''Reset the underlying iterator and return self.'''\n", "        self.data = iter(range(self.start, self.end))\n", "        return self\n", "\n", "loader = IterableDataset(1, 5)\n", "dataset = GeneratorDataset(source=loader, column_names=[\"data\"])\n", "\n", "for d in dataset:\n", "    print(d)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "#### Generators\n", "\n", "A generator is also an iterable dataset type. It relies directly on Python's `generator` type to return data until the generator raises `StopIteration`.\n", "\n", "The following builds a generator and loads it into `GeneratorDataset`." ] }, { "cell_type": "code", "execution_count": 54, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[Tensor(shape=[], dtype=Int32, value= 3)]\n", "[Tensor(shape=[], dtype=Int32, value= 4)]\n", "[Tensor(shape=[], dtype=Int32, value= 5)]\n" ] } ], "source": [ "# Generator\n", 
"def my_generator(start, end):\n", "    for i in range(start, end):\n", "        yield i\n", "\n", "# A generator instance can be iterated only once, so wrap it in a lambda to create a fresh instance for each pass.\n", "dataset = GeneratorDataset(source=lambda: my_generator(3, 6), column_names=[\"data\"])\n", "\n", "for d in dataset:\n", "    print(d)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### Standard-Format Datasets\n", "\n", "For datasets that MindSpore cannot yet load directly, you can convert them to the **MindRecord format** and then load them through the `MindDataset` interface." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, create a new **MindRecord format** dataset with the [FileWriter](https://www.mindspore.cn/docs/zh-CN/r2.4.1/api_python/mindspore.mindrecord.html#mindspore.mindrecord.FileWriter) interface, where each sample contains three fields: `file_name`, `label`, and `data`." ] }, { "cell_type": "code", "execution_count": 55, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "if os.path.exists(\"./test.mindrecord\"):\n", "    os.remove(\"./test.mindrecord\")\n", "if os.path.exists(\"./test.mindrecord.db\"):\n", "    os.remove(\"./test.mindrecord.db\")\n", "writer = FileWriter(file_name=\"test.mindrecord\", shard_num=1, overwrite=True)\n", "schema_json = {\"file_name\": {\"type\": \"string\"},\n", "               \"label\": {\"type\": \"int32\"},\n", "               \"data\": {\"type\": \"int32\", \"shape\": [-1]}}\n", "writer.add_schema(schema_json, \"test_schema\")\n", "for i in range(4):\n", "    data = [{\"file_name\": str(i) + \".jpg\",\n", "             \"label\": i,\n", "             \"data\": np.array([i]*(i+1), dtype=np.int32)}]\n", "    writer.write_raw_data(data)\n", "writer.commit()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "Then read the **MindRecord format** dataset through the `MindDataset` interface." ] }, { "cell_type": "code", "execution_count": 56, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[Tensor(shape=[1], dtype=Int32, 
value= [0]), Tensor(shape=[], dtype=String, value= '0.jpg'), Tensor(shape=[], dtype=Int32, value= 0)]\n", "[Tensor(shape=[2], dtype=Int32, value= [1, 1]), Tensor(shape=[], dtype=String, value= '1.jpg'), Tensor(shape=[], dtype=Int32, value= 1)]\n", "[Tensor(shape=[3], dtype=Int32, value= [2, 2, 2]), Tensor(shape=[], dtype=String, value= '2.jpg'), Tensor(shape=[], dtype=Int32, value= 2)]\n", "[Tensor(shape=[4], dtype=Int32, value= [3, 3, 3, 3]), Tensor(shape=[], dtype=String, value= '3.jpg'), Tensor(shape=[], dtype=Int32, value= 3)]\n" ] } ], "source": [ "dataset = MindDataset(\"test.mindrecord\", shuffle=False)\n", "for data in dataset:\n", "    print(data)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### Common Datasets\n", "\n", "We use the **MNIST** dataset as an example to show how to load a common dataset." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The interfaces provided by `mindspore.dataset` **only support decompressed data files**, so we use the `download` library to download and decompress the dataset." ] }, { "cell_type": "code", "execution_count": 57, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip (10.3 MB)\n", "\n", "file_sizes: 100%|██████████████████████████| 10.8M/10.8M [00:01<00:00, 5.78MB/s]\n", "Extracting zip file...\n", "Successfully downloaded / unzipped to ./\n" ] } ], "source": [ "# Download data from open datasets\n", "from download import download\n", "\n", "url = \"https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/\" \\\n", "      \"notebook/datasets/MNIST_Data.zip\"\n", "path = download(url, \"./\", kind=\"zip\", replace=True)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Once the archive has been extracted, load the dataset directly. Its type is `MnistDataset`." ] }, { "cell_type": "code", "execution_count": 58, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<class 
'mindspore.dataset.engine.datasets_vision.MnistDataset'>\n" ] } ], "source": [ "train_dataset = MnistDataset(\"MNIST_Data/train\", shuffle=False)\n", "print(type(train_dataset))" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "Use an iterator to output data in a loop. The following defines a visualization function that iterates over 9 images from the **MNIST** dataset and displays them." ] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [], "source": [ "def visualize(dataset):\n", "    figure = plt.figure(figsize=(4, 4))\n", "    cols, rows = 3, 3\n", "\n", "    plt.subplots_adjust(wspace=0.5, hspace=0.5)\n", "\n", "    for idx, (image, label) in enumerate(dataset.create_tuple_iterator()):\n", "        figure.add_subplot(rows, cols, idx + 1)\n", "        plt.title(int(label))\n", "        plt.axis(\"off\")\n", "        plt.imshow(image.asnumpy().squeeze(), cmap=\"gray\")\n", "        if idx == cols * rows - 1:\n", "            break\n", "    plt.show()" ] }, { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "<Figure size 400x400 with 9 Axes>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "visualize(train_dataset)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Data Transforms\n", "\n", "Raw data loaded directly usually cannot be fed into a neural network for training as is, so it must be preprocessed first. MindSpore provides different kinds of data transforms that work with the data processing Pipeline to perform preprocessing. All transforms can be passed in through the `.map(...)` method.\n", "\n", "1. The `.map(...)` operation adds transforms to a specified column of the dataset, applies them to each element of that column, and returns a new dataset containing the transformed elements.\n", "\n", "2. 
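The `.map(...)` operation can run the built-in transforms provided by the Dataset module as well as user-defined transforms.\n", "\n", "`mindspore.dataset` provides built-in transforms for different data types such as images, text, and audio, and also supports custom transforms. Both are described below.\n", "\n", "### Built-in Data Transforms\n", "\n", "The built-in transforms provided by `mindspore.dataset` are: [vision transforms](https://www.mindspore.cn/docs/zh-CN/r2.4.1/api_python/mindspore.dataset.transforms.html#%E8%A7%86%E8%A7%89), [text transforms](https://www.mindspore.cn/docs/zh-CN/r2.4.1/api_python/mindspore.dataset.transforms.html#%E6%96%87%E6%9C%AC), and [audio transforms](https://www.mindspore.cn/docs/zh-CN/r2.4.1/api_python/mindspore.dataset.transforms.html#%E9%9F%B3%E9%A2%91).\n", "\n", "The following example applies `Rescale`, `Normalize`, and `HWC2CHW` to the **image** column of the **MNIST** dataset, and `TypeCast` to the **label** column.\n", "\n", "1. Rescale: adjusts the pixel values of an image. It takes two parameters, rescale (scaling factor) and shift (offset), and each pixel is adjusted as $output_{i} = input_{i} * rescale + shift$.\n", "\n", "2. Normalize: normalizes the input image. It takes three parameters: mean (the mean of each image channel), std (the standard deviation of each image channel), and is_hwc (a bool giving the input image format: True for (height, width, channel), False for (channel, height, width)). Each channel is adjusted according to mean and std as $output_{c} = \\frac{input_{c} - mean_{c}}{std_{c}}$, where $c$ is the channel index.\n", "\n", "3. 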
HWC2CHW: converts the image format from HWC to CHW." ] }, { "cell_type": "code", "execution_count": 61, "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(1, 28, 28) Float32\n", "() Int32\n" ] } ], "source": [ "train_dataset = MnistDataset('MNIST_Data/train')\n", "train_dataset = train_dataset.map(operations=[vision.Rescale(1.0 / 255.0, 0),\n", "                                              vision.Normalize(mean=(0.1307,), std=(0.3081,)),\n", "                                              vision.HWC2CHW()],\n", "                                  input_columns=['image'])\n", "train_dataset = train_dataset.map(operations=[transforms.TypeCast(mstype.int32)],\n", "                                  input_columns=['label'])\n", "for data in train_dataset:\n", "    print(data[0].shape, data[0].dtype)\n", "    print(data[1].shape, data[1].dtype)\n", "    break" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "### Custom Data Transforms\n", "\n", "The following example applies a custom `Rescale`, a custom `Normalize`, and a custom `HWC2CHW` to the **image** column of the **MNIST** dataset, and a custom `TypeCast` to the **label** column." ] }, { "cell_type": "code", "execution_count": 62, "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(1, 28, 28) Float32\n", "() Int32\n" ] } ], "source": [ "train_dataset = MnistDataset('MNIST_Data/train')\n", "def rescale_normalize_hwc2chw(input_col):\n", "    # Rescale to [0, 1], normalize with the mean and std used above, then convert HWC to CHW.\n", "    trans_result = input_col / 255.0\n", "    trans_result = (trans_result - 0.1307) / 0.3081\n", "    trans_result = trans_result.transpose(2, 0, 1)\n", "    return trans_result.astype(np.float32)\n", "train_dataset = train_dataset.map(operations=rescale_normalize_hwc2chw,\n", "                                  input_columns=['image'])\n", "def typecast(input_col):\n", "    trans_result = input_col.astype(np.int32)\n", "    return trans_result\n", "train_dataset = train_dataset.map(operations=typecast,\n", "                                  input_columns=['label'])\n", "for data in train_dataset:\n", "    print(data[0].shape, data[0].dtype)\n", "    print(data[1].shape, data[1].dtype)\n", "    break" ] }, { "attachments": {}, 
"cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "## Batching\n", "\n", "Batching packs multiple samples into a fixed-size `batch`. It is a compromise for optimizing a model with gradient descent on limited hardware resources: it preserves the randomness of gradient descent while keeping the amount of computation manageable.\n", "\n", "Typically we set a fixed batch size and split the continuous data into batches. Batched data gains one extra dimension of size `batch_size`." ] }, { "cell_type": "code", "execution_count": 63, "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(2, 2) (2,)\n", "(2, 2) (2,)\n", "(2, 2) (2,)\n" ] } ], "source": [ "data = ([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], [0, 1, 0, 1, 0, 1])\n", "dataset = NumpySlicesDataset(data=data, column_names=[\"data\", \"label\"], shuffle=False)\n", "dataset = dataset.batch(2)\n", "for data in dataset.create_tuple_iterator():\n", "    print(data[0].shape, data[1].shape)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "## Dataset Iterators\n", "\n", "After the dataset Pipeline is defined, data is usually fetched by iteration and then fed into the neural network for training. We can create a data iterator with the [create_tuple_iterator](https://www.mindspore.cn/docs/zh-CN/r2.4.1/api_python/dataset/dataset_method/iterator/mindspore.dataset.Dataset.create_tuple_iterator.html) or [create_dict_iterator](https://www.mindspore.cn/docs/zh-CN/r2.4.1/api_python/dataset/dataset_method/iterator/mindspore.dataset.Dataset.create_dict_iterator.html) interface and iterate over the data.\n", "\n", "The data is returned as `Tensor` by default. With `output_numpy=True`, it is returned as NumPy arrays." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "The following shows the output of the `create_tuple_iterator` iterator." ] }, { "cell_type": "code", "execution_count": 64, "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[Tensor(shape=[2], dtype=Int32, value= [2, 4]), Tensor(shape=[2], dtype=Int32, value= [0, 1])]\n", "[Tensor(shape=[2], dtype=Int32, value= [6, 8]), Tensor(shape=[2], dtype=Int32, value= [0, 1])]\n", "[Tensor(shape=[2], dtype=Int32, value= [10, 12]), Tensor(shape=[2], dtype=Int32, 
value= [0, 1])]\n", "[Tensor(shape=[2], dtype=Int32, value= [14, 16]), Tensor(shape=[2], dtype=Int32, value= [0, 1])]\n" ] } ], "source": [ "data = ([1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 0, 1, 0, 1, 0, 1])\n", "dataset = NumpySlicesDataset(data=data, column_names=[\"data\", \"label\"], shuffle=False)\n", "dataset = dataset.map(lambda x: x * 2, input_columns=[\"data\"])\n", "dataset = dataset.batch(2)\n", "for data in dataset.create_tuple_iterator():\n", "    print(data)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "The following shows the output of the `create_dict_iterator` iterator." ] }, { "cell_type": "code", "execution_count": 65, "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'data': Tensor(shape=[2], dtype=Int32, value= [2, 4]), 'label': Tensor(shape=[2], dtype=Int32, value= [0, 1])}\n", "{'data': Tensor(shape=[2], dtype=Int32, value= [6, 8]), 'label': Tensor(shape=[2], dtype=Int32, value= [0, 1])}\n", "{'data': Tensor(shape=[2], dtype=Int32, value= [10, 12]), 'label': Tensor(shape=[2], dtype=Int32, value= [0, 1])}\n", "{'data': Tensor(shape=[2], dtype=Int32, value= [14, 16]), 'label': Tensor(shape=[2], dtype=Int32, value= [0, 1])}\n" ] } ], "source": [ "data = ([1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 0, 1, 0, 1, 0, 1])\n", "dataset = NumpySlicesDataset(data=data, column_names=[\"data\", \"label\"], shuffle=False)\n", "dataset = dataset.map(lambda x: x * 2, input_columns=[\"data\"])\n", "dataset = dataset.batch(2)\n", "for data in dataset.create_dict_iterator():\n", "    print(data)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.7.5 64-bit", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" }, "vscode": { "interpreter": { "hash": 
"5109d816b82be14675a6b11f8e0f0d2e80f029176ed3710d54e125caa8520dfd" } } }, "nbformat": 4, "nbformat_minor": 4 }