数据集加载
概述
MindSpore支持加载图像领域常用的数据集,用户可以直接使用mindspore.dataset
中对应的类实现数据集的加载。目前支持的常用数据集及对应的数据集类如下表所示。
图像数据集 |
数据集类 |
数据集简介 |
---|---|---|
MNIST |
MnistDataset |
MNIST是一个大型手写数字图像数据集,拥有60,000张训练图像和10,000张测试图像,常用于训练各种图像处理系统。 |
CIFAR-10 |
Cifar10Dataset |
CIFAR-10是一个微小图像数据集,包含10种类别下的60,000张32x32大小彩色图像,平均每种类别6,000张,其中5,000张为训练集,1,000张为测试集。 |
CIFAR-100 |
Cifar100Dataset |
CIFAR-100与CIFAR-10类似,但拥有100种类别,平均每种类别600张,其中500张为训练集,100张为测试集。 |
CelebA |
CelebADataset |
CelebA是一个大型人脸图像数据集,包含超过200,000张名人人脸图像,每张图像拥有40个特征标记。 |
PASCAL-VOC |
VOCDataset |
PASCAL-VOC是一个常用图像数据集,被广泛用于目标检测、图像分割等计算机视觉领域。 |
COCO |
CocoDataset |
COCO是一个大型目标检测、图像分割、姿态估计数据集。 |
CLUE |
CLUEDataset |
CLUE是一个大型中文语义理解数据集。 |
MindSpore还支持加载多种数据存储格式下的数据集,用户可以直接使用mindspore.dataset
中对应的类加载磁盘中的数据文件。目前支持的数据格式及对应加载方式如下表所示。
数据格式 |
数据集类 |
数据格式简介 |
---|---|---|
MindRecord |
MindDataset |
MindRecord是MindSpore的自研数据格式,具有读写高效、易于分布式处理等优势。 |
Manifest |
ManifestDataset |
Manifest是华为ModelArts支持的一种数据格式,描述了原始文件和标注信息,可用于标注、训练、推理场景。 |
TFRecord |
TFRecordDataset |
TFRecord是TensorFlow定义的一种二进制数据文件格式。 |
NumPy |
NumpySlicesDataset |
NumPy数据源指的是已经读入内存中的NumPy arrays格式数据集。 |
Text File |
TextFileDataset |
Text File指的是常见的文本格式数据。 |
CSV File |
CSVDataset |
CSV指逗号分隔值,其文件以纯文本形式存储表格数据。 |
MindSpore也同样支持使用GeneratorDataset
自定义数据集的加载方式,用户可以根据需要实现自己的数据集类。
数据集类 |
数据格式简介 |
---|---|
GeneratorDataset |
用户自定义的数据集读取、处理的方式。 |
NumpySlicesDataset |
用户自定义的由NumPy构建数据集的方式。 |
更多详细的数据集加载接口说明,参见API文档。
常用数据集加载
下面将介绍几种常用数据集的加载方式。
CIFAR-10/100数据集
下载CIFAR-10数据集并解压到指定位置:
[1]:
!wget -N https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/cifar-10-binary.tar.gz
!mkdir -p datasets
!tar -xzf cifar-10-binary.tar.gz -C datasets
!mkdir -p datasets/cifar-10-batches-bin/train datasets/cifar-10-batches-bin/test
!mv -f datasets/cifar-10-batches-bin/test_batch.bin datasets/cifar-10-batches-bin/test
!mv -f datasets/cifar-10-batches-bin/data_batch*.bin datasets/cifar-10-batches-bin/batches.meta.txt datasets/cifar-10-batches-bin/train
!tree ./datasets/cifar-10-batches-bin
./datasets/cifar-10-batches-bin
├── readme.html
├── test
│ └── test_batch.bin
└── train
├── batches.meta.txt
├── data_batch_1.bin
├── data_batch_2.bin
├── data_batch_3.bin
├── data_batch_4.bin
└── data_batch_5.bin
2 directories, 8 files
下面的样例通过Cifar10Dataset
接口加载CIFAR-10数据集,使用顺序采样器获取其中5个样本,然后展示了对应图片的形状和标签。
CIFAR-100数据集和MNIST数据集的加载方式也与之类似。
[2]:
import mindspore.dataset as ds
DATA_DIR = "./datasets/cifar-10-batches-bin/train/"
sampler = ds.SequentialSampler(num_samples=5)
dataset = ds.Cifar10Dataset(DATA_DIR, sampler=sampler)
for data in dataset.create_dict_iterator():
print("Image shape:", data['image'].shape, ", Label:", data['label'])
Image shape: (32, 32, 3) , Label: 6
Image shape: (32, 32, 3) , Label: 9
Image shape: (32, 32, 3) , Label: 9
Image shape: (32, 32, 3) , Label: 4
Image shape: (32, 32, 3) , Label: 1
VOC数据集
VOC数据集有多个版本,此处以VOC2012为例。下载VOC2012数据集并解压,目录结构如下。
└─ VOCtrainval_11-May-2012
└── VOCdevkit
└── VOC2012
├── Annotations
├── ImageSets
├── JPEGImages
├── SegmentationClass
└── SegmentationObject
下面的样例通过VOCDataset
接口加载VOC2012数据集,分别演示了将任务指定为分割(Segmentation)和检测(Detection)时的原始图像形状和目标形状。
import mindspore.dataset as ds
DATA_DIR = "VOCtrainval_11-May-2012/VOCdevkit/VOC2012/"
dataset = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", num_samples=2, decode=True, shuffle=False)
print("[Segmentation]:")
for data in dataset.create_dict_iterator():
print("image shape:", data["image"].shape)
print("target shape:", data["target"].shape)
dataset = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", num_samples=1, decode=True, shuffle=False)
print("[Detection]:")
for data in dataset.create_dict_iterator():
print("image shape:", data["image"].shape)
print("bbox shape:", data["bbox"].shape)
输出结果:
[Segmentation]:
image shape: (281, 500, 3)
target shape: (281, 500, 3)
image shape: (375, 500, 3)
target shape: (375, 500, 3)
[Detection]:
image shape: (442, 500, 3)
bbox shape: (2, 4)
COCO数据集
COCO数据集有多个版本,此处以COCO2017的验证数据集为例。下载COCO2017的验证集、检测任务标注和全景分割任务标注并解压,只取其中的验证集部分,按以下目录结构存放。
└─ COCO
├── val2017
└── annotations
├── instances_val2017.json
├── panoptic_val2017.json
└── person_keypoints_val2017.json
下面的样例通过CocoDataset
接口加载COCO2017数据集,分别演示了将任务指定为目标检测(Detection)、背景分割(Stuff)、关键点检测(Keypoint)和全景分割(Panoptic)时获取到的不同数据。
import mindspore.dataset as ds
DATA_DIR = "COCO/val2017/"
ANNOTATION_FILE = "COCO/annotations/instances_val2017.json"
KEYPOINT_FILE = "COCO/annotations/person_keypoints_val2017.json"
PANOPTIC_FILE = "COCO/annotations/panoptic_val2017.json"
dataset = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection", num_samples=1)
for data in dataset.create_dict_iterator():
print("Detection:", data.keys())
dataset = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Stuff", num_samples=1)
for data in dataset.create_dict_iterator():
print("Stuff:", data.keys())
dataset = ds.CocoDataset(DATA_DIR, annotation_file=KEYPOINT_FILE, task="Keypoint", num_samples=1)
for data in dataset.create_dict_iterator():
print("Keypoint:", data.keys())
dataset = ds.CocoDataset(DATA_DIR, annotation_file=PANOPTIC_FILE, task="Panoptic", num_samples=1)
for data in dataset.create_dict_iterator():
print("Panoptic:", data.keys())
输出结果:
Detection: dict_keys(['image', 'bbox', 'category_id', 'iscrowd'])
Stuff: dict_keys(['image', 'segmentation', 'iscrowd'])
Keypoint: dict_keys(['image', 'keypoints', 'num_keypoints'])
Panoptic: dict_keys(['image', 'bbox', 'category_id', 'iscrowd', 'area'])
特定格式数据集加载
下面将介绍几种特定格式数据集文件的加载方式。
MindRecord数据格式
MindRecord是MindSpore定义的一种数据格式,使用MindRecord能够获得更好的性能提升。
阅读数据格式转换章节,了解如何将数据集转化为MindSpore数据格式。
执行本例之前需下载对应的测试数据test_mindrecord.zip
并解压到指定位置,执行如下命令:
[3]:
!wget -N https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/test_mindrecord.zip
!unzip -o ./test_mindrecord.zip -d ./datasets/mindspore_dataset_loading/
!tree ./datasets/mindspore_dataset_loading/
./datasets/mindspore_dataset_loading/
├── test.mindrecord
└── test.mindrecord.db
0 directories, 2 files
下面的样例通过MindDataset
接口加载MindRecord文件,并展示已加载数据的标签。
[4]:
import mindspore.dataset as ds
DATA_FILE = ["./datasets/mindspore_dataset_loading/test.mindrecord"]
mindrecord_dataset = ds.MindDataset(DATA_FILE)
for data in mindrecord_dataset.create_dict_iterator(output_numpy=True):
print(data.keys())
dict_keys(['chinese', 'english'])
dict_keys(['chinese', 'english'])
dict_keys(['chinese', 'english'])
Manifest数据格式
Manifest是华为ModelArts支持的数据格式文件,详细说明请参见Manifest文档。
本次示例需下载测试数据test_manifest.zip
并将其解压到指定位置,执行如下命令:
[5]:
!wget -N https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/test_manifest.zip
!unzip -o ./test_manifest.zip -d ./datasets/mindspore_dataset_loading/test_manifest/
!tree ./datasets/mindspore_dataset_loading/test_manifest/
./datasets/mindspore_dataset_loading/test_manifest/
├── eval
│ ├── 1.JPEG
│ └── 2.JPEG
├── test_manifest.json
└── train
├── 1.JPEG
└── 2.JPEG
2 directories, 5 files
下面的样例通过ManifestDataset
接口加载Manifest文件test_manifest.json
,并展示已加载数据的标签。
[6]:
import mindspore.dataset as ds
DATA_FILE = "./datasets/mindspore_dataset_loading/test_manifest/test_manifest.json"
manifest_dataset = ds.ManifestDataset(DATA_FILE)
for data in manifest_dataset.create_dict_iterator():
print(data["label"])
0
1
TFRecord数据格式
TFRecord是TensorFlow定义的一种二进制数据文件格式。
下面的样例通过TFRecordDataset
接口加载TFRecord文件,并介绍了两种不同的数据集格式设定方案。
下载tfrecord
测试数据test_tftext.zip
并解压到指定位置,执行如下命令:
[7]:
!wget -N https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/test_tftext.zip
!unzip -o ./test_tftext.zip -d ./datasets/mindspore_dataset_loading/test_tfrecord/
!tree ./datasets/mindspore_dataset_loading/test_tfrecord/
./datasets/mindspore_dataset_loading/test_tfrecord/
└── test_tftext.tfrecord
0 directories, 1 file
传入数据集路径或TFRecord文件列表,本例使用
test_tftext.tfrecord
,创建TFRecordDataset
对象。
[8]:
import mindspore.dataset as ds
DATA_FILE = "./datasets/mindspore_dataset_loading/test_tfrecord/test_tftext.tfrecord"
tfrecord_dataset = ds.TFRecordDataset(DATA_FILE)
for tf_data in tfrecord_dataset.create_dict_iterator():
print(tf_data.keys())
dict_keys(['chinese', 'line', 'words'])
dict_keys(['chinese', 'line', 'words'])
dict_keys(['chinese', 'line', 'words'])
用户可以通过编写Schema文件或创建Schema对象,设定数据集格式及特征。
编写Schema文件
将数据集格式和特征按JSON格式写入Schema文件。
columns
:列信息字段,需要根据数据集的实际列名定义。上面的示例中,数据集有三组数据,其列均为chinese
、line
和words
。然后在创建
TFRecordDataset
时将Schema文件路径传入。
[9]:
import os
import json
data_json = {
"columns": {
"chinese": {
"type": "uint8",
"rank": 1
},
"line" : {
"type": "int8",
"rank": 1
},
"words" : {
"type": "uint8",
"rank": 0
}
}
}
if not os.path.exists("dataset_schema_path"):
os.mkdir("dataset_schema_path")
SCHEMA_DIR = "dataset_schema_path/schema.json"
with open(SCHEMA_DIR, "w") as f:
json.dump(data_json,f,indent=4)
tfrecord_dataset = ds.TFRecordDataset(DATA_FILE, schema=SCHEMA_DIR)
for tf_data in tfrecord_dataset.create_dict_iterator():
print(tf_data.values())
dict_values([Tensor(shape=[57], dtype=UInt8, value= [230, 177, 159, 229, 183, 158, 229, 184, 130, 233, 149, 191, 230, 177, 159, 229, 164, 167, 230, 161, 165, 229, 143, 130,
229, 138, 160, 228, 186, 134, 233, 149, 191, 230, 177, 159, 229, 164, 167, 230, 161, 165, 231, 154, 132, 233, 128, 154,
232, 189, 166, 228, 187, 170, 229, 188, 143]), Tensor(shape=[22], dtype=Int8, value= [ 71, 111, 111, 100, 32, 108, 117, 99, 107, 32, 116, 111, 32, 101, 118, 101, 114, 121, 111, 110, 101, 46]), Tensor(shape=[32], dtype=UInt8, value= [229, 165, 179, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 101, 118, 101, 114, 121, 111, 110, 101,
99, 32, 32, 32, 32, 32, 32, 32])])
dict_values([Tensor(shape=[12], dtype=UInt8, value= [231, 148, 183, 233, 187, 152, 229, 165, 179, 230, 179, 170]), Tensor(shape=[19], dtype=Int8, value= [ 66, 101, 32, 104, 97, 112, 112, 121, 32, 101, 118, 101, 114, 121, 32, 100, 97, 121, 46]), Tensor(shape=[20], dtype=UInt8, value= [ 66, 101, 32, 32, 32, 104, 97, 112, 112, 121, 100, 97, 121, 32, 32, 98, 32, 32, 32, 32])])
dict_values([Tensor(shape=[48], dtype=UInt8, value= [228, 187, 138, 229, 164, 169, 229, 164, 169, 230, 176, 148, 229, 164, 170, 229, 165, 189, 228, 186, 134, 230, 136, 145,
228, 187, 172, 228, 184, 128, 232, 181, 183, 229, 142, 187, 229, 164, 150, 233, 157, 162, 231, 142, 169, 229, 144, 167
]), Tensor(shape=[20], dtype=Int8, value= [ 84, 104, 105, 115, 32, 105, 115, 32, 97, 32, 116, 101, 120, 116, 32, 102, 105, 108, 101, 46]), Tensor(shape=[16], dtype=UInt8, value= [ 84, 104, 105, 115, 116, 101, 120, 116, 102, 105, 108, 101, 97, 32, 32, 32])])
创建Schema对象
创建Schema对象,为其添加自定义字段,然后在创建数据集对象时传入。
[10]:
from mindspore import dtype as mstype
schema = ds.Schema()
schema.add_column('chinese', de_type=mstype.uint8)
schema.add_column('line', de_type=mstype.uint8)
tfrecord_dataset = ds.TFRecordDataset(DATA_FILE, schema=schema)
for tf_data in tfrecord_dataset.create_dict_iterator():
print(tf_data)
{'chinese': Tensor(shape=[12], dtype=UInt8, value= [231, 148, 183, 233, 187, 152, 229, 165, 179, 230, 179, 170]), 'line': Tensor(shape=[19], dtype=UInt8, value= [ 66, 101, 32, 104, 97, 112, 112, 121, 32, 101, 118, 101, 114, 121, 32, 100, 97, 121, 46])}
{'chinese': Tensor(shape=[48], dtype=UInt8, value= [228, 187, 138, 229, 164, 169, 229, 164, 169, 230, 176, 148, 229, 164, 170, 229, 165, 189, 228, 186, 134, 230, 136, 145,
228, 187, 172, 228, 184, 128, 232, 181, 183, 229, 142, 187, 229, 164, 150, 233, 157, 162, 231, 142, 169, 229, 144, 167
]), 'line': Tensor(shape=[20], dtype=UInt8, value= [ 84, 104, 105, 115, 32, 105, 115, 32, 97, 32, 116, 101, 120, 116, 32, 102, 105, 108, 101, 46])}
{'chinese': Tensor(shape=[57], dtype=UInt8, value= [230, 177, 159, 229, 183, 158, 229, 184, 130, 233, 149, 191, 230, 177, 159, 229, 164, 167, 230, 161, 165, 229, 143, 130,
229, 138, 160, 228, 186, 134, 233, 149, 191, 230, 177, 159, 229, 164, 167, 230, 161, 165, 231, 154, 132, 233, 128, 154,
232, 189, 166, 228, 187, 170, 229, 188, 143]), 'line': Tensor(shape=[22], dtype=UInt8, value= [ 71, 111, 111, 100, 32, 108, 117, 99, 107, 32, 116, 111, 32, 101, 118, 101, 114, 121, 111, 110, 101, 46])}
对比上述中的步骤2和步骤3,可以看出:
步骤 |
chinese |
line |
words |
---|---|---|---|
2 |
UInt8 |
Int8 |
UInt8 |
3 |
UInt8 |
UInt8 |
示例步骤2中的columns
中数据由chinese
(UInt8)、line
(Int8)和words
(UInt8)变为了示例步骤3中的chinese
(UInt8)、line
(UInt8),通过Schema对象,设定数据集的数据类型和特征,使得列中的数据类型和特征相应改变了。
NumPy数据格式
如果所有数据已经读入内存,可以直接使用NumpySlicesDataset
类将其加载。
下面的样例分别介绍了通过NumpySlicesDataset
加载arrays数据、 list数据和dict数据的方式。
加载NumPy arrays数据
[11]:
import numpy as np
import mindspore.dataset as ds
np.random.seed(6)
features, labels = np.random.sample((4, 2)), np.random.sample((4, 1))
data = (features, labels)
dataset = ds.NumpySlicesDataset(data, column_names=["col1", "col2"], shuffle=False)
for np_arr_data in dataset:
print(np_arr_data[0], np_arr_data[1])
[0.89286015 0.33197981] [0.33540785]
[0.82122912 0.04169663] [0.62251943]
[0.10765668 0.59505206] [0.43814143]
[0.52981736 0.41880743] [0.73588211]
加载Python list数据
[12]:
import mindspore.dataset as ds
data1 = [[1, 2], [3, 4]]
dataset = ds.NumpySlicesDataset(data1, column_names=["col1"], shuffle=False)
for np_list_data in dataset:
print(np_list_data[0])
[1 2]
[3 4]
加载Python dict数据
[13]:
import mindspore.dataset as ds
data1 = {"a": [1, 2], "b": [3, 4]}
dataset = ds.NumpySlicesDataset(data1, column_names=["col1", "col2"], shuffle=False)
for np_dic_data in dataset.create_dict_iterator():
print(np_dic_data)
{'col1': Tensor(shape=[], dtype=Int64, value= 1), 'col2': Tensor(shape=[], dtype=Int64, value= 3)}
{'col1': Tensor(shape=[], dtype=Int64, value= 2), 'col2': Tensor(shape=[], dtype=Int64, value= 4)}
CSV数据格式
下面的样例通过CSVDataset
加载CSV格式数据集文件,并展示了已加载数据的keys
。
下载测试数据test_csv.zip
并解压到指定位置,执行如下命令:
[14]:
!wget -N https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/test_csv.zip
!unzip -o ./test_csv.zip -d ./datasets/mindspore_dataset_loading/test_csv/
!tree ./datasets/mindspore_dataset_loading/test_csv/
./datasets/mindspore_dataset_loading/test_csv/
├── test1.csv
└── test2.csv
0 directories, 2 files
传入数据集路径或CSV文件列表,Text格式数据集文件的加载方式与CSV文件类似。
[15]:
import mindspore.dataset as ds
DATA_FILE = ["./datasets/mindspore_dataset_loading/test_csv/test1.csv","./datasets/mindspore_dataset_loading/test_csv/test2.csv"]
csv_dataset = ds.CSVDataset(DATA_FILE)
for csv_data in csv_dataset.create_dict_iterator(output_numpy=True):
print(csv_data.keys())
dict_keys(['a', 'b', 'c', 'd'])
dict_keys(['a', 'b', 'c', 'd'])
dict_keys(['a', 'b', 'c', 'd'])
dict_keys(['a', 'b', 'c', 'd'])
自定义数据集加载
对于目前MindSpore不支持直接加载的数据集,可以通过构造GeneratorDataset
对象实现自定义方式的加载,或者将其转换成MindRecord数据格式。下面分别展示几种不同的自定义数据集加载方法,为了便于对比,生成的随机数据保持相同。
构造数据集生成函数
构造生成函数定义数据返回方式,再使用此函数构建自定义数据集对象。此方法适用于简单场景。
[16]:
import numpy as np
import mindspore.dataset as ds
np.random.seed(58)
data = np.random.sample((5, 2))
label = np.random.sample((5, 1))
def GeneratorFunc():
for i in range(5):
yield (data[i], label[i])
dataset = ds.GeneratorDataset(GeneratorFunc, ["data", "label"])
for item in dataset.create_dict_iterator():
print(item["data"], item["label"])
[0.36510558 0.45120592] [0.78888122]
[0.49606035 0.07562207] [0.38068183]
[0.57176158 0.28963401] [0.16271622]
[0.30880446 0.37487617] [0.54738768]
[0.81585667 0.96883469] [0.77994068]
构造可迭代的数据集类
构造数据集类实现__iter__
和__next__
方法,再使用此类的对象构建自定义数据集对象。相比于直接定义生成函数,使用数据集类能够实现更多的自定义功能。
[17]:
import numpy as np
import mindspore.dataset as ds
class IterDatasetGenerator:
def __init__(self):
np.random.seed(58)
self.__index = 0
self.__data = np.random.sample((5, 2))
self.__label = np.random.sample((5, 1))
def __next__(self):
if self.__index >= len(self.__data):
raise StopIteration
else:
item = (self.__data[self.__index], self.__label[self.__index])
self.__index += 1
return item
def __iter__(self):
self.__index = 0
return self
def __len__(self):
return len(self.__data)
dataset_generator = IterDatasetGenerator()
dataset = ds.GeneratorDataset(dataset_generator, ["data", "label"], shuffle=False)
for data in dataset.create_dict_iterator():
print(data["data"], data["label"])
[0.36510558 0.45120592] [0.78888122]
[0.49606035 0.07562207] [0.38068183]
[0.57176158 0.28963401] [0.16271622]
[0.30880446 0.37487617] [0.54738768]
[0.81585667 0.96883469] [0.77994068]
构造可随机访问的数据集类
构造数据集类实现__getitem__
方法,再使用此类的对象构建自定义数据集对象。此方法可以用于实现分布式训练。
[18]:
import numpy as np
import mindspore.dataset as ds
class GetDatasetGenerator:
def __init__(self):
np.random.seed(58)
self.__data = np.random.sample((5, 2))
self.__label = np.random.sample((5, 1))
def __getitem__(self, index):
return (self.__data[index], self.__label[index])
def __len__(self):
return len(self.__data)
dataset_generator = GetDatasetGenerator()
dataset = ds.GeneratorDataset(dataset_generator, ["data", "label"], shuffle=False)
for data in dataset.create_dict_iterator():
print(data["data"], data["label"])
[0.36510558 0.45120592] [0.78888122]
[0.49606035 0.07562207] [0.38068183]
[0.57176158 0.28963401] [0.16271622]
[0.30880446 0.37487617] [0.54738768]
[0.81585667 0.96883469] [0.77994068]
如果用户希望实现分布式训练,则需要在此方式的基础上,在采样器类中实现__iter__
方法,每次返回采样数据的索引。需要补充的代码如下:
[19]:
import math
class MySampler():
def __init__(self, dataset, local_rank, world_size):
self.__num_data = len(dataset)
self.__local_rank = local_rank
self.__world_size = world_size
self.samples_per_rank = int(math.ceil(self.__num_data / float(self.__world_size)))
self.total_num_samples = self.samples_per_rank * self.__world_size
def __iter__(self):
indices = list(range(self.__num_data))
indices.extend(indices[:self.total_num_samples-len(indices)])
indices = indices[self.__local_rank:self.total_num_samples:self.__world_size]
return iter(indices)
def __len__(self):
return self.samples_per_rank
dataset_generator = GetDatasetGenerator()
sampler = MySampler(dataset_generator, local_rank=0, world_size=2)
dataset = ds.GeneratorDataset(dataset_generator, ["data", "label"], shuffle=False, sampler=sampler)
for data in dataset.create_dict_iterator():
print(data["data"], data["label"])
[0.36510558 0.45120592] [0.78888122]
[0.57176158 0.28963401] [0.16271622]
[0.81585667 0.96883469] [0.77994068]