mindspore.dataset.Dataset.split

Dataset.split(sizes, randomize=True)

Split the dataset into smaller, non-overlapping datasets.

Parameters
  • sizes (Union[list[int], list[float]]) –

    If a list of integers [s1, s2, …, sn] is provided, the dataset is split into n datasets of sizes s1, s2, …, sn respectively. If the sum of all input sizes does not equal the original dataset size, an error is raised. If a list of floats [f1, f2, …, fn] is provided, all floats must be between 0 and 1 and must sum to 1; otherwise an error is raised. The dataset is then split into n datasets of sizes round(f1*K), round(f2*K), …, round(fn*K), where K is the size of the original dataset. If, after rounding:

    • any size equals 0, an error is raised.

    • the sum of the split sizes is less than K, the difference K - Σ round(fi*K) is added to the first split.

    • the sum of the split sizes is greater than K, the excess Σ round(fi*K) - K is removed from the first split that is large enough to keep at least 1 row after the removal.

  • randomize (bool, optional) – Determines whether or not to split the data randomly. Default: True. If True, the data will be randomly split. Otherwise, each split will be created with consecutive rows from the dataset.
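The size rules for sizes can be sketched in plain Python. The helper compute_split_sizes below is hypothetical and not part of the MindSpore API; it only mirrors the rules described above, assuming that Python's built-in round() (which rounds halves to even) is used and that valid floats lie in the range (0, 1].

```python
import math


def compute_split_sizes(sizes, dataset_size):
    """Hypothetical helper mirroring the documented rules for `sizes`."""
    # Integer sizes are used as-is, but must cover the dataset exactly.
    if all(isinstance(s, int) for s in sizes):
        if sum(sizes) != dataset_size:
            raise RuntimeError("sum of sizes must equal the dataset size")
        return list(sizes)

    # Float sizes: assumed bounds (0, 1], and they must sum to 1.
    if not all(0 < f <= 1 for f in sizes) or not math.isclose(sum(sizes), 1.0):
        raise ValueError("float sizes must be between 0 and 1 and sum to 1")

    # Assumption: Python's round(), which rounds halves to even.
    rounded = [round(f * dataset_size) for f in sizes]
    if any(r == 0 for r in rounded):
        raise RuntimeError("a split would have size 0 after rounding")

    diff = dataset_size - sum(rounded)
    if diff > 0:
        # Too few rows assigned: the remainder goes to the first split.
        rounded[0] += diff
    elif diff < 0:
        # Too many rows assigned: remove the excess from the first split
        # that still keeps at least 1 row afterwards.
        excess = -diff
        for i, r in enumerate(rounded):
            if r - excess >= 1:
                rounded[i] -= excess
                break
    return rounded
```

For example, with K = 7 and sizes [0.5, 0.5], both splits round to 4 (sum 8 > K), so 1 row is removed from the first split, giving [3, 4]. With K = 5, both round to 2 (sum 4 < K), so the missing row goes to the first split, giving [3, 2].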

Note

  1. The dataset cannot be sharded if split is going to be called.

  2. It is strongly recommended not to shuffle the dataset and to use randomize=True instead. Shuffling the dataset may not be deterministic, so the rows in each split may differ from epoch to epoch.
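Conceptually, randomize only decides which rows go into which split. The sketch below is a hypothetical illustration, not MindSpore's implementation: partition_indices and its seed parameter are assumptions. It partitions row indices 0..K-1 into non-overlapping splits, either as consecutive runs or through a single permutation drawn from a fixed seed. Because the permutation is drawn once, each split keeps the same rows in every epoch, which is the property a per-epoch shuffle of the source dataset can break.

```python
import random


def partition_indices(split_sizes, randomize=True, seed=0):
    """Hypothetical helper: assign row indices 0..K-1 to non-overlapping splits."""
    total = sum(split_sizes)
    indices = list(range(total))
    if randomize:
        # Draw one fixed permutation; reusing it keeps each split's
        # membership identical across epochs.
        random.Random(seed).shuffle(indices)
    splits, start = [], 0
    for size in split_sizes:
        # Consecutive slices of the (possibly permuted) index list never overlap.
        splits.append(indices[start:start + size])
        start += size
    return splits
```

With randomize=False, partition_indices([3, 2], randomize=False) returns [[0, 1, 2], [3, 4]], matching the "consecutive rows" behavior described for randomize.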

Returns

tuple(Dataset), a tuple of datasets that have been split.

Raises
  • RuntimeError – If get_dataset_size returns None or is not supported for this dataset.

  • RuntimeError – If sizes is a list of integers and the sum of all elements in sizes does not equal the dataset size.

  • RuntimeError – If sizes is a list of floats and a split has size 0 after rounding.

  • RuntimeError – If the dataset is sharded prior to calling split.

  • ValueError – If sizes is a list of floats and not all floats are between 0 and 1, or if the floats do not sum to 1.

Examples

>>> import mindspore.dataset as ds
>>> # TextFileDataset is not a mappable dataset, so the non-optimized split is used.
>>> # Many datasets shuffle by default; set shuffle=False if split will be called.
>>> # text_file_dataset_dir is a placeholder for the path to your text file(s).
>>> dataset = ds.TextFileDataset(text_file_dataset_dir, shuffle=False)
>>> train_dataset, test_dataset = dataset.split([0.9, 0.1])