Program Listing for File datasets.h

Return to documentation for file (include/datasets.h)

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_DATASETS_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_DATASETS_H_

#include <sys/stat.h>
#include <unistd.h>

#include <algorithm>
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "include/api/dual_abi_helper.h"
#include "include/api/types.h"
#include "include/dataset/iterator.h"
#include "include/dataset/json_fwd.hpp"
#include "include/dataset/samplers.h"
#include "include/dataset/text.h"

namespace mindspore {
namespace dataset {

class Tensor;
class TensorShape;
class TreeAdapter;
class TreeAdapterLite;
class TreeGetters;
class Vocab;

class DatasetCache;
class DatasetNode;

class Iterator;
class PullBasedIterator;

class TensorOperation;
class SchemaObj;
class SamplerObj;
class CsvBase;

// Forward declarations of dataset operation classes (note: not in strict alphabetical order)
class BatchDataset;
class MapDataset;
class ProjectDataset;
class ShuffleDataset;
class BucketBatchByLengthDataset;
class FilterDataset;
class CSVDataset;
class TransferDataset;
class ConcatDataset;
class RenameDataset;

class SentencePieceVocab;
enum class SentencePieceModel;

class DSCallback;

class RepeatDataset;
class SkipDataset;
class TakeDataset;
class ZipDataset;

/// \brief Common base class of the dataset objects declared in this header.
/// \note The operation helpers (Map, Repeat, Zip, ...) call `shared_from_this()`, so an instance must
///     already be owned by a `std::shared_ptr` before any of them is invoked.
class Dataset : public std::enable_shared_from_this<Dataset> {
 public:
  // Friends need access to the non-public members of this class.
  friend class Iterator;
  friend class TransferNode;

  /// \brief Constructor (defined out of line).
  Dataset();

  /// \brief Destructor.
  /// \note Virtual because this class is the public base of every concrete dataset type; destroying a
  ///     derived dataset through a `Dataset` pointer must run the derived destructor.
  virtual ~Dataset() = default;

  /// \brief Get the size of the dataset.
  /// \param[in] estimate Defaults to false. NOTE(review): presumably requests an estimated instead of an
  ///     exact size; the behaviour is in the out-of-line definition, not visible here.
  /// \return Dataset size as computed by the implementation.
  int64_t GetDatasetSize(bool estimate = false);

  /// \brief Get the output data types.
  std::vector<mindspore::DataType> GetOutputTypes();

  /// \brief Get the output shapes, one vector of dimensions per entry.
  std::vector<std::vector<int64_t>> GetOutputShapes();

  /// \brief Get the batch size.
  int64_t GetBatchSize();

  /// \brief Get the repeat count.
  int64_t GetRepeatCount();

  /// \brief Get the number of classes.
  int64_t GetNumClasses();

  /// \brief Get the column names.
  /// \return Result of GetColumnNamesCharIF() converted to std::string with VectorCharToString.
  std::vector<std::string> GetColumnNames() { return VectorCharToString(GetColumnNamesCharIF()); }

  /// \brief Get the class indexing.
  /// \return Result of GetClassIndexingCharIF() with the names converted to std::string.
  std::vector<std::pair<std::string, std::vector<int32_t>>> GetClassIndexing() {
    return ClassIndexCharToString(GetClassIndexingCharIF());
  }

  /// \brief Set the number of workers.
  /// \param[in] num_workers Requested number of workers.
  /// \return Shared pointer to a Dataset (see the out-of-line definition for which object is returned).
  std::shared_ptr<Dataset> SetNumWorkers(int32_t num_workers);

  /// \brief Create a pull-based iterator over this dataset.
  /// \param[in] columns Column names in char-vector form; defaults to an empty list.
  /// \return Shared pointer to the created PullIterator.
  std::shared_ptr<PullIterator> CreatePullBasedIterator(std::vector<std::vector<char>> columns = {});

  /// \brief Create an iterator over this dataset.
  /// \param[in] columns Column names; defaults to an empty list.
  /// \param[in] num_epochs Defaults to -1; forwarded unchanged to CreateIteratorCharIF.
  /// \return Shared pointer to the created Iterator.
  std::shared_ptr<Iterator> CreateIterator(std::vector<std::string> columns = {}, int32_t num_epochs = -1) {
    return CreateIteratorCharIF(VectorStringToChar(columns), num_epochs);
  }

  /// \brief std::string front end of DeviceQueueCharIF.
  /// \note `queue_name` and `device_type` are converted with StringToChar; every other argument is
  ///     forwarded unchanged.
  /// \return The bool returned by DeviceQueueCharIF.
  bool DeviceQueue(std::string queue_name = "", std::string device_type = "", int32_t device_id = 0,
                   int32_t num_epochs = -1, bool send_epoch_end = true, int32_t total_batches = 0,
                   bool create_data_info_queue = false) {
    return DeviceQueueCharIF(StringToChar(queue_name), StringToChar(device_type), device_id, num_epochs, send_epoch_end,
                             total_batches, create_data_info_queue);
  }

  /// \brief std::string front end of SaveCharIF.
  /// \param[in] dataset_path Output path.
  /// \param[in] num_files Defaults to 1.
  /// \param[in] dataset_type Defaults to "mindrecord".
  /// \return The bool returned by SaveCharIF.
  bool Save(std::string dataset_path, int32_t num_files = 1, std::string dataset_type = "mindrecord") {
    return SaveCharIF(StringToChar(dataset_path), num_files, StringToChar(dataset_type));
  }

  /// \brief Create a BatchDataset from this dataset (defined out of line).
  /// \param[in] batch_size Batch size.
  /// \param[in] drop_remainder Defaults to false.
  std::shared_ptr<BatchDataset> Batch(int32_t batch_size, bool drop_remainder = false);

  /// \brief Create a BucketBatchByLengthDataset that takes this dataset as input.
  /// \note `column_names` and the keys of `pad_info` are converted to char-vector form; all other
  ///     arguments are forwarded unchanged to the BucketBatchByLengthDataset constructor.
  std::shared_ptr<BucketBatchByLengthDataset> BucketBatchByLength(
    const std::vector<std::string> &column_names, const std::vector<int32_t> &bucket_boundaries,
    const std::vector<int32_t> &bucket_batch_sizes,
    std::function<MSTensorVec(MSTensorVec)> element_length_function = nullptr,
    const std::map<std::string, std::pair<std::vector<int64_t>, MSTensor>> &pad_info = {},
    bool pad_to_bucket_boundary = false, bool drop_remainder = false) {
    return std::make_shared<BucketBatchByLengthDataset>(
      shared_from_this(), VectorStringToChar(column_names), bucket_boundaries, bucket_batch_sizes,
      element_length_function, PadInfoStringToChar(pad_info), pad_to_bucket_boundary, drop_remainder);
  }

  /// \brief std::string front end of BuildSentencePieceVocabCharIF.
  /// \note `col_names` and `params` are converted to their char-vector forms before forwarding.
  std::shared_ptr<SentencePieceVocab> BuildSentencePieceVocab(
    const std::vector<std::string> &col_names, int32_t vocab_size, float character_coverage,
    SentencePieceModel model_type, const std::unordered_map<std::string, std::string> &params) {
    return BuildSentencePieceVocabCharIF(VectorStringToChar(col_names), vocab_size, character_coverage, model_type,
                                         UnorderedMapStringToChar(params));
  }

  /// \brief std::string front end of BuildVocabCharIF.
  /// \param[in] columns Defaults to an empty list.
  /// \param[in] freq_range Defaults to {0, kDeMaxFreq}.
  /// \param[in] top_k Defaults to kDeMaxTopk.
  /// \param[in] special_tokens Defaults to an empty list.
  /// \param[in] special_first Defaults to true.
  std::shared_ptr<Vocab> BuildVocab(const std::vector<std::string> &columns = {},
                                    const std::pair<int64_t, int64_t> &freq_range = {0, kDeMaxFreq},
                                    int64_t top_k = kDeMaxTopk, const std::vector<std::string> &special_tokens = {},
                                    bool special_first = true) {
    return BuildVocabCharIF(VectorStringToChar(columns), freq_range, top_k, VectorStringToChar(special_tokens),
                            special_first);
  }

  /// \brief Create a ConcatDataset from this dataset followed by `datasets`.
  /// \note This dataset is placed first in the list handed to ConcatDataset.
  std::shared_ptr<ConcatDataset> Concat(const std::vector<std::shared_ptr<Dataset>> &datasets) {
    std::vector<std::shared_ptr<Dataset>> all_datasets{shared_from_this()};
    all_datasets.insert(std::end(all_datasets), std::begin(datasets), std::end(datasets));
    return std::make_shared<ConcatDataset>(all_datasets);
  }

  /// \brief Create a FilterDataset that takes this dataset as input.
  /// \param[in] predicate Callable forwarded to the FilterDataset constructor.
  /// \param[in] input_columns Defaults to an empty list; converted to char-vector form.
  std::shared_ptr<FilterDataset> Filter(std::function<MSTensorVec(MSTensorVec)> predicate,
                                        const std::vector<std::string> &input_columns = {}) {
    return std::make_shared<FilterDataset>(shared_from_this(), predicate, VectorStringToChar(input_columns));
  }

  /// \brief Create a MapDataset from raw TensorTransform pointers.
  /// \note Each non-null pointer is turned into a TensorOperation via Parse(); a null pointer is
  ///     forwarded as a null TensorOperation rather than being rejected here.
  std::shared_ptr<MapDataset> Map(std::vector<TensorTransform *> operations,
                                  const std::vector<std::string> &input_columns = {},
                                  const std::vector<std::string> &output_columns = {},
                                  const std::vector<std::string> &project_columns = {},
                                  const std::shared_ptr<DatasetCache> &cache = nullptr,
                                  std::vector<std::shared_ptr<DSCallback>> callbacks = {}) {
    std::vector<std::shared_ptr<TensorOperation>> transform_ops;
    transform_ops.reserve(operations.size());  // exactly one parsed op per input transform
    (void)std::transform(
      operations.begin(), operations.end(), std::back_inserter(transform_ops),
      [](TensorTransform *op) -> std::shared_ptr<TensorOperation> { return op != nullptr ? op->Parse() : nullptr; });
    return std::make_shared<MapDataset>(shared_from_this(), transform_ops, VectorStringToChar(input_columns),
                                        VectorStringToChar(output_columns), VectorStringToChar(project_columns), cache,
                                        callbacks);
  }

  /// \brief Create a MapDataset from shared TensorTransform pointers.
  /// \note Same null handling as the raw-pointer overload: null entries become null TensorOperations.
  std::shared_ptr<MapDataset> Map(std::vector<std::shared_ptr<TensorTransform>> operations,
                                  const std::vector<std::string> &input_columns = {},
                                  const std::vector<std::string> &output_columns = {},
                                  const std::vector<std::string> &project_columns = {},
                                  const std::shared_ptr<DatasetCache> &cache = nullptr,
                                  std::vector<std::shared_ptr<DSCallback>> callbacks = {}) {
    std::vector<std::shared_ptr<TensorOperation>> transform_ops;
    transform_ops.reserve(operations.size());  // exactly one parsed op per input transform
    (void)std::transform(operations.begin(), operations.end(), std::back_inserter(transform_ops),
                         [](std::shared_ptr<TensorTransform> op) -> std::shared_ptr<TensorOperation> {
                           return op != nullptr ? op->Parse() : nullptr;
                         });
    return std::make_shared<MapDataset>(shared_from_this(), transform_ops, VectorStringToChar(input_columns),
                                        VectorStringToChar(output_columns), VectorStringToChar(project_columns), cache,
                                        callbacks);
  }

  /// \brief Create a MapDataset from references to TensorTransform objects.
  /// \note Parse() is called on every referenced transform; no null check is needed for references.
  std::shared_ptr<MapDataset> Map(const std::vector<std::reference_wrapper<TensorTransform>> operations,
                                  const std::vector<std::string> &input_columns = {},
                                  const std::vector<std::string> &output_columns = {},
                                  const std::vector<std::string> &project_columns = {},
                                  const std::shared_ptr<DatasetCache> &cache = nullptr,
                                  std::vector<std::shared_ptr<DSCallback>> callbacks = {}) {
    std::vector<std::shared_ptr<TensorOperation>> transform_ops;
    transform_ops.reserve(operations.size());  // exactly one parsed op per input transform
    (void)std::transform(operations.begin(), operations.end(), std::back_inserter(transform_ops),
                         [](TensorTransform &op) -> std::shared_ptr<TensorOperation> { return op.Parse(); });
    return std::make_shared<MapDataset>(shared_from_this(), transform_ops, VectorStringToChar(input_columns),
                                        VectorStringToChar(output_columns), VectorStringToChar(project_columns), cache,
                                        callbacks);
  }

  /// \brief Create a ProjectDataset that takes this dataset as input.
  /// \param[in] columns Column names, converted to char-vector form.
  std::shared_ptr<ProjectDataset> Project(const std::vector<std::string> &columns) {
    return std::make_shared<ProjectDataset>(shared_from_this(), VectorStringToChar(columns));
  }

  /// \brief Create a RenameDataset that takes this dataset as input.
  /// \param[in] input_columns Column names, converted to char-vector form.
  /// \param[in] output_columns Column names, converted to char-vector form.
  std::shared_ptr<RenameDataset> Rename(const std::vector<std::string> &input_columns,
                                        const std::vector<std::string> &output_columns) {
    return std::make_shared<RenameDataset>(shared_from_this(), VectorStringToChar(input_columns),
                                           VectorStringToChar(output_columns));
  }

  /// \brief Create a RepeatDataset that takes this dataset as input.
  /// \param[in] count Defaults to -1; forwarded unchanged to the RepeatDataset constructor.
  std::shared_ptr<RepeatDataset> Repeat(int32_t count = -1) {
    return std::make_shared<RepeatDataset>(shared_from_this(), count);
  }

  /// \brief Create a ShuffleDataset that takes this dataset as input.
  /// \param[in] buffer_size Forwarded unchanged to the ShuffleDataset constructor.
  std::shared_ptr<ShuffleDataset> Shuffle(int32_t buffer_size) {
    return std::make_shared<ShuffleDataset>(shared_from_this(), buffer_size);
  }

  /// \brief Create a SkipDataset that takes this dataset as input.
  /// \param[in] count Forwarded unchanged to the SkipDataset constructor.
  std::shared_ptr<SkipDataset> Skip(int32_t count) { return std::make_shared<SkipDataset>(shared_from_this(), count); }

  /// \brief Create a TakeDataset that takes this dataset as input.
  /// \param[in] count Defaults to -1; forwarded unchanged to the TakeDataset constructor.
  std::shared_ptr<TakeDataset> Take(int32_t count = -1) {
    return std::make_shared<TakeDataset>(shared_from_this(), count);
  }

  /// \brief Create a ZipDataset from `datasets` followed by this dataset.
  /// \note Unlike Concat, this dataset is appended AFTER the entries of `datasets`.
  std::shared_ptr<ZipDataset> Zip(const std::vector<std::shared_ptr<Dataset>> &datasets) {
    std::vector<std::shared_ptr<Dataset>> all_datasets = datasets;
    all_datasets.push_back(shared_from_this());
    return std::make_shared<ZipDataset>(all_datasets);
  }

  /// \brief Accessor for the stored IR node pointer.
  std::shared_ptr<DatasetNode> IRNode() { return ir_node_; }

 protected:
  std::shared_ptr<TreeGetters> tree_getters_;
  std::shared_ptr<DatasetNode> ir_node_;  // returned by IRNode()

 private:
  // Char interface(CharIF) of GetColumnNames
  std::vector<std::vector<char>> GetColumnNamesCharIF();

  // Char interface(CharIF) of GetClassIndexing
  std::vector<std::pair<std::vector<char>, std::vector<int32_t>>> GetClassIndexingCharIF();

  // Char interface(CharIF) of CreateIterator
  std::shared_ptr<Iterator> CreateIteratorCharIF(std::vector<std::vector<char>> columns, int32_t num_epochs);

  // Char interface(CharIF) of DeviceQueue
  bool DeviceQueueCharIF(const std::vector<char> &queue_name, const std::vector<char> &device_type, int32_t device_id,
                         int32_t num_epochs, bool send_epoch_end, int32_t total_batches, bool create_data_info_queue);

  // Char interface(CharIF) of Save
  bool SaveCharIF(const std::vector<char> &dataset_path, int32_t num_files, const std::vector<char> &dataset_type);

  // Char interface(CharIF) of BuildSentencePieceVocab
  std::shared_ptr<SentencePieceVocab> BuildSentencePieceVocabCharIF(
    const std::vector<std::vector<char>> &col_names, int32_t vocab_size, float character_coverage,
    SentencePieceModel model_type, const std::map<std::vector<char>, std::vector<char>> &params);

  // Char interface(CharIF) of BuildVocab
  std::shared_ptr<Vocab> BuildVocabCharIF(const std::vector<std::vector<char>> &columns,
                                          const std::pair<int64_t, int64_t> &freq_range, int64_t top_k,
                                          const std::vector<std::vector<char>> &special_tokens, bool special_first);
};

/// \brief Schema object. The public std::string-based methods are thin inline wrappers that convert
///     their string arguments with StringToChar and forward to private char-vector counterparts.
class  SchemaObj {
 public:
  /// \brief Constructor; delegates to the private char-vector constructor.
  /// \param[in] schema_file Schema file argument, defaults to an empty string.
  explicit SchemaObj(const std::string &schema_file = "") : SchemaObj(StringToChar(schema_file)) {}

  /// \brief Destructor.
  ~SchemaObj() = default;

  /// \brief Initialise the object (defined out of line).
  /// \return Status reported by the implementation.
  Status Init();

  /// \brief Add a column given its name and a DataType.
  /// \return Status returned by add_column_char.
  Status add_column(const std::string &name, mindspore::DataType ms_type) {
    return add_column_char(StringToChar(name), ms_type);
  }

  /// \brief Add a column given its name and the type spelled as a string.
  /// \return Status returned by add_column_char.
  Status add_column(const std::string &name, const std::string &ms_type) {
    return add_column_char(StringToChar(name), StringToChar(ms_type));
  }

  /// \brief Add a column given its name, a DataType and a shape.
  /// \return Status returned by add_column_char.
  Status add_column(const std::string &name, mindspore::DataType ms_type, const std::vector<int32_t> &shape) {
    return add_column_char(StringToChar(name), ms_type, shape);
  }

  /// \brief Add a column given its name, the type spelled as a string, and a shape.
  /// \return Status returned by add_column_char.
  Status add_column(const std::string &name, const std::string &ms_type, const std::vector<int32_t> &shape) {
    return add_column_char(StringToChar(name), StringToChar(ms_type), shape);
  }

  /// \brief Get the JSON text of the schema.
  /// \return Result of to_json_char() converted with CharToString.
  std::string to_json() { return CharToString(to_json_char()); }

  /// \brief Write the schema into a JSON object (defined out of line).
  /// \param[out] out_json Destination JSON object.
  Status schema_to_json(nlohmann::json *out_json);

  /// \brief String form of the schema; identical to to_json().
  std::string to_string() { return to_json(); }

  /// \brief Set the dataset type (defined out of line).
  void set_dataset_type(std::string dataset_type);

  /// \brief Set the number of rows (defined out of line).
  void set_num_rows(int32_t num_rows);

  /// \brief Get the number of rows (defined out of line).
  int32_t get_num_rows() const;

  /// \brief Populate the schema from a JSON object (defined out of line).
  Status from_json(nlohmann::json json_obj);

  /// \brief Populate the schema from JSON text.
  /// \return Status returned by FromJSONStringCharIF.
  Status FromJSONString(const std::string &json_string) { return FromJSONStringCharIF(StringToChar(json_string)); }

  /// \brief Parse column information from JSON text.
  /// \return Status returned by ParseColumnStringCharIF.
  Status ParseColumnString(const std::string &json_string) {
    return ParseColumnStringCharIF(StringToChar(json_string));
  }

 private:
  /// \brief Parse columns from a JSON object (defined out of line).
  Status parse_column(nlohmann::json columns);

  // Char constructor of SchemaObj; the public std::string constructor delegates to it
  explicit SchemaObj(const std::vector<char> &schema_file);

  // Char interfaces of the four add_column overloads
  Status add_column_char(const std::vector<char> &name, mindspore::DataType ms_type);

  Status add_column_char(const std::vector<char> &name, const std::vector<char> &ms_type);

  Status add_column_char(const std::vector<char> &name, mindspore::DataType ms_type, const std::vector<int32_t> &shape);

  Status add_column_char(const std::vector<char> &name, const std::vector<char> &ms_type,
                         const std::vector<int32_t> &shape);

  // Char interface of to_json
  const std::vector<char> to_json_char();

  // Char interface of FromJSONString
  Status FromJSONStringCharIF(const std::vector<char> &json_string);

  // Char interface of ParseColumnString
  Status ParseColumnStringCharIF(const std::vector<char> &json_string);

  // Opaque implementation data; the struct is only forward-declared in this header.
  struct Data;
  std::shared_ptr<Data> data_;
};

/// \brief Dataset type returned by Dataset::Batch.
class  BatchDataset : public Dataset {
 public:
  /// \brief Constructor (defined out of line).
  /// \param[in] input Dataset used as input.
  /// \param[in] batch_size Batch size.
  /// \param[in] drop_remainder Defaults to false.
  BatchDataset(std::shared_ptr<Dataset> input, int32_t batch_size, bool drop_remainder = false);

  /// \brief Destructor.
  ~BatchDataset() = default;
};

/// \brief Dataset type created by Dataset::BucketBatchByLength.
class  BucketBatchByLengthDataset : public Dataset {
 public:
  /// \brief Constructor (defined out of line). Takes the char-vector forms of the column names and
  ///     pad_info keys that Dataset::BucketBatchByLength produces from its std::string arguments.
  /// \param[in] input Dataset used as input.
  /// \param[in] element_length_function Defaults to nullptr.
  /// \param[in] pad_info Defaults to an empty map.
  /// \param[in] pad_to_bucket_boundary Defaults to false.
  /// \param[in] drop_remainder Defaults to false.
  BucketBatchByLengthDataset(
    std::shared_ptr<Dataset> input, const std::vector<std::vector<char>> &column_names,
    const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
    std::function<MSTensorVec(MSTensorVec)> element_length_function = nullptr,
    const std::map<std::vector<char>, std::pair<std::vector<int64_t>, MSTensor>> &pad_info = {},
    bool pad_to_bucket_boundary = false, bool drop_remainder = false);

  /// \brief Destructor.
  ~BucketBatchByLengthDataset() = default;
};

/// \brief Dataset type created by Dataset::Concat.
class  ConcatDataset : public Dataset {
 public:
  /// \brief Constructor (defined out of line).
  /// \param[in] input List of datasets; Dataset::Concat passes the calling dataset as the first entry.
  explicit ConcatDataset(const std::vector<std::shared_ptr<Dataset>> &input);

  /// \brief Destructor.
  ~ConcatDataset() = default;
};

/// \brief Dataset type created by Dataset::Filter.
class  FilterDataset : public Dataset {
 public:
  /// \brief Constructor (defined out of line).
  /// \param[in] input Dataset used as input.
  /// \param[in] predicate Callable taking and returning an MSTensorVec.
  /// \param[in] input_columns Column names in char-vector form.
  FilterDataset(std::shared_ptr<Dataset> input, std::function<MSTensorVec(MSTensorVec)> predicate,
                const std::vector<std::vector<char>> &input_columns);

  /// \brief Destructor.
  ~FilterDataset() = default;
};

/// \brief Dataset type created by the Dataset::Map overloads.
class  MapDataset : public Dataset {
 public:
  /// \brief Constructor (defined out of line).
  /// \param[in] input Dataset used as input.
  /// \param[in] operations Parsed tensor operations; Dataset::Map may pass null entries for null transforms.
  /// \param[in] input_columns Column names in char-vector form.
  /// \param[in] output_columns Column names in char-vector form.
  /// \param[in] project_columns Column names in char-vector form.
  /// \param[in] cache Dataset cache; Dataset::Map defaults it to nullptr.
  /// \param[in] callbacks List of DSCallback objects.
  MapDataset(std::shared_ptr<Dataset> input, std::vector<std::shared_ptr<TensorOperation>> operations,
             const std::vector<std::vector<char>> &input_columns, const std::vector<std::vector<char>> &output_columns,
             const std::vector<std::vector<char>> &project_columns, const std::shared_ptr<DatasetCache> &cache,
             std::vector<std::shared_ptr<DSCallback>> callbacks);

  /// \brief Destructor.
  ~MapDataset() = default;
};

/// \brief Dataset type created by Dataset::Project.
class  ProjectDataset : public Dataset {
 public:
  /// \brief Constructor (defined out of line).
  /// \param[in] input Dataset used as input.
  /// \param[in] columns Column names in char-vector form.
  ProjectDataset(std::shared_ptr<Dataset> input, const std::vector<std::vector<char>> &columns);

  /// \brief Destructor.
  ~ProjectDataset() = default;
};

/// \brief Dataset type created by Dataset::Rename.
class  RenameDataset : public Dataset {
 public:
  /// \brief Constructor (defined out of line).
  /// \param[in] input Dataset used as input.
  /// \param[in] input_columns Column names in char-vector form.
  /// \param[in] output_columns Column names in char-vector form.
  RenameDataset(std::shared_ptr<Dataset> input, const std::vector<std::vector<char>> &input_columns,
                const std::vector<std::vector<char>> &output_columns);

  /// \brief Destructor.
  ~RenameDataset() = default;
};

/// \brief Dataset type created by Dataset::Repeat.
class  RepeatDataset : public Dataset {
 public:
  /// \brief Constructor (defined out of line).
  /// \param[in] input Dataset used as input.
  /// \param[in] count Repeat count; Dataset::Repeat passes -1 when the caller gives none.
  RepeatDataset(std::shared_ptr<Dataset> input, int32_t count);

  /// \brief Destructor.
  ~RepeatDataset() = default;
};

/// \brief Dataset type created by Dataset::Shuffle.
class  ShuffleDataset : public Dataset {
 public:
  /// \brief Constructor (defined out of line).
  /// \param[in] input Dataset used as input.
  /// \param[in] buffer_size Shuffle buffer size.
  ShuffleDataset(std::shared_ptr<Dataset> input, int32_t buffer_size);

  /// \brief Destructor.
  ~ShuffleDataset() = default;
};

/// \brief Dataset type created by Dataset::Skip.
class  SkipDataset : public Dataset {
 public:
  /// \brief Constructor (defined out of line).
  /// \param[in] input Dataset used as input.
  /// \param[in] count Count argument given to Dataset::Skip.
  SkipDataset(std::shared_ptr<Dataset> input, int32_t count);

  /// \brief Destructor.
  ~SkipDataset() = default;
};

/// \brief Dataset type created by Dataset::Take.
class  TakeDataset : public Dataset {
 public:
  /// \brief Constructor (defined out of line).
  /// \param[in] input Dataset used as input.
  /// \param[in] count Count argument; Dataset::Take passes -1 when the caller gives none.
  TakeDataset(std::shared_ptr<Dataset> input, int32_t count);

  /// \brief Destructor.
  ~TakeDataset() = default;
};

/// \brief Dataset type created by Dataset::Zip.
class  ZipDataset : public Dataset {
 public:
  /// \brief Constructor (defined out of line).
  /// \param[in] inputs List of datasets; Dataset::Zip passes the calling dataset as the last entry.
  explicit ZipDataset(const std::vector<std::shared_ptr<Dataset>> &inputs);

  /// \brief Destructor.
  ~ZipDataset() = default;
};

/// \brief Char interface of Schema (defined out of line).
/// \param[in] schema_file Schema file argument in char-vector form.
std::shared_ptr<SchemaObj> SchemaCharIF(const std::vector<char> &schema_file);

/// \brief Obtain a SchemaObj through SchemaCharIF using a std::string argument.
/// \param[in] schema_file Schema file argument, defaults to an empty string.
/// \return Whatever SchemaCharIF returns for the converted argument.
inline std::shared_ptr<SchemaObj> Schema(const std::string &schema_file = "") {
  const auto schema_file_chars = StringToChar(schema_file);
  return SchemaCharIF(schema_file_chars);
}

/// \brief Dataset type created by the Album() helper functions. The three constructors differ only in
///     how the sampler is passed; all are defined out of line.
class  AlbumDataset : public Dataset {
 public:
  /// \brief Constructor taking the sampler as a shared pointer.
  AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,
               const std::vector<std::vector<char>> &column_names, bool decode, const std::shared_ptr<Sampler> &sampler,
               const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking the sampler as a raw pointer.
  AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,
               const std::vector<std::vector<char>> &column_names, bool decode, const Sampler *sampler,
               const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking the sampler as a reference wrapper.
  AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,
               const std::vector<std::vector<char>> &column_names, bool decode,
               const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Destructor.
  ~AlbumDataset() = default;
};

/// \brief Create an AlbumDataset (shared-pointer sampler overload).
/// \note String arguments are converted to char-vector form first; `sampler` defaults to a newly
///     created RandomSampler, `column_names` to an empty list, `decode` to false, `cache` to nullptr.
inline std::shared_ptr<AlbumDataset> Album(
  const std::string &dataset_dir, const std::string &data_schema, const std::vector<std::string> &column_names = {},
  bool decode = false, const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
  const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto schema_chars = StringToChar(data_schema);
  const auto column_chars = VectorStringToChar(column_names);
  return std::make_shared<AlbumDataset>(dir_chars, schema_chars, column_chars, decode, sampler, cache);
}
/// \brief Create an AlbumDataset (raw-pointer sampler overload).
/// \note String arguments are converted to char-vector form first; `cache` defaults to nullptr.
inline std::shared_ptr<AlbumDataset> Album(const std::string &dataset_dir, const std::string &data_schema,
                                           const std::vector<std::string> &column_names, bool decode,
                                           const Sampler *sampler,
                                           const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto schema_chars = StringToChar(data_schema);
  const auto column_chars = VectorStringToChar(column_names);
  return std::make_shared<AlbumDataset>(dir_chars, schema_chars, column_chars, decode, sampler, cache);
}
/// \brief Create an AlbumDataset (reference-wrapper sampler overload).
/// \note String arguments are converted to char-vector form first; `cache` defaults to nullptr.
inline std::shared_ptr<AlbumDataset> Album(const std::string &dataset_dir, const std::string &data_schema,
                                           const std::vector<std::string> &column_names, bool decode,
                                           const std::reference_wrapper<Sampler> sampler,
                                           const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto schema_chars = StringToChar(data_schema);
  const auto column_chars = VectorStringToChar(column_names);
  return std::make_shared<AlbumDataset>(dir_chars, schema_chars, column_chars, decode, sampler, cache);
}

/// \brief Dataset type created by the CelebA() helper functions. The three constructors differ only in
///     how the sampler is passed; all are defined out of line.
class  CelebADataset : public Dataset {
 public:
  /// \brief Constructor taking the sampler as a shared pointer.
  CelebADataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
                const std::shared_ptr<Sampler> &sampler, bool decode, const std::set<std::vector<char>> &extensions,
                const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking the sampler as a raw pointer.
  CelebADataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, const Sampler *sampler,
                bool decode, const std::set<std::vector<char>> &extensions, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking the sampler as a reference wrapper.
  CelebADataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
                const std::reference_wrapper<Sampler> sampler, bool decode,
                const std::set<std::vector<char>> &extensions, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Destructor.
  ~CelebADataset() = default;
};

/// \brief Create a CelebADataset (shared-pointer sampler overload).
/// \note Defaults: usage "all", a newly created RandomSampler, decode false, no extensions, no cache.
///     String arguments and the extension set are converted to char-vector form before forwarding.
inline std::shared_ptr<CelebADataset> CelebA(
  const std::string &dataset_dir, const std::string &usage = "all",
  const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(), bool decode = false,
  const std::set<std::string> &extensions = {}, const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  const auto extension_chars = SetStringToChar(extensions);
  return std::make_shared<CelebADataset>(dir_chars, usage_chars, sampler, decode, extension_chars, cache);
}

/// \brief Create a CelebADataset (raw-pointer sampler overload).
/// \note Defaults: decode false, no extensions, no cache. String arguments and the extension set are
///     converted to char-vector form before forwarding.
inline std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std::string &usage,
                                             const Sampler *sampler, bool decode = false,
                                             const std::set<std::string> &extensions = {},
                                             const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  const auto extension_chars = SetStringToChar(extensions);
  return std::make_shared<CelebADataset>(dir_chars, usage_chars, sampler, decode, extension_chars, cache);
}

/// \brief Create a CelebADataset (reference-wrapper sampler overload).
/// \note Defaults: decode false, no extensions, no cache. String arguments and the extension set are
///     converted to char-vector form before forwarding.
inline std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std::string &usage,
                                             const std::reference_wrapper<Sampler> sampler, bool decode = false,
                                             const std::set<std::string> &extensions = {},
                                             const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  const auto extension_chars = SetStringToChar(extensions);
  return std::make_shared<CelebADataset>(dir_chars, usage_chars, sampler, decode, extension_chars, cache);
}

/// \brief Dataset type created by the Cifar10() helper functions. The three constructors differ only in
///     how the sampler is passed; all are defined out of line.
class  Cifar10Dataset : public Dataset {
 public:
  /// \brief Constructor taking the sampler as a shared pointer.
  Cifar10Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
                 const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking the sampler as a raw pointer.
  Cifar10Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, const Sampler *sampler,
                 const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking the sampler as a reference wrapper.
  Cifar10Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
                 const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Destructor.
  ~Cifar10Dataset() = default;
};

/// \brief Create a Cifar10Dataset (shared-pointer sampler overload).
/// \note Defaults: usage "all", a newly created RandomSampler, no cache. String arguments are converted
///     to char-vector form before forwarding.
inline std::shared_ptr<Cifar10Dataset> Cifar10(
  const std::string &dataset_dir, const std::string &usage = "all",
  const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
  const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<Cifar10Dataset>(dir_chars, usage_chars, sampler, cache);
}

/// \brief Create a Cifar10Dataset (raw-pointer sampler overload).
/// \note `cache` defaults to nullptr. String arguments are converted to char-vector form first.
inline std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, const std::string &usage,
                                               const Sampler *sampler,
                                               const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<Cifar10Dataset>(dir_chars, usage_chars, sampler, cache);
}

/// \brief Create a Cifar10Dataset (reference-wrapper sampler overload).
/// \note `cache` defaults to nullptr. String arguments are converted to char-vector form first.
inline std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, const std::string &usage,
                                               const std::reference_wrapper<Sampler> sampler,
                                               const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<Cifar10Dataset>(dir_chars, usage_chars, sampler, cache);
}

/// \brief Dataset type created by the Cifar100() helper functions. The three constructors differ only
///     in how the sampler is passed; all are defined out of line.
class  Cifar100Dataset : public Dataset {
 public:
  /// \brief Constructor taking the sampler as a shared pointer.
  Cifar100Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
                  const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking the sampler as a raw pointer.
  Cifar100Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, const Sampler *sampler,
                  const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking the sampler as a reference wrapper.
  Cifar100Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
                  const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Destructor.
  ~Cifar100Dataset() = default;
};

/// \brief Create a Cifar100Dataset (shared-pointer sampler overload).
/// \note Defaults: usage "all", a newly created RandomSampler, no cache. String arguments are converted
///     to char-vector form before forwarding.
inline std::shared_ptr<Cifar100Dataset> Cifar100(
  const std::string &dataset_dir, const std::string &usage = "all",
  const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
  const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<Cifar100Dataset>(dir_chars, usage_chars, sampler, cache);
}

/// \brief Create a Cifar100Dataset (raw-pointer sampler overload).
/// \note `cache` defaults to nullptr. String arguments are converted to char-vector form first.
inline std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir, const std::string &usage,
                                                 const Sampler *sampler,
                                                 const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<Cifar100Dataset>(dir_chars, usage_chars, sampler, cache);
}

/// \brief Create a Cifar100Dataset (reference-wrapper sampler overload).
/// \note `cache` defaults to nullptr. String arguments are converted to char-vector form first.
inline std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir, const std::string &usage,
                                                 const std::reference_wrapper<Sampler> sampler,
                                                 const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<Cifar100Dataset>(dir_chars, usage_chars, sampler, cache);
}

/// \brief Dataset type created by the Cityscapes() helper functions. The three constructors differ only
///     in how the sampler is passed; all are defined out of line.
class  CityscapesDataset : public Dataset {
 public:
  /// \brief Constructor taking the sampler as a shared pointer.
  CityscapesDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
                    const std::vector<char> &quality_mode, const std::vector<char> &task, bool decode,
                    const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking the sampler as a raw pointer.
  CityscapesDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
                    const std::vector<char> &quality_mode, const std::vector<char> &task, bool decode,
                    const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking the sampler as a reference wrapper.
  CityscapesDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
                    const std::vector<char> &quality_mode, const std::vector<char> &task, bool decode,
                    const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Destructor.
  ~CityscapesDataset() = default;
};

/// \brief Create a CityscapesDataset (shared-pointer sampler overload).
/// \note Defaults: decode false, a newly created RandomSampler, no cache. The four string arguments are
///     converted to char-vector form before forwarding.
inline std::shared_ptr<CityscapesDataset> Cityscapes(
  const std::string &dataset_dir, const std::string &usage, const std::string &quality_mode, const std::string &task,
  bool decode = false, const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
  const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  const auto quality_chars = StringToChar(quality_mode);
  const auto task_chars = StringToChar(task);
  return std::make_shared<CityscapesDataset>(dir_chars, usage_chars, quality_chars, task_chars, decode, sampler,
                                             cache);
}

/// \brief Create a CityscapesDataset (raw-pointer sampler overload).
/// \note `cache` defaults to nullptr. The four string arguments are converted to char-vector form first.
inline std::shared_ptr<CityscapesDataset> Cityscapes(const std::string &dataset_dir, const std::string &usage,
                                                     const std::string &quality_mode, const std::string &task,
                                                     bool decode, const Sampler *sampler,
                                                     const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  const auto quality_chars = StringToChar(quality_mode);
  const auto task_chars = StringToChar(task);
  return std::make_shared<CityscapesDataset>(dir_chars, usage_chars, quality_chars, task_chars, decode, sampler,
                                             cache);
}

/// \brief Create a CityscapesDataset (reference-wrapper sampler overload).
/// \note `cache` defaults to nullptr. The four string arguments are converted to char-vector form first.
inline std::shared_ptr<CityscapesDataset> Cityscapes(const std::string &dataset_dir, const std::string &usage,
                                                     const std::string &quality_mode, const std::string &task,
                                                     bool decode, const std::reference_wrapper<Sampler> sampler,
                                                     const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  const auto quality_chars = StringToChar(quality_mode);
  const auto task_chars = StringToChar(task);
  return std::make_shared<CityscapesDataset>(dir_chars, usage_chars, quality_chars, task_chars, decode, sampler,
                                             cache);
}

/// \brief Dataset type created by the CLUE() helper function.
class  CLUEDataset : public Dataset {
 public:
  /// \brief Constructor (defined out of line). Takes the char-vector forms of the file list, task and
  ///     usage; the remaining arguments are the ones CLUE() forwards unchanged.
  CLUEDataset(const std::vector<std::vector<char>> &dataset_files, const std::vector<char> &task,
              const std::vector<char> &usage, int64_t num_samples, ShuffleMode shuffle, int32_t num_shards,
              int32_t shard_id, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Destructor.
  ~CLUEDataset() = default;
};

/// \brief Create a CLUEDataset.
/// \note Defaults: task "AFQMC", usage "train", num_samples 0, shuffle ShuffleMode::kGlobal,
///     num_shards 1, shard_id 0, no cache. The file list, task and usage are converted to char-vector
///     form; the other arguments are forwarded unchanged.
inline std::shared_ptr<CLUEDataset> CLUE(const std::vector<std::string> &dataset_files,
                                         const std::string &task = "AFQMC", const std::string &usage = "train",
                                         int64_t num_samples = 0, ShuffleMode shuffle = ShuffleMode::kGlobal,
                                         int32_t num_shards = 1, int32_t shard_id = 0,
                                         const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto file_chars = VectorStringToChar(dataset_files);
  const auto task_chars = StringToChar(task);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<CLUEDataset>(file_chars, task_chars, usage_chars, num_samples, shuffle, num_shards,
                                       shard_id, cache);
}

/// \brief Dataset type created by the Coco() helper functions. The three constructors differ only in
///     how the sampler is passed; all are defined out of line.
class  CocoDataset : public Dataset {
 public:
  /// \brief Constructor taking the sampler as a shared pointer.
  CocoDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
              const std::vector<char> &task, const bool &decode, const std::shared_ptr<Sampler> &sampler,
              const std::shared_ptr<DatasetCache> &cache, const bool &extra_metadata);

  /// \brief Constructor taking the sampler as a raw pointer.
  CocoDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
              const std::vector<char> &task, const bool &decode, const Sampler *sampler,
              const std::shared_ptr<DatasetCache> &cache, const bool &extra_metadata);

  /// \brief Constructor taking the sampler as a reference wrapper.
  CocoDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
              const std::vector<char> &task, const bool &decode, const std::reference_wrapper<Sampler> sampler,
              const std::shared_ptr<DatasetCache> &cache, const bool &extra_metadata);

  /// \brief Destructor.
  ~CocoDataset() = default;
};

/// \brief Create a CocoDataset (shared-pointer sampler overload).
/// \note Defaults: task "Detection", decode false, a newly created RandomSampler, no cache,
///     extra_metadata false. The three string arguments are converted to char-vector form first.
inline std::shared_ptr<CocoDataset> Coco(
  const std::string &dataset_dir, const std::string &annotation_file, const std::string &task = "Detection",
  const bool &decode = false, const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
  const std::shared_ptr<DatasetCache> &cache = nullptr, const bool &extra_metadata = false) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto annotation_chars = StringToChar(annotation_file);
  const auto task_chars = StringToChar(task);
  return std::make_shared<CocoDataset>(dir_chars, annotation_chars, task_chars, decode, sampler, cache,
                                       extra_metadata);
}

/// \brief Create a CocoDataset (raw-pointer sampler overload).
/// \note Defaults: no cache, extra_metadata false. The three string arguments are converted to
///     char-vector form first.
inline std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::string &annotation_file,
                                         const std::string &task, const bool &decode, const Sampler *sampler,
                                         const std::shared_ptr<DatasetCache> &cache = nullptr,
                                         const bool &extra_metadata = false) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto annotation_chars = StringToChar(annotation_file);
  const auto task_chars = StringToChar(task);
  return std::make_shared<CocoDataset>(dir_chars, annotation_chars, task_chars, decode, sampler, cache,
                                       extra_metadata);
}

/// \brief Create a CocoDataset (reference-wrapper sampler overload).
/// \note Defaults: no cache, extra_metadata false. The three string arguments are converted to
///     char-vector form first.
inline std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::string &annotation_file,
                                         const std::string &task, const bool &decode,
                                         const std::reference_wrapper<Sampler> sampler,
                                         const std::shared_ptr<DatasetCache> &cache = nullptr,
                                         const bool &extra_metadata = false) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto annotation_chars = StringToChar(annotation_file);
  const auto task_chars = StringToChar(task);
  return std::make_shared<CocoDataset>(dir_chars, annotation_chars, task_chars, decode, sampler, cache,
                                       extra_metadata);
}

/// \class CSVDataset
/// \brief Dataset subclass constructed by the CSV() factory function.
/// \note File paths and column names are taken as std::vector<char> sequences; CSV() performs the conversion
///     from std::string before calling this constructor.
class  CSVDataset : public Dataset {
 public:
  /// \brief Constructor; receives the arguments of CSV() with strings already converted to char vectors.
  CSVDataset(const std::vector<std::vector<char>> &dataset_files, char field_delim,
             const std::vector<std::shared_ptr<CsvBase>> &column_defaults,
             const std::vector<std::vector<char>> &column_names, int64_t num_samples, ShuffleMode shuffle,
             int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Defaulted destructor.
  ~CSVDataset() = default;
};

/// \brief Create a CSVDataset from a list of files.
/// \note dataset_files and column_names are converted with VectorStringToChar; the remaining arguments are
///     forwarded unchanged. Defaults: ',' delimiter, no column defaults or names, num_samples 0,
///     ShuffleMode::kGlobal, one shard with id 0, no cache.
/// \return Shared pointer to the newly constructed CSVDataset.
inline std::shared_ptr<CSVDataset> CSV(const std::vector<std::string> &dataset_files, char field_delim = ',',
                                       const std::vector<std::shared_ptr<CsvBase>> &column_defaults = {},
                                       const std::vector<std::string> &column_names = {}, int64_t num_samples = 0,
                                       ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
                                       int32_t shard_id = 0, const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto file_chars = VectorStringToChar(dataset_files);
  const auto name_chars = VectorStringToChar(column_names);
  return std::make_shared<CSVDataset>(file_chars, field_delim, column_defaults, name_chars, num_samples, shuffle,
                                      num_shards, shard_id, cache);
}

/// \class DIV2KDataset
/// \brief Dataset subclass constructed by the DIV2K() factory functions.
/// \note The three constructors differ only in how the sampler is supplied (std::shared_ptr, raw pointer or
///     std::reference_wrapper). String arguments are taken as std::vector<char>.
class  DIV2KDataset : public Dataset {
 public:
  /// \brief Constructor taking the sampler as a std::shared_ptr.
  DIV2KDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, const std::vector<char> &downgrade,
               int32_t scale, bool decode, const std::shared_ptr<Sampler> &sampler,
               const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking the sampler as a raw pointer.
  DIV2KDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, const std::vector<char> &downgrade,
               int32_t scale, bool decode, const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking the sampler as a std::reference_wrapper.
  DIV2KDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, const std::vector<char> &downgrade,
               int32_t scale, bool decode, const std::reference_wrapper<Sampler> sampler,
               const std::shared_ptr<DatasetCache> &cache);

  /// \brief Defaulted destructor.
  ~DIV2KDataset() = default;
};

/// \brief Create a DIV2KDataset, supplying the sampler as a std::shared_ptr.
/// \note dataset_dir, usage and downgrade are converted with StringToChar; scale, decode, sampler and cache
///     are forwarded unchanged. When omitted, decode is false and sampler is a new RandomSampler.
/// \return Shared pointer to the newly constructed DIV2KDataset.
inline std::shared_ptr<DIV2KDataset> DIV2K(const std::string &dataset_dir, const std::string &usage,
                                           const std::string &downgrade, int32_t scale, bool decode = false,
                                           const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
                                           const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  const auto downgrade_chars = StringToChar(downgrade);
  return std::make_shared<DIV2KDataset>(dir_chars, usage_chars, downgrade_chars, scale, decode, sampler, cache);
}

/// \brief Create a DIV2KDataset, supplying the sampler as a raw pointer.
/// \note decode and sampler have no defaults in this overload. String arguments are converted with
///     StringToChar; the remaining arguments are forwarded unchanged.
/// \return Shared pointer to the newly constructed DIV2KDataset.
inline std::shared_ptr<DIV2KDataset> DIV2K(const std::string &dataset_dir, const std::string &usage,
                                           const std::string &downgrade, int32_t scale, bool decode,
                                           const Sampler *sampler,
                                           const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  const auto downgrade_chars = StringToChar(downgrade);
  return std::make_shared<DIV2KDataset>(dir_chars, usage_chars, downgrade_chars, scale, decode, sampler, cache);
}

/// \brief Create a DIV2KDataset, supplying the sampler as a std::reference_wrapper.
/// \note decode and sampler have no defaults in this overload. String arguments are converted with
///     StringToChar; the remaining arguments are forwarded unchanged.
/// \return Shared pointer to the newly constructed DIV2KDataset.
inline std::shared_ptr<DIV2KDataset> DIV2K(const std::string &dataset_dir, const std::string &usage,
                                           const std::string &downgrade, int32_t scale, bool decode,
                                           const std::reference_wrapper<Sampler> sampler,
                                           const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  const auto downgrade_chars = StringToChar(downgrade);
  return std::make_shared<DIV2KDataset>(dir_chars, usage_chars, downgrade_chars, scale, decode, sampler, cache);
}

/// \class FlickrDataset
/// \brief Dataset subclass constructed by the Flickr() factory functions.
/// \note The three constructors differ only in how the sampler is supplied (std::shared_ptr, raw pointer or
///     std::reference_wrapper). String arguments are taken as std::vector<char>.
class  FlickrDataset : public Dataset {
 public:
  /// \brief Constructor taking the sampler as a std::shared_ptr.
  FlickrDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file, bool decode,
                const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking the sampler as a raw pointer.
  FlickrDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file, bool decode,
                const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking the sampler as a std::reference_wrapper.
  FlickrDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file, bool decode,
                const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Defaulted destructor.
  ~FlickrDataset() = default;
};

/// \brief Create a FlickrDataset, supplying the sampler as a std::shared_ptr.
/// \note dataset_dir and annotation_file are converted with StringToChar; decode, sampler and cache are
///     forwarded unchanged. When omitted, decode is false and sampler is a new RandomSampler.
/// \return Shared pointer to the newly constructed FlickrDataset.
inline std::shared_ptr<FlickrDataset> Flickr(
  const std::string &dataset_dir, const std::string &annotation_file, bool decode = false,
  const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
  const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto annotation_chars = StringToChar(annotation_file);
  return std::make_shared<FlickrDataset>(dir_chars, annotation_chars, decode, sampler, cache);
}

/// \brief Create a FlickrDataset, supplying the sampler as a raw pointer.
/// \note decode and sampler have no defaults in this overload. String arguments are converted with
///     StringToChar; the remaining arguments are forwarded unchanged.
/// \return Shared pointer to the newly constructed FlickrDataset.
inline std::shared_ptr<FlickrDataset> Flickr(const std::string &dataset_dir, const std::string &annotation_file,
                                             bool decode, const Sampler *sampler,
                                             const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto annotation_chars = StringToChar(annotation_file);
  return std::make_shared<FlickrDataset>(dir_chars, annotation_chars, decode, sampler, cache);
}

/// \brief Create a FlickrDataset, supplying the sampler as a std::reference_wrapper.
/// \note decode and sampler have no defaults in this overload. String arguments are converted with
///     StringToChar; the remaining arguments are forwarded unchanged.
/// \return Shared pointer to the newly constructed FlickrDataset.
inline std::shared_ptr<FlickrDataset> Flickr(const std::string &dataset_dir, const std::string &annotation_file,
                                             bool decode, const std::reference_wrapper<Sampler> sampler,
                                             const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto annotation_chars = StringToChar(annotation_file);
  return std::make_shared<FlickrDataset>(dir_chars, annotation_chars, decode, sampler, cache);
}

/// \class ImageFolderDataset
/// \brief Dataset subclass constructed by the ImageFolder() factory functions.
/// \note The three constructors differ only in how the sampler is supplied (std::shared_ptr, raw pointer or
///     std::reference_wrapper). The directory, extensions and class_indexing keys are taken as
///     std::vector<char>.
class  ImageFolderDataset : public Dataset {
 public:
  /// \brief Constructor taking the sampler as a std::shared_ptr.
  ImageFolderDataset(const std::vector<char> &dataset_dir, bool decode, const std::shared_ptr<Sampler> &sampler,
                     const std::set<std::vector<char>> &extensions,
                     const std::map<std::vector<char>, int32_t> &class_indexing,
                     const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking the sampler as a raw pointer.
  ImageFolderDataset(const std::vector<char> &dataset_dir, bool decode, const Sampler *sampler,
                     const std::set<std::vector<char>> &extensions,
                     const std::map<std::vector<char>, int32_t> &class_indexing,
                     const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking the sampler as a std::reference_wrapper.
  ImageFolderDataset(const std::vector<char> &dataset_dir, bool decode, const std::reference_wrapper<Sampler> sampler,
                     const std::set<std::vector<char>> &extensions,
                     const std::map<std::vector<char>, int32_t> &class_indexing,
                     const std::shared_ptr<DatasetCache> &cache);

  /// \brief Defaulted destructor.
  ~ImageFolderDataset() = default;
};

/// \brief Create an ImageFolderDataset, supplying the sampler as a std::shared_ptr.
/// \note dataset_dir, extensions and class_indexing are converted with StringToChar, SetStringToChar and
///     MapStringToChar respectively; decode, sampler and cache are forwarded unchanged. When omitted,
///     decode is false, sampler is a new RandomSampler, and extensions/class_indexing are empty.
/// \return Shared pointer to the newly constructed ImageFolderDataset.
inline std::shared_ptr<ImageFolderDataset> ImageFolder(
  const std::string &dataset_dir, bool decode = false,
  const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
  const std::set<std::string> &extensions = {}, const std::map<std::string, int32_t> &class_indexing = {},
  const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto extension_chars = SetStringToChar(extensions);
  const auto class_index_chars = MapStringToChar(class_indexing);
  return std::make_shared<ImageFolderDataset>(dir_chars, decode, sampler, extension_chars, class_index_chars, cache);
}

/// \brief Create an ImageFolderDataset, supplying the sampler as a raw pointer.
/// \note decode and sampler have no defaults in this overload. dataset_dir, extensions and class_indexing
///     are converted with StringToChar, SetStringToChar and MapStringToChar respectively.
/// \return Shared pointer to the newly constructed ImageFolderDataset.
inline std::shared_ptr<ImageFolderDataset> ImageFolder(const std::string &dataset_dir, bool decode,
                                                       const Sampler *sampler,
                                                       const std::set<std::string> &extensions = {},
                                                       const std::map<std::string, int32_t> &class_indexing = {},
                                                       const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto extension_chars = SetStringToChar(extensions);
  const auto class_index_chars = MapStringToChar(class_indexing);
  return std::make_shared<ImageFolderDataset>(dir_chars, decode, sampler, extension_chars, class_index_chars, cache);
}

/// \brief Create an ImageFolderDataset, supplying the sampler as a std::reference_wrapper.
/// \note decode and sampler have no defaults in this overload. dataset_dir, extensions and class_indexing
///     are converted with StringToChar, SetStringToChar and MapStringToChar respectively.
/// \return Shared pointer to the newly constructed ImageFolderDataset.
inline std::shared_ptr<ImageFolderDataset> ImageFolder(const std::string &dataset_dir, bool decode,
                                                       const std::reference_wrapper<Sampler> sampler,
                                                       const std::set<std::string> &extensions = {},
                                                       const std::map<std::string, int32_t> &class_indexing = {},
                                                       const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto extension_chars = SetStringToChar(extensions);
  const auto class_index_chars = MapStringToChar(class_indexing);
  return std::make_shared<ImageFolderDataset>(dir_chars, decode, sampler, extension_chars, class_index_chars, cache);
}

/// \class ManifestDataset
/// \brief Dataset subclass constructed by the Manifest() factory functions.
/// \note The three constructors differ only in how the sampler is supplied (std::shared_ptr, raw pointer or
///     std::reference_wrapper). The file path, usage and class_indexing keys are taken as std::vector<char>.
class  ManifestDataset : public Dataset {
 public:
  /// \brief Constructor taking the sampler as a std::shared_ptr.
  ManifestDataset(const std::vector<char> &dataset_file, const std::vector<char> &usage,
                  const std::shared_ptr<Sampler> &sampler, const std::map<std::vector<char>, int32_t> &class_indexing,
                  bool decode, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking the sampler as a raw pointer.
  ManifestDataset(const std::vector<char> &dataset_file, const std::vector<char> &usage, const Sampler *sampler,
                  const std::map<std::vector<char>, int32_t> &class_indexing, bool decode,
                  const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking the sampler as a std::reference_wrapper.
  ManifestDataset(const std::vector<char> &dataset_file, const std::vector<char> &usage,
                  const std::reference_wrapper<Sampler> sampler,
                  const std::map<std::vector<char>, int32_t> &class_indexing, bool decode,
                  const std::shared_ptr<DatasetCache> &cache);

  /// \brief Defaulted destructor.
  ~ManifestDataset() = default;
};

/// \brief Create a ManifestDataset, supplying the sampler as a std::shared_ptr.
/// \note dataset_file and usage are converted with StringToChar and class_indexing with MapStringToChar;
///     sampler, decode and cache are forwarded unchanged. When omitted, usage is "train", sampler is a new
///     RandomSampler, class_indexing is empty and decode is false.
/// \return Shared pointer to the newly constructed ManifestDataset.
inline std::shared_ptr<ManifestDataset> Manifest(
  const std::string &dataset_file, const std::string &usage = "train",
  const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
  const std::map<std::string, int32_t> &class_indexing = {}, bool decode = false,
  const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto file_chars = StringToChar(dataset_file);
  const auto usage_chars = StringToChar(usage);
  const auto class_index_chars = MapStringToChar(class_indexing);
  return std::make_shared<ManifestDataset>(file_chars, usage_chars, sampler, class_index_chars, decode, cache);
}

/// \brief Create a ManifestDataset, supplying the sampler as a raw pointer.
/// \note usage and sampler have no defaults in this overload. dataset_file and usage are converted with
///     StringToChar and class_indexing with MapStringToChar; other arguments are forwarded unchanged.
/// \return Shared pointer to the newly constructed ManifestDataset.
inline std::shared_ptr<ManifestDataset> Manifest(const std::string &dataset_file, const std::string &usage,
                                                 const Sampler *sampler,
                                                 const std::map<std::string, int32_t> &class_indexing = {},
                                                 bool decode = false,
                                                 const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto file_chars = StringToChar(dataset_file);
  const auto usage_chars = StringToChar(usage);
  const auto class_index_chars = MapStringToChar(class_indexing);
  return std::make_shared<ManifestDataset>(file_chars, usage_chars, sampler, class_index_chars, decode, cache);
}

/// \brief Create a ManifestDataset, supplying the sampler as a std::reference_wrapper.
/// \note usage and sampler have no defaults in this overload. dataset_file and usage are converted with
///     StringToChar and class_indexing with MapStringToChar; other arguments are forwarded unchanged.
/// \return Shared pointer to the newly constructed ManifestDataset.
inline std::shared_ptr<ManifestDataset> Manifest(const std::string &dataset_file, const std::string &usage,
                                                 const std::reference_wrapper<Sampler> sampler,
                                                 const std::map<std::string, int32_t> &class_indexing = {},
                                                 bool decode = false,
                                                 const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto file_chars = StringToChar(dataset_file);
  const auto usage_chars = StringToChar(usage);
  const auto class_index_chars = MapStringToChar(class_indexing);
  return std::make_shared<ManifestDataset>(file_chars, usage_chars, sampler, class_index_chars, decode, cache);
}

/// \class MindDataDataset
/// \brief Dataset subclass constructed by the MindData() factory functions.
/// \note Six constructors are declared: the first three take a single file path, the last three take a list
///     of file paths. Within each group the overloads differ only in how the sampler is supplied
///     (std::shared_ptr, raw pointer or std::reference_wrapper). shuffle_mode defaults to
///     ShuffleMode::kGlobal and cache to nullptr in every overload.
class  MindDataDataset : public Dataset {
 public:
  /// \brief Single-file constructor taking the sampler as a std::shared_ptr.
  MindDataDataset(const std::vector<char> &dataset_file, const std::vector<std::vector<char>> &columns_list,
                  const std::shared_ptr<Sampler> &sampler, const nlohmann::json *padded_sample, int64_t num_padded,
                  ShuffleMode shuffle_mode = ShuffleMode::kGlobal,
                  const std::shared_ptr<DatasetCache> &cache = nullptr);

  /// \brief Single-file constructor taking the sampler as a raw pointer.
  MindDataDataset(const std::vector<char> &dataset_file, const std::vector<std::vector<char>> &columns_list,
                  const Sampler *sampler, const nlohmann::json *padded_sample, int64_t num_padded,
                  ShuffleMode shuffle_mode = ShuffleMode::kGlobal,
                  const std::shared_ptr<DatasetCache> &cache = nullptr);

  /// \brief Single-file constructor taking the sampler as a std::reference_wrapper.
  MindDataDataset(const std::vector<char> &dataset_file, const std::vector<std::vector<char>> &columns_list,
                  const std::reference_wrapper<Sampler> sampler, const nlohmann::json *padded_sample,
                  int64_t num_padded, ShuffleMode shuffle_mode = ShuffleMode::kGlobal,
                  const std::shared_ptr<DatasetCache> &cache = nullptr);

  /// \brief File-list constructor taking the sampler as a std::shared_ptr.
  MindDataDataset(const std::vector<std::vector<char>> &dataset_files,
                  const std::vector<std::vector<char>> &columns_list, const std::shared_ptr<Sampler> &sampler,
                  const nlohmann::json *padded_sample, int64_t num_padded,
                  ShuffleMode shuffle_mode = ShuffleMode::kGlobal,
                  const std::shared_ptr<DatasetCache> &cache = nullptr);

  /// \brief File-list constructor taking the sampler as a raw pointer.
  MindDataDataset(const std::vector<std::vector<char>> &dataset_files,
                  const std::vector<std::vector<char>> &columns_list, const Sampler *sampler,
                  const nlohmann::json *padded_sample, int64_t num_padded,
                  ShuffleMode shuffle_mode = ShuffleMode::kGlobal,
                  const std::shared_ptr<DatasetCache> &cache = nullptr);

  /// \brief File-list constructor taking the sampler as a std::reference_wrapper.
  MindDataDataset(const std::vector<std::vector<char>> &dataset_files,
                  const std::vector<std::vector<char>> &columns_list, const std::reference_wrapper<Sampler> sampler,
                  const nlohmann::json *padded_sample, int64_t num_padded,
                  ShuffleMode shuffle_mode = ShuffleMode::kGlobal,
                  const std::shared_ptr<DatasetCache> &cache = nullptr);

  /// \brief Defaulted destructor.
  ~MindDataDataset() = default;
};

/// \brief Create a MindDataDataset from a single file, supplying the sampler as a std::shared_ptr.
/// \note dataset_file is converted with StringToChar and columns_list with VectorStringToChar; all other
///     arguments are forwarded unchanged. When omitted, columns_list is empty, sampler is a new
///     RandomSampler, padded_sample is nullptr, num_padded is 0 and shuffle_mode is ShuffleMode::kGlobal.
/// \return Shared pointer to the newly constructed MindDataDataset.
inline std::shared_ptr<MindDataDataset> MindData(
  const std::string &dataset_file, const std::vector<std::string> &columns_list = {},
  const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
  nlohmann::json *padded_sample = nullptr, int64_t num_padded = 0, ShuffleMode shuffle_mode = ShuffleMode::kGlobal,
  const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto file_chars = StringToChar(dataset_file);
  const auto column_chars = VectorStringToChar(columns_list);
  return std::make_shared<MindDataDataset>(file_chars, column_chars, sampler, padded_sample, num_padded,
                                           shuffle_mode, cache);
}

/// \brief Create a MindDataDataset from a single file, supplying the sampler as a raw pointer.
/// \note columns_list and sampler have no defaults in this overload. dataset_file is converted with
///     StringToChar and columns_list with VectorStringToChar; other arguments are forwarded unchanged.
/// \return Shared pointer to the newly constructed MindDataDataset.
inline std::shared_ptr<MindDataDataset> MindData(const std::string &dataset_file,
                                                 const std::vector<std::string> &columns_list,
                                                 const Sampler *sampler, nlohmann::json *padded_sample = nullptr,
                                                 int64_t num_padded = 0,
                                                 ShuffleMode shuffle_mode = ShuffleMode::kGlobal,
                                                 const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto file_chars = StringToChar(dataset_file);
  const auto column_chars = VectorStringToChar(columns_list);
  return std::make_shared<MindDataDataset>(file_chars, column_chars, sampler, padded_sample, num_padded,
                                           shuffle_mode, cache);
}
/// \brief Create a MindDataDataset from a single file, supplying the sampler as a std::reference_wrapper.
/// \note columns_list and sampler have no defaults in this overload. dataset_file is converted with
///     StringToChar and columns_list with VectorStringToChar; other arguments are forwarded unchanged.
/// \return Shared pointer to the newly constructed MindDataDataset.
inline std::shared_ptr<MindDataDataset> MindData(const std::string &dataset_file,
                                                 const std::vector<std::string> &columns_list,
                                                 const std::reference_wrapper<Sampler> sampler,
                                                 nlohmann::json *padded_sample = nullptr, int64_t num_padded = 0,
                                                 ShuffleMode shuffle_mode = ShuffleMode::kGlobal,
                                                 const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto file_chars = StringToChar(dataset_file);
  const auto column_chars = VectorStringToChar(columns_list);
  return std::make_shared<MindDataDataset>(file_chars, column_chars, sampler, padded_sample, num_padded,
                                           shuffle_mode, cache);
}

/// \brief Create a MindDataDataset from a list of files, supplying the sampler as a std::shared_ptr.
/// \note dataset_files and columns_list are converted with VectorStringToChar; all other arguments are
///     forwarded unchanged. When omitted, columns_list is empty, sampler is a new RandomSampler,
///     padded_sample is nullptr, num_padded is 0 and shuffle_mode is ShuffleMode::kGlobal.
/// \return Shared pointer to the newly constructed MindDataDataset.
inline std::shared_ptr<MindDataDataset> MindData(
  const std::vector<std::string> &dataset_files, const std::vector<std::string> &columns_list = {},
  const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
  nlohmann::json *padded_sample = nullptr, int64_t num_padded = 0, ShuffleMode shuffle_mode = ShuffleMode::kGlobal,
  const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto file_chars = VectorStringToChar(dataset_files);
  const auto column_chars = VectorStringToChar(columns_list);
  return std::make_shared<MindDataDataset>(file_chars, column_chars, sampler, padded_sample, num_padded,
                                           shuffle_mode, cache);
}

/// \brief Create a MindDataDataset from a list of files, supplying the sampler as a raw pointer.
/// \note columns_list and sampler have no defaults in this overload. dataset_files and columns_list are
///     converted with VectorStringToChar; other arguments are forwarded unchanged.
/// \return Shared pointer to the newly constructed MindDataDataset.
inline std::shared_ptr<MindDataDataset> MindData(const std::vector<std::string> &dataset_files,
                                                 const std::vector<std::string> &columns_list,
                                                 const Sampler *sampler, nlohmann::json *padded_sample = nullptr,
                                                 int64_t num_padded = 0,
                                                 ShuffleMode shuffle_mode = ShuffleMode::kGlobal,
                                                 const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto file_chars = VectorStringToChar(dataset_files);
  const auto column_chars = VectorStringToChar(columns_list);
  return std::make_shared<MindDataDataset>(file_chars, column_chars, sampler, padded_sample, num_padded,
                                           shuffle_mode, cache);
}

/// \brief Create a MindDataDataset from a list of files, supplying the sampler as a std::reference_wrapper.
/// \note columns_list and sampler have no defaults in this overload. dataset_files and columns_list are
///     converted with VectorStringToChar; other arguments are forwarded unchanged.
/// \return Shared pointer to the newly constructed MindDataDataset.
inline std::shared_ptr<MindDataDataset> MindData(const std::vector<std::string> &dataset_files,
                                                 const std::vector<std::string> &columns_list,
                                                 const std::reference_wrapper<Sampler> sampler,
                                                 nlohmann::json *padded_sample = nullptr, int64_t num_padded = 0,
                                                 ShuffleMode shuffle_mode = ShuffleMode::kGlobal,
                                                 const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto file_chars = VectorStringToChar(dataset_files);
  const auto column_chars = VectorStringToChar(columns_list);
  return std::make_shared<MindDataDataset>(file_chars, column_chars, sampler, padded_sample, num_padded,
                                           shuffle_mode, cache);
}

/// \class MnistDataset
/// \brief Dataset subclass constructed by the Mnist() factory functions.
/// \note The three constructors differ only in how the sampler is supplied (std::shared_ptr, raw pointer or
///     std::reference_wrapper). String arguments are taken as std::vector<char>.
class  MnistDataset : public Dataset {
 public:
  /// \brief Constructor taking the sampler as a std::shared_ptr.
  MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
               const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking the sampler as a raw pointer.
  MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, const Sampler *sampler,
               const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking the sampler as a std::reference_wrapper.
  MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
               const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Defaulted destructor.
  ~MnistDataset() = default;
};

/// \brief Create a MnistDataset, supplying the sampler as a std::shared_ptr.
/// \note dataset_dir and usage are converted with StringToChar; sampler and cache are forwarded unchanged.
///     When omitted, usage is "all" and sampler is a new RandomSampler.
/// \return Shared pointer to the newly constructed MnistDataset.
inline std::shared_ptr<MnistDataset> Mnist(
  const std::string &dataset_dir, const std::string &usage = "all",
  const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
  const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<MnistDataset>(dir_chars, usage_chars, sampler, cache);
}

/// \brief Create a MnistDataset, supplying the sampler as a raw pointer.
/// \note usage and sampler have no defaults in this overload. dataset_dir and usage are converted with
///     StringToChar; sampler and cache are forwarded unchanged.
/// \return Shared pointer to the newly constructed MnistDataset.
inline std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::string &usage,
                                           const Sampler *sampler,
                                           const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<MnistDataset>(dir_chars, usage_chars, sampler, cache);
}

/// \brief Create a MnistDataset, supplying the sampler as a std::reference_wrapper.
/// \note usage and sampler have no defaults in this overload. dataset_dir and usage are converted with
///     StringToChar; sampler and cache are forwarded unchanged.
/// \return Shared pointer to the newly constructed MnistDataset.
inline std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::string &usage,
                                           const std::reference_wrapper<Sampler> sampler,
                                           const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<MnistDataset>(dir_chars, usage_chars, sampler, cache);
}

/// \brief Build a ConcatDataset from two datasets.
/// \note The operands are placed, in order (datasets1 then datasets2), into a temporary vector that is passed
///     to the ConcatDataset constructor.
/// \return Shared pointer to the newly constructed ConcatDataset.
inline std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &datasets1,
                                                const std::shared_ptr<Dataset> &datasets2) {
  using DatasetList = std::vector<std::shared_ptr<Dataset>>;
  return std::make_shared<ConcatDataset>(DatasetList{datasets1, datasets2});
}

/// \class RandomDataDataset
/// \brief Dataset subclass constructed by the RandomData() factory template.
/// \note Two constructors are declared; they differ in whether the schema is given as a SchemaObj pointer or
///     as a path encoded as std::vector<char>.
class  RandomDataDataset : public Dataset {
 public:
  /// \brief Constructor taking the schema as a std::shared_ptr<SchemaObj>.
  RandomDataDataset(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema,
                    const std::vector<std::vector<char>> &columns_list, std::shared_ptr<DatasetCache> cache);

  /// \brief Constructor taking the schema as a path (char vector).
  RandomDataDataset(const int32_t &total_rows, const std::vector<char> &schema_path,
                    const std::vector<std::vector<char>> &columns_list, std::shared_ptr<DatasetCache> cache);

  /// \brief Defaulted destructor.
  ~RandomDataDataset() = default;
};

/// \brief Create a RandomDataDataset.
/// \tparam T Type of the schema argument. When T is std::nullptr_t or std::shared_ptr<SchemaObj> the schema is
///     passed to the constructor as a std::shared_ptr<SchemaObj>; for any other T it is passed through
///     StringToChar, selecting the schema-path constructor.
/// \note columns_list is converted with VectorStringToChar; total_rows and cache are forwarded unchanged.
/// \return Shared pointer to the newly constructed RandomDataDataset.
template <typename T = std::shared_ptr<SchemaObj>>
std::shared_ptr<RandomDataDataset> RandomData(const int32_t &total_rows = 0, const T &schema = nullptr,
                                              const std::vector<std::string> &columns_list = {},
                                              const std::shared_ptr<DatasetCache> &cache = nullptr) {
  constexpr bool kSchemaIsObject =
    std::is_same_v<T, std::nullptr_t> || std::is_same_v<T, std::shared_ptr<SchemaObj>>;
  if constexpr (kSchemaIsObject) {
    // Materialise the schema pointer first so that a nullptr argument also converts cleanly.
    std::shared_ptr<SchemaObj> schema_obj = schema;
    return std::make_shared<RandomDataDataset>(total_rows, std::move(schema_obj), VectorStringToChar(columns_list),
                                               cache);
  } else {
    return std::make_shared<RandomDataDataset>(total_rows, StringToChar(schema), VectorStringToChar(columns_list),
                                               cache);
  }
}

/// \class SBUDataset
/// \brief Dataset subclass constructed by the SBU() factory functions.
/// \note The three constructors differ only in how the sampler is supplied (std::shared_ptr, raw pointer or
///     std::reference_wrapper). The directory is taken as std::vector<char>.
class  SBUDataset : public Dataset {
 public:
  /// \brief Constructor taking the sampler as a std::shared_ptr.
  SBUDataset(const std::vector<char> &dataset_dir, bool decode, const std::shared_ptr<Sampler> &sampler,
             const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking the sampler as a raw pointer.
  SBUDataset(const std::vector<char> &dataset_dir, bool decode, const Sampler *sampler,
             const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking the sampler as a std::reference_wrapper.
  SBUDataset(const std::vector<char> &dataset_dir, bool decode, const std::reference_wrapper<Sampler> sampler,
             const std::shared_ptr<DatasetCache> &cache);

  /// \brief Defaulted destructor.
  ~SBUDataset() = default;
};

/// \brief Create an SBUDataset, supplying the sampler as a std::shared_ptr.
/// \note dataset_dir is converted with StringToChar; decode, sampler and cache are forwarded unchanged.
///     When omitted, decode is false and sampler is a new RandomSampler.
/// \return Shared pointer to the newly constructed SBUDataset.
inline std::shared_ptr<SBUDataset> SBU(const std::string &dataset_dir, bool decode = false,
                                       const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
                                       const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  return std::make_shared<SBUDataset>(dir_chars, decode, sampler, cache);
}

/// \brief Create an SBUDataset, supplying the sampler as a raw pointer.
/// \note decode and sampler have no defaults in this overload. dataset_dir is converted with StringToChar;
///     the remaining arguments are forwarded unchanged.
/// \return Shared pointer to the newly constructed SBUDataset.
inline std::shared_ptr<SBUDataset> SBU(const std::string &dataset_dir, bool decode, const Sampler *sampler,
                                       const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  return std::make_shared<SBUDataset>(dir_chars, decode, sampler, cache);
}

/// \brief Create an SBUDataset, supplying the sampler as a std::reference_wrapper.
/// \note decode and sampler have no defaults in this overload. dataset_dir is converted with StringToChar;
///     the remaining arguments are forwarded unchanged.
/// \return Shared pointer to the newly constructed SBUDataset.
inline std::shared_ptr<SBUDataset> SBU(const std::string &dataset_dir, bool decode,
                                       const std::reference_wrapper<Sampler> sampler,
                                       const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  return std::make_shared<SBUDataset>(dir_chars, decode, sampler, cache);
}

/// \class TextFileDataset
/// \brief Dataset subclass constructed by the TextFile() factory function.
/// \note File paths are taken as std::vector<char> sequences; TextFile() performs the conversion from
///     std::string before calling this constructor.
class  TextFileDataset : public Dataset {
 public:
  /// \brief Constructor; receives the arguments of TextFile() with file paths already converted.
  TextFileDataset(const std::vector<std::vector<char>> &dataset_files, int64_t num_samples, ShuffleMode shuffle,
                  int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Defaulted destructor.
  ~TextFileDataset() = default;
};

/// \brief Create a TextFileDataset from a list of files.
/// \note dataset_files is converted with VectorStringToChar; the remaining arguments are forwarded unchanged.
///     Defaults: num_samples 0, ShuffleMode::kGlobal, one shard with id 0, no cache.
/// \return Shared pointer to the newly constructed TextFileDataset.
inline std::shared_ptr<TextFileDataset> TextFile(const std::vector<std::string> &dataset_files,
                                                 int64_t num_samples = 0,
                                                 ShuffleMode shuffle = ShuffleMode::kGlobal,
                                                 int32_t num_shards = 1, int32_t shard_id = 0,
                                                 const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto file_chars = VectorStringToChar(dataset_files);
  return std::make_shared<TextFileDataset>(file_chars, num_samples, shuffle, num_shards, shard_id, cache);
}

/// \class TFRecordDataset
/// \brief Dataset subclass constructed by the TFRecord() factory template.
/// \note Two constructors are declared; they differ in whether the schema is given as a path encoded as
///     std::vector<char> or as a SchemaObj pointer.
class  TFRecordDataset : public Dataset {
 public:
  /// \brief Constructor taking the schema as a path (char vector).
  TFRecordDataset(const std::vector<std::vector<char>> &dataset_files, const std::vector<char> &schema,
                  const std::vector<std::vector<char>> &columns_list, int64_t num_samples, ShuffleMode shuffle,
                  int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr<DatasetCache> cache);

  /// \brief Constructor taking the schema as a std::shared_ptr<SchemaObj>.
  TFRecordDataset(const std::vector<std::vector<char>> &dataset_files, std::shared_ptr<SchemaObj> schema,
                  const std::vector<std::vector<char>> &columns_list, int64_t num_samples, ShuffleMode shuffle,
                  int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr<DatasetCache> cache);

  /// \brief Defaulted destructor.
  ~TFRecordDataset() = default;
};

/// \brief Create a TFRecordDataset.
/// \tparam T Type of the schema argument. When T is std::nullptr_t or std::shared_ptr<SchemaObj> the schema is
///     passed to the constructor as a std::shared_ptr<SchemaObj>; for any other T it is treated as a
///     schema file path convertible to std::string.
/// \note dataset_files and columns_list are converted with VectorStringToChar; the numeric, shuffle and cache
///     arguments are forwarded unchanged.
/// \return Shared pointer to the new TFRecordDataset, or nullptr when a non-empty schema path is given and
///     stat() on that path fails.
template <typename T = std::shared_ptr<SchemaObj>>
std::shared_ptr<TFRecordDataset> TFRecord(const std::vector<std::string> &dataset_files, const T &schema = nullptr,
                                          const std::vector<std::string> &columns_list = {}, int64_t num_samples = 0,
                                          ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
                                          int32_t shard_id = 0, bool shard_equal_rows = false,
                                          const std::shared_ptr<DatasetCache> &cache = nullptr) {
  constexpr bool kSchemaIsObject =
    std::is_same_v<T, std::nullptr_t> || std::is_same_v<T, std::shared_ptr<SchemaObj>>;
  if constexpr (kSchemaIsObject) {
    // Materialise the schema pointer first so that a nullptr argument also converts cleanly.
    std::shared_ptr<SchemaObj> schema_obj = schema;
    return std::make_shared<TFRecordDataset>(VectorStringToChar(dataset_files), std::move(schema_obj),
                                             VectorStringToChar(columns_list), num_samples, shuffle, num_shards,
                                             shard_id, shard_equal_rows, cache);
  } else {
    const std::string schema_path = schema;
    // An empty path skips the check; a non-empty path must be stat()-able or no dataset is created.
    struct stat path_info;
    if (!schema_path.empty() && stat(schema_path.c_str(), &path_info) != 0) {
      return nullptr;
    }
    return std::make_shared<TFRecordDataset>(VectorStringToChar(dataset_files), StringToChar(schema_path),
                                             VectorStringToChar(columns_list), num_samples, shuffle, num_shards,
                                             shard_id, shard_equal_rows, cache);
  }
}

/// \class USPSDataset
/// \brief Dataset subclass constructed by the USPS() factory function.
/// \note String arguments are taken as std::vector<char>; USPS() performs the conversion from std::string
///     before calling this constructor.
class  USPSDataset : public Dataset {
 public:
  /// \brief Constructor; receives the arguments of USPS() with strings already converted to char vectors.
  USPSDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, int64_t num_samples,
              ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Defaulted destructor.
  ~USPSDataset() = default;
};

/// \brief Create a USPSDataset.
/// \note dataset_dir and usage are converted with StringToChar; the remaining arguments are forwarded
///     unchanged. Defaults: usage "all", num_samples 0, ShuffleMode::kGlobal, one shard with id 0, no cache.
/// \return Shared pointer to the newly constructed USPSDataset.
inline std::shared_ptr<USPSDataset> USPS(const std::string &dataset_dir, const std::string &usage = "all",
                                         int64_t num_samples = 0, ShuffleMode shuffle = ShuffleMode::kGlobal,
                                         int32_t num_shards = 1, int32_t shard_id = 0,
                                         const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<USPSDataset>(dir_chars, usage_chars, num_samples, shuffle, num_shards, shard_id, cache);
}

/// \class VOCDataset
/// \brief Dataset subclass constructed by the VOC() factory functions.
/// \note The three constructors differ only in how the sampler is supplied (std::shared_ptr, raw pointer or
///     std::reference_wrapper). The directory, task, usage and class_indexing keys are taken as
///     std::vector<char>.
class  VOCDataset : public Dataset {
 public:
  /// \brief Constructor taking the sampler as a std::shared_ptr.
  VOCDataset(const std::vector<char> &dataset_dir, const std::vector<char> &task, const std::vector<char> &usage,
             const std::map<std::vector<char>, int32_t> &class_indexing, bool decode,
             const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache, bool extra_metadata);

  /// \brief Constructor taking the sampler as a raw pointer.
  VOCDataset(const std::vector<char> &dataset_dir, const std::vector<char> &task, const std::vector<char> &usage,
             const std::map<std::vector<char>, int32_t> &class_indexing, bool decode, const Sampler *sampler,
             const std::shared_ptr<DatasetCache> &cache, bool extra_metadata);

  /// \brief Constructor taking the sampler as a std::reference_wrapper.
  VOCDataset(const std::vector<char> &dataset_dir, const std::vector<char> &task, const std::vector<char> &usage,
             const std::map<std::vector<char>, int32_t> &class_indexing, bool decode,
             const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache,
             bool extra_metadata);

  /// \brief Defaulted destructor.
  ~VOCDataset() = default;
};

/// \brief Create a VOCDataset, supplying the sampler as a std::shared_ptr.
/// \note dataset_dir, task and usage are converted with StringToChar and class_indexing with MapStringToChar;
///     decode, sampler, cache and extra_metadata are forwarded unchanged. When omitted, task is
///     "Segmentation", usage is "train", class_indexing is empty, decode is false, sampler is a new
///     RandomSampler and extra_metadata is false.
/// \return Shared pointer to the newly constructed VOCDataset.
inline std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::string &task = "Segmentation",
                                       const std::string &usage = "train",
                                       const std::map<std::string, int32_t> &class_indexing = {},
                                       bool decode = false,
                                       const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
                                       const std::shared_ptr<DatasetCache> &cache = nullptr,
                                       bool extra_metadata = false) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto task_chars = StringToChar(task);
  const auto usage_chars = StringToChar(usage);
  const auto class_index_chars = MapStringToChar(class_indexing);
  return std::make_shared<VOCDataset>(dir_chars, task_chars, usage_chars, class_index_chars, decode, sampler, cache,
                                      extra_metadata);
}

/// \brief Create a VOCDataset, supplying the sampler as a raw pointer.
/// \note Only cache and extra_metadata have defaults in this overload. dataset_dir, task and usage are
///     converted with StringToChar and class_indexing with MapStringToChar; other arguments are forwarded
///     unchanged.
/// \return Shared pointer to the newly constructed VOCDataset.
inline std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::string &task,
                                       const std::string &usage,
                                       const std::map<std::string, int32_t> &class_indexing, bool decode,
                                       const Sampler *sampler,
                                       const std::shared_ptr<DatasetCache> &cache = nullptr,
                                       bool extra_metadata = false) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto task_chars = StringToChar(task);
  const auto usage_chars = StringToChar(usage);
  const auto class_index_chars = MapStringToChar(class_indexing);
  return std::make_shared<VOCDataset>(dir_chars, task_chars, usage_chars, class_index_chars, decode, sampler, cache,
                                      extra_metadata);
}

/// \brief Create a VOCDataset, supplying the sampler as a std::reference_wrapper.
/// \note Only cache and extra_metadata have defaults in this overload. dataset_dir, task and usage are
///     converted with StringToChar and class_indexing with MapStringToChar; other arguments are forwarded
///     unchanged.
/// \return Shared pointer to the newly constructed VOCDataset.
inline std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::string &task,
                                       const std::string &usage,
                                       const std::map<std::string, int32_t> &class_indexing, bool decode,
                                       const std::reference_wrapper<Sampler> sampler,
                                       const std::shared_ptr<DatasetCache> &cache = nullptr,
                                       bool extra_metadata = false) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto task_chars = StringToChar(task);
  const auto usage_chars = StringToChar(usage);
  const auto class_index_chars = MapStringToChar(class_indexing);
  return std::make_shared<VOCDataset>(dir_chars, task_chars, usage_chars, class_index_chars, decode, sampler, cache,
                                      extra_metadata);
}

/// \brief Char-vector variant of CreateDatasetCache(); declared here and defined elsewhere.
/// \note Identical parameter list to CreateDatasetCache() except that hostname is an optional
///     std::vector<char> instead of an optional std::string. All optional parameters default to std::nullopt.
/// \return Shared pointer to a DatasetCache.
std::shared_ptr<DatasetCache>  CreateDatasetCacheCharIF(session_id_type id, uint64_t mem_sz, bool spill,
                                                              std::optional<std::vector<char>> hostname = std::nullopt,
                                                              std::optional<int32_t> port = std::nullopt,
                                                              std::optional<int32_t> num_connections = std::nullopt,
                                                              std::optional<int32_t> prefetch_sz = std::nullopt);

/// \brief Create a DatasetCache by delegating to CreateDatasetCacheCharIF().
/// \note When hostname holds a value, its characters are copied into an optional std::vector<char>; an empty
///     optional is passed through as empty. All other arguments are forwarded unchanged.
/// \return The shared pointer returned by CreateDatasetCacheCharIF().
inline std::shared_ptr<DatasetCache> CreateDatasetCache(session_id_type id, uint64_t mem_sz, bool spill,
                                                        std::optional<std::string> hostname = std::nullopt,
                                                        std::optional<int32_t> port = std::nullopt,
                                                        std::optional<int32_t> num_connections = std::nullopt,
                                                        std::optional<int32_t> prefetch_sz = std::nullopt) {
  std::optional<std::vector<char>> hostname_chars;
  if (hostname.has_value()) {
    hostname_chars.emplace(hostname->begin(), hostname->end());
  }
  return CreateDatasetCacheCharIF(id, mem_sz, spill, hostname_chars, port, num_connections, prefetch_sz);
}

/// \brief Create a ZipDataset from a list of datasets.
/// \note The vector is passed unchanged to the ZipDataset constructor.
/// \return Shared pointer to the newly constructed ZipDataset.
inline std::shared_ptr<ZipDataset> Zip(const std::vector<std::shared_ptr<Dataset>> &datasets) {
  auto zipped = std::make_shared<ZipDataset>(datasets);
  return zipped;
}
}  // namespace dataset
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_DATASETS_H_