Function Differences with torch.utils.data.DataLoader
torch.utils.data.DataLoader
class torch.utils.data.DataLoader(
dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
num_workers=0, collate_fn=None, pin_memory=False, drop_last=False,
timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *,
prefetch_factor=2, persistent_workers=False)
For more information, see torch.utils.data.DataLoader.
MindSpore
Datasets in MindSpore do not require a loader.
Differences
PyTorch: The dataset, the sampler object, and the batching, shuffling, and parallelism parameters are all passed to the DataLoader, which then iterates over the data in parallel with sampling, batching, and shuffling applied.
MindSpore: The sampler object and the batching, shuffling, and parallelism parameters are either passed to the dataset object at construction or set through the dataset's own methods. The dataset therefore handles sampling, batching, shuffling, and parallel iteration itself, and no separate loader is needed.
Code Example
import torch
import mindspore.dataset as ds
import numpy as np
np.random.seed(0)
data = np.random.random((6, 2, 2))
# The following implements data iteration with PyTorch.
dataset = torch.utils.data.TensorDataset(torch.Tensor(data))
sampler = torch.utils.data.SequentialSampler(dataset)
loader = torch.utils.data.DataLoader(dataset, batch_size=3, sampler=sampler)
print(next(iter(loader)))
# Out:
# [tensor([[[0.5488, 0.7152],
#           [0.6028, 0.5449]],
#
#          [[0.4237, 0.6459],
#           [0.4376, 0.8918]],
#
#          [[0.9637, 0.3834],
#           [0.7917, 0.5289]]])]
# The following implements data iteration with MindSpore.
sampler = ds.SequentialSampler()
dataset = ds.NumpySlicesDataset(data, sampler=sampler)
dataset = dataset.batch(batch_size=3)
iterator = dataset.create_dict_iterator()
print(next(iter(iterator)))
# Out:
# {'column_0': Tensor(shape=[3, 2, 2], dtype=Float64, value=
#  [[[ 5.48813504e-01,  7.15189366e-01],
#    [ 6.02763376e-01,  5.44883183e-01]],
#   [[ 4.23654799e-01,  6.45894113e-01],
#    [ 4.37587211e-01,  8.91773001e-01]],
#   [[ 9.63662761e-01,  3.83441519e-01],
#    [ 7.91725038e-01,  5.28894920e-01]]])}