Program Listing for File samplers.h
↰ Return to documentation for file (include/samplers.h)
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_SAMPLERS_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_SAMPLERS_H_
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>
namespace mindspore {
namespace dataset {
// Forward declare
class SamplerObj;

/// @brief Abstract base class representing a sampler in the data pipeline.
///
/// Concrete samplers implement Parse(), which returns the SamplerObj for the sampler.
/// Child samplers added through AddChild() are stored in `children_` for derived
/// classes to use.
///
/// The std::enable_shared_from_this base is inherited publicly. The std::shared_ptr
/// constructors only initialize shared_from_this() support when this base is
/// accessible and unambiguous. With a private base, shared_from_this() would fail at
/// runtime even when called from inside Sampler.
class Sampler : public std::enable_shared_from_this<Sampler> {
  friend class AlbumDataset;
  friend class CelebADataset;
  friend class Cifar10Dataset;
  friend class Cifar100Dataset;
  friend class CLUEDataset;
  friend class CocoDataset;
  friend class CSVDataset;
  friend class ImageFolderDataset;
  friend class ManifestDataset;
  friend class MindDataDataset;
  friend class MnistDataset;
  friend class RandomDataDataset;
  friend class TextFileDataset;
  friend class TFRecordDataset;
  friend class VOCDataset;
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

 public:
  Sampler() = default;

  /// Virtual because Sampler is a polymorphic base class. Destroying a derived
  /// sampler through a Sampler pointer must run the derived destructor; otherwise
  /// the behavior is undefined.
  virtual ~Sampler() = default;

  /// @brief Appends `child` to this sampler's children.
  /// @param child Sampler to append. It is stored as given; a null pointer is not
  ///              rejected here.
  virtual void AddChild(std::shared_ptr<Sampler> child) { children_.push_back(std::move(child)); }

 protected:
  /// @brief Returns the SamplerObj for this sampler. Implemented by every concrete sampler.
  virtual std::shared_ptr<SamplerObj> Parse() const = 0;

  /// Samplers added through AddChild(), in insertion order.
  std::vector<std::shared_ptr<Sampler>> children_;
};
/// @brief Sampler configured with a shard count and a shard id, plus shuffle,
///        sample-count, seed, offset and distribution options.
///
/// This header only stores the constructor arguments. What the sampler does with
/// them is defined by the out-of-line Parse().
class DistributedSampler final : public Sampler {
 public:
  // Access specifiers do not apply to friend declarations; grouped here for readability.
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

  /// @param num_shards  Number of shards.
  /// @param shard_id    Shard id.
  /// @param shuffle     Shuffle flag (default: true).
  /// @param num_samples Sample count (default: 0; presumably 0 means "all" -- confirm in the .cc).
  /// @param seed        Seed value (default: 1).
  /// @param offset      Offset (default: -1).
  /// @param even_dist   Even-distribution flag (default: true).
  DistributedSampler(int64_t num_shards, int64_t shard_id,
                     bool shuffle = true,
                     int64_t num_samples = 0,
                     uint32_t seed = 1,
                     int64_t offset = -1,
                     bool even_dist = true);

  ~DistributedSampler() = default;

 protected:
  /// @brief Returns the SamplerObj for this sampler; defined out of line.
  auto Parse() const -> std::shared_ptr<SamplerObj> override;

 private:
  // Members are initialized in declaration order, so keep this order stable.
  int64_t num_shards_;
  int64_t shard_id_;
  bool shuffle_;
  int64_t num_samples_;
  uint32_t seed_;
  int64_t offset_;
  bool even_dist_;
};
/// @brief Sampler configured with a value count, a shuffle flag and a sample count.
///
/// NOTE(review): `num_val` is presumably the per-class count (the "K" in P-K
/// sampling) -- confirm against the out-of-line Parse().
class PKSampler final : public Sampler {
 public:
  // Access specifiers do not apply to friend declarations; grouped here for readability.
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

  /// @param num_val     Value count (required).
  /// @param shuffle     Shuffle flag (default: false).
  /// @param num_samples Sample count (default: 0).
  explicit PKSampler(int64_t num_val,
                     bool shuffle = false,
                     int64_t num_samples = 0);

  ~PKSampler() = default;

 protected:
  /// @brief Returns the SamplerObj for this sampler; defined out of line.
  auto Parse() const -> std::shared_ptr<SamplerObj> override;

 private:
  // Members are initialized in declaration order, so keep this order stable.
  int64_t num_val_;
  bool shuffle_;
  int64_t num_samples_;
};
/// @brief Sampler configured with a replacement flag and a sample count.
class RandomSampler final : public Sampler {
 public:
  // Access specifiers do not apply to friend declarations; grouped here for readability.
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

  /// @param replacement Replacement flag (default: false).
  /// @param num_samples Sample count (default: 0).
  explicit RandomSampler(bool replacement = false,
                         int64_t num_samples = 0);

  ~RandomSampler() = default;

 protected:
  /// @brief Returns the SamplerObj for this sampler; defined out of line.
  auto Parse() const -> std::shared_ptr<SamplerObj> override;

 private:
  // Members are initialized in declaration order, so keep this order stable.
  bool replacement_;
  int64_t num_samples_;
};
/// @brief Sampler configured with a start index and a sample count.
class SequentialSampler final : public Sampler {
 public:
  // Access specifiers do not apply to friend declarations; grouped here for readability.
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

  /// @param start_index Start index (default: 0).
  /// @param num_samples Sample count (default: 0).
  explicit SequentialSampler(int64_t start_index = 0,
                             int64_t num_samples = 0);

  ~SequentialSampler() = default;

 protected:
  /// @brief Returns the SamplerObj for this sampler; defined out of line.
  auto Parse() const -> std::shared_ptr<SamplerObj> override;

 private:
  // Members are initialized in declaration order, so keep this order stable.
  int64_t start_index_;
  int64_t num_samples_;
};
/// @brief Sampler configured with an explicit list of indices and a sample count.
///
/// Not `final`: SubsetRandomSampler derives from it. The data members are
/// `protected`, presumably so that subclass can reuse them.
class SubsetSampler : public Sampler {
 public:
  // Access specifiers do not apply to friend declarations; grouped here for readability.
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

  /// @param indices     Indices to sample from (taken by value).
  /// @param num_samples Sample count (default: 0).
  explicit SubsetSampler(std::vector<int64_t> indices,
                         int64_t num_samples = 0);

  ~SubsetSampler() = default;

 protected:
  /// @brief Returns the SamplerObj for this sampler; defined out of line.
  auto Parse() const -> std::shared_ptr<SamplerObj> override;

 protected:
  // Shared with subclasses. Members are initialized in declaration order, so keep
  // this order stable.
  std::vector<int64_t> indices_;
  int64_t num_samples_;
};
/// @brief SubsetSampler variant with its own Parse() implementation.
///
/// It declares no data members of its own and uses the indices and sample count
/// inherited from SubsetSampler.
class SubsetRandomSampler final : public SubsetSampler {
 public:
  // Access specifiers do not apply to friend declarations; grouped here for readability.
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

  /// @param indices     Indices to sample from (taken by value).
  /// @param num_samples Sample count (default: 0).
  explicit SubsetRandomSampler(std::vector<int64_t> indices,
                               int64_t num_samples = 0);

  ~SubsetRandomSampler() = default;

 protected:
  /// @brief Returns the SamplerObj for this sampler; overrides SubsetSampler::Parse().
  auto Parse() const -> std::shared_ptr<SamplerObj> override;
};
/// @brief Sampler configured with per-element weights, a sample count and a
///        replacement flag.
class WeightedRandomSampler final : public Sampler {
 public:
  // Access specifiers do not apply to friend declarations; grouped here for readability.
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

  /// @param weights     Weights (taken by value).
  /// @param num_samples Sample count (default: 0).
  /// @param replacement Replacement flag (default: true).
  explicit WeightedRandomSampler(std::vector<double> weights,
                                 int64_t num_samples = 0,
                                 bool replacement = true);

  ~WeightedRandomSampler() = default;

 protected:
  /// @brief Returns the SamplerObj for this sampler; defined out of line.
  auto Parse() const -> std::shared_ptr<SamplerObj> override;

 private:
  // Members are initialized in declaration order, so keep this order stable.
  std::vector<double> weights_;
  int64_t num_samples_;
  bool replacement_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_SAMPLERS_H_