高级数据集管理
Ascend
GPU
CPU
进阶
数据准备
MindSpore可以加载常见的数据集或自定义的数据集,这部分功能在初级教程中进行了部分介绍。加载自定义数据集有两种途径:
通过
GeneratorDataset
对象加载,使用方法可参考初级教程-自定义数据集。将数据集转换为MindRecord,即MindSpore数据格式,通过读取MindRecord文件进行加载数据。
如果用户想要获得更好的性能体验,可以将数据集转换为MindRecord,从而方便地加载到MindSpore中进行训练。
MindRecord的性能优化如下:
实现多变的用户数据统一存储、访问,训练数据读取更加简便。
数据聚合存储,高效读取,且方便管理、移动。
高效的数据编解码操作,对用户透明、无感知。
可以灵活控制分区的大小,实现分布式训练。
常见数据集转换MindRecord可参考官方编程指南中的MindSpore数据格式转换,自定义数据集转换可参考下文。
MindRecord的目标是归一化用户的数据集,并进一步通过MindDataset
实现数据的读取,用于训练过程。下面对这两步进行说明。
自定义数据集转换为MindRecord
首先,下载需要处理的图片数据transform.jpg
作为待处理的原始数据。
创建文件夹目录./datasets/convert_dataset_to_mindrecord/data_to_mindrecord/
用于存放所有的转换数据集。
创建文件夹目录./datasets/convert_dataset_to_mindrecord/images/
用于存放下载下来的图片数据。
以下示例代码完成图片下载和文件夹的创建,并将图片移动到指定位置。
[ ]:
import os
import requests
import tarfile
import zipfile
requests.packages.urllib3.disable_warnings()
def download_dataset(url, target_path):
"""下载并解压数据集"""
if not os.path.exists(target_path):
os.makedirs(target_path)
download_file = url.split("/")[-1]
if not os.path.exists(download_file):
res = requests.get(url, stream=True, verify=False)
if download_file.split(".")[-1] not in ["tgz", "zip", "tar", "gz"]:
download_file = os.path.join(target_path, download_file)
with open(download_file, "wb") as f:
for chunk in res.iter_content(chunk_size=512):
if chunk:
f.write(chunk)
if download_file.endswith("zip"):
z = zipfile.ZipFile(download_file, "r")
z.extractall(path=target_path)
z.close()
if download_file.endswith(".tar.gz") or download_file.endswith(".tar") or download_file.endswith(".tgz"):
t = tarfile.open(download_file)
names = t.getnames()
for name in names:
t.extract(name, target_path)
t.close()
print("The {} file is downloaded and saved in the path {} after processing".format(os.path.basename(url), target_path))
download_dataset("https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/transform.jpg", "./datasets/convert_dataset_to_mindrecord/images/")
if not os.path.exists("./datasets/convert_dataset_to_mindrecord/data_to_mindrecord/"):
os.makedirs("./datasets/convert_dataset_to_mindrecord/data_to_mindrecord/")
下载的图片数据文件的目录结构如下:
./datasets/convert_dataset_to_mindrecord/images/
└── transform.jpg
导入文件写入工具类FileWriter
。
[2]:
from mindspore.mindrecord import FileWriter
创建FileWriter对象,传入文件名及分片数量,进行覆盖写。
[3]:
data_record_path = './datasets/convert_dataset_to_mindrecord/data_to_mindrecord/test.mindrecord'
writer = FileWriter(file_name=data_record_path, shard_num=4, overwrite=True)
定义数据集结构文件Schema,调用write_raw_data
接口写入数据,最后调用commit
接口生成本地数据文件。
Schema文件主要包含字段名name
、字段数据类型type
和字段各维度维数shape
:
字段名:字段的引用名称,可以包含字母、数字和下划线。
字段数据类型:包含int32、int64、float32、float64、string、bytes。
字段维数:一维数组用[-1]表示,更高维度可表示为[m, n, …],其中m、n为各维度维数。
如果字段有属性
shape
,则用户传入write_raw_data
接口的数据必须为numpy.ndarray
类型,对应数据类型必须为int32、int64、float32、float64。
[4]:
# 定义schema
data_schema = {"file_name": {"type": "string"}, "label": {"type": "int32"}, "data": {"type": "bytes"}}
writer.add_schema(data_schema, "test_schema")
# 数据准备
file_name = "./datasets/convert_dataset_to_mindrecord/images/transform.jpg"
with open(file_name, "rb") as f:
bytes_data = f.read()
data = [{"file_name": "transform.jpg", "label": 1, "data": bytes_data}]
indexes = ["file_name", "label"]
writer.add_index(indexes)
# 数据写入
writer.write_raw_data(data)
# 生成本地数据
writer.commit()
[4]:
MSRStatus.SUCCESS
该示例会生成8个文件,成为MindRecord数据集。test.mindrecord0
和test.mindrecord0.db
称为1个MindRecord文件,其中test.mindrecord0
为数据文件,test.mindrecord0.db
为索引文件,生成的文件为:
./datasets/convert_dataset_to_mindrecord/data_to_mindrecord/
├── test.mindrecord0
├── test.mindrecord0.db
├── test.mindrecord1
├── test.mindrecord1.db
├── test.mindrecord2
├── test.mindrecord2.db
├── test.mindrecord3
└── test.mindrecord3.db
0 directories, 8 files
读取MindRecord数据集
导入读取类mindspore.dataset
。
[5]:
import mindspore.dataset as ds
首先使用MindDataset
读取MindRecord数据集,然后对数据创建字典迭代器,并通过迭代器读取一条数据记录。
[6]:
file_name = './datasets/convert_dataset_to_mindrecord/data_to_mindrecord/test.mindrecord0'
# 创建MindDataset
define_data_set = ds.MindDataset(dataset_files=file_name)
# 创建字典迭代器并通过迭代器读取数据记录
count = 0
for item in define_data_set.create_dict_iterator(output_numpy=True):
print("sample: {}".format(item))
count += 1
print("Got {} samples".format(count))
sample: {'data': array([255, 216, 255, ..., 159, 255, 217], dtype=uint8), 'file_name': array(b'transform.jpg', dtype='|S13'), 'label': array(1, dtype=int32)}
Got 1 samples