// Program listing for file samplers.h.

// Return to the documentation for this file (include/samplers.h).

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_SAMPLERS_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_SAMPLERS_H_

#include <memory>
#include <utility>
#include <vector>

#include "include/api/types.h"

namespace mindspore {
namespace dataset {

// Forward declare
class SamplerObj;

/// @brief Abstract base class representing a sampler in the data pipeline.
///
/// Concrete samplers implement Parse(), which returns the SamplerObj counterpart of the sampler
/// (presumably the internal IR node — confirm against the implementation). Parse() is protected;
/// the friend declarations below give the listed dataset classes and SelectSampler() access to it.
/// Samplers can be chained with AddChild().
///
/// Inheritance from std::enable_shared_from_this is public on purpose: std::shared_ptr only wires
/// up the internal weak reference when that base is unambiguous and accessible. With a private
/// base, shared_from_this() is unreachable from outside and would throw std::bad_weak_ptr (C++17)
/// even when called from inside.
class Sampler : public std::enable_shared_from_this<Sampler> {
  friend class AlbumDataset;
  friend class CelebADataset;
  friend class Cifar10Dataset;
  friend class Cifar100Dataset;
  friend class CityscapesDataset;
  friend class CLUEDataset;
  friend class CocoDataset;
  friend class CSVDataset;
  friend class DIV2KDataset;
  friend class FlickrDataset;
  friend class ImageFolderDataset;
  friend class ManifestDataset;
  friend class MindDataDataset;
  friend class MnistDataset;
  friend class RandomDataDataset;
  friend class SBUDataset;
  friend class TextFileDataset;
  friend class TFRecordDataset;
  friend class USPSDataset;
  friend class VOCDataset;
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

 public:
  /// @brief Constructs a sampler with no children.
  Sampler() = default;

  /// @brief Virtual so that destroying a derived sampler through a Sampler pointer
  /// (e.g. std::unique_ptr<Sampler>) runs the derived destructor instead of being undefined behavior.
  virtual ~Sampler() = default;

  /// @brief Appends a child sampler to this sampler.
  /// @param[in] child Sampler to append. Ownership is shared; the parameter is a sink and is moved into storage.
  virtual void AddChild(std::shared_ptr<Sampler> child) { children_.push_back(std::move(child)); }

 protected:
  /// @brief Returns the SamplerObj counterpart of this sampler; implemented by each concrete sampler.
  virtual std::shared_ptr<SamplerObj> Parse() const = 0;

  /// Child samplers added through AddChild(), in insertion order.
  std::vector<std::shared_ptr<Sampler>> children_;
};

/// @brief Sampler configured with a shard count and a shard id, plus optional shuffle/seed/offset settings.
///
/// NOTE(review): the constructor and Parse() are defined out of line; the parameter descriptions
/// below beyond the defaults are inferred from names — confirm against the implementation.
class DistributedSampler final : public Sampler {
 public:
  /// @param[in] num_shards Number of shards.
  /// @param[in] shard_id Shard represented by this sampler (presumably in [0, num_shards) — confirm).
  /// @param[in] shuffle Shuffle flag; defaults to true.
  /// @param[in] num_samples Sample count; defaults to 0 (presumably "no limit" — confirm).
  /// @param[in] seed Seed value; defaults to 1.
  /// @param[in] offset Offset value; defaults to -1.
  /// @param[in] even_dist Even-distribution flag; defaults to true.
  DistributedSampler(int64_t num_shards,
                     int64_t shard_id,
                     bool shuffle = true,
                     int64_t num_samples = 0,
                     uint32_t seed = 1,
                     int64_t offset = -1,
                     bool even_dist = true);

  ~DistributedSampler() = default;

 protected:
  /// @brief Returns the SamplerObj counterpart of this sampler (defined out of line).
  std::shared_ptr<SamplerObj> Parse() const override;

 private:
  // Grants SelectSampler() access to the non-public members of this class.
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

  // Configuration fields. Keep the declaration order: it fixes both the object layout
  // and the order in which the constructor initializes these members.
  int64_t num_shards_;
  int64_t shard_id_;
  bool shuffle_;
  int64_t num_samples_;
  uint32_t seed_;
  int64_t offset_;
  bool even_dist_;
};

/// @brief "PK" sampler configured by a per-group count `num_val`.
///
/// NOTE(review): presumably samples K elements from each of P classes — confirm against the
/// out-of-line implementation.
class PKSampler final : public Sampler {
 public:
  /// @param[in] num_val Per-group element count.
  /// @param[in] shuffle Shuffle flag; defaults to false.
  /// @param[in] num_samples Sample count; defaults to 0 (presumably "no limit" — confirm).
  explicit PKSampler(int64_t num_val,
                     bool shuffle = false,
                     int64_t num_samples = 0);

  ~PKSampler() = default;

 protected:
  /// @brief Returns the SamplerObj counterpart of this sampler (defined out of line).
  std::shared_ptr<SamplerObj> Parse() const override;

 private:
  // Grants SelectSampler() access to the non-public members of this class.
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

  // Configuration fields; declaration order fixes layout and initialization order.
  int64_t num_val_;
  bool shuffle_;
  int64_t num_samples_;
};

/// @brief Sampler configured by a replacement flag and a sample count.
class RandomSampler final : public Sampler {
 public:
  /// @param[in] replacement Replacement flag; defaults to false.
  /// @param[in] num_samples Sample count; defaults to 0 (presumably "no limit" — confirm).
  explicit RandomSampler(bool replacement = false,
                         int64_t num_samples = 0);

  ~RandomSampler() = default;

 protected:
  /// @brief Returns the SamplerObj counterpart of this sampler (defined out of line).
  std::shared_ptr<SamplerObj> Parse() const override;

 private:
  // Grants SelectSampler() access to the non-public members of this class.
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

  // Configuration fields; declaration order fixes layout and initialization order.
  bool replacement_;
  int64_t num_samples_;
};

/// @brief Sampler configured by a start index and a sample count.
class SequentialSampler final : public Sampler {
 public:
  /// @param[in] start_index Start index; defaults to 0.
  /// @param[in] num_samples Sample count; defaults to 0 (presumably "no limit" — confirm).
  explicit SequentialSampler(int64_t start_index = 0,
                             int64_t num_samples = 0);

  ~SequentialSampler() = default;

 protected:
  /// @brief Returns the SamplerObj counterpart of this sampler (defined out of line).
  std::shared_ptr<SamplerObj> Parse() const override;

 private:
  // Grants SelectSampler() access to the non-public members of this class.
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

  // Configuration fields; declaration order fixes layout and initialization order.
  int64_t start_index_;
  int64_t num_samples_;
};

/// @brief Sampler configured by an explicit list of indices and a sample count.
///
/// Not final: SubsetRandomSampler derives from it. The configuration fields are therefore
/// protected, so derived classes can access them.
class SubsetSampler : public Sampler {
 public:
  /// @param[in] indices Index list, taken by value.
  /// @param[in] num_samples Sample count; defaults to 0 (presumably "no limit" — confirm).
  explicit SubsetSampler(std::vector<int64_t> indices,
                         int64_t num_samples = 0);

  ~SubsetSampler() = default;

 protected:
  /// @brief Returns the SamplerObj counterpart of this sampler (defined out of line).
  std::shared_ptr<SamplerObj> Parse() const override;

  // Configuration fields; declaration order fixes layout and initialization order.
  std::vector<int64_t> indices_;
  int64_t num_samples_;

 private:
  // Grants SelectSampler() access to the non-public members of this class.
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);
};

/// @brief SubsetSampler variant that overrides Parse(); it adds no fields of its own.
class SubsetRandomSampler final : public SubsetSampler {
 public:
  /// @param[in] indices Index list, taken by value.
  /// @param[in] num_samples Sample count; defaults to 0 (presumably "no limit" — confirm).
  explicit SubsetRandomSampler(std::vector<int64_t> indices,
                               int64_t num_samples = 0);

  ~SubsetRandomSampler() = default;

 protected:
  /// @brief Returns the SamplerObj counterpart of this sampler (defined out of line).
  std::shared_ptr<SamplerObj> Parse() const override;

 private:
  // Grants SelectSampler() access to the non-public members of this class.
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);
};

/// @brief Sampler configured by a list of weights, a sample count, and a replacement flag.
class WeightedRandomSampler final : public Sampler {
 public:
  /// @param[in] weights Weight list, taken by value.
  /// @param[in] num_samples Sample count; defaults to 0 (presumably "no limit" — confirm).
  /// @param[in] replacement Replacement flag; defaults to true.
  explicit WeightedRandomSampler(std::vector<double> weights,
                                 int64_t num_samples = 0,
                                 bool replacement = true);

  ~WeightedRandomSampler() = default;

 protected:
  /// @brief Returns the SamplerObj counterpart of this sampler (defined out of line).
  std::shared_ptr<SamplerObj> Parse() const override;

 private:
  // Grants SelectSampler() access to the non-public members of this class.
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

  // Configuration fields; declaration order fixes layout and initialization order.
  std::vector<double> weights_;
  int64_t num_samples_;
  bool replacement_;
};

}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_SAMPLERS_H_