Differences from torch.utils.data.SubsetRandomSampler

torch.utils.data.SubsetRandomSampler

class torch.utils.data.SubsetRandomSampler(indices, generator=None)

For more information, see torch.utils.data.SubsetRandomSampler.

mindspore.dataset.SubsetRandomSampler

class mindspore.dataset.SubsetRandomSampler(indices, num_samples=None)

For more information, see mindspore.dataset.SubsetRandomSampler.

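Differences

PyTorch: Given a sequence of sample indices, randomly draws indices from that sequence to sample the dataset. A generator can be passed in to control the random number generation used for sampling.

MindSpore: Given a sequence of sample indices, randomly draws indices from that sequence to sample the dataset. A generator cannot be passed in; sampling follows the global random seed.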
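The difference between a dedicated generator and a global seed can be illustrated without either framework. The following is a minimal sketch that uses Python's standard random module as a stand-in; the function sample_subset and its rng parameter are hypothetical names chosen for illustration and are not part of the PyTorch or MindSpore APIs.

```python
import random

def sample_subset(indices, rng=None):
    """Return the given indices in random order.

    rng is a hypothetical stand-in for a dedicated generator. When it is
    None, the module-level (global) random state is used instead.
    """
    source = rng if rng is not None else random
    shuffled = list(indices)
    source.shuffle(shuffled)
    return shuffled

indices = [0, 2, 5, 7, 9]

# Dedicated generator: the order depends only on that generator's seed.
first = sample_subset(indices, rng=random.Random(0))
random.seed(123)  # reseeding the global state in between has no effect
second = sample_subset(indices, rng=random.Random(0))
print(first == second)

# Global seed: the order depends on the process-wide random state.
random.seed(1)
third = sample_subset(indices)
random.seed(1)
fourth = sample_subset(indices)
print(third == fourth)
```

In this sketch the dedicated generator keeps the sampling order isolated from other code that consumes the global random state, whereas the global-seed variant is reproducible only if the seed is set immediately before sampling.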

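| Category | Subcategory | PyTorch | MindSpore | Difference |
| --- | --- | --- | --- | --- |
| Parameter | Parameter 1 | indices | indices | - |
| | Parameter 2 | generator | - | Random number generator used for sampling; MindSpore has no such parameter and samples using the global random seed |
| | Parameter 3 | - | num_samples | Number of samples to return; can be used to take only part of the sampled indices |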
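The effect of num_samples described above can be pictured as shuffling the index sequence and keeping only the first few entries. The sketch below models that idea in plain Python under this assumption; take_random_subset and its seed parameter are hypothetical names for illustration, and the code is not MindSpore's implementation.

```python
import random

def take_random_subset(indices, num_samples=None, seed=None):
    """Shuffle the indices and keep at most num_samples of them.

    num_samples=None keeps every index, mirroring the documented default.
    seed is a hypothetical parameter that makes this sketch reproducible.
    """
    rng = random.Random(seed)
    shuffled = list(indices)
    rng.shuffle(shuffled)
    if num_samples is None:
        return shuffled
    return shuffled[:num_samples]

indices = [0, 2, 5, 7, 9]

# All indices, in random order.
full = take_random_subset(indices, seed=0)

# Only the first two of the same shuffled order.
part = take_random_subset(indices, num_samples=2, seed=0)

print(full)
print(part)
```

In this sketch a num_samples value larger than the number of indices simply returns all of them, because the slice is capped at the list length.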

Code Example

# PyTorch
import torch
from torch.utils.data import SubsetRandomSampler

# Fix the global seed so the shuffled order is reproducible.
torch.manual_seed(0)

class MyMapDataset(torch.utils.data.Dataset):
    """Map-style dataset holding the integers 0 to 3."""
    def __init__(self):
        super().__init__()
        self.data = list(range(4))
    def __getitem__(self, index):
        return self.data[index]
    def __len__(self):
        return len(self.data)

ds = MyMapDataset()
# Sample only indices 0 and 2, in random order.
sampler = SubsetRandomSampler(indices=[0, 2])
dataloader = torch.utils.data.DataLoader(ds, sampler=sampler)

for data in dataloader:
    print(data)
# Out:
# tensor([2])
# tensor([0])
# MindSpore
import mindspore as ms
from mindspore.dataset import SubsetRandomSampler

# Fix the global dataset seed so the shuffled order is reproducible.
ms.dataset.config.set_seed(1)

class MyMapDataset:
    """Random-access data source holding the integers 0 to 3."""
    def __init__(self):
        self.data = list(range(4))
    def __getitem__(self, index):
        return self.data[index]
    def __len__(self):
        return len(self.data)

ds = MyMapDataset()
# Sample only indices 0 and 2, in random order.
sampler = SubsetRandomSampler(indices=[0, 2])
dataloader = ms.dataset.GeneratorDataset(ds, column_names=["data"], sampler=sampler)

for data in dataloader:
    print(data)
# Out:
# [Tensor(shape=[], dtype=Int64, value= 2)]
# [Tensor(shape=[], dtype=Int64, value= 0)]