# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
The sampler module provides several samplers to generate data from datasets.
The provided samplers include: DistributedSampler, PKSampler, RandomSampler,
SequentialSampler, SubsetRandomSampler, and WeightedRandomSampler.
Users can also define a custom sampler by extending from the Sampler class.
"""
import numbers
import numpy as np
import mindspore._c_dataengine as cde
import mindspore.dataset as ds
from ..core import validator_helpers as validator
def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
"""
Create sampler based on user input.
Args:
num_samples (int): Number of samples.
input_sampler (Union[Iterable, Sampler]): Sampler from user.
shuffle (bool): Shuffle.
num_shards (int): Number of shard for sharding.
shard_id (int): Shard ID.
Returns:
Sampler, sampler selected based on user input.
"""
if input_sampler is not None:
# If the user provided a sampler, then it doesn't matter what the other args are because
# we are being asked specifically to use the given sampler.
# That means the following arguments: num_shards, shard_id, shuffle, num_samples should all
# be None. Consider this example:
# sampler = ds.DistributedSampler(num_shards=8, shard_id=3, shuffle=shuffle)
# data1 = ds.VOCDataset(voc_dir, decode=True, sampler=sampler, num_shards=4, shard_id=1)
# In this case, the user has given different sample-related arguments that contradict each other.
# To prevent this, only allow the user to manually specify the sampler if those arguments are all None
if (isinstance(input_sampler, BuiltinSampler) and
(any(arg is not None for arg in [num_shards, shard_id, shuffle, num_samples]))):
raise ValueError(
'Conflicting arguments during sampler assignments. num_samples: {}, num_shards: {},'
' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle))
if isinstance(input_sampler, BuiltinSampler):
return input_sampler
if not isinstance(input_sampler, str) and isinstance(input_sampler, (np.ndarray, list, tuple)):
return SubsetSampler(input_sampler, num_samples)
if not isinstance(input_sampler, str) and validator.is_iterable(input_sampler):
# in this case, the user passed in their own sampler object that's not of type BuiltinSampler
return IterSampler(input_sampler, num_samples)
if isinstance(input_sampler, int):
return SubsetSampler([input_sampler])
raise TypeError('Unsupported sampler object of type ({})'.format(type(input_sampler)))
if shuffle is None:
if num_shards is not None:
# If shuffle is not specified, sharding enabled, use distributed random sampler
shuffle = True
return DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
# If shuffle is not specified, sharding disabled, use random sampler
if num_samples is not None and num_samples != 0:
return RandomSampler(replacement=True, num_samples=num_samples)
return RandomSampler(num_samples=num_samples)
if shuffle is True:
if num_shards is not None:
# If shuffle enabled, sharding enabled, use distributed random sampler
return DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
# If shuffle enabled, sharding disabled, use random sampler
if num_samples is not None:
return RandomSampler(replacement=True, num_samples=num_samples)
return RandomSampler(num_samples=num_samples)
if num_shards is not None:
# If shuffle disabled, sharding enabled, use distributed sequential sampler
return DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
# If shuffle disabled, sharding disabled, use sequential sampler
return SequentialSampler(num_samples=num_samples)
class BuiltinSampler:
"""
Base class for BuiltinSampler.
User should not extend this class.
"""
def __init__(self, num_samples=None):
self.child_sampler = None
self.num_samples = num_samples
def parse(self):
""" Parse the sampler."""
def add_child(self, sampler):
"""
Add a sub-sampler for given sampler. The parent will receive all data from the
output of sub-sampler sampler and apply its sample logic to return new samples.
Args:
sampler (Sampler): Object used to choose samples from the dataset. Only builtin
samplers(DistributedSampler, PKSampler, RandomSampler, SequentialSampler,
SubsetRandomSampler, WeightedRandomSampler) are supported.
Examples:
>>> sampler = ds.SequentialSampler(start_index=0, num_samples=3)
>>> sampler.add_child(ds.RandomSampler(num_samples=4))
>>> dataset = ds.Cifar10Dataset(cifar10_dataset_dir, sampler=sampler)
"""
if self.child_sampler is not None:
raise RuntimeError("Cannot add child sampler, this sampler already has a child.")
self.child_sampler = sampler
def get_child(self):
"""
Get the child sampler of given sampler.
Returns:
Sampler, The child sampler of given sampler.
Examples:
>>> sampler = ds.SequentialSampler(start_index=0, num_samples=3)
>>> sampler.add_child(ds.RandomSampler(num_samples=2))
>>> child_sampler = sampler.get_child()
"""
return self.child_sampler
def parse_child(self):
""" Parse the child sampler. """
c_child_sampler = None
if self.child_sampler is not None:
c_child_sampler = self.child_sampler.parse()
return c_child_sampler
def parse_child_for_minddataset(self):
""" Parse the child sampler for MindRecord. """
c_child_sampler = None
if self.child_sampler is not None:
c_child_sampler = self.child_sampler.parse_for_minddataset()
return c_child_sampler
def is_shuffled(self):
""" Not implemented. """
raise NotImplementedError("Sampler must implement is_shuffled.")
def is_sharded(self):
""" Not implemented. """
raise NotImplementedError("Sampler must implement is_sharded.")
def get_num_samples(self):
"""
Get num_samples value of the current sampler instance.
This parameter can be optionally passed in when defining the Sampler. Default: None.
This method will return the num_samples value.
If the current sampler has child samplers,
it will continue to access the child samplers and process the obtained value according to certain rules.
The following table shows the various possible combinations, and the final results returned.
.. list-table::
:widths: 25 25 25 25
:header-rows: 1
* - child sampler
- num_samples
- child_samples
- result
* - T
- x
- y
- min(x, y)
* - T
- x
- None
- x
* - T
- None
- y
- y
* - T
- None
- None
- None
* - None
- x
- n/a
- x
* - None
- None
- n/a
- None
Returns:
int, the number of samples, or None.
Examples:
>>> sampler = ds.SequentialSampler(start_index=0, num_samples=3)
>>> num_samplers = sampler.get_num_samples()
"""
if self.child_sampler is not None:
child_samples = self.child_sampler.get_num_samples()
if self.num_samples is not None:
if child_samples is not None:
return min(self.num_samples, child_samples)
return self.num_samples
return child_samples
return self.num_samples
class Sampler(BuiltinSampler):
"""
Base class for user defined sampler.
A user defined sampler can be used with any existing dataset with sampler support.
A required _iter_() method should by overridden by the user for sample index generation.
An optional reset() method can be overridden for per repeat reset,
dataset_size and num_samples will be set by dataset once a dataset iterator is created.
Examples:
>>> class ReverseSampler(ds.Sampler):
... def __iter__(self):
... for i in range(self.dataset_size - 1, -1, -1):
... yield i
>>>
>>> ds = ds.ImageFolderDataset(image_folder_dataset_dir, sampler=ReverseSampler())
"""
def __init__(self, num_samples=None):
super().__init__(num_samples)
self.dataset_size = 0
self.child_sampler = None
self.num_samples = num_samples
def __iter__(self):
"""
User defined iterator, must be overridden.
_handshake is guaranteed to be called prior to iterator construction.
"""
raise NotImplementedError
def reset(self):
"""
Per repeat reset callback, override this method if necessary
"""
# Initialization handshake callback
# Do not override this method!
def _handshake(self, ds_size, num_samples):
self.dataset_size = ds_size
self.num_samples = num_samples
# Indices fetcher
# Do not override this method!
# pylint: disable=missing-docstring
def _get_indices(self):
sampler_iter = iter(self)
ret = []
for _ in range(self.num_samples):
try:
idx = next(sampler_iter)
ret.append(idx)
except StopIteration:
break
indices = np.array(ret)
if indices.dtype == object:
raise RuntimeError("Fetched indices can not be converted to a valid ndarray.")
return indices
# Instance fetcher
# Do not override this method!
def parse(self):
""" Parse the sampler."""
num_samples = self.num_samples if self.num_samples is not None else 0
c_sampler = cde.PreBuiltSamplerObj(num_samples, self)
c_child_sampler = self.parse_child()
c_sampler.add_child(c_child_sampler)
return c_sampler
def add_child(self, sampler):
self.child_sampler = sampler
def get_child(self):
return self.child_sampler
def parse_child(self):
c_child_sampler = None
if self.child_sampler is not None:
c_child_sampler = self.child_sampler.parse()
return c_child_sampler
def is_shuffled(self):
if self.child_sampler is None:
return False
return self.child_sampler.is_shuffled()
def is_sharded(self):
if self.child_sampler is None:
return False
return self.child_sampler.is_sharded()
def get_num_samples(self):
if self.num_samples is None:
return None
return self._get_indices().size
[文档]class DistributedSampler(BuiltinSampler):
"""
A sampler that accesses a shard of the dataset, it helps divide dataset into multi-subset for distributed training.
Args:
num_shards (int): Number of shards to divide the dataset into.
shard_id (int): Shard ID of the current shard, which should within the range of [0, `num_shards` - 1].
shuffle (bool, optional): If True, the indices are shuffled, otherwise it will not be shuffled. Default: True.
num_samples (int, optional): The number of samples to draw. Default: None, which means sample all elements.
offset(int, optional): The starting shard ID where the elements in the dataset are sent to, which
should be no more than `num_shards` . This parameter is only valid when a ConcatDataset takes
a DistributedSampler as its sampler. It will affect the number of samples of per shard.
Default: -1, which means each shard has the same number of samples.
Raises:
TypeError: If `num_shards` is not of type int.
TypeError: If `shard_id` is not of type int.
TypeError: If `shuffle` is not of type bool.
TypeError: If `num_samples` is not of type int.
TypeError: If `offset` is not of type int.
ValueError: If `num_samples` is a negative value.
RuntimeError: If `num_shards` is not a positive value.
RuntimeError: If `shard_id` is smaller than 0 or equal to `num_shards` or larger than `num_shards` .
RuntimeError: If `offset` is greater than `num_shards` .
Examples:
>>> # creates a distributed sampler with 10 shards in total. This shard is shard 5.
>>> sampler = ds.DistributedSampler(10, 5)
>>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
... num_parallel_workers=8,
... sampler=sampler)
"""
def __init__(self, num_shards, shard_id, shuffle=True, num_samples=None, offset=-1):
if not isinstance(num_shards, int):
raise TypeError("num_shards must be integer but was: {}.".format(num_shards))
if not isinstance(shard_id, int):
raise TypeError("shard_id must be integer but was: {}.".format(shard_id))
if not isinstance(shuffle, bool):
raise TypeError("shuffle must be a boolean value but was: {}.".format(shuffle))
if num_samples is not None:
if not isinstance(num_samples, int):
raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
if num_samples < 0 or num_samples > validator.INT64_MAX:
raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!"
.format(0, validator.INT64_MAX))
if not isinstance(offset, int):
raise TypeError("offset must be integer but was: {}.".format(offset))
self.num_shards = num_shards
self.shard_id = shard_id
self.shuffle = shuffle
# get seed in distributed scenario
# Example 1. if user set seeds by ds.config.set_seed(4321), then seed 4321 is used
# Example 2. if user does not set the seed, then existing or default seed (like 5489) is used
self.seed = ds.config.get_seed()
self.offset = offset
super().__init__(num_samples)
def parse(self):
""" Parse the sampler."""
num_samples = self.num_samples if self.num_samples is not None else 0
shuffle = self.shuffle if self.shuffle is not None else True
offset = self.offset if self.offset is not None else -1
# each time user calls create_dict_iterator() (to do repeat) sampler would get a different seed to shuffle
self.seed += 1
c_sampler = cde.DistributedSamplerObj(self.num_shards, self.shard_id,
shuffle, num_samples, self.seed, offset, True)
c_child_sampler = self.parse_child()
c_sampler.add_child(c_child_sampler)
return c_sampler
def parse_for_minddataset(self):
""" Parse the sampler for MindRecord."""
num_samples = self.num_samples if self.num_samples is not None else 0
shuffle = self.shuffle if self.shuffle is not None else True
c_sampler = cde.MindrecordDistributedSampler(self.num_shards, self.shard_id, shuffle,
self.seed, num_samples, self.offset)
c_child_sampler = self.parse_child_for_minddataset()
c_sampler.add_child(c_child_sampler)
c_sampler.set_num_samples(num_samples)
return c_sampler
def is_shuffled(self):
if self.child_sampler is None:
return self.shuffle
return self.child_sampler.is_shuffled()
def is_sharded(self):
if self.child_sampler is None:
return self.num_shards > 1
return self.child_sampler.is_sharded()
def set_offset(self, offset):
self.offset = offset
return self
[文档]class PKSampler(BuiltinSampler):
"""
Samples K elements for each P class in the dataset.
Args:
num_val (int): Number of elements to sample for each class.
num_class (int, optional): Number of classes to sample. Default: None, sample all classes.
The parameter does not support to specify currently.
shuffle (bool, optional): If True, the class IDs are shuffled, otherwise it will not be
shuffled. Default: False.
class_column (str, optional): Name of column with class labels for MindDataset. Default: 'label'.
num_samples (int, optional): The number of samples to draw. Default: None, which means sample all elements.
Raises:
TypeError: If `shuffle` is not of type bool.
TypeError: If `class_column` is not of type str.
TypeError: If `num_samples` is not of type int.
NotImplementedError: If `num_class` is not None.
RuntimeError: If `num_val` is not a positive value.
ValueError: If `num_samples` is a negative value.
Examples:
>>> # creates a PKSampler that will get 3 samples from every class.
>>> sampler = ds.PKSampler(3)
>>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
... num_parallel_workers=8,
... sampler=sampler)
"""
def __init__(self, num_val, num_class=None, shuffle=False, class_column='label', num_samples=None):
if not isinstance(num_val, int):
raise TypeError("num_val must be integer but was: {}.".format(num_val))
if num_class is not None:
raise NotImplementedError("Not supported to specify num_class for PKSampler.")
if not isinstance(shuffle, bool):
raise TypeError("shuffle must be a boolean value but was: {}.".format(shuffle))
if not isinstance(class_column, str):
raise TypeError("class_column must be a str value but was: {}.".format(class_column))
if num_samples is not None:
if not isinstance(num_samples, int):
raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
if num_samples < 0 or num_samples > validator.INT64_MAX:
raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!"
.format(0, validator.INT64_MAX))
self.num_val = num_val
self.shuffle = shuffle
self.class_column = class_column # work for minddataset
super().__init__(num_samples)
def parse(self):
""" Parse the sampler."""
num_samples = self.num_samples if self.num_samples is not None else 0
shuffle = self.shuffle if self.shuffle is not None else False
c_sampler = cde.PKSamplerObj(self.num_val, shuffle, num_samples)
c_child_sampler = self.parse_child()
c_sampler.add_child(c_child_sampler)
return c_sampler
def is_shuffled(self):
if self.child_sampler is None:
return self.shuffle
return self.child_sampler.is_shuffled()
def is_sharded(self):
if self.child_sampler is None:
return False
return self.child_sampler.is_sharded()
def parse_for_minddataset(self):
"""Parse the sampler for MindRecord."""
if not self.class_column or not isinstance(self.class_column, str):
raise ValueError("class_column should be a not empty string value, \
but got class_column: {}.".format(self.class_column))
num_samples = self.num_samples if self.num_samples is not None else 0
c_sampler = cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle, num_samples)
c_child_sampler = self.parse_child_for_minddataset()
c_sampler.add_child(c_child_sampler)
c_sampler.set_num_samples(num_samples)
return c_sampler
[文档]class RandomSampler(BuiltinSampler):
"""
Samples the elements randomly.
Args:
replacement (bool, optional): If True, put the sample ID back for the next draw. Default: False.
num_samples (int, optional): Number of elements to sample. Default: None, which means sample all elements.
Raises:
TypeError: If `replacement` is not of type bool.
TypeError: If `num_samples` is not of type int.
ValueError: If `num_samples` is a negative value.
Examples:
>>> # creates a RandomSampler
>>> sampler = ds.RandomSampler()
>>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
... num_parallel_workers=8,
... sampler=sampler)
"""
def __init__(self, replacement=False, num_samples=None):
if not isinstance(replacement, bool):
raise TypeError("replacement must be a boolean value but was: {}.".format(replacement))
if num_samples is not None:
if not isinstance(num_samples, int):
raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
if num_samples < 0 or num_samples > validator.INT64_MAX:
raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!"
.format(0, validator.INT64_MAX))
self.deterministic = False
self.replacement = replacement
self.reshuffle_each_epoch = True
super().__init__(num_samples)
def parse(self):
""" Parse the sampler."""
num_samples = self.num_samples if self.num_samples is not None else 0
replacement = self.replacement if self.replacement is not None else False
c_sampler = cde.RandomSamplerObj(replacement, num_samples, self.reshuffle_each_epoch)
c_child_sampler = self.parse_child()
c_sampler.add_child(c_child_sampler)
return c_sampler
def parse_for_minddataset(self):
"""Parse the sampler for MindRecord."""
num_samples = self.num_samples if self.num_samples is not None else 0
c_sampler = cde.MindrecordRandomSampler(num_samples, self.replacement, self.reshuffle_each_epoch)
c_child_sampler = self.parse_child_for_minddataset()
c_sampler.add_child(c_child_sampler)
c_sampler.set_num_samples(num_samples)
return c_sampler
def is_shuffled(self):
return True
def is_sharded(self):
if self.child_sampler is None:
return False
return self.child_sampler.is_sharded()
[文档]class SequentialSampler(BuiltinSampler):
"""
Samples the dataset elements sequentially that is equivalent to not using a sampler.
Args:
start_index (int, optional): Index to start sampling at. Default: None, start at first ID.
num_samples (int, optional): Number of elements to sample. Default: None, which means sample all elements.
Raises:
TypeError: If `start_index` is not of type int.
TypeError: If `num_samples` is not of type int.
RuntimeError: If `start_index` is a negative value.
ValueError: If `num_samples` is a negative value.
Examples:
>>> # creates a SequentialSampler
>>> sampler = ds.SequentialSampler()
>>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
... num_parallel_workers=8,
... sampler=sampler)
"""
def __init__(self, start_index=None, num_samples=None):
if start_index is not None and not isinstance(start_index, int):
raise TypeError("start_index must be integer but was: {}.".format(start_index))
if num_samples is not None:
if not isinstance(num_samples, int):
raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
if num_samples < 0 or num_samples > validator.INT64_MAX:
raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!"
.format(0, validator.INT64_MAX))
self.start_index = start_index
super().__init__(num_samples)
def parse(self):
""" Parse the sampler."""
start_index = self.start_index if self.start_index is not None else 0
num_samples = self.num_samples if self.num_samples is not None else 0
c_sampler = cde.SequentialSamplerObj(start_index, num_samples)
c_child_sampler = self.parse_child()
c_sampler.add_child(c_child_sampler)
return c_sampler
def parse_for_minddataset(self):
"""Parse the sampler for MindRecord."""
start_index = self.start_index if self.start_index is not None else 0
num_samples = self.num_samples if self.num_samples is not None else 0
c_sampler = cde.MindrecordSequentialSampler(num_samples, start_index)
c_child_sampler = self.parse_child_for_minddataset()
c_sampler.add_child(c_child_sampler)
c_sampler.set_num_samples(num_samples)
return c_sampler
def is_shuffled(self):
if self.child_sampler is None:
return False
return self.child_sampler.is_shuffled()
def is_sharded(self):
if self.child_sampler is None:
return False
return self.child_sampler.is_sharded()
[文档]class SubsetSampler(BuiltinSampler):
"""
Samples the elements from a sequence of indices.
Args:
indices (Iterable): A sequence of indices (Any iterable Python object but string).
num_samples (int, optional): Number of elements to sample. Default: None, which means sample all elements.
Raises:
TypeError: If elements of `indices` are not of type number.
TypeError: If `num_samples` is not of type int.
ValueError: If `num_samples` is a negative value.
Examples:
>>> indices = [0, 1, 2, 3, 4, 5]
>>>
>>> # creates a SubsetSampler, will sample from the provided indices
>>> sampler = ds.SubsetSampler(indices)
>>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
... num_parallel_workers=8,
... sampler=sampler)
"""
def __init__(self, indices, num_samples=None):
def _get_sample_ids_as_list(sampler, number_of_samples=None):
if number_of_samples is None:
return list(sampler)
if isinstance(sampler, list):
return sampler[:number_of_samples]
return [sample_id for sample_id, _ in zip(sampler, range(number_of_samples))]
if num_samples is not None:
if not isinstance(num_samples, int):
raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
if num_samples < 0 or num_samples > validator.INT64_MAX:
raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!"
.format(0, validator.INT64_MAX))
if not isinstance(indices, str) and validator.is_iterable(indices):
indices = _get_sample_ids_as_list(indices, num_samples)
elif isinstance(indices, int):
indices = [indices]
else:
raise TypeError('Unsupported sampler object of type ({})'.format(type(indices)))
for i, item in enumerate(indices):
if not isinstance(item, (int, np.integer)):
raise TypeError("SubsetSampler: Type of indices element must be int, "
"but got list[{}]: {}, type: {}.".format(i, item, type(item)))
self.indices = indices
super().__init__(num_samples)
def parse(self):
""" Parse the sampler."""
num_samples = self.num_samples if self.num_samples is not None else 0
c_sampler = cde.SubsetSamplerObj(self.indices, num_samples)
c_child_sampler = self.parse_child()
c_sampler.add_child(c_child_sampler)
return c_sampler
def is_shuffled(self):
return False
def is_sharded(self):
if self.child_sampler is None:
return False
return self.child_sampler.is_sharded()
def parse_for_minddataset(self):
"""Parse the sampler for MindRecord."""
c_sampler = cde.MindrecordSubsetSampler(self.indices)
c_child_sampler = self.parse_child_for_minddataset()
c_sampler.add_child(c_child_sampler)
c_sampler.set_num_samples(self.get_num_samples())
return c_sampler
def get_num_samples(self):
num_samples = super().get_num_samples()
if num_samples is None:
return len(self.indices)
return min(len(self.indices), num_samples)
[文档]class SubsetRandomSampler(SubsetSampler):
"""
Samples the elements randomly from a sequence of indices.
Args:
indices (Iterable): A sequence of indices (Any iterable Python object but string).
num_samples (int, optional): Number of elements to sample. Default: None, which means sample all elements.
Raises:
TypeError: If elements of `indices` are not of type number.
TypeError: If `num_samples` is not of type int.
ValueError: If `num_samples` is a negative value.
Examples:
>>> indices = [0, 1, 2, 3, 7, 88, 119]
>>>
>>> # create a SubsetRandomSampler, will sample from the provided indices
>>> sampler = ds.SubsetRandomSampler(indices)
>>> data = ds.ImageFolderDataset(image_folder_dataset_dir, num_parallel_workers=8, sampler=sampler)
"""
def parse(self):
""" Parse the sampler."""
num_samples = self.num_samples if self.num_samples is not None else 0
c_sampler = cde.SubsetRandomSamplerObj(self.indices, num_samples)
c_child_sampler = self.parse_child()
c_sampler.add_child(c_child_sampler)
return c_sampler
def is_shuffled(self):
return True
def parse_for_minddataset(self):
"""Parse the sampler for MindRecord."""
c_sampler = cde.MindrecordSubsetSampler(self.indices, ds.config.get_seed())
c_child_sampler = self.parse_child_for_minddataset()
c_sampler.add_child(c_child_sampler)
c_sampler.set_num_samples(self.get_num_samples())
return c_sampler
class IterSampler(Sampler):
"""
User provided an iterable object without inheriting from our Sampler class.
Note:
This class exists to allow handshake logic between dataset operations and user defined samplers.
By constructing this object we avoid the user having to inherit from our Sampler class.
Args:
sampler (iterable object): an user defined iterable object.
num_samples (int, optional): Number of elements to sample. Default: None, which means sample all elements.
Examples:
>>> class MySampler:
... def __iter__(self):
... for i in range(99, -1, -1):
... yield i
>>> # creates an IterSampler
>>> sampler = ds.IterSampler(sampler=MySampler())
>>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
... num_parallel_workers=8,
... sampler=sampler)
"""
def __init__(self, sampler, num_samples=None):
if num_samples is None:
num_samples = len(list(sampler))
super().__init__(num_samples=num_samples)
self.sampler = sampler
def __iter__(self):
return iter(self.sampler)
[文档]class WeightedRandomSampler(BuiltinSampler):
"""
Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities).
Args:
weights (list[float, int]): A sequence of weights, not necessarily summing up to 1.
num_samples (int, optional): Number of elements to sample. Default: None, which means sample all elements.
replacement (bool): If True, put the sample ID back for the next draw. Default: True.
Raises:
TypeError: If elements of `weights` are not of type number.
TypeError: If `num_samples` is not of type int.
TypeError: If `replacement` is not of type bool.
RuntimeError: If `weights` is empty or all zero.
ValueError: If `num_samples` is a negative value.
Examples:
>>> weights = [0.9, 0.01, 0.4, 0.8, 0.1, 0.1, 0.3]
>>>
>>> # creates a WeightedRandomSampler that will sample 4 elements without replacement
>>> sampler = ds.WeightedRandomSampler(weights, 4)
>>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
... num_parallel_workers=8,
... sampler=sampler)
"""
def __init__(self, weights, num_samples=None, replacement=True):
if not isinstance(weights, list):
weights = [weights]
for ind, w in enumerate(weights):
if not isinstance(w, numbers.Number):
raise TypeError("type of weights element must be number, "
"but got w[{}]: {}, type: {}.".format(ind, w, type(w)))
if num_samples is not None:
if not isinstance(num_samples, int):
raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
if num_samples < 0 or num_samples > validator.INT64_MAX:
raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!"
.format(0, validator.INT64_MAX))
if not isinstance(replacement, bool):
raise TypeError("replacement must be a boolean value but was: {}.".format(replacement))
self.weights = weights
self.replacement = replacement
super().__init__(num_samples)
def parse(self):
""" Parse the sampler."""
num_samples = self.num_samples if self.num_samples is not None else 0
replacement = self.replacement if self.replacement is not None else True
c_sampler = cde.WeightedRandomSamplerObj(self.weights, num_samples, replacement)
c_child_sampler = self.parse_child()
c_sampler.add_child(c_child_sampler)
return c_sampler
def is_shuffled(self):
return True
def is_sharded(self):
if self.child_sampler is None:
return False
return self.child_sampler.is_sharded()