mindspore.dataset.GeneratorDataset
- class mindspore.dataset.GeneratorDataset(source, column_names=None, column_types=None, schema=None, num_samples=None, num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None, python_multiprocessing=True, max_rowsize=6)[source]
A source dataset that generates data from Python by invoking the Python data source each epoch.
The column names and column types of the generated dataset depend on the Python data defined by the user.
- Parameters
source (Union[Callable, Iterable, Random Accessible]) – A generator callable object, an iterable Python object or a random accessible Python object. A callable source is required to return a tuple of NumPy arrays as a row of the dataset on next(source()). An iterable source is required to return a tuple of NumPy arrays as a row of the dataset on next(iter(source)). A random accessible source is required to return a tuple of NumPy arrays as a row of the dataset on source[idx].
column_names (Union[str, list[str]], optional) – List of column names of the dataset. Default: None. Users are required to provide either column_names or schema.
column_types (list[mindspore.dtype], optional) – List of column data types of the dataset. Default: None. If provided, a sanity check will be performed on the generator output.
schema (Union[str, Schema], optional) – Data format policy, which specifies the data types and shapes of the data columns to be read. Both a JSON file path and objects constructed by mindspore.dataset.Schema are acceptable. Default: None.
num_samples (int, optional) – The number of samples to be included in the dataset. Default: None, all samples.
num_parallel_workers (int, optional) – Number of worker threads/subprocesses used to fetch the dataset in parallel. Default: 1.
shuffle (bool, optional) – Whether or not to shuffle the dataset. Random accessible input is required. Default: None, expected order behavior shown in the table below.
sampler (Union[Sampler, Iterable], optional) – Object used to choose samples from the dataset. Random accessible input is required. Default: None, expected order behavior shown in the table below.
num_shards (int, optional) – Number of shards that the dataset will be divided into. Default: None. Random accessible input is required. When this argument is specified, num_samples reflects the maximum number of samples per shard.
shard_id (int, optional) – The shard ID within num_shards. Default: None. This argument must be specified only when num_shards is also specified. Random accessible input is required.
python_multiprocessing (bool, optional) – Parallelize Python operations with multiple worker processes. This option could be beneficial if the Python operation is computationally heavy. Default: True.
max_rowsize (int, optional) – Maximum size of a row in MB that is used for shared memory allocation to copy data between processes. The total occupied shared memory will increase as num_parallel_workers and mindspore.dataset.config.set_prefetch_size() increase. This is only used if python_multiprocessing is set to True. Default: 6.
- Raises
RuntimeError – If source raises an exception during execution.
RuntimeError – If the length of column_names does not match the number of columns output by source.
ValueError – If num_parallel_workers exceeds the maximum number of threads.
ValueError – If sampler and shuffle are specified at the same time.
ValueError – If sampler and sharding are specified at the same time.
ValueError – If num_shards is specified but shard_id is None.
ValueError – If shard_id is specified but num_shards is None.
ValueError – If shard_id is not in the range [0, num_shards).
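The ValueError conditions listed above that involve sampler, shuffle and sharding can be read as a small set of argument checks. The following is a minimal sketch of those rules in plain Python. The function names check_generator_args and error_for are hypothetical and are not part of MindSpore, and the order of the checks and the error messages are assumptions made for illustration.

```python
def check_generator_args(sampler=None, shuffle=None, num_shards=None, shard_id=None):
    """Hypothetical validator mirroring the documented ValueError conditions."""
    if sampler is not None and shuffle is not None:
        raise ValueError("sampler and shuffle cannot be specified at the same time")
    if sampler is not None and (num_shards is not None or shard_id is not None):
        raise ValueError("sampler and sharding cannot be specified at the same time")
    if num_shards is not None and shard_id is None:
        raise ValueError("num_shards is specified but shard_id is None")
    if shard_id is not None and num_shards is None:
        raise ValueError("shard_id is specified but num_shards is None")
    if num_shards is not None and not 0 <= shard_id < num_shards:
        raise ValueError("shard_id is not in range [0, num_shards)")


def error_for(**kwargs):
    """Return the error message for an argument combination, or None if it is valid."""
    try:
        check_generator_args(**kwargs)
    except ValueError as err:
        return str(err)
    return None


# A valid sharded configuration produces no error.
print(error_for(num_shards=4, shard_id=0))
# shard_id equal to num_shards falls outside the half-open range.
print(error_for(num_shards=4, shard_id=4))
```

Checking these combinations before constructing the dataset can make configuration mistakes easier to spot in distributed training scripts.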
Note
If you configure python_multiprocessing=True (default: True) and num_parallel_workers>1 (default: 1), multi-process mode is started to accelerate data loading. In this mode, as the dataset iterates, the memory consumption of the subprocesses gradually increases, mainly because the subprocesses of the user-defined dataset obtain the member variables from the main process in a copy-on-write manner. For example, if you define a dataset whose __init__ function holds a large amount of member variable data (for example, a very large file name list loaded during dataset construction) and use multi-process mode, this may cause an out-of-memory (OOM) problem. The estimated total memory usage is (num_parallel_workers+1) * size of the parent process. The simplest solution is to replace Python objects (such as list/dict/int/float/string) with non-referenced data types (such as Pandas, NumPy or PyArrow objects) for member variables, or load less metadata in member variables, or configure python_multiprocessing=False to use multi-threading mode.
There are several classes/functions that can help you reduce the size of member variables, and you can choose to use them:
1. mindspore.dataset.utils.LineReader: Use this class to initialize your text file object in the __init__ function. Then read the file content based on the line number of the object with the __getitem__ function.
Input source accepts user-defined Python functions (PyFuncs). Do not add network computing operators from mindspore.nn and mindspore.ops or others into this source.
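To illustrate the idea behind keeping member variables small, the sketch below indexes a text file by byte offsets and reads a line only when it is requested, instead of holding every line as a Python string. The class OffsetLineReader is a hypothetical stand-in written for this example. It is not the implementation of mindspore.dataset.utils.LineReader, and the file name and contents are made up.

```python
import os
import tempfile
from array import array


class OffsetLineReader:
    """Hypothetical stand-in: keep byte offsets of lines rather than the lines themselves."""

    def __init__(self, path):
        self._path = path
        # A typed array of unsigned 64-bit integers is one compact buffer,
        # unlike a list of separately allocated Python objects.
        self._offsets = array("Q")
        with open(path, "rb") as f:
            position = 0
            for line in f:
                self._offsets.append(position)
                position += len(line)

    def __len__(self):
        return len(self._offsets)

    def __getitem__(self, index):
        # Read only the requested line; an out-of-range index raises IndexError.
        with open(self._path, "rb") as f:
            f.seek(self._offsets[index])
            return f.readline().decode("utf-8").rstrip("\r\n")


# Demonstration on a small temporary file (made-up contents).
with tempfile.TemporaryDirectory() as tmp_dir:
    file_path = os.path.join(tmp_dir, "names.txt")
    with open(file_path, "w", encoding="utf-8", newline="\n") as f:
        f.write("a.jpg\nb.jpg\nc.jpg\n")
    reader = OffsetLineReader(file_path)
    first, last, count = reader[0], reader[2], len(reader)
```

A dataset class could hold such a reader in __init__ and call it from __getitem__, so that each worker process carries only the offset table rather than the full list of strings.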
The parameters num_samples, shuffle, num_shards, shard_id can be used to control the sampler used in the dataset, and their effects when combined with parameter sampler are as follows.
| Parameter sampler | Parameter num_shards / shard_id | Parameter shuffle | Parameter num_samples | Sampler Used |
| mindspore.dataset.Sampler type | None | None | None | sampler |
| numpy.ndarray, list, tuple, int type | / | / | num_samples | SubsetSampler(indices=sampler, num_samples=num_samples) |
| iterable type | / | / | num_samples | IterSampler(sampler=sampler, num_samples=num_samples) |
| None | num_shards / shard_id | None / True | num_samples | DistributedSampler(num_shards=num_shards, shard_id=shard_id, shuffle=True, num_samples=num_samples) |
| None | num_shards / shard_id | False | num_samples | DistributedSampler(num_shards=num_shards, shard_id=shard_id, shuffle=False, num_samples=num_samples) |
| None | None | None / True | None | RandomSampler(num_samples=num_samples) |
| None | None | None / True | num_samples | RandomSampler(replacement=True, num_samples=num_samples) |
| None | None | False | num_samples | SequentialSampler(num_samples=num_samples) |
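The combinations in the table above can be read as a decision procedure. The following is a minimal sketch of that procedure in plain Python which returns only the name of the sampler. The function sampler_used and the sampler_kind labels are hypothetical and are not part of MindSpore. The sketch also assumes that shuffle=False without num_samples still selects a sequential sampler, a case the table does not list explicitly.

```python
def sampler_used(sampler_kind=None, sharded=False, shuffle=None, num_samples=None):
    """Return the name of the sampler listed in the table for a parameter combination.

    sampler_kind is a hypothetical label for the type of the `sampler` argument:
    "sampler" (a mindspore.dataset.Sampler), "indices" (numpy.ndarray, list,
    tuple or int), "iterable", or None when no sampler is passed.
    sharded is True when num_shards / shard_id are given.
    """
    if sampler_kind == "sampler":
        return "sampler"
    if sampler_kind == "indices":
        return "SubsetSampler"
    if sampler_kind == "iterable":
        return "IterSampler"
    if sharded:
        # shuffle=None is treated like shuffle=True for sharded input.
        return "DistributedSampler(shuffle=%s)" % (shuffle is not False)
    if shuffle is False:
        # Assumption: applies whether or not num_samples is given.
        return "SequentialSampler"
    if num_samples is None:
        return "RandomSampler"
    return "RandomSampler(replacement=True)"


print(sampler_used(sharded=True))                   # sharded, default shuffle
print(sampler_used(shuffle=False, num_samples=10))  # ordered, limited sample count
```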
Examples
>>> import mindspore.dataset as ds
>>> import numpy as np
>>>
>>> # 1) Multidimensional generator function as callable input.
>>> def generator_multidimensional():
...     for i in range(64):
...         yield (np.array([[i, i + 1], [i + 2, i + 3]]),)
>>>
>>> dataset = ds.GeneratorDataset(source=generator_multidimensional, column_names=["multi_dimensional_data"])
>>>
>>> # 2) Multi-column generator function as callable input.
>>> def generator_multi_column():
...     for i in range(64):
...         yield np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]])
>>>
>>> dataset = ds.GeneratorDataset(source=generator_multi_column, column_names=["col1", "col2"])
>>>
>>> # 3) Iterable dataset as iterable input.
>>> class MyIterable:
...     def __init__(self):
...         self._index = 0
...         self._data = np.random.sample((5, 2))
...         self._label = np.random.sample((5, 1))
...
...     def __next__(self):
...         if self._index >= len(self._data):
...             raise StopIteration
...         else:
...             item = (self._data[self._index], self._label[self._index])
...             self._index += 1
...             return item
...
...     def __iter__(self):
...         self._index = 0
...         return self
...
...     def __len__(self):
...         return len(self._data)
>>>
>>> dataset = ds.GeneratorDataset(source=MyIterable(), column_names=["data", "label"])
>>>
>>> # 4) Random accessible dataset as random accessible input.
>>> class MyAccessible:
...     def __init__(self):
...         self._data = np.random.sample((5, 2))
...         self._label = np.random.sample((5, 1))
...
...     def __getitem__(self, index):
...         return self._data[index], self._label[index]
...
...     def __len__(self):
...         return len(self._data)
>>>
>>> dataset = ds.GeneratorDataset(source=MyAccessible(), column_names=["data", "label"])
>>>
>>> # list, dict, tuple of Python is also random accessible
>>> dataset = ds.GeneratorDataset(source=[(np.array(0),), (np.array(1),), (np.array(2),)], column_names=["col"])
Pre-processing Operation
Apply a function in this dataset.
Concatenate the dataset objects in the input list.
Filter dataset by predicate.
Map func to each row in dataset and flatten the result.
Apply each operation in operations to this dataset.
The specified columns will be selected from the dataset and passed into the pipeline with the order specified.
Rename the columns in input datasets.
Repeat this dataset count times.
Reset the dataset for next epoch.
Save the dynamic data processed by the dataset pipeline in common dataset format.
Shuffle the dataset by creating a cache with the size of buffer_size.
Skip the first N elements of this dataset.
Split the dataset into smaller, non-overlapping datasets.
Take the first specified number of samples from the dataset.
Zip the datasets in the sense of input tuple of datasets.
Batch
Combine batch_size consecutive rows into a batch, applying per_batch_map to the samples first.
Bucket elements according to their lengths.
Combine batch_size consecutive rows into a batch, applying pad_info to the samples first.
Iterator
Create an iterator over the dataset.
Create an iterator over the dataset.
Attribute
Return the size of batch.
Get the mapping dictionary from category names to category indexes.
Return the names of the columns in dataset.
Return the number of batches in an epoch.
Get the replication times in RepeatDataset.
Get the column index, which represents the corresponding relationship between the data column order and the network when using the sink mode.
Get the number of classes in a dataset.
Get the shapes of output data.
Get the types of output data.
Apply Sampler
Add a child sampler for the current dataset.
Replace the last child sampler of the current dataset, leaving the parent sampler unchanged.
Others
Return a transferred Dataset that transfers data through a device.
Release a blocking condition and trigger callback with given data.
Add a blocking condition to the input Dataset and a synchronize action will be applied.
Serialize a pipeline into JSON string and dump into file if filename is provided.