Source code for mindspore.dataset.engine.samplers

# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
The sampler module provides several samplers to generate data from datasets.
The provided samplers include: DistributedSampler, PKSampler, RandomSampler,
SequentialSampler, SubsetRandomSampler, and WeightedRandomSampler.
Users can also define a custom sampler by extending from the Sampler class.
"""

import numbers
import numpy as np
import mindspore._c_dataengine as cde
import mindspore.dataset as ds
from ..core import validator_helpers as validator


def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
    """
    Create sampler based on user input.

    Args:
        num_samples (int): Number of samples.
        input_sampler (Union[Iterable, Sampler]): Sampler from user.
        shuffle (bool): Shuffle.
        num_shards (int): Number of shard for sharding.
        shard_id (int): Shard ID.

    Returns:
        Sampler, sampler selected based on user input.
    """

    if input_sampler is not None:
        # If the user provided a sampler, then it doesn't matter what the other args are because
        # we are being asked specifically to use the given sampler.
        # That means the following arguments: num_shards, shard_id, shuffle, num_samples should all
        # be None. Consider this example:
        #     sampler = ds.DistributedSampler(num_shards=8, shard_id=3, shuffle=shuffle)
        #     data1 = ds.VOCDataset(voc_dir, decode=True, sampler=sampler, num_shards=4, shard_id=1)
        # In this case, the user has given different sample-related arguments that contradict each other.
        # To prevent this, only allow the user to manually specify the sampler if those arguments are all None
        if (isinstance(input_sampler, BuiltinSampler) and
                (any(arg is not None for arg in [num_shards, shard_id, shuffle, num_samples]))):
            raise ValueError(
                'Conflicting arguments during sampler assignments. num_samples: {}, num_shards: {},'
                ' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle))
        if isinstance(input_sampler, BuiltinSampler):
            return input_sampler
        if not isinstance(input_sampler, str) and isinstance(input_sampler, (np.ndarray, list)):
            return SubsetSampler(input_sampler, num_samples)
        if not isinstance(input_sampler, str) and validator.is_iterable(input_sampler):
            # in this case, the user passed in their own sampler object that's not of type BuiltinSampler
            return IterSampler(input_sampler, num_samples)
        if isinstance(input_sampler, int):
            return SubsetSampler([input_sampler])
        raise TypeError('Unsupported sampler object of type ({})'.format(type(input_sampler)))
    if shuffle is None:
        if num_shards is not None:
            # If shuffle is not specified, sharding enabled, use distributed random sampler
            shuffle = True
            return DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
        # If shuffle is not specified, sharding disabled, use random sampler
        if num_samples is not None and num_samples != 0:
            return RandomSampler(replacement=True, num_samples=num_samples)
        return RandomSampler(num_samples=num_samples)
    if shuffle is True:
        if num_shards is not None:
            # If shuffle enabled, sharding enabled, use distributed random sampler
            return DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
        # If shuffle enabled, sharding disabled, use random sampler
        if num_samples is not None:
            return RandomSampler(replacement=True, num_samples=num_samples)
        return RandomSampler(num_samples=num_samples)
    if num_shards is not None:
        # If shuffle disabled, sharding enabled, use distributed sequential sampler
        return DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
    # If shuffle disabled, sharding disabled, use sequential sampler
    return SequentialSampler(num_samples=num_samples)


class BuiltinSampler:
    """
    Base class for BuiltinSampler.

    User should not extend this class.
    """

    def __init__(self, num_samples=None):
        self.child_sampler = None
        self.num_samples = num_samples

    def parse(self):
        """ Parse the sampler."""

    def add_child(self, sampler):
        """
        Add a sub-sampler for given sampler. The sub-sampler will receive all data from the
        output of parent sampler and apply its sample logic to return new samples.

        Args:
            sampler (Sampler): Object used to choose samples from the dataset. Only builtin
                samplers(DistributedSampler, PKSampler, RandomSampler, SequentialSampler,
                SubsetRandomSampler, WeightedRandomSampler) are supported.

        Examples:
            >>> sampler = ds.SequentialSampler(start_index=0, num_samples=3)
            >>> sampler.add_child(ds.RandomSampler(num_samples=2))
            >>> dataset = ds.Cifar10Dataset(cifar10_dataset_dir, sampler=sampler)
        """
        self.child_sampler = sampler

    def get_child(self):
        """ Get the child sampler. """
        return self.child_sampler

    def parse_child(self):
        """ Parse the child sampler. """
        c_child_sampler = None
        if self.child_sampler is not None:
            c_child_sampler = self.child_sampler.parse()
        return c_child_sampler

    def parse_child_for_minddataset(self):
        """ Parse the child sampler for MindRecord. """
        c_child_sampler = None
        if self.child_sampler is not None:
            c_child_sampler = self.child_sampler.parse_for_minddataset()
        return c_child_sampler

    def is_shuffled(self):
        """ Not implemented. """
        raise NotImplementedError("Sampler must implement is_shuffled.")

    def is_sharded(self):
        """ Not implemented. """
        raise NotImplementedError("Sampler must implement is_sharded.")

    def get_num_samples(self):
        """
        All samplers can contain a numeric num_samples value (or it can be set to None).
        A child sampler can exist or be None.
        If a child sampler exists, then the child sampler count can be a numeric value or None.
        These conditions impact the resultant sampler count that is used.
        The following table shows the possible results from calling this function.

        .. list-table::
           :widths: 25 25 25 25
           :header-rows: 1

           * - child sampler
             - num_samples
             - child_samples
             - result
           * - T
             - x
             - y
             - min(x, y)
           * - T
             - x
             - None
             - x
           * - T
             - None
             - y
             - y
           * - T
             - None
             - None
             - None
           * - None
             - x
             - n/a
             - x
           * - None
             - None
             - n/a
             - None

        Returns:
            int, the number of samples, or None.
        """
        if self.child_sampler is not None:
            child_samples = self.child_sampler.get_num_samples()
            if self.num_samples is not None:
                if child_samples is not None:
                    return min(self.num_samples, child_samples)

                return self.num_samples

            return child_samples

        return self.num_samples


class Sampler(BuiltinSampler):
    """
    Base class for user defined sampler.
    A user defined sampler can be used with any existing dataset with sampler support.

    A required  _iter_() method should by overridden by the user for sample index generation.
    An optional reset() method can be overridden for per repeat reset,

    dataset_size and num_samples will be set by dataset once a dataset iterator is created.

    Examples:
        >>> class ReverseSampler(ds.Sampler):
        ...     def __iter__(self):
        ...         for i in range(self.dataset_size - 1, -1, -1):
        ...             yield i
        >>>
        >>> ds = ds.ImageFolderDataset(image_folder_dataset_dir, sampler=ReverseSampler())
    """

    def __init__(self, num_samples=None):
        super().__init__(num_samples)
        self.dataset_size = 0
        self.child_sampler = None
        self.num_samples = num_samples

    def __iter__(self):
        """
        User defined iterator, must be overridden.
        _handshake is guaranteed to be called prior to iterator construction.
        """
        raise NotImplementedError

    def reset(self):
        """
        Per repeat reset callback, override this method if necessary
        """

    # Initialization handshake callback
    # Do not override this method!
    def _handshake(self, ds_size, num_samples):
        self.dataset_size = ds_size
        self.num_samples = num_samples

    # Indices fetcher
    # Do not override this method!
    # pylint: disable=missing-docstring
    def _get_indices(self):
        sampler_iter = iter(self)
        ret = []
        for _ in range(self.num_samples):
            try:
                idx = next(sampler_iter)
                ret.append(idx)
            except StopIteration:
                break
        indices = np.array(ret)
        if indices.dtype == object:
            raise RuntimeError("Fetched indices can not be converted to a valid ndarray.")
        return indices

    # Instance fetcher
    # Do not override this method!
    def parse(self):
        """ Parse the sampler."""
        num_samples = self.num_samples if self.num_samples is not None else 0
        c_sampler = cde.PreBuiltSamplerObj(num_samples, self)
        c_child_sampler = self.parse_child()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

    def add_child(self, sampler):
        self.child_sampler = sampler

    def get_child(self):
        return self.child_sampler

    def parse_child(self):
        c_child_sampler = None
        if self.child_sampler is not None:
            c_child_sampler = self.child_sampler.parse()

        return c_child_sampler

    def is_shuffled(self):
        if self.child_sampler is None:
            return False

        return self.child_sampler.is_shuffled()

    def is_sharded(self):
        if self.child_sampler is None:
            return False

        return self.child_sampler.is_sharded()

    def get_num_samples(self):
        if self.num_samples is None:
            return None
        return self._get_indices().size


[docs]class DistributedSampler(BuiltinSampler): """ A sampler that accesses a shard of the dataset. Args: num_shards (int): Number of shards to divide the dataset into. shard_id (int): Shard ID of the current shard, which should within the range of [0, num_shards-1]. shuffle (bool, optional): If True, the indices are shuffled, otherwise it will not be shuffled(default=True). num_samples (int, optional): The number of samples to draw (default=None, which means sample all elements). offset(int, optional): The starting shard ID where the elements in the dataset are sent to, which should be no more than num_shards. This parameter is only valid when a ConcatDataset takes a DistributedSampler as its sampler. It will affect the number of samples of per shard (default=-1, which means each shard has same number of samples). Examples: >>> # creates a distributed sampler with 10 shards in total. This shard is shard 5. >>> sampler = ds.DistributedSampler(10, 5) >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir, ... num_parallel_workers=8, ... sampler=sampler) Raises: TypeError: If num_shards is not an integer value. TypeError: If shard_id is not an integer value. TypeError: If shuffle is not a boolean value. TypeError: If num_samples is not an integer value. TypeError: If offset is not an integer value. ValueError: If num_samples is a negative value. RuntimeError: If num_shards is not a positive value. RuntimeError: If shard_id is smaller than 0 or equal to num_shards or larger than num_shards. RuntimeError: If offset is greater than num_shards. """ def __init__(self, num_shards, shard_id, shuffle=True, num_samples=None, offset=-1): if not isinstance(num_shards, int): raise TypeError("num_shards must be integer but was: {}.".format(num_shards)) if not isinstance(shard_id, int): raise TypeError("shard_id must be integer but was: {}.".format(shard_id)) if not isinstance(shuffle, bool): raise TypeError("shuffle must be a boolean value but was: {}.".format(shuffle)) if num_samples is not None: if not isinstance(num_samples, int): raise TypeError("num_samples must be integer but was: {}.".format(num_samples)) if num_samples < 0 or num_samples > validator.INT64_MAX: raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!" .format(0, validator.INT64_MAX)) if not isinstance(offset, int): raise TypeError("offset must be integer but was: {}.".format(offset)) self.num_shards = num_shards self.shard_id = shard_id self.shuffle = shuffle self.seed = 0 self.offset = offset super().__init__(num_samples)
[docs] def parse(self): """ Parse the sampler.""" num_samples = self.num_samples if self.num_samples is not None else 0 shuffle = self.shuffle if self.shuffle is not None else True offset = self.offset if self.offset is not None else -1 # each time user calls create_dict_iterator() (to do repeat) sampler would get a different seed to shuffle self.seed += 1 c_sampler = cde.DistributedSamplerObj(self.num_shards, self.shard_id, shuffle, num_samples, self.seed, offset, True) c_child_sampler = self.parse_child() c_sampler.add_child(c_child_sampler) return c_sampler
[docs] def parse_for_minddataset(self): """ Parse the sampler for MindRecord.""" num_samples = self.num_samples if self.num_samples is not None else 0 shuffle = self.shuffle if self.shuffle is not None else True c_sampler = cde.MindrecordDistributedSampler(self.num_shards, self.shard_id, shuffle, self.seed, num_samples, self.offset) c_child_sampler = self.parse_child_for_minddataset() c_sampler.add_child(c_child_sampler) return c_sampler
def is_shuffled(self): if self.child_sampler is None: return self.shuffle return self.child_sampler.is_shuffled() def is_sharded(self): if self.child_sampler is None: return self.num_shards > 1 return self.child_sampler.is_sharded() def set_offset(self, offset): self.offset = offset return self
[docs]class PKSampler(BuiltinSampler): """ Samples K elements for each P class in the dataset. Args: num_val (int): Number of elements to sample for each class. num_class (int, optional): Number of classes to sample (default=None, sample all classes). The parameter does not supported to specify currently. shuffle (bool, optional): If True, the class IDs are shuffled, otherwise it will not be shuffled (default=False). class_column (str, optional): Name of column with class labels for MindDataset (default='label'). num_samples (int, optional): The number of samples to draw (default=None, which means sample all elements). Examples: >>> # creates a PKSampler that will get 3 samples from every class. >>> sampler = ds.PKSampler(3) >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir, ... num_parallel_workers=8, ... sampler=sampler) Raises: TypeError: If shuffle is not a boolean value. TypeError: If class_column is not a str value. TypeError: If num_samples is not an integer value. NotImplementedError: If num_class is not None. RuntimeError: If num_val is not a positive value. ValueError: If num_samples is a negative value. """ def __init__(self, num_val, num_class=None, shuffle=False, class_column='label', num_samples=None): if not isinstance(num_val, int): raise TypeError("num_val must be integer but was: {}.".format(num_val)) if num_class is not None: raise NotImplementedError("Not supported to specify num_class for PKSampler.") if not isinstance(shuffle, bool): raise TypeError("shuffle must be a boolean value but was: {}.".format(shuffle)) if not isinstance(class_column, str): raise TypeError("class_column must be a str value but was: {}.".format(class_column)) if num_samples is not None: if not isinstance(num_samples, int): raise TypeError("num_samples must be integer but was: {}.".format(num_samples)) if num_samples < 0 or num_samples > validator.INT64_MAX: raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!" .format(0, validator.INT64_MAX)) self.num_val = num_val self.shuffle = shuffle self.class_column = class_column # work for minddataset super().__init__(num_samples)
[docs] def parse(self): """ Parse the sampler.""" num_samples = self.num_samples if self.num_samples is not None else 0 shuffle = self.shuffle if self.shuffle is not None else False c_sampler = cde.PKSamplerObj(self.num_val, shuffle, num_samples) c_child_sampler = self.parse_child() c_sampler.add_child(c_child_sampler) return c_sampler
def is_shuffled(self): if self.child_sampler is None: return self.shuffle return self.child_sampler.is_shuffled() def is_sharded(self): if self.child_sampler is None: return False return self.child_sampler.is_sharded()
[docs] def parse_for_minddataset(self): """Parse the sampler for MindRecord.""" if not self.class_column or not isinstance(self.class_column, str): raise ValueError("class_column should be a not empty string value, \ but got class_column: {}.".format(self.class_column)) num_samples = self.num_samples if self.num_samples is not None else 0 c_sampler = cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle, num_samples) c_child_sampler = self.parse_child_for_minddataset() c_sampler.add_child(c_child_sampler) return c_sampler
[docs]class RandomSampler(BuiltinSampler): """ Samples the elements randomly. Args: replacement (bool, optional): If True, put the sample ID back for the next draw (default=False). num_samples (int, optional): Number of elements to sample (default=None, which means sample all elements). Examples: >>> # creates a RandomSampler >>> sampler = ds.RandomSampler() >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir, ... num_parallel_workers=8, ... sampler=sampler) Raises: TypeError: If replacement is not a boolean value. TypeError: If num_samples is not an integer value. ValueError: If num_samples is a negative value. """ def __init__(self, replacement=False, num_samples=None): if not isinstance(replacement, bool): raise TypeError("replacement must be a boolean value but was: {}.".format(replacement)) if num_samples is not None: if not isinstance(num_samples, int): raise TypeError("num_samples must be integer but was: {}.".format(num_samples)) if num_samples < 0 or num_samples > validator.INT64_MAX: raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!" .format(0, validator.INT64_MAX)) self.deterministic = False self.replacement = replacement self.reshuffle_each_epoch = True super().__init__(num_samples)
[docs] def parse(self): """ Parse the sampler.""" num_samples = self.num_samples if self.num_samples is not None else 0 replacement = self.replacement if self.replacement is not None else False c_sampler = cde.RandomSamplerObj(replacement, num_samples, self.reshuffle_each_epoch) c_child_sampler = self.parse_child() c_sampler.add_child(c_child_sampler) return c_sampler
[docs] def parse_for_minddataset(self): """Parse the sampler for MindRecord.""" num_samples = self.num_samples if self.num_samples is not None else 0 c_sampler = cde.MindrecordRandomSampler(num_samples, self.replacement, self.reshuffle_each_epoch) c_child_sampler = self.parse_child_for_minddataset() c_sampler.add_child(c_child_sampler) return c_sampler
def is_shuffled(self): return True def is_sharded(self): if self.child_sampler is None: return False return self.child_sampler.is_sharded()
[docs]class SequentialSampler(BuiltinSampler): """ Samples the dataset elements sequentially that is equivalent to not using a sampler. Args: start_index (int, optional): Index to start sampling at. (default=None, start at first ID) num_samples (int, optional): Number of elements to sample (default=None, which means sample all elements). Examples: >>> # creates a SequentialSampler >>> sampler = ds.SequentialSampler() >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir, ... num_parallel_workers=8, ... sampler=sampler) Raises: TypeError: If start_index is not an integer value. TypeError: If num_samples is not an integer value. RuntimeError: If start_index is a negative value. ValueError: If num_samples is a negative value. """ def __init__(self, start_index=None, num_samples=None): if start_index is not None and not isinstance(start_index, int): raise TypeError("start_index must be integer but was: {}.".format(start_index)) if num_samples is not None: if not isinstance(num_samples, int): raise TypeError("num_samples must be integer but was: {}.".format(num_samples)) if num_samples < 0 or num_samples > validator.INT64_MAX: raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!" .format(0, validator.INT64_MAX)) self.start_index = start_index super().__init__(num_samples)
[docs] def parse(self): """ Parse the sampler.""" start_index = self.start_index if self.start_index is not None else 0 num_samples = self.num_samples if self.num_samples is not None else 0 c_sampler = cde.SequentialSamplerObj(start_index, num_samples) c_child_sampler = self.parse_child() c_sampler.add_child(c_child_sampler) return c_sampler
[docs] def parse_for_minddataset(self): """Parse the sampler for MindRecord.""" start_index = self.start_index if self.start_index is not None else 0 num_samples = self.num_samples if self.num_samples is not None else 0 c_sampler = cde.MindrecordSequentialSampler(num_samples, start_index) c_child_sampler = self.parse_child_for_minddataset() c_sampler.add_child(c_child_sampler) return c_sampler
def is_shuffled(self): if self.child_sampler is None: return False return self.child_sampler.is_shuffled() def is_sharded(self): if self.child_sampler is None: return False return self.child_sampler.is_sharded()
[docs]class SubsetSampler(BuiltinSampler): """ Samples the elements from a sequence of indices. Args: indices (Any iterable Python object but string): A sequence of indices. num_samples (int, optional): Number of elements to sample (default=None, which means sample all elements). Examples: >>> indices = [0, 1, 2, 3, 4, 5] >>> >>> # creates a SubsetSampler, will sample from the provided indices >>> sampler = ds.SubsetSampler(indices) >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir, ... num_parallel_workers=8, ... sampler=sampler) Raises: TypeError: If type of indices element is not a number. TypeError: If num_samples is not an integer value. ValueError: If num_samples is a negative value. """ def __init__(self, indices, num_samples=None): def _get_sample_ids_as_list(sampler, number_of_samples=None): if number_of_samples is None: return list(sampler) if isinstance(sampler, list): return sampler[:number_of_samples] return [sample_id for sample_id, _ in zip(sampler, range(number_of_samples))] if num_samples is not None: if not isinstance(num_samples, int): raise TypeError("num_samples must be integer but was: {}.".format(num_samples)) if num_samples < 0 or num_samples > validator.INT64_MAX: raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!" .format(0, validator.INT64_MAX)) if not isinstance(indices, str) and validator.is_iterable(indices): indices = _get_sample_ids_as_list(indices, num_samples) elif isinstance(indices, int): indices = [indices] else: raise TypeError('Unsupported sampler object of type ({})'.format(type(indices))) for i, item in enumerate(indices): if not isinstance(item, (int, np.integer)): raise TypeError("SubsetSampler: Type of indices element must be int, " "but got list[{}]: {}, type: {}.".format(i, item, type(item))) self.indices = indices super().__init__(num_samples)
[docs] def parse(self): """ Parse the sampler.""" num_samples = self.num_samples if self.num_samples is not None else 0 c_sampler = cde.SubsetSamplerObj(self.indices, num_samples) c_child_sampler = self.parse_child() c_sampler.add_child(c_child_sampler) return c_sampler
def is_shuffled(self): return False def is_sharded(self): if self.child_sampler is None: return False return self.child_sampler.is_sharded()
[docs] def parse_for_minddataset(self): """Parse the sampler for MindRecord.""" c_sampler = cde.MindrecordSubsetSampler(self.indices) c_child_sampler = self.parse_child_for_minddataset() c_sampler.add_child(c_child_sampler) return c_sampler
def get_num_samples(self): num_samples = super().get_num_samples() if num_samples is None: return len(self.indices) return min(len(self.indices), num_samples)
[docs]class SubsetRandomSampler(SubsetSampler): """ Samples the elements randomly from a sequence of indices. Args: indices (Any iterable Python object but string): A sequence of indices. num_samples (int, optional): Number of elements to sample (default=None, which means sample all elements). Examples: >>> indices = [0, 1, 2, 3, 7, 88, 119] >>> >>> # create a SubsetRandomSampler, will sample from the provided indices >>> sampler = ds.SubsetRandomSampler(indices) >>> data = ds.ImageFolderDataset(image_folder_dataset_dir, num_parallel_workers=8, sampler=sampler) Raises: TypeError: If type of indices element is not a number. TypeError: If num_samples is not an integer value. ValueError: If num_samples is a negative value. """
[docs] def parse(self): """ Parse the sampler.""" num_samples = self.num_samples if self.num_samples is not None else 0 c_sampler = cde.SubsetRandomSamplerObj(self.indices, num_samples) c_child_sampler = self.parse_child() c_sampler.add_child(c_child_sampler) return c_sampler
def is_shuffled(self): return True
[docs] def parse_for_minddataset(self): """Parse the sampler for MindRecord.""" c_sampler = cde.MindrecordSubsetSampler(self.indices, ds.config.get_seed()) c_child_sampler = self.parse_child_for_minddataset() c_sampler.add_child(c_child_sampler) return c_sampler
class IterSampler(Sampler): """ User provided an iterable object without inheriting from our Sampler class. Note: This class exists to allow handshake logic between dataset operators and user defined samplers. By constructing this object we avoid the user having to inherit from our Sampler class. Args: sampler (iterable object): an user defined iterable object. num_samples (int, optional): Number of elements to sample (default=None, which means sample all elements). Examples: >>> class MySampler: ... def __iter__(self): ... for i in range(99, -1, -1): ... yield i >>> # creates an IterSampler >>> sampler = ds.IterSampler(sampler=MySampler()) >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir, ... num_parallel_workers=8, ... sampler=sampler) """ def __init__(self, sampler, num_samples=None): if num_samples is None: num_samples = len(list(sampler)) super().__init__(num_samples=num_samples) self.sampler = sampler def __iter__(self): return iter(self.sampler)
[docs]class WeightedRandomSampler(BuiltinSampler): """ Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities). Args: weights (list[float, int]): A sequence of weights, not necessarily summing up to 1. num_samples (int, optional): Number of elements to sample (default=None, which means sample all elements). replacement (bool): If True, put the sample ID back for the next draw (default=True). Examples: >>> weights = [0.9, 0.01, 0.4, 0.8, 0.1, 0.1, 0.3] >>> >>> # creates a WeightedRandomSampler that will sample 4 elements without replacement >>> sampler = ds.WeightedRandomSampler(weights, 4) >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir, ... num_parallel_workers=8, ... sampler=sampler) Raises: TypeError: If type of weights element is not a number. TypeError: If num_samples is not an integer value. TypeError: If replacement is not a boolean value. RuntimeError: If weights is empty or all zero. ValueError: If num_samples is a negative value. """ def __init__(self, weights, num_samples=None, replacement=True): if not isinstance(weights, list): weights = [weights] for ind, w in enumerate(weights): if not isinstance(w, numbers.Number): raise TypeError("type of weights element must be number, " "but got w[{}]: {}, type: {}.".format(ind, w, type(w))) if num_samples is not None: if not isinstance(num_samples, int): raise TypeError("num_samples must be integer but was: {}.".format(num_samples)) if num_samples < 0 or num_samples > validator.INT64_MAX: raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!" .format(0, validator.INT64_MAX)) if not isinstance(replacement, bool): raise TypeError("replacement must be a boolean value but was: {}.".format(replacement)) self.weights = weights self.replacement = replacement super().__init__(num_samples)
[docs] def parse(self): """ Parse the sampler.""" num_samples = self.num_samples if self.num_samples is not None else 0 replacement = self.replacement if self.replacement is not None else True c_sampler = cde.WeightedRandomSamplerObj(self.weights, num_samples, replacement) c_child_sampler = self.parse_child() c_sampler.add_child(c_child_sampler) return c_sampler
def is_shuffled(self): return True def is_sharded(self): if self.child_sampler is None: return False return self.child_sampler.is_sharded()