mindspore.dataset.WeightedRandomSampler

class mindspore.dataset.WeightedRandomSampler(weights, num_samples=None, replacement=True)

Samples element indices from [0, len(weights) - 1] at random, where index i is drawn with probability proportional to weights[i].
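As a rough illustration of how weights that do not sum to 1 act as sampling probabilities, the plain-Python sketch below normalizes them as p_i = w_i / sum(w) and draws indices with replacement using the standard library's `random.choices`. It is not MindSpore's implementation; the helper names `normalize` and `sample_indices` are hypothetical.

```python
import random
from typing import List, Sequence


def normalize(weights: Sequence[float]) -> List[float]:
    """Hypothetical helper: turn weights into probabilities p_i = w_i / sum(w)."""
    total = sum(weights)
    return [w / total for w in weights]


def sample_indices(weights: Sequence[float], k: int, seed: int = 0) -> List[int]:
    """Hypothetical helper: draw k indices from [0, len(weights) - 1] with replacement."""
    rng = random.Random(seed)
    # random.choices accepts unnormalized weights, so sum(weights) need not be 1.
    return rng.choices(range(len(weights)), weights=weights, k=k)


print(normalize([2, 1, 1]))  # -> [0.5, 0.25, 0.25]
print(sample_indices([0.9, 0.01, 0.4, 0.8, 0.1, 0.1, 0.3], k=4))
```

Because only the ratios between weights matter, scaling every weight by the same positive factor leaves the sampling distribution unchanged.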

Parameters:
  • weights (list[float, int]) – A sequence of weights, not necessarily summing up to 1.

  • num_samples (int, optional) – Number of elements to sample. Default: None, which means all elements are sampled.

  • replacement (bool, optional) – If True, each sampled ID is put back so it can be drawn again; if False, each ID is drawn at most once. Default: True.
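The effect of replacement=False can be sketched in plain Python as repeated weighted draws in which each chosen index is removed from the pool. This is an assumed model for illustration, not MindSpore's actual algorithm; the function `sample_without_replacement`, the exclusion of zero-weight indices, and the choice to raise when too few non-zero weights remain are all assumptions of the sketch.

```python
import random
from typing import Dict, List, Sequence


def sample_without_replacement(weights: Sequence[float], num_samples: int,
                               seed: int = 0) -> List[int]:
    """Hypothetical sketch of replacement=False: each index is drawn at most once."""
    rng = random.Random(seed)
    # Assumption of this sketch: zero-weight indices can never be drawn.
    remaining: Dict[int, float] = {i: w for i, w in enumerate(weights) if w > 0}
    if num_samples > len(remaining):
        # Assumption of this sketch: fail instead of returning fewer samples.
        raise ValueError("not enough non-zero weights to draw without replacement")
    picked: List[int] = []
    for _ in range(num_samples):
        candidates = list(remaining)
        # Draw one index in proportion to the weights still in the pool.
        choice = rng.choices(candidates,
                             weights=[remaining[i] for i in candidates], k=1)[0]
        picked.append(choice)
        # Do not put the sample ID back, so it cannot be drawn again.
        del remaining[choice]
    return picked


print(sample_without_replacement([0.9, 0.01, 0.4, 0.8, 0.1, 0.1, 0.3], 4))
```

With replacement, the same index may appear several times in the output; in this sketch every index in the result is distinct.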

Raises:
  • TypeError – If any element of weights is not a number.

  • TypeError – If num_samples is not of type int.

  • TypeError – If replacement is not of type bool.

  • RuntimeError – If weights is empty or all of its elements are zero.

  • ValueError – If num_samples is a negative value.
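The conditions listed above can be summarized as a hypothetical argument check. The helper `check_weighted_sampler_args` below is not part of MindSpore, and both the order of the checks and the rejection of `bool` values where numbers are expected are assumptions of this sketch.

```python
from typing import Optional, Sequence


def check_weighted_sampler_args(weights: Sequence, num_samples: Optional[int] = None,
                                replacement: bool = True) -> None:
    """Hypothetical check mirroring the documented Raises conditions."""
    # Assumption: bool is rejected even though it is a subclass of int in Python.
    if not all(isinstance(w, (int, float)) and not isinstance(w, bool) for w in weights):
        raise TypeError("every element of weights must be a number")
    if num_samples is not None and (not isinstance(num_samples, int)
                                    or isinstance(num_samples, bool)):
        raise TypeError("num_samples must be an int")
    if not isinstance(replacement, bool):
        raise TypeError("replacement must be a bool")
    if len(weights) == 0 or all(w == 0 for w in weights):
        raise RuntimeError("weights must not be empty or all zero")
    if num_samples is not None and num_samples < 0:
        raise ValueError("num_samples must not be negative")


check_weighted_sampler_args([0.9, 0.01, 0.4], 4)  # valid arguments: no exception
```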

Examples

>>> import mindspore.dataset as ds
>>> weights = [0.9, 0.01, 0.4, 0.8, 0.1, 0.1, 0.3]
>>>
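>>> # creates a WeightedRandomSampler that will sample 4 elements without replacement
>>> sampler = ds.WeightedRandomSampler(weights, 4, replacement=False)
>>>
>>> # placeholder path: replace with the directory of a real image folder dataset
>>> image_folder_dataset_dir = "/path/to/image_folder_dataset_directory"
>>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
...                                 num_parallel_workers=8,
...                                 sampler=sampler)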