mindspore.dataset.Dataset.split
- Dataset.split(sizes, randomize=True)
Split the dataset into smaller, non-overlapping datasets.
- Parameters
sizes (Union[list[int], list[float]]) –
If a list of integers [s1, s2, …, sn] is provided, the dataset will be split into n datasets of sizes s1, s2, …, sn respectively. If the sum of all input sizes does not equal the original dataset size, an error is raised. If a list of floats [f1, f2, …, fn] is provided, all floats must be between 0 and 1 and must sum to 1; otherwise an error is raised. The dataset will be split into n datasets of sizes round(f1*K), round(f2*K), …, round(fn*K), where K is the size of the original dataset. If, after rounding:
Any size equals 0, an error is raised.
The sum of split sizes is less than K, the difference K - sum(round(fi*K)) is added to the first split.
The sum of split sizes is greater than K, the difference sum(round(fi*K)) - K is removed from the first split that is large enough to keep at least 1 row after the removal.
randomize (bool, optional) – Whether to split the data randomly. Default: True. If True, the data will be randomly split. Otherwise, each split will be created with consecutive rows from the dataset.
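The sizing rules above can be mirrored in plain Python to see how float fractions turn into row counts and how randomize chooses between consecutive and shuffled rows. This is a sketch, not MindSpore's implementation: `compute_split_sizes` and `partition_indices` are hypothetical helpers, Python's built-in `round` (which rounds halves to even) stands in for whatever rounding MindSpore actually uses, and treating "between 0 and 1" as inclusive is an assumption.

```python
import math
import random


def compute_split_sizes(sizes, dataset_size):
    """Hypothetical helper mirroring the documented sizing rules of split()."""
    if all(isinstance(s, int) for s in sizes):
        # Integer sizes are used as-is but must add up to the dataset size.
        if sum(sizes) != dataset_size:
            raise RuntimeError("sum of split sizes must equal the dataset size")
        return list(sizes)

    # Float sizes: each must lie in [0, 1] (assumed inclusive) and they must sum to 1.
    if not all(0.0 <= f <= 1.0 for f in sizes) or not math.isclose(sum(sizes), 1.0):
        raise ValueError("float sizes must be between 0 and 1 and sum to 1")

    # Assumption: Python's round() (half to even) approximates MindSpore's rounding.
    absolute = [round(f * dataset_size) for f in sizes]
    if any(a == 0 for a in absolute):
        raise RuntimeError("a split would be empty after rounding")

    diff = dataset_size - sum(absolute)
    if diff > 0:
        # Too few rows: the first split absorbs the remainder.
        absolute[0] += diff
    elif diff < 0:
        # Too many rows: shrink the first split that keeps at least one row.
        for i, a in enumerate(absolute):
            if a + diff >= 1:
                absolute[i] += diff
                break
    return absolute


def partition_indices(dataset_size, sizes, randomize=True, seed=None):
    """Hypothetical helper: assign every row index to exactly one split."""
    indices = list(range(dataset_size))
    if randomize:
        # Random split: shuffle row indices once, then slice them.
        random.Random(seed).shuffle(indices)
    # Without randomize, slices are runs of consecutive rows.
    result, start = [], 0
    for n in compute_split_sizes(sizes, dataset_size):
        result.append(indices[start:start + n])
        start += n
    return result
```

For example, under Python's half-to-even rounding, `compute_split_sizes([0.25, 0.25, 0.5], 10)` rounds to 2, 2 and 5; those sum to 9, so the missing row goes to the first split and the result is `[3, 2, 5]`.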
Note
The dataset cannot be sharded if split is going to be called.
It is strongly recommended not to shuffle the dataset before calling split; use randomize=True instead. Shuffling the dataset may not be deterministic, which means the data in each split may differ from epoch to epoch.
- Returns
Tuple[Dataset], a tuple of new datasets split from the original one.
- Raises
RuntimeError – If get_dataset_size returns None or is not supported for this dataset.
RuntimeError – If sizes is a list of integers and the sum of all elements in sizes does not equal the dataset size.
RuntimeError – If sizes is a list of floats and a split has size 0 after calculation.
RuntimeError – If the dataset is sharded prior to calling split.
ValueError – If sizes is a list of floats and not all floats are between 0 and 1, or if the floats do not sum to 1.
Examples
>>> # Split the data into train part and test part.
>>> import mindspore.dataset as ds
>>> dataset = ds.GeneratorDataset([i for i in range(10)], "column1")
>>> train_dataset, test_dataset = dataset.split([0.9, 0.1])