比较与torch.utils.data.DataLoader的差异
torch.utils.data.DataLoader
class torch.utils.data.DataLoader(
dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
num_workers=0, collate_fn=None, pin_memory=False, drop_last=False,
timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *,
prefetch_factor=2, persistent_workers=False)
更多内容详见torch.utils.data.DataLoader。
mindspore.dataset.GeneratorDataset
class mindspore.dataset.GeneratorDataset(
source, column_names=None, column_types=None, schema=None,
num_samples=None, num_parallel_workers=1, shuffle=None, sampler=None,
num_shards=None, shard_id=None, python_multiprocessing=True, max_rowsize=None)
差异对比
PyTorch:DataLoader需要接收一个数据加载类、采样器,及批处理、混洗、多进程并行度等参数,以实现一个具有采样、分批、混洗等功能的数据迭代对象。
其中dataset
参数支持继承自torch.utils.data.Dataset
的自定义类,或传入由torchvision.datasets
、torchtext.datasets
、torchaudio.datasets
等组件中预定义好的数据集加载类。
MindSpore:GeneratorDataset需要接收一个数据加载类、采样器、混洗、分片和多进程并行性来创建一个用于数据迭代的迭代器。 此API与PyTorch的DataLoader功能定位一样,均是用于加载自定义的数据集,但参数列表差异较大,下面的多个代码示例将演示如何使用2个API实现同样的功能。
分类 |
子类 |
PyTorch |
MindSpore |
差异 |
---|---|---|---|---|
参数 |
参数1 |
dataset |
source |
定义数据集加载逻辑的对象 |
参数2 |
batch_size |
- |
MindSpore通过 |
|
参数3 |
shuffle |
shuffle |
- |
|
参数4 |
sampler |
sampler |
- |
|
参数5 |
batch_sampler |
- |
MindSpore不支持 |
|
参数6 |
num_workers |
num_parallel_workers |
- |
|
参数7 |
collate_fn |
- |
MindSpore通过 |
|
参数8 |
pin_memory |
- |
MindSpore不支持 |
|
参数9 |
drop_last |
- |
MindSpore通过 |
|
参数10 |
timeout |
- |
MindSpore不支持 |
|
参数11 |
worker_init_fn |
- |
MindSpore不支持 |
|
参数12 |
multiprocessing_context |
- |
多进程上下文,MindSpore不支持 |
|
参数13 |
generator |
- |
自定义索引生成器,MindSpore不支持 |
|
参数14 |
prefetch_factor |
- |
定义在 |
|
参数15 |
persistent_workers |
- |
指定遍历完一次数据后是否释放数据集对象, MindSpore通过 |
|
参数16 |
- |
column_names |
指定数据集生成的列名 |
|
参数17 |
- |
column_types |
指定生成数据集各个数据列的数据类型 |
|
参数18 |
- |
schema |
数据格式策略,用于指定读取数据列的数据类型、数据维度等信 |
|
参数19 |
- |
num_samples |
指定从数据集中读取的样本数 |
|
参数20 |
- |
num_shards |
指定分布式训练时将数据集进行划分的分片数 |
|
参数21 |
- |
shard_id |
指定分布式训练时使用的分片ID号 |
|
参数22 |
- |
python_multiprocessing |
指定是否启用Python多进程模式加速运算 |
|
参数23 |
- |
max_rowsize |
指定在多进程之间复制数据时,共享内存分配的最大空间 |
代码示例1
定义一个迭代类型的数据集类与一个随机访问类型的数据集类,并通过DataLoader/GeneratorDataset加载。注意DataLoader的shuffle参数默认行为是False,GeneratorDataset的shuffle默认行为是True。
# Torch
import torch
class MyIterableDataset(torch.utils.data.IterableDataset):
def __init__(self, start, end):
super(MyIterableDataset).__init__()
self.start = start
self.end = end
def __iter__(self):
return iter(range(self.start, self.end))
ds = MyIterableDataset(start=3, end=7)
# Single-process loading
print(list(torch.utils.data.DataLoader(ds, num_workers=0, shuffle=False)))
# Out: [tensor([3]), tensor([4]), tensor([5]), tensor([6])]
class MyMapDataset(torch.utils.data.Dataset):
def __init__(self):
super(MyMapDataset).__init__()
self.data = [1, 2, 3, 4]
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return len(self.data)
ds = MyMapDataset()
# Single-process loading
print(list(torch.utils.data.DataLoader(ds)))
# Out: [tensor([1]), tensor([2]), tensor([3]), tensor([4])]
# MindSpore
import mindspore as ms
class MyIterableDataset():
def __init__(self, start, end):
self.start = start
self.end = end
def __iter__(self):
return iter(range(self.start, self.end))
ds = MyIterableDataset(start=3, end=7)
# Single-process loading
print(list(ms.dataset.GeneratorDataset(ds, column_names=["data"], num_parallel_workers=1, shuffle=False)))
# Out: [[Tensor(shape=[], dtype=Int64, value= 3)], [Tensor(shape=[], dtype=Int64, value= 4)], [Tensor(shape=[], dtype=Int64, value= 5)], [Tensor(shape=[], dtype=Int64, value= 6)]]
class MyMapDataset():
def __init__(self):
super(MyMapDataset).__init__()
self.data = [1, 2, 3, 4]
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return len(self.data)
ds = MyMapDataset()
# Single-process loading
print(list(ms.dataset.GeneratorDataset(ds, column_names=["data"], shuffle=False)))
# Out: [[Tensor(shape=[], dtype=Int64, value= 1)], [Tensor(shape=[], dtype=Int64, value= 2)], [Tensor(shape=[], dtype=Int64, value= 3)], [Tensor(shape=[], dtype=Int64, value= 4)]]
代码示例2
定义一个数据集类,并对数据进行batch为2的批处理。
# Torch
import torch
class MyMapDataset(torch.utils.data.Dataset):
def __init__(self):
super(MyMapDataset).__init__()
self.data = [1, 2, 3, 4, 5]
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return len(self.data)
ds = MyMapDataset()
dataloader = torch.utils.data.DataLoader(ds, batch_size=2, drop_last=True)
print(list(dataloader))
# Out: [tensor([1, 2]), tensor([3, 4])]
# MindSpore
import mindspore as ms
class MyMapDataset():
def __init__(self):
super(MyMapDataset).__init__()
self.data = [1, 2, 3, 4, 5]
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return len(self.data)
ds = MyMapDataset()
dataloader = ms.dataset.GeneratorDataset(ds, column_names=["data"], shuffle=False)
dataloader = dataloader.batch(2, drop_remainder=True)
print(list(dataloader))
# Out: [[Tensor(shape=[2], dtype=Int64, value= [1, 2])], [Tensor(shape=[2], dtype=Int64, value= [3, 4])]]
代码示例3
定义一个数据集类,进行批处理时引入collate_fn逻辑。
# Torch
import torch
class MyMapDataset(torch.utils.data.Dataset):
def __init__(self):
super(MyMapDataset).__init__()
self.data = torch.Tensor([1, 2, 3, 4, 5])
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return len(self.data)
def my_collate_fn(batch):
for i, _ in enumerate(batch):
batch[i] = batch[i] + 2
return torch.stack(batch)
ds = MyMapDataset()
dataloader = torch.utils.data.DataLoader(ds, batch_size=2, drop_last=True, collate_fn=my_collate_fn)
print(list(dataloader))
# Out: [tensor([3., 4.]), tensor([5., 6.])]
# MindSpore
import mindspore as ms
import numpy as np
class MyMapDataset():
def __init__(self):
super(MyMapDataset).__init__()
self.data = [1, 2, 3, 4, 5]
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return len(self.data)
def my_collate_fn(batch, batchinfo):
for i, _ in enumerate(batch):
batch[i] = batch[i] + 2
return np.stack(batch),
ds = MyMapDataset()
dataloader = ms.dataset.GeneratorDataset(ds, column_names=["data"], shuffle=False)
dataloader = dataloader.batch(2, drop_remainder=True, per_batch_map=my_collate_fn)
print(list(dataloader))
# Out: [[Tensor(shape=[2], dtype=Int64, value= [3, 4])], [Tensor(shape=[2], dtype=Int64, value= [5, 6])]]