Differences with torch.utils.data.WeightedRandomSampler

View Source On Gitee

torch.utils.data.WeightedRandomSampler

class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True, generator=None)

For more information, see torch.utils.data.WeightedRandomSampler.

mindspore.dataset.WeightedRandomSampler

class mindspore.dataset.WeightedRandomSampler(weights, num_samples=None, replacement=True)

For more information, see mindspore.dataset.WeightedRandomSampler.

Differences

PyTorch: Samples elements from [0, len(weights) - 1] randomly according to the given weights (probabilities). The random number generator can be set manually through the `generator` parameter.

MindSpore: Samples elements from [0, len(weights) - 1] randomly according to the given weights (probabilities). Specifying a random number generator is not supported.

Categories

Subcategories

PyTorch

MindSpore

Difference

Parameter

Parameter1

weights

weights

-

Parameter2

num_samples

num_samples

-

Parameter3

replacement

replacement

-

Parameter4

generator

-

In PyTorch, `generator` specifies the random number generator used for sampling. MindSpore has no such parameter and uses global random sampling instead.

Code Example

import torch
from torch.utils.data import WeightedRandomSampler

# Seed PyTorch's global RNG; the output recorded at the end of this example
# was presumably produced with this seed.
torch.manual_seed(0)

class MyMapDataset(torch.utils.data.Dataset):
    """Minimal map-style dataset holding the integers 1 through 4."""

    def __init__(self):
        # Use zero-argument super(): ``super(MyMapDataset).__init__()`` only
        # re-initialises an unbound ``super`` object and never reaches the
        # parent class initialiser.
        super().__init__()
        self.data = list(range(1, 5))

    def __getitem__(self, index):
        """Return the element stored at position ``index``."""
        return self.data[index]

    def __len__(self):
        """Return the number of elements in the dataset."""
        return len(self.data)

# Dataset holding the values 1..4 (see MyMapDataset above).
ds = MyMapDataset()
# Draw 4 indices at random. Indices 2 and 3 carry weight 0.9 versus 0.1 for
# indices 0 and 1, so the values 3 and 4 are the most likely to be picked.
sampler = WeightedRandomSampler(weights=[0.1, 0.1, 0.9, 0.9], num_samples=4)
dataloader = torch.utils.data.DataLoader(ds, sampler=sampler)

# Print every sampled item; as the recorded output shows, each one arrives as
# a single-element tensor.
for data in dataloader:
    print(data)
# Out:
# tensor([4])
# tensor([3])
# tensor([4])
# tensor([4])
import mindspore as ms
from mindspore.dataset import WeightedRandomSampler

# The MindSpore sampler takes no generator argument, so the dataset seed is set
# globally; the output recorded at the end of this example was presumably
# produced with this seed.
ms.dataset.config.set_seed(4)

class MyMapDataset():
    """Minimal random-access data source holding the integers 1 through 4."""

    def __init__(self):
        # Use zero-argument super(): ``super(MyMapDataset).__init__()`` only
        # re-initialises an unbound ``super`` object and never reaches the
        # parent class initialiser.
        super().__init__()
        self.data = list(range(1, 5))

    def __getitem__(self, index):
        """Return the element stored at position ``index``."""
        return self.data[index]

    def __len__(self):
        """Return the number of elements in the data source."""
        return len(self.data)

# Data source holding the values 1..4 (see MyMapDataset above).
ds = MyMapDataset()
# Draw 4 indices at random. Indices 2 and 3 carry weight 0.9 versus 0.1 for
# indices 0 and 1, so the values 3 and 4 are the most likely to be picked.
sampler = WeightedRandomSampler(weights=[0.1, 0.1, 0.9, 0.9], num_samples=4)
# Wrap the Python object as a dataset with a single column named "data".
dataloader = ms.dataset.GeneratorDataset(ds, column_names=["data"], sampler=sampler)

# Print every sampled row; as the recorded output shows, each row is a list
# containing one scalar Tensor.
for data in dataloader:
    print(data)
# Out:
# [Tensor(shape=[], dtype=Int64, value= 4)]
# [Tensor(shape=[], dtype=Int64, value= 3)]
# [Tensor(shape=[], dtype=Int64, value= 4)]
# [Tensor(shape=[], dtype=Int64, value= 4)]