mindspore_gl.dataloader.samplers 源代码

# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Implement various data sampler."""
import random
import mindspore.dataset as ds


class RandomBatchSampler(ds.Sampler):
    """
    Random Batched Node Sampler, randomly samples nodes from the graph.
    Trailing samples that do not fill a whole batch are dropped.

    Args:
        data_source (Union[List, Tuple, Iterable]): data source to sample from.
            Lists and tuples are shallow-copied so that shuffling never
            reorders the caller's object; inputs that do not support item
            assignment (e.g. ``range`` or a generator) are materialised into
            a list. ``None`` is treated as an empty source.
        batch_size (int): number of sampling subgraphs per batch.

    Raises:
        TypeError: If `batch_size` is not a positive integer.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> from mindspore_gl.dataloader.samplers import RandomBatchSampler
        >>> ds = list(range(10))
        >>> sampler = RandomBatchSampler(ds, 3)
        >>> print(list(sampler))
        # results will be random for shuffle
        [[5, 9, 3], [4, 6, 7], [2, 8, 1]]
    """

    def __init__(self, data_source, batch_size):
        super().__init__()
        # Zero must be rejected as well: it would otherwise fail later with a
        # ValueError in range() and a ZeroDivisionError in __len__.
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise TypeError("batch_size should be a positive integer value,"
                            "but got batch_size = {}.".format(batch_size))
        if data_source is None:
            data_source = []
        elif isinstance(data_source, list) or not hasattr(data_source, "__setitem__"):
            # Copy lists (random.shuffle works in place and must not touch the
            # caller's data) and turn tuples / ranges / generators into a
            # list that supports len(), slicing and in-place shuffling.
            data_source = list(data_source)
        self.data_source = data_source
        self.batch_size = batch_size
        self.epoch = 1

    def _node_iter(self):
        """Yield consecutive full batches, dropping the incomplete remainder."""
        last_start = len(self.data_source) - self.batch_size
        # Only start offsets that leave room for a complete batch are visited.
        for start in range(0, last_start + 1, self.batch_size):
            yield self.data_source[start: start + self.batch_size]

    def __iter__(self):
        """Advance the epoch, reshuffle with the epoch as seed and iterate."""
        # The seed is derived from the epoch counter so every epoch gets a
        # different but reproducible permutation.
        self.epoch += 1
        random.seed(self.epoch)
        random.shuffle(self.data_source)
        return self._node_iter()

    def __len__(self):
        """Return the number of full batches produced per epoch."""
        return len(self.data_source) // self.batch_size
class DistributeRandomBatchSampler(ds.Sampler):
    """
    Distribute Random Batch Sampler.

    Every process keeps the elements ``data_source[rank::world_size]`` and
    yields shuffled batches from that shard. Trailing samples that do not
    fill a whole batch are dropped.

    Args:
        rank (int): Rank of the current process within distributed group,
            less than `world_size`.
        world_size (int): Number of processes in distributed computing.
        data_source (Union[List, Tuple, Iterable]): data source to sample from.
            Inputs that do not support item assignment (e.g. a tuple,
            ``range`` or a generator) are materialised into a list.
            ``None`` is treated as an empty source.
        batch_size (int): number of sampling subgraphs per batch.

    Raises:
        TypeError: If `batch_size` is not a positive integer.
        TypeError: If `rank` is negative, not an integer, or not less than
            `world_size`.
        TypeError: If `world_size` is not a positive integer.

    Examples:
        >>> from mindspore_gl.dataloader.samplers import DistributeRandomBatchSampler
        >>> ds = list(range(20))
        >>> rank_id = 0
        >>> world_size = 2
        >>> sampler = DistributeRandomBatchSampler(rank_id, world_size, ds, 3)
        >>> print(list(sampler))
        # results will be random for shuffle
        [[10, 18, 6], [8, 12, 14], [4, 16, 2]]
    """

    def __init__(self, rank, world_size, data_source, batch_size):
        super().__init__()
        # Validate before slicing: a zero or non-integer world_size / rank
        # would otherwise escape as an unrelated error from the slice below.
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise TypeError("batch_size should be a positive integer value,"
                            "but got batch_size = {}.".format(batch_size))
        if not isinstance(world_size, int) or world_size <= 0:
            raise TypeError("world_size should be a positive integer value,"
                            "but got world_size = {}.".format(world_size))
        if not isinstance(rank, int) or rank < 0 or rank >= world_size:
            raise TypeError("rank should be a positive integer value less than work_size,"
                            "but got rank = {}.".format(rank))

        if data_source is None:
            data_source = []
        elif not hasattr(data_source, "__setitem__"):
            # Tuples, ranges and generators cannot be shuffled in place.
            data_source = list(data_source)

        # Strided slice: this rank owns every world_size-th element.
        self.data_source_rank = data_source[rank::world_size]
        self.batch_size = batch_size
        self.epoch = 1
        self.rank = rank
        self.world_size = world_size

    def node_iter(self):
        """Yield consecutive full batches of this rank's shard."""
        last_start = len(self.data_source_rank) - self.batch_size
        # Only start offsets that leave room for a complete batch are visited,
        # so the incomplete remainder is dropped.
        for start in range(0, last_start + 1, self.batch_size):
            yield self.data_source_rank[start: start + self.batch_size]

    def __iter__(self):
        """Advance the epoch, reshuffle with the epoch as seed and iterate."""
        # The seed is derived from the epoch counter so every epoch gets a
        # different but reproducible permutation.
        self.epoch += 1
        random.seed(self.epoch)
        random.shuffle(self.data_source_rank)
        return self.node_iter()

    def __len__(self):
        """Return the number of full batches produced per epoch."""
        # Floor division, because node_iter drops the incomplete last batch.
        return len(self.data_source_rank) // self.batch_size