# 数据处理 [](https://gitee.com/mindspore/docs/blob/r2.3.q1/docs/mindspore/source_zh_cn/migration_guide/model_development/dataset.md) 本章节主要对网络迁移中数据处理相关的注意事项加以说明,基础的数据处理概念请参考: [数据处理](https://www.mindspore.cn/tutorials/zh-CN/r2.3.0rc1/beginner/dataset.html) [自动数据增强](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.3.0rc1/dataset/augment.html) [轻量化数据处理](https://mindspore.cn/tutorials/zh-CN/r2.3.0rc1/advanced/dataset/eager.html) [数据处理性能优化](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.3.0rc1/dataset/optimize.html) ## 数据处理差异对比 MindSpore和PyTorch的数据构建基本流程主要包括两个方面:数据集加载和数据增强。下面从读取常见数据集处理流程、读取自定义数据集处理流程两方面来比较二者的写法差异: ### 处理常见数据集 MindSpore提供了很多不同领域的[常见数据集的加载接口](https://www.mindspore.cn/docs/zh-CN/r2.3.0rc1/api_python/mindspore.dataset.html)。 除以上业界常用数据集外,MindSpore还开发了MindRecord数据格式以应对高效的读取、超大型数据存储与读取场景,感兴趣可以参阅[MindRecord](https://www.mindspore.cn/tutorials/zh-CN/r2.3.0rc1/advanced/dataset/record.html)。由于此文章是介绍同类API及写法差异,故选取一个较为经典的数据集API作为迁移对比示例。其他数据集接口差异详细可参考PyTorch与MindSpore API映射表的 [torchaudio](https://www.mindspore.cn/docs/zh-CN/r2.3.0rc1/note/api_mapping/pytorch_api_mapping.html#torchaudio)、[torchtext](https://www.mindspore.cn/docs/zh-CN/r2.3.0rc1/note/api_mapping/pytorch_api_mapping.html#torchtext)、[torchvision](https://www.mindspore.cn/docs/zh-CN/r2.3.0rc1/note/api_mapping/pytorch_api_mapping.html#torchvision) 模块。 这里以FashionMnistDataset举例。下图展示了PyTorch的API使用方法(左边部分),以及MindSpore的API使用方法(右边部分)。主要的读取流程为:使用FashionMnist API加载源数据集,再使用transforms对数据内容进行变换,最后根据对数据集进行`batch`操作。两侧代码对应的关键部分,均使用颜色框进行了标记。  可以看到MindSpore和PyTorch在读取常见数据有以下不同: 1. 获取和读取数据集的方式不同: * PyTorch既可以先将数据集下载到本地然后传给API接口进行读取和解析数据,也可以通过设置API接口的参数 `download` 来下载数据集然后进行读取。 * MindSpore需要先将数据集下载到本地然后传给API接口进行读取和解析数据。 2. 对数据集本身进行混洗、批处理、并行加载等功能支持的方式不同: * PyTorch支持在 `DataLoader` 中配置参数 `shuffle` 、`batch` 、`num_workers` 等来实现相应功能。 * 由于接口API设计的差异,MindSpore则直接在数据集API接口,通过参数 `shuffle` 、 `num_parallel_workers` 承载了混洗、并行加载功能,然后在数据增强结束后,使用 `batch` 操作将数据集中连续的数据合并为一个批处理数据。`batch` 操作详情请参考[batch](https://www.mindspore.cn/docs/zh-CN/r2.3.0rc1/api_python/dataset/dataset_method/batch/mindspore.dataset.Dataset.batch.html#mindspore.dataset.Dataset.batch),由于API设计差异,需要注意MindSpore中 `batch` 操作的参数 `drop_remainder` 与 PyTorch的DataLoader中的参数 `drop_last` 含义一致。 除了FashionMnist API,所有的数据集加载API均有相同的参数设计,上述例子中的 `batch` 操作均适用于所有数据集API。下面以一个可以返回假图像的数据集API `FakeImageDataset` 再次举例并使用相关的数据操作: ```python import mindspore.dataset as ds dataset = ds.FakeImageDataset(num_images=1000, image_size=(32, 32, 3), num_classes=10, base_seed=0)\ .batch(32, drop_remainder=True) print("When drop_remainder=True, the last batch will be drop, the total batch number is ", dataset.get_dataset_size()) # 1000 // 32 = 31 dataset = ds.FakeImageDataset(num_images=1000, image_size=(32, 32, 3), num_classes=10, base_seed=0)\ .batch(32, drop_remainder=False) print("When drop_remainder=False, the last batch will not be drop, the total batch number is ", dataset.get_dataset_size()) # ceil(1000 / 32) = 32 ``` 运行结果: ```text When drop_remainder=True, the last batch will be drop, the total batch number is 31 When drop_remainder=False, the last batch will not be drop, the total batch number is 32 ``` batch操作也可以使用一些batch内的增强操作,详情可参考[YOLOv3](https://gitee.com/mindspore/models/blob/r2.3/official/cv/YOLOv3/src/yolo_dataset.py#L177)。 上面提到的**数据集加载API含有相同的参数**,在这里介绍一些常用的: | 属性 | 介绍 | | ---- | ---- | | num_samples(int) | 规定数据总的sample数 | | shuffle(bool) | 是否对数据做随机打乱 | | sampler(Sampler) | 数据取样器,可以自定义数据打乱、分配,`sampler` 设置和 `num_shards` 、`shard_id` 互斥 | | num_shards(int) | 用于分布式场景,将数据分为多少份,与 `shard_id` 配合使用 | | shard_id(int) | 用于分布式场景,取第几份数据(0~n-1,n为设置的 `num_shards` ),与 `num_shards` 配合使用 | | num_parallel_workers(int) | 并行配置的线程数 | 这里还是以 `FakeImageDataset` 举个例子: ```python import mindspore.dataset as ds dataset = ds.FakeImageDataset(num_images=1000, image_size=(32, 32, 3), num_classes=10, base_seed=0) print(dataset.get_dataset_size()) # 1000 dataset = ds.FakeImageDataset(num_images=1000, image_size=(32, 32, 3), num_classes=10, base_seed=0, num_samples=3) print(dataset.get_dataset_size()) # 3 dataset = ds.FakeImageDataset(num_images=1000, image_size=(32, 32, 3), num_classes=10, base_seed=0, num_shards=8, shard_id=0) print(dataset.get_dataset_size()) # 1000 / 8 = 125 ``` 运行结果: ```text 1000 3 125 ``` 3. 数据增强操作使用的方式不同:
PyTorch | MindSpore |
```python trans = torchvision.transforms.Resize(...) mnist_train = torchvision.datasets.FashionMNIST(..., transforms=trans, ...) ``` |
```python trans = mindspore.dataset.vision.Resize(...) mnist_train = mindspore.dataset.FashionMnistDataset(...) mnist_train = mnist_train.map(trans, ...) ``` |
PyTorch | MindSpore |
```python ... img_resize = torchvision.transforms.Resize(...)(input_ids) img_resize = torchvision.transforms.ToTensor()(img_resize) tmp_tensor = torch.tensor(np.ones_like(img_resize)) img_resize = torch.mul(img_resize, tmp_tensor) img_resize = torchvision.transforms.Normalize(...)(img_resize) ... ``` |
```python ... img_resize = mindspore.dataset.vision.Resize(...)(input_ids) img_resize = mindspore.dataset.vision.ToTensor()(img_resize) tmp_array = np.ones_like(img_resize) img_resize = np.multiply(img_resize, tmp_array) img_resize = mindspore.dataset.vision.Normalize(...)(img_resize) ... ``` |
PyTorch | MindSpore |
```python import numpy as np import torch from torch.utils.data import DataLoader x = np.random.randint(0, 255, size=(20, 32, 32, 3)) tensor_x = torch.Tensor(x) dataloader = DataLoader(tensor_x, batch_size=10) for i, data in enumerate(dataloader): print(i, data.shape) ``` 运行结果: ```text 0 torch.Size([10, 32, 32, 3]) 1 torch.Size([10, 32, 32, 3]) ``` |
```python import numpy as np import mindspore.dataset as ds x = np.random.randint(0, 255, size=(20, 32, 32, 3)) dataset = ds.GeneratorDataset(x, column_names=["data"]) dataset = dataset.batch(10, drop_remainder=True) for data in dataset: print(data[0].shape) ``` 运行结果: ```text (10, 32, 32, 3) (10, 32, 32, 3) ``` |