Program Listing for File datasets.h

Return to documentation for file (include/runtime/include/dataset/datasets.h)

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_LITEAPI_INCLUDE_DATASETS_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_LITEAPI_INCLUDE_DATASETS_H_

#include <sys/stat.h>
#include <unistd.h>

#include <algorithm>
#include <cstdint>
#include <functional>
#include <iterator>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "include/api/dual_abi_helper.h"
#include "include/api/types.h"
#include "include/dataset/iterator.h"
#include "include/dataset/samplers.h"
#include "include/dataset/transforms.h"

namespace mindspore {
namespace dataset {
// Forward declarations: these types are only named (not defined) at this point of the header.
class Tensor;
class TensorShape;
class TreeAdapter;
class TreeAdapterLite;
class TreeGetters;

// Types held by or passed to Dataset through std::shared_ptr.
class DatasetCache;
class DatasetNode;

// Befriended by Dataset and returned by Dataset::CreateIterator().
class Iterator;

class TensorOperation;
class SchemaObj;
class SamplerObj;

// Dataset classes defined later in this header (in alphabetical order)
class BatchDataset;
class MapDataset;
class ProjectDataset;
class ShuffleDataset;
// Callback type accepted by Dataset::Map() (not one of the dataset classes listed above).
class DSCallback;

/// \brief Common base class of every dataset object declared in this header.
/// \note The pipeline helpers defined inline below (Map, Project, Shuffle) call shared_from_this(), so they must
///     be invoked on an instance that is already owned by a std::shared_ptr.
class DATASET_API Dataset : public std::enable_shared_from_this<Dataset> {
 public:
  // Befriended so that these classes can reach the non-public members of Dataset.
  friend class Iterator;
  friend class DataQueueNode;

  /// \brief Default constructor (defined out of line).
  Dataset();

  /// \brief Virtual destructor; derived datasets may be destroyed through a Dataset pointer.
  virtual ~Dataset() = default;

  /// \brief Query the dataset size (defined out of line).
  /// \param[in] estimate NOTE(review): presumably permits an approximate answer; confirm in the implementation.
  /// \return Dataset size as reported by the implementation.
  int64_t GetDatasetSize(bool estimate = false);

  /// \brief Query the data type of each output (defined out of line).
  /// \return One mindspore::DataType per output.
  std::vector<mindspore::DataType> GetOutputTypes();

  /// \brief Query the shape of each output (defined out of line).
  /// \return One vector of dimensions per output.
  std::vector<std::vector<int64_t>> GetOutputShapes();

  /// \brief Query the batch size (defined out of line).
  int64_t GetBatchSize();

  /// \brief Query the repeat count (defined out of line).
  int64_t GetRepeatCount();

  /// \brief Query the number of classes (defined out of line).
  int64_t GetNumClasses();

  /// \brief Get the column names as std::string values.
  /// \return The result of GetColumnNamesCharIF() converted with VectorCharToString().
  std::vector<std::string> GetColumnNames() {
    const auto column_name_chars = GetColumnNamesCharIF();
    return VectorCharToString(column_name_chars);
  }

  /// \brief Get the class indexing as (name, indices) pairs with std::string names.
  /// \return The result of GetClassIndexingCharIF() converted with ClassIndexCharToString().
  std::vector<std::pair<std::string, std::vector<int32_t>>> GetClassIndexing() {
    const auto class_index_chars = GetClassIndexingCharIF();
    return ClassIndexCharToString(class_index_chars);
  }

  /// \brief Set the number of workers (defined out of line).
  /// \param[in] num_workers Requested number of workers.
  /// \return Shared pointer to a Dataset; NOTE(review): presumably this object, confirm in the implementation.
  std::shared_ptr<Dataset> SetNumWorkers(int32_t num_workers);

  /// \brief Create a pull-based iterator over this dataset (defined out of line).
  std::shared_ptr<PullIterator> CreatePullBasedIterator();

  /// \brief Create an iterator over this dataset.
  /// \param[in] num_epochs Forwarded unchanged to CreateIteratorCharIF(); defaults to -1.
  /// \return Whatever CreateIteratorCharIF() returns.
  std::shared_ptr<Iterator> CreateIterator(int32_t num_epochs = -1) { return CreateIteratorCharIF(num_epochs); }

  /// \brief std::string front end of DeviceQueueCharIF().
  /// \param[in] queue_name Converted with StringToChar() before forwarding.
  /// \param[in] device_type Converted with StringToChar() before forwarding.
  /// \param[in] device_id Forwarded unchanged.
  /// \param[in] num_epochs Forwarded unchanged.
  /// \param[in] send_epoch_end Forwarded unchanged.
  /// \param[in] total_batches Forwarded unchanged.
  /// \param[in] create_data_info_queue Forwarded unchanged.
  /// \return The bool returned by DeviceQueueCharIF().
  bool DeviceQueue(const std::string &queue_name = "", const std::string &device_type = "", int32_t device_id = 0,
                   int32_t num_epochs = -1, bool send_epoch_end = true, int32_t total_batches = 0,
                   bool create_data_info_queue = false) {
    const auto queue_name_chars = StringToChar(queue_name);
    const auto device_type_chars = StringToChar(device_type);
    return DeviceQueueCharIF(queue_name_chars, device_type_chars, device_id, num_epochs, send_epoch_end,
                             total_batches, create_data_info_queue);
  }

  /// \brief std::string front end of SaveCharIF().
  /// \param[in] dataset_path Converted with StringToChar() before forwarding.
  /// \param[in] num_files Forwarded unchanged; defaults to 1.
  /// \param[in] dataset_type Converted with StringToChar() before forwarding; defaults to "mindrecord".
  /// \return The bool returned by SaveCharIF().
  bool Save(const std::string &dataset_path, int32_t num_files = 1, const std::string &dataset_type = "mindrecord") {
    const auto dataset_path_chars = StringToChar(dataset_path);
    const auto dataset_type_chars = StringToChar(dataset_type);
    return SaveCharIF(dataset_path_chars, num_files, dataset_type_chars);
  }

  /// \brief Create a BatchDataset on top of this dataset (defined out of line).
  /// \param[in] batch_size Batch size.
  /// \param[in] drop_remainder Defaults to false.
  std::shared_ptr<BatchDataset> Batch(int32_t batch_size, bool drop_remainder = false);

  /// \brief Create a MapDataset on top of this dataset from raw TensorTransform pointers.
  /// \param[in] operations Transforms to apply; each non-null entry is converted with Parse(), a null entry is
  ///     passed on as a null TensorOperation pointer.
  /// \param[in] input_columns Converted with VectorStringToChar() and handed to the MapDataset constructor.
  /// \param[in] output_columns Converted with VectorStringToChar() and handed to the MapDataset constructor.
  /// \param[in] cache Handed to the MapDataset constructor unchanged.
  /// \param[in] callbacks Handed to the MapDataset constructor unchanged.
  /// \return Shared pointer to the newly created MapDataset.
  std::shared_ptr<MapDataset> Map(const std::vector<TensorTransform *> &operations,
                                  const std::vector<std::string> &input_columns = {},
                                  const std::vector<std::string> &output_columns = {},
                                  const std::shared_ptr<DatasetCache> &cache = nullptr,
                                  const std::vector<std::shared_ptr<DSCallback>> &callbacks = {}) {
    std::vector<std::shared_ptr<TensorOperation>> parsed_ops;
    for (TensorTransform *transform : operations) {
      std::shared_ptr<TensorOperation> parsed_op;  // stays null when the caller supplied a null transform
      if (transform != nullptr) {
        parsed_op = transform->Parse();
      }
      parsed_ops.push_back(std::move(parsed_op));
    }
    const auto input_column_chars = VectorStringToChar(input_columns);
    const auto output_column_chars = VectorStringToChar(output_columns);
    return std::make_shared<MapDataset>(shared_from_this(), parsed_ops, input_column_chars, output_column_chars,
                                        cache, callbacks);
  }

  /// \brief Create a MapDataset on top of this dataset from shared TensorTransform pointers.
  /// \param[in] operations Transforms to apply; each non-null entry is converted with Parse(), a null entry is
  ///     passed on as a null TensorOperation pointer.
  /// \param[in] input_columns Converted with VectorStringToChar() and handed to the MapDataset constructor.
  /// \param[in] output_columns Converted with VectorStringToChar() and handed to the MapDataset constructor.
  /// \param[in] cache Handed to the MapDataset constructor unchanged.
  /// \param[in] callbacks Handed to the MapDataset constructor unchanged.
  /// \return Shared pointer to the newly created MapDataset.
  std::shared_ptr<MapDataset> Map(const std::vector<std::shared_ptr<TensorTransform>> &operations,
                                  const std::vector<std::string> &input_columns = {},
                                  const std::vector<std::string> &output_columns = {},
                                  const std::shared_ptr<DatasetCache> &cache = nullptr,
                                  const std::vector<std::shared_ptr<DSCallback>> &callbacks = {}) {
    std::vector<std::shared_ptr<TensorOperation>> parsed_ops;
    for (const std::shared_ptr<TensorTransform> &transform : operations) {
      std::shared_ptr<TensorOperation> parsed_op;  // stays null when the caller supplied a null transform
      if (transform != nullptr) {
        parsed_op = transform->Parse();
      }
      parsed_ops.push_back(std::move(parsed_op));
    }
    const auto input_column_chars = VectorStringToChar(input_columns);
    const auto output_column_chars = VectorStringToChar(output_columns);
    return std::make_shared<MapDataset>(shared_from_this(), parsed_ops, input_column_chars, output_column_chars,
                                        cache, callbacks);
  }

  /// \brief Create a MapDataset on top of this dataset from references to TensorTransform objects.
  /// \param[in] operations Transforms to apply; every entry is converted with Parse() (a reference cannot be null,
  ///     so no null check is performed).
  /// \param[in] input_columns Converted with VectorStringToChar() and handed to the MapDataset constructor.
  /// \param[in] output_columns Converted with VectorStringToChar() and handed to the MapDataset constructor.
  /// \param[in] cache Handed to the MapDataset constructor unchanged.
  /// \param[in] callbacks Handed to the MapDataset constructor unchanged.
  /// \return Shared pointer to the newly created MapDataset.
  std::shared_ptr<MapDataset> Map(const std::vector<std::reference_wrapper<TensorTransform>> &operations,
                                  const std::vector<std::string> &input_columns = {},
                                  const std::vector<std::string> &output_columns = {},
                                  const std::shared_ptr<DatasetCache> &cache = nullptr,
                                  const std::vector<std::shared_ptr<DSCallback>> &callbacks = {}) {
    std::vector<std::shared_ptr<TensorOperation>> parsed_ops;
    for (const auto &wrapped_transform : operations) {
      parsed_ops.push_back(wrapped_transform.get().Parse());
    }
    const auto input_column_chars = VectorStringToChar(input_columns);
    const auto output_column_chars = VectorStringToChar(output_columns);
    return std::make_shared<MapDataset>(shared_from_this(), parsed_ops, input_column_chars, output_column_chars,
                                        cache, callbacks);
  }

  /// \brief Create a ProjectDataset on top of this dataset.
  /// \param[in] columns Column names, converted with VectorStringToChar() for the ProjectDataset constructor.
  /// \return Shared pointer to the newly created ProjectDataset.
  std::shared_ptr<ProjectDataset> Project(const std::vector<std::string> &columns) {
    const auto column_chars = VectorStringToChar(columns);
    return std::make_shared<ProjectDataset>(shared_from_this(), column_chars);
  }

  /// \brief Create a ShuffleDataset on top of this dataset.
  /// \param[in] buffer_size Handed to the ShuffleDataset constructor unchanged.
  /// \return Shared pointer to the newly created ShuffleDataset.
  std::shared_ptr<ShuffleDataset> Shuffle(int32_t buffer_size) {
    return std::make_shared<ShuffleDataset>(shared_from_this(), buffer_size);
  }

  /// \brief Accessor for the stored IR node.
  /// \return A copy of the ir_node_ shared pointer.
  std::shared_ptr<DatasetNode> IRNode() { return ir_node_; }

 protected:
  std::shared_ptr<TreeGetters> tree_getters_;
  std::shared_ptr<DatasetNode> ir_node_;  // returned by IRNode()

 private:
  // Char interface (CharIF) backing GetColumnNames()
  std::vector<std::vector<char>> GetColumnNamesCharIF();

  // Char interface (CharIF) backing GetClassIndexing()
  std::vector<std::pair<std::vector<char>, std::vector<int32_t>>> GetClassIndexingCharIF();

  // Char interface (CharIF) backing CreateIterator()
  std::shared_ptr<Iterator> CreateIteratorCharIF(int32_t num_epochs);

  // Char interface (CharIF) backing DeviceQueue()
  bool DeviceQueueCharIF(const std::vector<char> &queue_name, const std::vector<char> &device_type, int32_t device_id,
                         int32_t num_epochs, bool send_epoch_end, int32_t total_batches, bool create_data_info_queue);

  // Char interface (CharIF) backing Save()
  bool SaveCharIF(const std::vector<char> &dataset_path, int32_t num_files, const std::vector<char> &dataset_type);
};

class DATASET_API SchemaObj {
 public:
  explicit SchemaObj(const std::string &schema_file = "") : SchemaObj(StringToChar(schema_file)) {}

  ~SchemaObj() = default;

  Status Init();

  Status add_column(const std::string &name, mindspore::DataType ms_type) {
    return add_column_char(StringToChar(name), ms_type);
  }

  Status add_column(const std::string &name, const std::string &ms_type) {
    return add_column_char(StringToChar(name), StringToChar(ms_type));
  }

  Status add_column(const std::string &name, mindspore::DataType ms_type, const std::vector<int32_t> &shape) {
    return add_column_char(StringToChar(name), ms_type, shape);
  }

  Status add_column(const std::string &name, const std::string &ms_type, const std::vector<int32_t> &shape) {
    return add_column_char(StringToChar(name), StringToChar(ms_type), shape);
  }

  std::string to_json() { return CharToString(to_json_char()); }

  std::string to_string() { return to_json(); }

  void set_dataset_type(const std::string &dataset_type);

  void set_num_rows(int32_t num_rows);

  int32_t get_num_rows() const;

  Status FromJSONString(const std::string &json_string) { return FromJSONStringCharIF(StringToChar(json_string)); }

  Status ParseColumnString(const std::string &json_string) {
    return ParseColumnStringCharIF(StringToChar(json_string));
  }

 private:
  // Char constructor of SchemaObj
  explicit SchemaObj(const std::vector<char> &schema_file);

  // Char interface of add_column
  Status add_column_char(const std::vector<char> &name, mindspore::DataType ms_type);

  Status add_column_char(const std::vector<char> &name, const std::vector<char> &ms_type);

  Status add_column_char(const std::vector<char> &name, mindspore::DataType ms_type, const std::vector<int32_t> &shape);

  Status add_column_char(const std::vector<char> &name, const std::vector<char> &ms_type,
                         const std::vector<int32_t> &shape);

  // Char interface of to_json
  const std::vector<char> to_json_char();

  // Char interface of FromJSONString
  Status FromJSONStringCharIF(const std::vector<char> &json_string);

  // Char interface of ParseColumnString
  Status ParseColumnStringCharIF(const std::vector<char> &json_string);

  struct Data;
  std::shared_ptr<Data> data_;
};

/// \brief Dataset type returned by Dataset::Batch(). Only the constructor and destructor are declared here.
class DATASET_API BatchDataset : public Dataset {
 public:
  /// \brief Constructor (defined out of line).
  /// \param[in] input The dataset this one is built on.
  /// \param[in] batch_size Batch size.
  /// \param[in] drop_remainder Defaults to false.
  BatchDataset(const std::shared_ptr<Dataset> &input, int32_t batch_size, bool drop_remainder = false);

  /// \brief Destructor.
  ~BatchDataset() override = default;
};

/// \brief Dataset type created by the Dataset::Map() overloads. Only the constructor and destructor are declared here.
class DATASET_API MapDataset : public Dataset {
 public:
  /// \brief Constructor (defined out of line).
  /// \param[in] input The dataset this one is built on.
  /// \param[in] operations Tensor operations; Dataset::Map() may pass null entries for null transforms.
  /// \param[in] input_columns Input column names in char-vector form.
  /// \param[in] output_columns Output column names in char-vector form.
  /// \param[in] cache Dataset cache; Dataset::Map() passes nullptr by default.
  /// \param[in] callbacks Callbacks; Dataset::Map() passes an empty list by default.
  MapDataset(const std::shared_ptr<Dataset> &input, const std::vector<std::shared_ptr<TensorOperation>> &operations,
             const std::vector<std::vector<char>> &input_columns, const std::vector<std::vector<char>> &output_columns,
             const std::shared_ptr<DatasetCache> &cache, const std::vector<std::shared_ptr<DSCallback>> &callbacks);

  /// \brief Destructor.
  ~MapDataset() override = default;
};

/// \brief Dataset type created by Dataset::Project(). Only the constructor and destructor are declared here.
class DATASET_API ProjectDataset : public Dataset {
 public:
  /// \brief Constructor (defined out of line).
  /// \param[in] input The dataset this one is built on.
  /// \param[in] columns Column names in char-vector form.
  ProjectDataset(const std::shared_ptr<Dataset> &input, const std::vector<std::vector<char>> &columns);

  /// \brief Destructor.
  ~ProjectDataset() override = default;
};

/// \brief Dataset type created by Dataset::Shuffle(). Only the constructor and destructor are declared here.
class DATASET_API ShuffleDataset : public Dataset {
 public:
  /// \brief Constructor (defined out of line).
  /// \param[in] input The dataset this one is built on.
  /// \param[in] buffer_size Shuffle buffer size.
  ShuffleDataset(const std::shared_ptr<Dataset> &input, int32_t buffer_size);

  /// \brief Destructor.
  ~ShuffleDataset() override = default;
};

/// \brief Char interface (CharIF) that the inline Schema() helper forwards to; defined out of line.
/// \param[in] schema_file Schema file name in char-vector form.
/// \return Shared pointer to a SchemaObj.
std::shared_ptr<SchemaObj> DATASET_API SchemaCharIF(const std::vector<char> &schema_file);

/// \brief Create a SchemaObj from a schema file name given as std::string.
/// \param[in] schema_file Schema file name, empty by default; converted with StringToChar().
/// \return Whatever SchemaCharIF() returns for the converted name.
inline std::shared_ptr<SchemaObj> DATASET_API Schema(const std::string &schema_file = "") {
  const auto schema_file_chars = StringToChar(schema_file);
  return SchemaCharIF(schema_file_chars);
}

/// \brief Dataset type created by the Album() helpers. The three constructors differ only in how the sampler is
///     passed: shared pointer, raw pointer, or std::reference_wrapper. All are defined out of line.
class DATASET_API AlbumDataset : public Dataset {
 public:
  /// \brief Constructor taking the sampler as a shared pointer.
  /// \param[in] dataset_dir Dataset directory in char-vector form.
  /// \param[in] data_schema Data schema in char-vector form.
  /// \param[in] column_names Column names in char-vector form.
  /// \param[in] decode Decode flag.
  /// \param[in] sampler Sampler as a shared pointer.
  /// \param[in] cache Dataset cache.
  AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,
               const std::vector<std::vector<char>> &column_names, bool decode, const std::shared_ptr<Sampler> &sampler,
               const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking the sampler as a raw pointer; other parameters as in the shared-pointer overload.
  AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,
               const std::vector<std::vector<char>> &column_names, bool decode, const Sampler *sampler,
               const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking the sampler as a std::reference_wrapper; other parameters as in the shared-pointer
  ///     overload.
  AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,
               const std::vector<std::vector<char>> &column_names, bool decode,
               const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Destructor.
  ~AlbumDataset() override = default;
};

/// \brief Create an AlbumDataset, passing the sampler as a shared pointer.
/// \param[in] dataset_dir Dataset directory; converted with StringToChar().
/// \param[in] data_schema Data schema; converted with StringToChar().
/// \param[in] column_names Column names, empty by default; converted with VectorStringToChar().
/// \param[in] decode Decode flag, false by default; forwarded unchanged.
/// \param[in] sampler Sampler; a newly created RandomSampler by default.
/// \param[in] cache Dataset cache, nullptr by default.
/// \return Shared pointer to the newly created AlbumDataset.
inline std::shared_ptr<AlbumDataset> DATASET_API
Album(const std::string &dataset_dir, const std::string &data_schema, const std::vector<std::string> &column_names = {},
      bool decode = false, const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
      const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dataset_dir_chars = StringToChar(dataset_dir);
  const auto data_schema_chars = StringToChar(data_schema);
  const auto column_name_chars = VectorStringToChar(column_names);
  return std::make_shared<AlbumDataset>(dataset_dir_chars, data_schema_chars, column_name_chars, decode, sampler,
                                        cache);
}

/// \brief Create an AlbumDataset, passing the sampler as a raw pointer.
/// \param[in] dataset_dir Dataset directory; converted with StringToChar().
/// \param[in] data_schema Data schema; converted with StringToChar().
/// \param[in] column_names Column names; converted with VectorStringToChar().
/// \param[in] decode Decode flag; forwarded unchanged.
/// \param[in] sampler Raw pointer to the sampler; forwarded unchanged.
/// \param[in] cache Dataset cache, nullptr by default.
/// \return Shared pointer to the newly created AlbumDataset.
inline std::shared_ptr<AlbumDataset> DATASET_API Album(const std::string &dataset_dir, const std::string &data_schema,
                                                       const std::vector<std::string> &column_names, bool decode,
                                                       const Sampler *sampler,
                                                       const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dataset_dir_chars = StringToChar(dataset_dir);
  const auto data_schema_chars = StringToChar(data_schema);
  const auto column_name_chars = VectorStringToChar(column_names);
  return std::make_shared<AlbumDataset>(dataset_dir_chars, data_schema_chars, column_name_chars, decode, sampler,
                                        cache);
}

/// \brief Create an AlbumDataset, passing the sampler as a std::reference_wrapper.
/// \param[in] dataset_dir Dataset directory; converted with StringToChar().
/// \param[in] data_schema Data schema; converted with StringToChar().
/// \param[in] column_names Column names; converted with VectorStringToChar().
/// \param[in] decode Decode flag; forwarded unchanged.
/// \param[in] sampler Reference wrapper around the sampler; forwarded unchanged.
/// \param[in] cache Dataset cache, nullptr by default.
/// \return Shared pointer to the newly created AlbumDataset.
inline std::shared_ptr<AlbumDataset> DATASET_API Album(const std::string &dataset_dir, const std::string &data_schema,
                                                       const std::vector<std::string> &column_names, bool decode,
                                                       const std::reference_wrapper<Sampler> sampler,
                                                       const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dataset_dir_chars = StringToChar(dataset_dir);
  const auto data_schema_chars = StringToChar(data_schema);
  const auto column_name_chars = VectorStringToChar(column_names);
  return std::make_shared<AlbumDataset>(dataset_dir_chars, data_schema_chars, column_name_chars, decode, sampler,
                                        cache);
}

/// \brief Dataset type created by the Mnist() helpers. The three constructors differ only in how the sampler is
///     passed: shared pointer, raw pointer, or std::reference_wrapper. All are defined out of line.
class DATASET_API MnistDataset : public Dataset {
 public:
  /// \brief Constructor taking the sampler as a shared pointer.
  /// \param[in] dataset_dir Dataset directory in char-vector form.
  /// \param[in] usage Usage string in char-vector form (the Mnist() helper defaults it to "all").
  /// \param[in] sampler Sampler as a shared pointer.
  /// \param[in] cache Dataset cache.
  MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
               const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking the sampler as a raw pointer; other parameters as in the shared-pointer overload.
  MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, const Sampler *sampler,
               const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking the sampler as a std::reference_wrapper; other parameters as in the shared-pointer
  ///     overload.
  MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
               const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Destructor.
  ~MnistDataset() override = default;
};

/// \brief Create a MnistDataset, passing the sampler as a shared pointer.
/// \param[in] dataset_dir Dataset directory; converted with StringToChar().
/// \param[in] usage Usage string, "all" by default; converted with StringToChar().
/// \param[in] sampler Sampler; a newly created RandomSampler by default.
/// \param[in] cache Dataset cache, nullptr by default.
/// \return Shared pointer to the newly created MnistDataset.
inline std::shared_ptr<MnistDataset> DATASET_API
Mnist(const std::string &dataset_dir, const std::string &usage = "all",
      const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
      const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dataset_dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<MnistDataset>(dataset_dir_chars, usage_chars, sampler, cache);
}

/// \brief Create a MnistDataset, passing the sampler as a raw pointer.
/// \param[in] dataset_dir Dataset directory; converted with StringToChar().
/// \param[in] usage Usage string; converted with StringToChar().
/// \param[in] sampler Raw pointer to the sampler; forwarded unchanged.
/// \param[in] cache Dataset cache, nullptr by default.
/// \return Shared pointer to the newly created MnistDataset.
inline std::shared_ptr<MnistDataset> DATASET_API Mnist(const std::string &dataset_dir, const std::string &usage,
                                                       const Sampler *sampler,
                                                       const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dataset_dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<MnistDataset>(dataset_dir_chars, usage_chars, sampler, cache);
}

/// \brief Create a MnistDataset, passing the sampler as a std::reference_wrapper.
/// \param[in] dataset_dir Dataset directory; converted with StringToChar().
/// \param[in] usage Usage string; converted with StringToChar().
/// \param[in] sampler Reference wrapper around the sampler; forwarded unchanged.
/// \param[in] cache Dataset cache, nullptr by default.
/// \return Shared pointer to the newly created MnistDataset.
inline std::shared_ptr<MnistDataset> DATASET_API Mnist(const std::string &dataset_dir, const std::string &usage,
                                                       const std::reference_wrapper<Sampler> sampler,
                                                       const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dataset_dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<MnistDataset>(dataset_dir_chars, usage_chars, sampler, cache);
}
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_LITEAPI_INCLUDE_DATASETS_H_