// Program Listing for File samplers.h
// ↰ Return to documentation for file (include/samplers.h)
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_SAMPLERS_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_SAMPLERS_H_
#include <memory>
#include <vector>
#include "include/api/types.h"
#include "include/api/status.h"
namespace mindspore {
namespace dataset {
// Forward declare
class SamplerObj;
/// \brief Abstract class to represent a sampler in the data pipeline.
///
/// A sampler keeps a list of child samplers (children_) and declares a pure virtual Parse() hook
/// that yields a SamplerObj. Both are protected; besides derived classes, only the dataset classes
/// and the SelectSampler() function befriended below can reach them.
///
/// \note std::enable_shared_from_this is inherited publicly on purpose. Base classes of a `class`
///     are private by default, and std::shared_ptr only initialises the internal weak reference
///     when that base is publicly accessible. With a private base, shared_from_this() can never
///     succeed (it throws std::bad_weak_ptr since C++17).
class Sampler : public std::enable_shared_from_this<Sampler> {
  // Dataset classes that are granted access to the protected members of Sampler.
  friend class AlbumDataset;
  friend class Caltech256Dataset;
  friend class CelebADataset;
  friend class Cifar10Dataset;
  friend class Cifar100Dataset;
  friend class CityscapesDataset;
  friend class CLUEDataset;
  friend class CMUArcticDataset;
  friend class CocoDataset;
  friend class CSVDataset;
  friend class DIV2KDataset;
  friend class EMnistDataset;
  friend class FakeImageDataset;
  friend class FashionMnistDataset;
  friend class FlickrDataset;
  friend class GTZANDataset;
  friend class ImageFolderDataset;
  friend class IMDBDataset;
  friend class KITTIDataset;
  friend class KMnistDataset;
  friend class LFWDataset;
  friend class LibriTTSDataset;
  friend class LJSpeechDataset;
  friend class LSUNDataset;
  friend class ManifestDataset;
  friend class MindDataDataset;
  friend class MnistDataset;
  friend class OmniglotDataset;
  friend class PhotoTourDataset;
  friend class Places365Dataset;
  friend class QMnistDataset;
  friend class RandomDataDataset;
  friend class SBUDataset;
  friend class SemeionDataset;
  friend class SpeechCommandsDataset;
  friend class STL10Dataset;
  friend class TedliumDataset;
  friend class TextFileDataset;
  friend class TFRecordDataset;
  friend class USPSDataset;
  friend class VOCDataset;
  friend class WIDERFaceDataset;
  friend class YesNoDataset;
  // Free function that is granted the same access; it is declared here and defined elsewhere.
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

 public:
  /// \brief Default constructor.
  Sampler() = default;

  /// \brief Virtual destructor, so derived samplers can be destroyed through a Sampler pointer.
  virtual ~Sampler() = default;

  /// \brief Append a child sampler to this sampler's list of children.
  /// \param[in] child The sampler to append. It is stored as given; no null check is done here.
  virtual void AddChild(const std::shared_ptr<Sampler> &child) { children_.push_back(child); }

 protected:
  /// \brief Pure virtual hook that every concrete sampler implements to produce its SamplerObj.
  /// \return Shared pointer to a SamplerObj.
  virtual std::shared_ptr<SamplerObj> Parse() const = 0;

  /// \brief Helper defined out of line.
  /// \note NOTE(review): presumably parses each entry of children_ and attaches the result to
  ///     *sampler - confirm against the implementation.
  /// \param[in,out] sampler Pointer to the shared pointer of the SamplerObj to work on.
  /// \return Status of the operation.
  Status BuildChildren(std::shared_ptr<SamplerObj> *const sampler) const;

  /// Child samplers; the AddChild() defined in this class appends to the back of this vector.
  std::vector<std::shared_ptr<Sampler>> children_;
};
/// \brief Concrete Sampler described by a shard count, a shard id and a few tuning options.
///
/// Every constructor parameter has a same-named private member (with a trailing underscore).
/// The constructor and Parse() are defined out of line.
class DistributedSampler final : public Sampler {
  // SelectSampler() is granted access to the non-public members of this class.
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

 public:
  /// \brief Constructor.
  /// \param[in] num_shards Shard count (required).
  /// \param[in] shard_id Id of the shard (required).
  ///     NOTE(review): presumably expected to lie in [0, num_shards) - confirm.
  /// \param[in] shuffle Shuffle flag (default: true).
  /// \param[in] num_samples Sample count (default: 0).
  ///     NOTE(review): 0 presumably means "no limit" - confirm.
  /// \param[in] seed Seed value (default: 1). NOTE(review): presumably only used when shuffling - confirm.
  /// \param[in] offset Offset value (default: -1). NOTE(review): exact meaning is not visible here - confirm.
  /// \param[in] even_dist Even-distribution flag (default: true).
  ///     NOTE(review): exact meaning is not visible here - confirm.
  DistributedSampler(int64_t num_shards,
                     int64_t shard_id,
                     bool shuffle = true,
                     int64_t num_samples = 0,
                     uint32_t seed = 1,
                     int64_t offset = -1,
                     bool even_dist = true);

  /// \brief Destructor.
  ~DistributedSampler() = default;

 protected:
  /// \brief Produce the SamplerObj for this sampler (overrides Sampler::Parse).
  /// \return Shared pointer to a SamplerObj.
  std::shared_ptr<SamplerObj> Parse() const override;

 private:
  int64_t num_shards_;   // counterpart of constructor parameter num_shards
  int64_t shard_id_;     // counterpart of constructor parameter shard_id
  bool shuffle_;         // counterpart of constructor parameter shuffle
  int64_t num_samples_;  // counterpart of constructor parameter num_samples
  uint32_t seed_;        // counterpart of constructor parameter seed
  int64_t offset_;       // counterpart of constructor parameter offset
  bool even_dist_;       // counterpart of constructor parameter even_dist
};
/// \brief Concrete Sampler described by a value count, a shuffle flag and a sample count.
///
/// Every constructor parameter has a same-named private member (with a trailing underscore).
/// The constructor and Parse() are defined out of line.
class PKSampler final : public Sampler {
  // SelectSampler() is granted access to the non-public members of this class.
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

 public:
  /// \brief Constructor; explicit, so an int64_t is never converted to a PKSampler implicitly.
  /// \param[in] num_val Value count (required).
  ///     NOTE(review): presumably the number of samples drawn per class - confirm.
  /// \param[in] shuffle Shuffle flag (default: false).
  /// \param[in] num_samples Sample count (default: 0).
  ///     NOTE(review): 0 presumably means "no limit" - confirm.
  explicit PKSampler(int64_t num_val,
                     bool shuffle = false,
                     int64_t num_samples = 0);

  /// \brief Destructor.
  ~PKSampler() = default;

 protected:
  /// \brief Produce the SamplerObj for this sampler (overrides Sampler::Parse).
  /// \return Shared pointer to a SamplerObj.
  std::shared_ptr<SamplerObj> Parse() const override;

 private:
  int64_t num_val_;      // counterpart of constructor parameter num_val
  bool shuffle_;         // counterpart of constructor parameter shuffle
  int64_t num_samples_;  // counterpart of constructor parameter num_samples
};
/// \brief Concrete Sampler described by a replacement flag and a sample count.
///
/// Every constructor parameter has a same-named private member (with a trailing underscore).
/// The constructor and Parse() are defined out of line.
class RandomSampler final : public Sampler {
  // SelectSampler() is granted access to the non-public members of this class.
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

 public:
  /// \brief Constructor; explicit, so a bool is never converted to a RandomSampler implicitly.
  /// \param[in] replacement Replacement flag (default: false).
  ///     NOTE(review): presumably selects sampling with replacement - confirm.
  /// \param[in] num_samples Sample count (default: 0).
  ///     NOTE(review): 0 presumably means "no limit" - confirm.
  explicit RandomSampler(bool replacement = false,
                         int64_t num_samples = 0);

  /// \brief Destructor.
  ~RandomSampler() = default;

 protected:
  /// \brief Produce the SamplerObj for this sampler (overrides Sampler::Parse).
  /// \return Shared pointer to a SamplerObj.
  std::shared_ptr<SamplerObj> Parse() const override;

 private:
  bool replacement_;     // counterpart of constructor parameter replacement
  int64_t num_samples_;  // counterpart of constructor parameter num_samples
};
/// \brief Concrete Sampler described by a start index and a sample count.
///
/// Every constructor parameter has a same-named private member (with a trailing underscore).
/// The constructor and Parse() are defined out of line.
class SequentialSampler final : public Sampler {
  // SelectSampler() is granted access to the non-public members of this class.
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

 public:
  /// \brief Constructor; explicit, so an int64_t is never converted to a SequentialSampler implicitly.
  /// \param[in] start_index Start index (default: 0).
  /// \param[in] num_samples Sample count (default: 0).
  ///     NOTE(review): 0 presumably means "no limit" - confirm.
  explicit SequentialSampler(int64_t start_index = 0,
                             int64_t num_samples = 0);

  /// \brief Destructor.
  ~SequentialSampler() = default;

 protected:
  /// \brief Produce the SamplerObj for this sampler (overrides Sampler::Parse).
  /// \return Shared pointer to a SamplerObj.
  std::shared_ptr<SamplerObj> Parse() const override;

 private:
  int64_t start_index_;  // counterpart of constructor parameter start_index
  int64_t num_samples_;  // counterpart of constructor parameter num_samples
};
/// \brief Concrete Sampler described by a list of indices and a sample count.
///
/// Unlike the other concrete samplers in this header, this class is not final and keeps its data
/// members protected, so a derived class can read them. The constructor and Parse() are defined
/// out of line.
class SubsetSampler : public Sampler {
  // SelectSampler() is granted access to the non-public members of this class.
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

 public:
  /// \brief Constructor; explicit, so a vector is never converted to a SubsetSampler implicitly.
  /// \param[in] indices List of indices (required).
  /// \param[in] num_samples Sample count (default: 0).
  ///     NOTE(review): 0 presumably means "no limit" - confirm.
  explicit SubsetSampler(const std::vector<int64_t> &indices,
                         int64_t num_samples = 0);

  /// \brief Destructor.
  ~SubsetSampler() = default;

 protected:
  /// \brief Produce the SamplerObj for this sampler (overrides Sampler::Parse).
  /// \return Shared pointer to a SamplerObj.
  std::shared_ptr<SamplerObj> Parse() const override;

  std::vector<int64_t> indices_;  // counterpart of constructor parameter indices
  int64_t num_samples_;           // counterpart of constructor parameter num_samples
};
/// \brief Final specialisation of SubsetSampler that supplies its own Parse().
///
/// It declares no data members of its own. The constructor takes the same parameter list as
/// SubsetSampler's; it and Parse() are defined out of line.
class SubsetRandomSampler final : public SubsetSampler {
  // SelectSampler() is granted access to the non-public members of this class.
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

 public:
  /// \brief Constructor; explicit, so a vector is never converted to a SubsetRandomSampler implicitly.
  /// \param[in] indices List of indices (required).
  /// \param[in] num_samples Sample count (default: 0).
  ///     NOTE(review): 0 presumably means "no limit" - confirm.
  explicit SubsetRandomSampler(const std::vector<int64_t> &indices,
                               int64_t num_samples = 0);

  /// \brief Destructor.
  ~SubsetRandomSampler() = default;

 protected:
  /// \brief Produce the SamplerObj for this sampler (overrides the inherited Parse()).
  /// \return Shared pointer to a SamplerObj.
  std::shared_ptr<SamplerObj> Parse() const override;
};
/// \brief Concrete Sampler described by a list of weights, a sample count and a replacement flag.
///
/// Every constructor parameter has a same-named private member (with a trailing underscore).
/// The constructor and Parse() are defined out of line.
class WeightedRandomSampler final : public Sampler {
  // SelectSampler() is granted access to the non-public members of this class.
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

 public:
  /// \brief Constructor; explicit, so a vector is never converted to a WeightedRandomSampler implicitly.
  /// \param[in] weights List of weights (required).
  ///     NOTE(review): whether they must be normalised is not visible here - confirm.
  /// \param[in] num_samples Sample count (default: 0).
  ///     NOTE(review): 0 presumably means "no limit" - confirm.
  /// \param[in] replacement Replacement flag (default: true).
  ///     NOTE(review): presumably selects sampling with replacement - confirm.
  explicit WeightedRandomSampler(const std::vector<double> &weights,
                                 int64_t num_samples = 0,
                                 bool replacement = true);

  /// \brief Destructor.
  ~WeightedRandomSampler() = default;

 protected:
  /// \brief Produce the SamplerObj for this sampler (overrides Sampler::Parse).
  /// \return Shared pointer to a SamplerObj.
  std::shared_ptr<SamplerObj> Parse() const override;

 private:
  std::vector<double> weights_;  // counterpart of constructor parameter weights
  int64_t num_samples_;          // counterpart of constructor parameter num_samples
  bool replacement_;             // counterpart of constructor parameter replacement
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_SAMPLERS_H_