mindspore.dataset.Dataset.batch
- Dataset.batch(batch_size, drop_remainder=False, num_parallel_workers=None, **kwargs)
Combine batch_size consecutive rows into a batch, applying per_batch_map (if given) to the samples of each batch first.
For any column, all the elements within that column must have the same shape.
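To make these two rules concrete, here is a minimal pure-Python model. It assumes each row is a tuple with one entry per column and represents arrays as nested lists. The helpers `shape_of` and `batch_rows` are hypothetical and not part of MindSpore. The model keeps a short last batch, which corresponds to drop_remainder=False.

```python
# Hypothetical pure-Python model of batching; NOT MindSpore's implementation.

def shape_of(value):
    """Return the shape of a rectangular nested list (a scalar has shape ())."""
    if isinstance(value, list):
        if not value:
            return (0,)
        return (len(value),) + shape_of(value[0])
    return ()


def batch_rows(rows, batch_size):
    """Group consecutive rows into batches, one list per column.

    Within a batch, every element of a given column must have the same
    shape; otherwise the elements cannot be stacked and ValueError is raised.
    """
    batches = []
    for start in range(0, len(rows), batch_size):
        chunk = rows[start:start + batch_size]
        columns = list(zip(*chunk))  # transpose: rows -> columns
        for col_idx, column in enumerate(columns):
            shapes = {shape_of(item) for item in column}
            if len(shapes) > 1:
                raise ValueError(f"column {col_idx} has mixed shapes {sorted(shapes)}")
        batches.append([list(column) for column in columns])
    return batches


rows = [([1, 2], 0), ([3, 4], 1), ([5, 6], 0)]
print(batch_rows(rows, 2))
# [[[[1, 2], [3, 4]], [0, 1]], [[[5, 6]], [0]]]
```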
Refer to the following figure for the execution process:
Note
The order in which repeat and batch are applied affects the number of batches and which rows per_batch_map receives. It is recommended to apply repeat after batch.
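The effect of the order can be checked with a small arithmetic model. The sketch below uses hypothetical `batch` and `repeat` helpers on plain integers, not MindSpore operations, and assumes 10 rows, a batch size of 4, drop_remainder=True and 2 repeats.

```python
# Hypothetical model of how the order of repeat and batch changes the batches.
# Plain Python, not MindSpore; rows are just integers here.

def batch(rows, batch_size, drop_remainder=False):
    """Split rows into consecutive batches, optionally dropping a short last batch."""
    batches = [rows[i:i + batch_size] for i in range(0, len(rows), batch_size)]
    if drop_remainder and batches and len(batches[-1]) < batch_size:
        batches.pop()
    return batches


def repeat(items, count):
    """Concatenate `count` copies of items, like running several passes."""
    return list(items) * count


rows = list(range(10))

# Batch first, then repeat: each pass drops its own 2 leftover rows.
batch_then_repeat = repeat(batch(rows, 4, drop_remainder=True), 2)
# Repeat first, then batch: the 20 rows are batched as one stream.
repeat_then_batch = batch(repeat(rows, 2), 4, drop_remainder=True)

print(len(batch_then_repeat))  # 4
print(len(repeat_then_batch))  # 5
```

In this model, batching first yields 4 batches with the same grouping in every pass, while repeating first yields 5 batches, and the third one, [8, 9, 0, 1], mixes rows from both passes.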
- Parameters
batch_size (Union[int, Callable]) – The number of rows each batch is created with. Either an int, or a callable that takes exactly one parameter, a BatchInfo object, and returns the batch size.
drop_remainder (bool, optional) – Whether to drop the last batch if it has fewer than batch_size rows. Default: False. If True and fewer than batch_size rows are available to make the last batch, those rows are dropped and not propagated to the child node.
num_parallel_workers (int, optional) – Number of workers (threads) to process the dataset in parallel. Default: None.
**kwargs –
  per_batch_map (Callable[[List[numpy.ndarray], …, List[numpy.ndarray], BatchInfo], (List[numpy.ndarray], …, List[numpy.ndarray])], optional): Per-batch map callable. Default: None. A callable which takes (List[numpy.ndarray], …, List[numpy.ndarray], BatchInfo) as input parameters. Each List[numpy.ndarray] represents a batch of numpy.ndarray on a given column. The number of lists should match the number of entries in input_columns. The last parameter of the callable should always be a BatchInfo object. per_batch_map should return (List[numpy.ndarray], List[numpy.ndarray], …). Each list in the output should have the same length as the input lists. output_columns is required if the number of output lists differs from the number of input lists.
  input_columns (Union[str, list[str]], optional): List of names of the input columns. The size of the list should match the signature of the per_batch_map callable. Default: None.
  output_columns (Union[str, list[str]], optional): List of names assigned to the columns output by the last operation. This parameter is mandatory if the number of columns returned by per_batch_map differs from len(input_columns). The size of this list must match the number of output columns of the last operation. Default: None, in which case the output columns have the same names as the input columns, i.e., the columns are replaced.
  python_multiprocessing (bool, optional): Whether to parallelize the Python function per_batch_map in multi-processing or multi-threading mode. True means multi-processing, False means multi-threading. If per_batch_map is an I/O bound task, use multi-threading mode. If per_batch_map is a CPU bound task, multi-processing mode is recommended. Default: False, i.e., Python multi-threading mode.
  max_rowsize (int, optional): Maximum size of a row in MB, used for the shared memory allocation that copies data between processes. This is only used if python_multiprocessing is set to True. Default: 16.
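The interaction between a callable batch_size and drop_remainder can be modeled in plain Python. `FakeBatchInfo` and `dynamic_batches` are hypothetical helpers, not MindSpore APIs; `FakeBatchInfo` only mimics BatchInfo.get_batch_num(), which this model assumes counts batches from 0.

```python
# Hypothetical pure-Python model of a callable batch_size with drop_remainder.
# FakeBatchInfo is a stand-in for MindSpore's BatchInfo; it only provides
# get_batch_num(), assumed here to be the 0-based index of the batch being built.

class FakeBatchInfo:
    def __init__(self, batch_num):
        self._batch_num = batch_num

    def get_batch_num(self):
        return self._batch_num


def dynamic_batches(rows, batch_size, drop_remainder=False):
    """Build batches whose size is an int or comes from a callable(batch_info)."""
    batches = []
    start = 0
    while start < len(rows):
        info = FakeBatchInfo(len(batches))
        size = batch_size(info) if callable(batch_size) else batch_size
        if size <= 0:
            raise ValueError("batch size must be positive")
        chunk = rows[start:start + size]
        if len(chunk) < size and drop_remainder:
            break  # too few rows left for a full batch: drop them
        batches.append(chunk)
        start += size
    return batches


def add_one(batch_info):
    """Batch size grows by 1 with each batch: 1, 2, 3, ..."""
    return batch_info.get_batch_num() + 1


print(dynamic_batches(list(range(8)), add_one))
# [[0], [1, 2], [3, 4, 5], [6, 7]]
print(dynamic_batches(list(range(8)), add_one, drop_remainder=True))
# [[0], [1, 2], [3, 4, 5]]
```

With 8 rows, the fourth batch would need 4 rows but only 2 remain, so in this model they are kept as a short batch by default and dropped when drop_remainder is True.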
- Returns
BatchDataset, dataset batched.
Examples
>>> # 1) Create a dataset where every 5 rows are combined into a batch
>>> # and drop the last incomplete batch if there is one.
>>> import mindspore.dataset as ds
>>> import numpy as np
>>> from PIL import Image
>>>
>>> cifar10_dataset_dir = "/path/to/cifar10_dataset_directory"
>>> dataset = ds.Cifar10Dataset(dataset_dir=cifar10_dataset_dir, num_samples=10)
>>> dataset = dataset.batch(5, True)
>>>
>>> # 2) Resize images according to the batch number: the 5th batch
>>> # (get_batch_num() == 4) is resized to (5^2, 5^2) = (25, 25).
>>> def np_resize(col, batch_info):
...     output = col.copy()
...     s = (batch_info.get_batch_num() + 1) ** 2
...     for index, c in enumerate(col):
...         img = Image.fromarray(c.astype('uint8')).convert('RGB')
...         img = img.resize((s, s))
...         output[index] = np.array(img)
...     return (output,)
>>> # Start from the unbatched dataset, since np_resize expects single images.
>>> dataset = ds.Cifar10Dataset(dataset_dir=cifar10_dataset_dir, num_samples=10)
>>> dataset = dataset.batch(batch_size=8, input_columns=["image"], per_batch_map=np_resize)
>>>
>>> # 3) Create a dataset whose batch size is dynamic.
>>> # Define a callable batch size function that increases the batch size by 1 each time.
>>> def add_one(batch_info):
...     return batch_info.get_batch_num() + 1
>>> dataset = ds.Cifar10Dataset(dataset_dir=cifar10_dataset_dir, num_samples=10)
>>> dataset = dataset.batch(batch_size=add_one, drop_remainder=True)
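The per_batch_map contract described under Parameters can also be modeled without MindSpore. In this sketch, `apply_per_batch_map` and `split_sign` are hypothetical helpers, not library APIs: the callable receives one list per input column followed by a batch-info object, and must return a tuple of lists as long as the input batch. Because `split_sign` turns one input column into two output columns, a real batch call using it would also need output_columns.

```python
# Hypothetical model of the per_batch_map contract; plain Python, not MindSpore.

def apply_per_batch_map(batch_columns, per_batch_map, batch_info):
    """Call per_batch_map on one batch and check the shape of its result."""
    outputs = per_batch_map(*batch_columns, batch_info)
    if not isinstance(outputs, tuple):
        raise TypeError("per_batch_map must return a tuple of lists")
    batch_len = len(batch_columns[0])
    for out in outputs:
        if len(out) != batch_len:
            raise ValueError("each output list must match the input batch length")
    return outputs


def split_sign(values, batch_info):
    """One input column in, two output columns out (magnitude and sign)."""
    magnitudes = [abs(v) for v in values]
    signs = [1 if v >= 0 else -1 for v in values]
    return (magnitudes, signs)


print(apply_per_batch_map([[3, -2, 0]], split_sign, None))
# ([3, 2, 0], [1, -1, 1])
```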