mindspore.dataset.Dataset.batch
- Dataset.batch(batch_size, drop_remainder=False, num_parallel_workers=None, **kwargs)[source]
Combine batch_size consecutive rows into a single batch; if per_batch_map is provided, it is applied to the samples of each batch before they are combined.
For any column, all the elements within that column must have the same shape.
[Figure: execution process of the batch operation]
Note
The order in which repeat and batch are applied affects the number of batches and the behavior of per_batch_map. It is recommended to apply the repeat operation after the batch operation has finished (see the sketch below this note).
When using Data Sinking in Graph mode, the input shape of the network should remain consistent. You should either set drop_remainder to True to discard the last incomplete batch of data, or add/remove samples so that the dataset size is divisible by batch_size.
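A minimal sketch of the recommended order (the NumpySlicesDataset, its toy data, and the column name "col" are illustrative assumptions, not part of this API): batching before repeating keeps the per-epoch batch boundaries identical, and drop_remainder=True gives every batch a fixed shape, as data sinking in Graph mode requires.
>>> import numpy as np
>>> import mindspore.dataset as ds
>>> data = {"col": np.arange(10).astype(np.int32)}
>>> dataset = ds.NumpySlicesDataset(data, shuffle=False)
>>> # batch first: two full batches of 4 rows each; the 2 leftover rows are dropped
>>> dataset = dataset.batch(batch_size=4, drop_remainder=True)
>>> # then repeat: the same two fixed-shape batches are produced for 3 epochs
>>> dataset = dataset.repeat(3)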
- Parameters
batch_size (Union[int, Callable]) – The number of rows each batch is created with. An int or callable object which takes exactly 1 parameter, BatchInfo.
drop_remainder (bool, optional) – Determines whether to drop the last block whose number of data rows is less than batch_size. Default: False. If True, and if there are fewer than batch_size rows available to make the last batch, those rows will be dropped and not propagated to the child node.
num_parallel_workers (int, optional) – Number of workers (threads) to process the dataset in parallel. Default: None.
**kwargs –
per_batch_map (Callable[[List[numpy.ndarray], …, List[numpy.ndarray], BatchInfo], (List[numpy.ndarray], …, List[numpy.ndarray])], optional): Per-batch map callable. Default: None. A callable which takes (List[numpy.ndarray], …, List[numpy.ndarray], BatchInfo) as input parameters. Each List[numpy.ndarray] represents a batch of numpy.ndarray on a given column. The number of lists should match the number of entries in input_columns. The last parameter of the callable should always be a BatchInfo object. per_batch_map should return (list[numpy.ndarray], list[numpy.ndarray], …). The length of each output list should be the same as that of the input. output_columns is required if the number of output lists is different from the number of input lists.
input_columns (Union[str, list[str]], optional): List of names of the input columns. The size of the list should match the signature of the per_batch_map callable. Default: None.
output_columns (Union[str, list[str]], optional): List of names assigned to the columns output by the last operation. This parameter is mandatory if len(input_columns) != len(output_columns); see the sketch after this parameter list. The size of this list must match the number of output columns of the last operation. Default: None, output columns will have the same names as the input columns, i.e., the columns will be replaced.
python_multiprocessing (bool, optional): Parallelize the Python function per_batch_map with multiprocessing or multithreading mode. True means multiprocessing, False means multithreading. If per_batch_map is an I/O-bound task, use multithreading mode; if it is a CPU-bound task, multiprocessing mode is recommended. Default: False, use Python multithreading mode.
max_rowsize (Union[int, list[int]], optional): Maximum size in MB of a row that is used for shared memory allocation to copy data between processes. The total occupied shared memory increases as num_parallel_workers and mindspore.dataset.config.set_prefetch_size() increase. If set to -1, shared memory will be allocated dynamically with the actual size of the data. This is only used if python_multiprocessing is set to True. If it is an int, both input_columns and output_columns use this value as the unit to create shared memory. If it is a list, the first element is the unit used to create shared memory for input_columns, and the second element is the unit used for output_columns. Default: None, allocate shared memory dynamically.
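A minimal sketch of the output_columns requirement above (the per-batch function split_minmax, the NumpySlicesDataset input, and the column names col, col_min and col_max are hypothetical; only the batch keyword arguments come from this API): the callable receives one input list and returns two output lists, so output_columns must name both.
>>> import numpy as np
>>> import mindspore.dataset as ds
>>> def split_minmax(col, batch_info):
...     # one input list in, two output lists out -> output_columns is mandatory
...     mins = [np.min(x, keepdims=True) for x in col]
...     maxs = [np.max(x, keepdims=True) for x in col]
...     return (mins, maxs)
>>> data = {"col": np.random.rand(8, 4).astype(np.float32)}
>>> dataset = ds.NumpySlicesDataset(data, shuffle=False)
>>> dataset = dataset.batch(batch_size=4, per_batch_map=split_minmax,
...                         input_columns=["col"], output_columns=["col_min", "col_max"])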
- Returns
Dataset, a new dataset with the above operation applied.
Examples
>>> # 1) Create a dataset where every 5 rows are combined into a batch
>>> # and drop the last incomplete batch if there is one.
>>> import numpy as np
>>> import mindspore.dataset as ds
>>> from PIL import Image
>>>
>>> cifar10_dataset_dir = "/path/to/cifar10_dataset_directory"
>>> dataset = ds.Cifar10Dataset(dataset_dir=cifar10_dataset_dir, num_samples=10)
>>> dataset = dataset.batch(5, True)
>>>
>>> # 2) Resize each image according to its batch number: for the 5-th batch, resize to (5^2, 5^2) = (25, 25).
>>> def np_resize(col, BatchInfo):
...     output = col.copy()
...     s = (BatchInfo.get_batch_num() + 1) ** 2
...     index = 0
...     for c in col:
...         img = Image.fromarray(c.astype('uint8')).convert('RGB')
...         img = img.resize((s, s))
...         output[index] = np.array(img)
...         index += 1
...     return (output,)
>>> dataset = dataset.batch(batch_size=8, input_columns=["image"], per_batch_map=np_resize)
>>>
>>> # 3) Create a dataset whose batch size is dynamic.
>>> # Define a callable batch size function that increases the batch size by 1 each time.
>>> def add_one(BatchInfo):
...     return BatchInfo.get_batch_num() + 1
>>> dataset = dataset.batch(batch_size=add_one, drop_remainder=True)
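As a further hedged sketch (reusing np_resize from example 2), per_batch_map can also be run in Python multiprocessing mode for CPU-bound transforms; the max_rowsize value of 16 MB is purely illustrative, not a tuned recommendation.
>>> # 4) Run a CPU-bound per_batch_map in multiprocessing mode.
>>> # max_rowsize=16 is an illustrative shared-memory unit in MB.
>>> dataset = dataset.batch(batch_size=8, input_columns=["image"], per_batch_map=np_resize,
...                         python_multiprocessing=True, max_rowsize=16)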