Program Listing for File samplers.h

Return to documentation for file (include/runtime/include/dataset/samplers.h)

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_SAMPLERS_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_SAMPLERS_H_

#include <memory>
#include <vector>

#include "include/api/types.h"
#include "include/api/status.h"

namespace mindspore {
namespace dataset {
// Forward declaration. Within this header SamplerObj is only referenced through std::shared_ptr (and a pointer to
// one), so the incomplete type is sufficient here; the full definition lives outside this header.
class SamplerObj;

// Abstract class to represent a sampler in the data pipeline.
class DATASET_API Sampler : std::enable_shared_from_this<Sampler> {
  friend class AlbumDataset;
  friend class Caltech256Dataset;
  friend class CelebADataset;
  friend class Cifar10Dataset;
  friend class Cifar100Dataset;
  friend class CityscapesDataset;
  friend class CLUEDataset;
  friend class CMUArcticDataset;
  friend class CocoDataset;
  friend class CSVDataset;
  friend class DIV2KDataset;
  friend class EMnistDataset;
  friend class FakeImageDataset;
  friend class FashionMnistDataset;
  friend class FlickrDataset;
  friend class Food101Dataset;
  friend class GTZANDataset;
  friend class ImageFolderDataset;
  friend class IMDBDataset;
  friend class KITTIDataset;
  friend class KMnistDataset;
  friend class LFWDataset;
  friend class LibriTTSDataset;
  friend class LJSpeechDataset;
  friend class LSUNDataset;
  friend class ManifestDataset;
  friend class MindDataDataset;
  friend class MnistDataset;
  friend class OmniglotDataset;
  friend class PhotoTourDataset;
  friend class Places365Dataset;
  friend class QMnistDataset;
  friend class RandomDataDataset;
  friend class RenderedSST2Dataset;
  friend class SBUDataset;
  friend class SemeionDataset;
  friend class SpeechCommandsDataset;
  friend class SST2Dataset;
  friend class STL10Dataset;
  friend class SUN397Dataset;
  friend class TedliumDataset;
  friend class TextFileDataset;
  friend class TFRecordDataset;
  friend class USPSDataset;
  friend class VOCDataset;
  friend class WIDERFaceDataset;
  friend class YesNoDataset;
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards,
                                                   int32_t shard_id);

 public:
  Sampler() = default;

  virtual ~Sampler() = default;

  virtual void AddChild(const std::shared_ptr<Sampler> &child) { children_.push_back(child); }

 protected:
  virtual std::shared_ptr<SamplerObj> Parse() const = 0;

  Status BuildChildren(std::shared_ptr<SamplerObj> *const sampler) const;

  std::vector<std::shared_ptr<Sampler>> children_;
};

/// \brief Sampler configured with a shard count and a shard id, plus shuffle, seed, offset and distribution options.
/// \note Only the declaration is visible in this header; how the stored options are interpreted is decided by the
///     out-of-line constructor, Parse() and the resulting SamplerObj.
class DATASET_API DistributedSampler final : public Sampler {
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards,
                                                   int32_t shard_id);

 public:
  /// \brief Constructor.
  /// \param[in] num_shards Number of shards (presumably how many pieces the dataset is divided into - confirm).
  /// \param[in] shard_id Id of the shard this sampler refers to (presumably in [0, num_shards) - confirm).
  /// \param[in] shuffle Shuffle flag (default: true).
  /// \param[in] num_samples Number of samples (default: 0; NOTE(review): 0 presumably means "all" - confirm).
  /// \param[in] seed Seed value (default: 1; presumably only relevant when shuffle is true - confirm).
  /// \param[in] offset Offset value (default: -1; NOTE(review): meaning of -1 is not visible here - confirm).
  /// \param[in] even_dist Even-distribution flag (default: true).
  DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle = true, int64_t num_samples = 0,
                     uint32_t seed = 1, int64_t offset = -1, bool even_dist = true);

  /// \brief Destructor.
  ~DistributedSampler() override = default;

 protected:
  /// \brief Create the SamplerObj for this sampler.
  /// \return Shared pointer to the SamplerObj.
  std::shared_ptr<SamplerObj> Parse() const override;

 private:
  int64_t num_shards_;
  int64_t shard_id_;
  bool shuffle_;
  int64_t num_samples_;
  uint32_t seed_;
  int64_t offset_;
  bool even_dist_;
};

/// \brief Sampler configured with a per-class value count, a shuffle flag and a sample count.
/// \note NOTE(review): presumably samples num_val elements for each class of the dataset - the behaviour lives in
///     the implementation outside this header; confirm there.
class DATASET_API PKSampler final : public Sampler {
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards,
                                                   int32_t shard_id);

 public:
  /// \brief Constructor. Explicit, so an integer is never implicitly converted into a PKSampler.
  /// \param[in] num_val Value count (presumably the number of elements to sample per class - confirm).
  /// \param[in] shuffle Shuffle flag (default: false).
  /// \param[in] num_samples Number of samples (default: 0; NOTE(review): 0 presumably means "all" - confirm).
  explicit PKSampler(int64_t num_val, bool shuffle = false, int64_t num_samples = 0);

  /// \brief Destructor.
  ~PKSampler() override = default;

 protected:
  /// \brief Create the SamplerObj for this sampler.
  /// \return Shared pointer to the SamplerObj.
  std::shared_ptr<SamplerObj> Parse() const override;

 private:
  int64_t num_val_;
  bool shuffle_;
  int64_t num_samples_;
};

/// \brief Sampler configured with a replacement flag and a sample count.
/// \note The sampling behaviour itself is implemented outside this header (Parse() and the resulting SamplerObj).
class DATASET_API RandomSampler final : public Sampler {
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards,
                                                   int32_t shard_id);

 public:
  /// \brief Constructor. Explicit, so a bool is never implicitly converted into a RandomSampler.
  /// \param[in] replacement Replacement flag (default: false; presumably "sample with replacement" - confirm).
  /// \param[in] num_samples Number of samples (default: 0; NOTE(review): 0 presumably means "all" - confirm).
  explicit RandomSampler(bool replacement = false, int64_t num_samples = 0);

  /// \brief Destructor.
  ~RandomSampler() override = default;

 protected:
  /// \brief Create the SamplerObj for this sampler.
  /// \return Shared pointer to the SamplerObj.
  std::shared_ptr<SamplerObj> Parse() const override;

 private:
  bool replacement_;
  int64_t num_samples_;
};

/// \brief Sampler configured with a start index and a sample count.
/// \note The sampling behaviour itself is implemented outside this header (Parse() and the resulting SamplerObj).
class DATASET_API SequentialSampler final : public Sampler {
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards,
                                                   int32_t shard_id);

 public:
  /// \brief Constructor. Explicit, so an integer is never implicitly converted into a SequentialSampler.
  /// \param[in] start_index Start index (default: 0; presumably the index at which sampling begins - confirm).
  /// \param[in] num_samples Number of samples (default: 0; NOTE(review): 0 presumably means "all" - confirm).
  explicit SequentialSampler(int64_t start_index = 0, int64_t num_samples = 0);

  /// \brief Destructor.
  ~SequentialSampler() override = default;

 protected:
  /// \brief Create the SamplerObj for this sampler.
  /// \return Shared pointer to the SamplerObj.
  std::shared_ptr<SamplerObj> Parse() const override;

 private:
  int64_t start_index_;
  int64_t num_samples_;
};

/// \brief Sampler configured with an explicit list of indices and a sample count.
/// \note This class is not final; SubsetRandomSampler derives from it, which is why the data members are protected
///     rather than private.
class DATASET_API SubsetSampler : public Sampler {
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards,
                                                   int32_t shard_id);

 public:
  /// \brief Constructor. Explicit, so a vector is never implicitly converted into a SubsetSampler.
  /// \param[in] indices List of indices (presumably the dataset positions to sample from - confirm).
  /// \param[in] num_samples Number of samples (default: 0; NOTE(review): 0 presumably means "all" - confirm).
  explicit SubsetSampler(const std::vector<int64_t> &indices, int64_t num_samples = 0);

  /// \brief Destructor.
  ~SubsetSampler() override = default;

 protected:
  /// \brief Create the SamplerObj for this sampler.
  /// \return Shared pointer to the SamplerObj.
  std::shared_ptr<SamplerObj> Parse() const override;

  std::vector<int64_t> indices_;
  int64_t num_samples_;
};

/// \brief SubsetSampler specialisation that overrides Parse(); it adds no data members of its own.
/// \note NOTE(review): presumably draws from the given indices in random order - the behaviour lives in the
///     implementation outside this header; confirm there.
class DATASET_API SubsetRandomSampler final : public SubsetSampler {
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards,
                                                   int32_t shard_id);

 public:
  /// \brief Constructor. Explicit, so a vector is never implicitly converted into a SubsetRandomSampler.
  /// \param[in] indices List of indices (presumably the dataset positions to sample from - confirm).
  /// \param[in] num_samples Number of samples (default: 0; NOTE(review): 0 presumably means "all" - confirm).
  explicit SubsetRandomSampler(const std::vector<int64_t> &indices, int64_t num_samples = 0);

  /// \brief Destructor.
  ~SubsetRandomSampler() override = default;

 protected:
  /// \brief Create the SamplerObj for this sampler.
  /// \return Shared pointer to the SamplerObj.
  std::shared_ptr<SamplerObj> Parse() const override;
};

/// \brief Sampler configured with a list of weights, a sample count and a replacement flag.
/// \note NOTE(review): presumably weights[i] is the relative probability of drawing element i - the behaviour lives
///     in the implementation outside this header; confirm there.
class DATASET_API WeightedRandomSampler final : public Sampler {
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards,
                                                   int32_t shard_id);

 public:
  /// \brief Constructor. Explicit, so a vector is never implicitly converted into a WeightedRandomSampler.
  /// \param[in] weights List of weights.
  /// \param[in] num_samples Number of samples (default: 0; NOTE(review): 0 presumably means "all" - confirm).
  /// \param[in] replacement Replacement flag (default: true; presumably "sample with replacement" - confirm).
  explicit WeightedRandomSampler(const std::vector<double> &weights, int64_t num_samples = 0, bool replacement = true);

  /// \brief Destructor.
  ~WeightedRandomSampler() override = default;

 protected:
  /// \brief Create the SamplerObj for this sampler.
  /// \return Shared pointer to the SamplerObj.
  std::shared_ptr<SamplerObj> Parse() const override;

 private:
  std::vector<double> weights_;
  int64_t num_samples_;
  bool replacement_;
};
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_SAMPLERS_H_