# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
The sampler module provides several samplers to generate data from datasets.
The provided samplers include: DistributedSampler, PKSampler, RandomSampler,
SequentialSampler, SubsetSampler, SubsetRandomSampler, and WeightedRandomSampler.
Users can also define a custom sampler by extending from the Sampler class.
"""
import numbers
import numpy as np
import mindspore._c_dataengine as cde
import mindspore.dataset as ds
from ..core import validator_helpers as validator
def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
    """
    Choose the sampler a dataset should use from the user supplied arguments.

    Args:
        num_samples (int): Number of samples.
        input_sampler (Union[Iterable, Sampler]): Sampler from user.
        shuffle (bool): Shuffle.
        num_shards (int): Number of shard for sharding.
        shard_id (int): Shard ID.

    Returns:
        Sampler, sampler selected based on user input.

    Raises:
        ValueError: If a builtin sampler is given together with any of num_samples,
            num_shards, shard_id or shuffle.
        TypeError: If input_sampler is of an unsupported type.
    """
    if input_sampler is not None:
        if isinstance(input_sampler, BuiltinSampler):
            # An explicit builtin sampler already carries its own sampling options, so the
            # dataset-level options must all be left as None. Otherwise a call such as
            #     sampler = ds.DistributedSampler(num_shards=8, shard_id=3, shuffle=shuffle)
            #     data1 = ds.VOCDataset(voc_dir, decode=True, sampler=sampler, num_shards=4, shard_id=1)
            # would describe two contradicting sharding setups.
            dataset_level_options = [num_shards, shard_id, shuffle, num_samples]
            if any(option is not None for option in dataset_level_options):
                raise ValueError(
                    'Conflicting arguments during sampler assignments. num_samples: {}, num_shards: {},'
                    ' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle))
            return input_sampler

        # A list / ndarray is treated as an explicit sequence of indices.
        if isinstance(input_sampler, (np.ndarray, list)):
            return SubsetSampler(input_sampler, num_samples)

        # Any other non-string iterable is a user object that does not derive from BuiltinSampler.
        if not isinstance(input_sampler, str) and validator.is_iterable(input_sampler):
            return IterSampler(input_sampler, num_samples)

        # A single integer selects exactly one index.
        if isinstance(input_sampler, int):
            return SubsetSampler([input_sampler])

        raise TypeError('Unsupported sampler object of type ({})'.format(type(input_sampler)))

    if num_shards is not None:
        # Sharding enabled: always a distributed sampler. An unspecified shuffle means
        # shuffle; every other value is forwarded untouched.
        shard_shuffle = True if shuffle is None else shuffle
        return DistributedSampler(num_shards, shard_id, shuffle=shard_shuffle, num_samples=num_samples)

    # Sharding disabled from here on.
    if shuffle is None:
        # NOTE(review): only this branch treats num_samples == 0 like None; the explicit
        # shuffle=True branch below does not. Kept as in the original - confirm intent.
        draw_with_replacement = num_samples is not None and num_samples != 0
    elif shuffle is True:
        draw_with_replacement = num_samples is not None
    else:
        # Shuffle disabled, sharding disabled: plain sequential access.
        return SequentialSampler(num_samples=num_samples)

    if draw_with_replacement:
        return RandomSampler(replacement=True, num_samples=num_samples)
    return RandomSampler(num_samples=num_samples)
class BuiltinSampler:
    """
    Common base of the builtin samplers.

    It stores the optional sample count and the optional child sampler, and provides
    the child handling helpers shared by all builtin samplers.

    User should not extend this class.
    """

    def __init__(self, num_samples=None):
        self.child_sampler = None
        self.num_samples = num_samples

    def parse(self):
        """ Parse the sampler."""

    def add_child(self, sampler):
        """
        Add a sub-sampler for given sampler. The sub-sampler will receive all data from the
        output of parent sampler and apply its sample logic to return new samples.

        Args:
            sampler (Sampler): Object used to choose samples from the dataset. Only builtin
                samplers(DistributedSampler, PKSampler, RandomSampler, SequentialSampler,
                SubsetRandomSampler, WeightedRandomSampler) are supported.

        Examples:
            >>> sampler = ds.SequentialSampler(start_index=0, num_samples=3)
            >>> sampler.add_child(ds.RandomSampler(num_samples=2))
            >>> dataset = ds.Cifar10Dataset(cifar10_dataset_dir, sampler=sampler)
        """
        self.child_sampler = sampler

    def get_child(self):
        """
        Get the child sampler of given sampler.

        Returns:
            Sampler, The child sampler of given sampler.

        Examples:
            >>> sampler = ds.SequentialSampler(start_index=0, num_samples=3)
            >>> sampler.add_child(ds.RandomSampler(num_samples=2))
            >>> child_sampler = sampler.get_child()
        """
        return self.child_sampler

    def parse_child(self):
        """ Parse the child sampler; returns None when there is no child. """
        if self.child_sampler is None:
            return None
        return self.child_sampler.parse()

    def parse_child_for_minddataset(self):
        """ Parse the child sampler for MindRecord; returns None when there is no child. """
        if self.child_sampler is None:
            return None
        return self.child_sampler.parse_for_minddataset()

    def is_shuffled(self):
        """ Not implemented. """
        raise NotImplementedError("Sampler must implement is_shuffled.")

    def is_sharded(self):
        """ Not implemented. """
        raise NotImplementedError("Sampler must implement is_sharded.")

    def get_num_samples(self):
        """
        All samplers can contain a numeric num_samples value (or it can be set to None).
        A child sampler can exist or be None.
        If a child sampler exists, then the child sampler count can be a numeric value or None.
        These conditions impact the resultant sampler count that is used.
        The following table shows the possible results from calling this function.

        .. list-table::
           :widths: 25 25 25 25
           :header-rows: 1

           * - child sampler
             - num_samples
             - child_samples
             - result
           * - T
             - x
             - y
             - min(x, y)
           * - T
             - x
             - None
             - x
           * - T
             - None
             - y
             - y
           * - T
             - None
             - None
             - None
           * - None
             - x
             - n/a
             - x
           * - None
             - None
             - n/a
             - None

        Returns:
            int, the number of samples, or None.

        Examples:
            >>> sampler = ds.SequentialSampler(start_index=0, num_samples=3)
            >>> num_samplers = sampler.get_num_samples()
        """
        own_count = self.num_samples
        if self.child_sampler is None:
            return own_count

        child_count = self.child_sampler.get_num_samples()
        if own_count is None:
            # Covers both "child has a count" and "neither has a count".
            return child_count
        if child_count is None:
            return own_count
        return min(own_count, child_count)
class Sampler(BuiltinSampler):
    """
    Base class for user defined sampler.

    A user defined sampler can be used with any existing dataset with sampler support.

    A required __iter__() method should be overridden by the user for sample index generation.
    An optional reset() method can be overridden for per repeat reset,

    dataset_size and num_samples will be set by dataset once a dataset iterator is created.

    Examples:
        >>> class ReverseSampler(ds.Sampler):
        ...     def __iter__(self):
        ...         for i in range(self.dataset_size - 1, -1, -1):
        ...             yield i
        >>>
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir, sampler=ReverseSampler())
    """

    # add_child, get_child and parse_child are inherited unchanged from BuiltinSampler.

    def __init__(self, num_samples=None):
        # The base constructor already stores num_samples and resets child_sampler.
        super().__init__(num_samples)
        self.dataset_size = 0

    def __iter__(self):
        """
        User defined iterator, must be overridden.
        _handshake is guaranteed to be called prior to iterator construction.
        """
        raise NotImplementedError

    def reset(self):
        """
        Per repeat reset callback, override this method if necessary
        """

    # Initialization handshake callback
    # Do not override this method!
    def _handshake(self, ds_size, num_samples):
        self.dataset_size = ds_size
        self.num_samples = num_samples

    # Indices fetcher
    # Do not override this method!
    # pylint: disable=missing-docstring
    def _get_indices(self):
        sampler_iter = iter(self)
        # Take at most num_samples indices and stop early if the user iterator runs out.
        # range() is the first zip argument so the user iterator is never advanced
        # beyond the indices that are actually kept.
        fetched = [idx for _, idx in zip(range(self.num_samples), sampler_iter)]
        indices = np.array(fetched)
        # An object dtype means the yielded values were not uniform numbers.
        if indices.dtype == object:
            raise RuntimeError("Fetched indices can not be converted to a valid ndarray.")
        return indices

    # Instance fetcher
    # Do not override this method!
    def parse(self):
        """ Parse the sampler."""
        # A missing num_samples is passed on as 0.
        num_samples = self.num_samples if self.num_samples is not None else 0
        c_sampler = cde.PreBuiltSamplerObj(num_samples, self)
        c_child_sampler = self.parse_child()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

    def is_shuffled(self):
        """ Report False without a child sampler, otherwise defer to the child. """
        if self.child_sampler is None:
            return False
        return self.child_sampler.is_shuffled()

    def is_sharded(self):
        """ Report False without a child sampler, otherwise defer to the child. """
        if self.child_sampler is None:
            return False
        return self.child_sampler.is_sharded()

    def get_num_samples(self):
        """
        Return None when num_samples is unset, otherwise the number of indices the
        user iterator actually yields (capped at num_samples).
        """
        if self.num_samples is None:
            return None
        return self._get_indices().size
class DistributedSampler(BuiltinSampler):
    """
    A sampler that accesses a shard of the dataset, it helps divide dataset into multi-subset for distributed training.

    Args:
        num_shards (int): Number of shards to divide the dataset into.
        shard_id (int): Shard ID of the current shard, which should within the range of [0, num_shards-1].
        shuffle (bool, optional): If True, the indices are shuffled, otherwise it will not be shuffled(default=True).
        num_samples (int, optional): The number of samples to draw (default=None, which means sample all elements).
        offset(int, optional): The starting shard ID where the elements in the dataset are sent to, which
            should be no more than num_shards. This parameter is only valid when a ConcatDataset takes
            a DistributedSampler as its sampler. It will affect the number of samples of per shard
            (default=-1, which means each shard has same number of samples).

    Examples:
        >>> # creates a distributed sampler with 10 shards in total. This shard is shard 5.
        >>> sampler = ds.DistributedSampler(10, 5)
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
        ...                                 num_parallel_workers=8,
        ...                                 sampler=sampler)

    Raises:
        TypeError: If num_shards is not an integer value.
        TypeError: If shard_id is not an integer value.
        TypeError: If shuffle is not a boolean value.
        TypeError: If num_samples is not an integer value.
        TypeError: If offset is not an integer value.
        ValueError: If num_samples is a negative value.
        RuntimeError: If num_shards is not a positive value.
        RuntimeError: If shard_id is smaller than 0 or equal to num_shards or larger than num_shards.
        RuntimeError: If offset is greater than num_shards.
    """

    def __init__(self, num_shards, shard_id, shuffle=True, num_samples=None, offset=-1):
        # Argument checks run in declaration order; nothing is stored until all pass.
        if not isinstance(num_shards, int):
            raise TypeError("num_shards must be integer but was: {}.".format(num_shards))

        if not isinstance(shard_id, int):
            raise TypeError("shard_id must be integer but was: {}.".format(shard_id))

        if not isinstance(shuffle, bool):
            raise TypeError("shuffle must be a boolean value but was: {}.".format(shuffle))

        if num_samples is not None and not isinstance(num_samples, int):
            raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
        if num_samples is not None and not 0 <= num_samples <= validator.INT64_MAX:
            raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!"
                             .format(0, validator.INT64_MAX))

        if not isinstance(offset, int):
            raise TypeError("offset must be integer but was: {}.".format(offset))

        self.num_shards = num_shards
        self.shard_id = shard_id
        self.shuffle = shuffle
        self.seed = 0
        self.offset = offset
        super().__init__(num_samples)

    def parse(self):
        """ Parse the sampler."""
        sample_count = 0 if self.num_samples is None else self.num_samples
        do_shuffle = True if self.shuffle is None else self.shuffle
        shard_offset = -1 if self.offset is None else self.offset
        # each time user calls create_dict_iterator() (to do repeat) sampler would get a different seed to shuffle
        self.seed += 1
        # NOTE(review): the meaning of the trailing True argument is not visible in this file.
        c_sampler = cde.DistributedSamplerObj(self.num_shards, self.shard_id,
                                              do_shuffle, sample_count, self.seed, shard_offset, True)
        c_sampler.add_child(self.parse_child())
        return c_sampler

    def parse_for_minddataset(self):
        """ Parse the sampler for MindRecord."""
        sample_count = 0 if self.num_samples is None else self.num_samples
        do_shuffle = True if self.shuffle is None else self.shuffle
        # Unlike parse(), the seed is not advanced here and offset is forwarded as stored.
        c_sampler = cde.MindrecordDistributedSampler(self.num_shards, self.shard_id, do_shuffle,
                                                     self.seed, sample_count, self.offset)
        c_sampler.add_child(self.parse_child_for_minddataset())
        return c_sampler

    def is_shuffled(self):
        """ Own shuffle flag when there is no child sampler, otherwise the child's answer. """
        child = self.child_sampler
        return self.shuffle if child is None else child.is_shuffled()

    def is_sharded(self):
        """ True for more than one shard when there is no child, otherwise the child's answer. """
        child = self.child_sampler
        return self.num_shards > 1 if child is None else child.is_sharded()

    def set_offset(self, offset):
        """ Store a new offset and return this sampler so calls can be chained. """
        self.offset = offset
        return self
class PKSampler(BuiltinSampler):
    """
    Samples K elements for each P class in the dataset.

    Args:
        num_val (int): Number of elements to sample for each class.
        num_class (int, optional): Number of classes to sample (default=None, sample all classes).
            The parameter does not supported to specify currently.
        shuffle (bool, optional): If True, the class IDs are shuffled, otherwise it will not be
            shuffled (default=False).
        class_column (str, optional): Name of column with class labels for MindDataset (default='label').
        num_samples (int, optional): The number of samples to draw (default=None, which means sample all elements).

    Examples:
        >>> # creates a PKSampler that will get 3 samples from every class.
        >>> sampler = ds.PKSampler(3)
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
        ...                                 num_parallel_workers=8,
        ...                                 sampler=sampler)

    Raises:
        TypeError: If num_val is not an integer value.
        TypeError: If shuffle is not a boolean value.
        TypeError: If class_column is not a str value.
        TypeError: If num_samples is not an integer value.
        NotImplementedError: If num_class is not None.
        RuntimeError: If num_val is not a positive value.
        ValueError: If num_samples is a negative value.
    """

    def __init__(self, num_val, num_class=None, shuffle=False, class_column='label', num_samples=None):
        if not isinstance(num_val, int):
            raise TypeError("num_val must be integer but was: {}.".format(num_val))

        if num_class is not None:
            raise NotImplementedError("Not supported to specify num_class for PKSampler.")

        if not isinstance(shuffle, bool):
            raise TypeError("shuffle must be a boolean value but was: {}.".format(shuffle))

        if not isinstance(class_column, str):
            raise TypeError("class_column must be a str value but was: {}.".format(class_column))

        if num_samples is not None:
            if not isinstance(num_samples, int):
                raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
            if num_samples < 0 or num_samples > validator.INT64_MAX:
                raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!"
                                 .format(0, validator.INT64_MAX))

        self.num_val = num_val
        self.shuffle = shuffle
        self.class_column = class_column  # work for minddataset
        super().__init__(num_samples)

    def parse(self):
        """ Parse the sampler."""
        num_samples = self.num_samples if self.num_samples is not None else 0
        shuffle = self.shuffle if self.shuffle is not None else False
        c_sampler = cde.PKSamplerObj(self.num_val, shuffle, num_samples)
        c_child_sampler = self.parse_child()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

    def is_shuffled(self):
        """ Own shuffle flag when there is no child sampler, otherwise the child's answer. """
        if self.child_sampler is None:
            return self.shuffle
        return self.child_sampler.is_shuffled()

    def is_sharded(self):
        """ False when there is no child sampler, otherwise the child's answer. """
        if self.child_sampler is None:
            return False
        return self.child_sampler.is_sharded()

    def parse_for_minddataset(self):
        """
        Parse the sampler for MindRecord.

        Raises:
            ValueError: If class_column is empty or not a string.
        """
        if not self.class_column or not isinstance(self.class_column, str):
            # Adjacent literals instead of a backslash continuation inside the string, so the
            # message text does not pick up the source file's indentation.
            raise ValueError("class_column should be a not empty string value, "
                             "but got class_column: {}.".format(self.class_column))
        num_samples = self.num_samples if self.num_samples is not None else 0
        # Same None fallback as parse(), so both code paths treat an unset shuffle alike.
        shuffle = self.shuffle if self.shuffle is not None else False
        c_sampler = cde.MindrecordPkSampler(self.num_val, self.class_column, shuffle, num_samples)
        c_child_sampler = self.parse_child_for_minddataset()
        c_sampler.add_child(c_child_sampler)
        return c_sampler
class RandomSampler(BuiltinSampler):
    """
    Samples the elements randomly.

    Args:
        replacement (bool, optional): If True, put the sample ID back for the next draw (default=False).
        num_samples (int, optional): Number of elements to sample (default=None, which means sample all elements).

    Examples:
        >>> # creates a RandomSampler
        >>> sampler = ds.RandomSampler()
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
        ...                                 num_parallel_workers=8,
        ...                                 sampler=sampler)

    Raises:
        TypeError: If replacement is not a boolean value.
        TypeError: If num_samples is not an integer value.
        ValueError: If num_samples is a negative value.
    """

    def __init__(self, replacement=False, num_samples=None):
        if not isinstance(replacement, bool):
            raise TypeError("replacement must be a boolean value but was: {}.".format(replacement))

        if num_samples is not None and not isinstance(num_samples, int):
            raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
        if num_samples is not None and not 0 <= num_samples <= validator.INT64_MAX:
            raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!"
                             .format(0, validator.INT64_MAX))

        self.deterministic = False
        self.replacement = replacement
        self.reshuffle_each_epoch = True
        super().__init__(num_samples)

    def parse(self):
        """ Parse the sampler."""
        sample_count = 0 if self.num_samples is None else self.num_samples
        with_replacement = False if self.replacement is None else self.replacement
        c_sampler = cde.RandomSamplerObj(with_replacement, sample_count, self.reshuffle_each_epoch)
        c_sampler.add_child(self.parse_child())
        return c_sampler

    def parse_for_minddataset(self):
        """Parse the sampler for MindRecord."""
        sample_count = 0 if self.num_samples is None else self.num_samples
        # The stored replacement flag is forwarded as-is on this path.
        c_sampler = cde.MindrecordRandomSampler(sample_count, self.replacement, self.reshuffle_each_epoch)
        c_sampler.add_child(self.parse_child_for_minddataset())
        return c_sampler

    def is_shuffled(self):
        """ A random sampler always reports shuffled output. """
        return True

    def is_sharded(self):
        """ False when there is no child sampler, otherwise the child's answer. """
        child = self.child_sampler
        return False if child is None else child.is_sharded()
class SequentialSampler(BuiltinSampler):
    """
    Samples the dataset elements sequentially that is equivalent to not using a sampler.

    Args:
        start_index (int, optional): Index to start sampling at. (default=None, start at first ID)
        num_samples (int, optional): Number of elements to sample (default=None, which means sample all elements).

    Examples:
        >>> # creates a SequentialSampler
        >>> sampler = ds.SequentialSampler()
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
        ...                                 num_parallel_workers=8,
        ...                                 sampler=sampler)

    Raises:
        TypeError: If start_index is not an integer value.
        TypeError: If num_samples is not an integer value.
        RuntimeError: If start_index is a negative value.
        ValueError: If num_samples is a negative value.
    """

    def __init__(self, start_index=None, num_samples=None):
        if start_index is not None and not isinstance(start_index, int):
            raise TypeError("start_index must be integer but was: {}.".format(start_index))

        if num_samples is not None and not isinstance(num_samples, int):
            raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
        if num_samples is not None and not 0 <= num_samples <= validator.INT64_MAX:
            raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!"
                             .format(0, validator.INT64_MAX))

        self.start_index = start_index
        super().__init__(num_samples)

    def parse(self):
        """ Parse the sampler."""
        # Unset values are both passed on as 0.
        first_index = 0 if self.start_index is None else self.start_index
        sample_count = 0 if self.num_samples is None else self.num_samples
        c_sampler = cde.SequentialSamplerObj(first_index, sample_count)
        c_sampler.add_child(self.parse_child())
        return c_sampler

    def parse_for_minddataset(self):
        """Parse the sampler for MindRecord."""
        first_index = 0 if self.start_index is None else self.start_index
        sample_count = 0 if self.num_samples is None else self.num_samples
        # Argument order differs from parse(): the sample count comes first here.
        c_sampler = cde.MindrecordSequentialSampler(sample_count, first_index)
        c_sampler.add_child(self.parse_child_for_minddataset())
        return c_sampler

    def is_shuffled(self):
        """ False when there is no child sampler, otherwise the child's answer. """
        child = self.child_sampler
        return False if child is None else child.is_shuffled()

    def is_sharded(self):
        """ False when there is no child sampler, otherwise the child's answer. """
        child = self.child_sampler
        return False if child is None else child.is_sharded()
class SubsetSampler(BuiltinSampler):
    """
    Samples the elements from a sequence of indices.

    Args:
        indices (Any iterable Python object but string): A sequence of indices.
        num_samples (int, optional): Number of elements to sample (default=None, which means sample all elements).

    Examples:
        >>> indices = [0, 1, 2, 3, 4, 5]
        >>>
        >>> # creates a SubsetSampler, will sample from the provided indices
        >>> sampler = ds.SubsetSampler(indices)
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
        ...                                 num_parallel_workers=8,
        ...                                 sampler=sampler)

    Raises:
        TypeError: If type of indices element is not a number.
        TypeError: If num_samples is not an integer value.
        ValueError: If num_samples is a negative value.
    """

    def __init__(self, indices, num_samples=None):
        if num_samples is not None and not isinstance(num_samples, int):
            raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
        if num_samples is not None and not 0 <= num_samples <= validator.INT64_MAX:
            raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!"
                             .format(0, validator.INT64_MAX))

        # Normalise the input into a plain list holding at most num_samples entries.
        if not isinstance(indices, str) and validator.is_iterable(indices):
            if num_samples is None:
                indices = list(indices)
            elif isinstance(indices, list):
                indices = indices[:num_samples]
            else:
                # Generic iterable: pair it with a bounded range to stop after num_samples items.
                indices = [sample_id for sample_id, _ in zip(indices, range(num_samples))]
        elif isinstance(indices, int):
            # A lone integer stands for a one-element sequence.
            indices = [indices]
        else:
            raise TypeError('Unsupported sampler object of type ({})'.format(type(indices)))

        for position, sample_id in enumerate(indices):
            if isinstance(sample_id, (int, np.integer)):
                continue
            raise TypeError("SubsetSampler: Type of indices element must be int, "
                            "but got list[{}]: {}, type: {}.".format(position, sample_id, type(sample_id)))

        self.indices = indices
        super().__init__(num_samples)

    def parse(self):
        """ Parse the sampler."""
        sample_count = 0 if self.num_samples is None else self.num_samples
        c_sampler = cde.SubsetSamplerObj(self.indices, sample_count)
        c_sampler.add_child(self.parse_child())
        return c_sampler

    def is_shuffled(self):
        """ The indices are used in the given order, so this is never shuffled. """
        return False

    def is_sharded(self):
        """ False when there is no child sampler, otherwise the child's answer. """
        child = self.child_sampler
        return False if child is None else child.is_sharded()

    def parse_for_minddataset(self):
        """Parse the sampler for MindRecord."""
        c_sampler = cde.MindrecordSubsetSampler(self.indices)
        c_sampler.add_child(self.parse_child_for_minddataset())
        return c_sampler

    def get_num_samples(self):
        """
        Return the sample count from the base class rules, capped by the number of stored
        indices; when that count is None the number of stored indices is returned.
        """
        index_count = len(self.indices)
        base_count = super().get_num_samples()
        if base_count is None:
            return index_count
        return min(index_count, base_count)
class SubsetRandomSampler(SubsetSampler):
    """
    Samples the elements randomly from a sequence of indices.

    Args:
        indices (Any iterable Python object but string): A sequence of indices.
        num_samples (int, optional): Number of elements to sample (default=None, which means sample all elements).

    Examples:
        >>> indices = [0, 1, 2, 3, 7, 88, 119]
        >>>
        >>> # create a SubsetRandomSampler, will sample from the provided indices
        >>> sampler = ds.SubsetRandomSampler(indices)
        >>> data = ds.ImageFolderDataset(image_folder_dataset_dir, num_parallel_workers=8, sampler=sampler)

    Raises:
        TypeError: If type of indices element is not a number.
        TypeError: If num_samples is not an integer value.
        ValueError: If num_samples is a negative value.
    """

    def parse(self):
        """ Parse the sampler."""
        sample_count = 0 if self.num_samples is None else self.num_samples
        c_sampler = cde.SubsetRandomSamplerObj(self.indices, sample_count)
        c_sampler.add_child(self.parse_child())
        return c_sampler

    def is_shuffled(self):
        """ The random variant always reports shuffled output. """
        return True

    def parse_for_minddataset(self):
        """Parse the sampler for MindRecord."""
        # The MindRecord path additionally receives the globally configured seed.
        c_sampler = cde.MindrecordSubsetSampler(self.indices, ds.config.get_seed())
        c_sampler.add_child(self.parse_child_for_minddataset())
        return c_sampler
class IterSampler(Sampler):
    """
    User provided an iterable object without inheriting from our Sampler class.

    Note:
        This class exists to allow handshake logic between dataset operators and user defined samplers.
        By constructing this object we avoid the user having to inherit from our Sampler class.

    Args:
        sampler (iterable object): an user defined iterable object.
        num_samples (int, optional): Number of elements to sample (default=None, which means sample all elements).

    Examples:
        >>> class MySampler:
        ...     def __iter__(self):
        ...         for i in range(99, -1, -1):
        ...             yield i

        >>> # creates an IterSampler
        >>> sampler = ds.IterSampler(sampler=MySampler())
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
        ...                                 num_parallel_workers=8,
        ...                                 sampler=sampler)
    """

    def __init__(self, sampler, num_samples=None):
        if num_samples is None:
            # Counting the elements requires one full pass over the iterable.
            materialized = list(sampler)
            if iter(sampler) is sampler:
                # A one-shot iterator (e.g. a generator) returns itself from iter() and is now
                # exhausted by the counting pass; keep the collected indices so that
                # __iter__ can still produce them.
                sampler = materialized
            num_samples = len(materialized)
        super().__init__(num_samples=num_samples)
        self.sampler = sampler

    def __iter__(self):
        """ Iterate over the wrapped user object. """
        return iter(self.sampler)
class WeightedRandomSampler(BuiltinSampler):
    """
    Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities).

    Args:
        weights (list[float, int]): A sequence of weights, not necessarily summing up to 1.
        num_samples (int, optional): Number of elements to sample (default=None, which means sample all elements).
        replacement (bool): If True, put the sample ID back for the next draw (default=True).

    Examples:
        >>> weights = [0.9, 0.01, 0.4, 0.8, 0.1, 0.1, 0.3]
        >>>
        >>> # creates a WeightedRandomSampler that will sample 4 elements (with replacement, the default)
        >>> sampler = ds.WeightedRandomSampler(weights, 4)
        >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
        ...                                 num_parallel_workers=8,
        ...                                 sampler=sampler)

    Raises:
        TypeError: If type of weights element is not a number.
        TypeError: If num_samples is not an integer value.
        TypeError: If replacement is not a boolean value.
        RuntimeError: If `weights` is empty or all zero.
        ValueError: If num_samples is a negative value.
    """

    def __init__(self, weights, num_samples=None, replacement=True):
        # Anything that is not a list is wrapped as a single weight.
        weight_list = weights if isinstance(weights, list) else [weights]

        for position, weight in enumerate(weight_list):
            if isinstance(weight, numbers.Number):
                continue
            raise TypeError("type of weights element must be number, "
                            "but got w[{}]: {}, type: {}.".format(position, weight, type(weight)))

        if num_samples is not None and not isinstance(num_samples, int):
            raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
        if num_samples is not None and not 0 <= num_samples <= validator.INT64_MAX:
            raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!"
                             .format(0, validator.INT64_MAX))

        if not isinstance(replacement, bool):
            raise TypeError("replacement must be a boolean value but was: {}.".format(replacement))

        self.weights = weight_list
        self.replacement = replacement
        super().__init__(num_samples)

    def parse(self):
        """ Parse the sampler."""
        sample_count = 0 if self.num_samples is None else self.num_samples
        with_replacement = True if self.replacement is None else self.replacement
        c_sampler = cde.WeightedRandomSamplerObj(self.weights, sample_count, with_replacement)
        c_sampler.add_child(self.parse_child())
        return c_sampler

    def is_shuffled(self):
        """ A weighted random sampler always reports shuffled output. """
        return True

    def is_sharded(self):
        """ False when there is no child sampler, otherwise the child's answer. """
        child = self.child_sampler
        return False if child is None else child.is_sharded()