Program Listing for File datasets.h
↰ Return to documentation for file (include/datasets.h)
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_DATASETS_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_DATASETS_H_
#include <sys/stat.h>
#include <unistd.h>
#include <algorithm>
#include <functional>
#include <iterator>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "include/api/dual_abi_helper.h"
#include "include/api/types.h"
#include "include/dataset/iterator.h"
#include "include/dataset/json_fwd.hpp"
#include "include/dataset/samplers.h"
#include "include/dataset/text.h"
namespace mindspore {
namespace dataset {
class Tensor;
class TensorShape;
class TreeAdapter;
class TreeAdapterLite;
class TreeGetters;
class Vocab;
class DatasetCache;
class DatasetNode;
class Iterator;
class PullBasedIterator;
class TensorOperation;
class SchemaObj;
class SamplerObj;
class CsvBase;
// Forward declarations of Dataset subclasses and helper types referenced later in this header
class BatchDataset;
class MapDataset;
class ProjectDataset;
class ShuffleDataset;
class BucketBatchByLengthDataset;
class FilterDataset;
class CSVDataset;
class TransferDataset;
class ConcatDataset;
class RenameDataset;
class SentencePieceVocab;
enum class SentencePieceModel;
class DSCallback;
class RepeatDataset;
class SkipDataset;
class TakeDataset;
class ZipDataset;
class Dataset : public std::enable_shared_from_this<Dataset> {
public:
// need friend class so they can access the children_ field
friend class Iterator;
friend class TransferNode;
Dataset();
virtual ~Dataset() = default;
int64_t GetDatasetSize(bool estimate = false);
std::vector<mindspore::DataType> GetOutputTypes();
std::vector<std::vector<int64_t>> GetOutputShapes();
int64_t GetBatchSize();
int64_t GetRepeatCount();
int64_t GetNumClasses();
std::vector<std::string> GetColumnNames() { return VectorCharToString(GetColumnNamesCharIF()); }
std::vector<std::pair<std::string, std::vector<int32_t>>> GetClassIndexing() {
return ClassIndexCharToString(GetClassIndexingCharIF());
}
std::shared_ptr<Dataset> SetNumWorkers(int32_t num_workers);
std::shared_ptr<PullIterator> CreatePullBasedIterator(const std::vector<std::vector<char>> &columns = {});
std::shared_ptr<Iterator> CreateIterator(const std::vector<std::string> &columns = {}, int32_t num_epochs = -1) {
return CreateIteratorCharIF(VectorStringToChar(columns), num_epochs);
}
bool DeviceQueue(const std::string &queue_name = "", const std::string &device_type = "", int32_t device_id = 0,
int32_t num_epochs = -1, bool send_epoch_end = true, int32_t total_batches = 0,
bool create_data_info_queue = false) {
return DeviceQueueCharIF(StringToChar(queue_name), StringToChar(device_type), device_id, num_epochs, send_epoch_end,
total_batches, create_data_info_queue);
}
bool Save(const std::string &dataset_path, int32_t num_files = 1, const std::string &dataset_type = "mindrecord") {
return SaveCharIF(StringToChar(dataset_path), num_files, StringToChar(dataset_type));
}
std::shared_ptr<BatchDataset> Batch(int32_t batch_size, bool drop_remainder = false);
std::shared_ptr<BucketBatchByLengthDataset> BucketBatchByLength(
const std::vector<std::string> &column_names, const std::vector<int32_t> &bucket_boundaries,
const std::vector<int32_t> &bucket_batch_sizes,
const std::function<MSTensorVec(MSTensorVec)> &element_length_function = nullptr,
const std::map<std::string, std::pair<std::vector<int64_t>, MSTensor>> &pad_info = {},
bool pad_to_bucket_boundary = false, bool drop_remainder = false) {
return std::make_shared<BucketBatchByLengthDataset>(
shared_from_this(), VectorStringToChar(column_names), bucket_boundaries, bucket_batch_sizes,
element_length_function, PadInfoStringToChar(pad_info), pad_to_bucket_boundary, drop_remainder);
}
std::shared_ptr<SentencePieceVocab> BuildSentencePieceVocab(
const std::vector<std::string> &col_names, int32_t vocab_size, float character_coverage,
SentencePieceModel model_type, const std::unordered_map<std::string, std::string> ¶ms) {
return BuildSentencePieceVocabCharIF(VectorStringToChar(col_names), vocab_size, character_coverage, model_type,
UnorderedMapStringToChar(params));
}
std::shared_ptr<Vocab> BuildVocab(const std::vector<std::string> &columns = {},
const std::pair<int64_t, int64_t> &freq_range = {0, kDeMaxFreq},
int64_t top_k = kDeMaxTopk, const std::vector<std::string> &special_tokens = {},
bool special_first = true) {
return BuildVocabCharIF(VectorStringToChar(columns), freq_range, top_k, VectorStringToChar(special_tokens),
special_first);
}
std::shared_ptr<ConcatDataset> Concat(const std::vector<std::shared_ptr<Dataset>> &datasets) {
std::vector<std::shared_ptr<Dataset>> all_datasets{shared_from_this()};
all_datasets.insert(std::end(all_datasets), std::begin(datasets), std::end(datasets));
return std::make_shared<ConcatDataset>(all_datasets);
}
std::shared_ptr<FilterDataset> Filter(const std::function<MSTensorVec(MSTensorVec)> &predicate,
const std::vector<std::string> &input_columns = {}) {
return std::make_shared<FilterDataset>(shared_from_this(), predicate, VectorStringToChar(input_columns));
}
std::shared_ptr<MapDataset> Map(const std::vector<TensorTransform *> &operations,
const std::vector<std::string> &input_columns = {},
const std::vector<std::string> &output_columns = {},
const std::vector<std::string> &project_columns = {},
const std::shared_ptr<DatasetCache> &cache = nullptr,
const std::vector<std::shared_ptr<DSCallback>> &callbacks = {}) {
std::vector<std::shared_ptr<TensorOperation>> transform_ops;
(void)std::transform(
operations.begin(), operations.end(), std::back_inserter(transform_ops),
[](TensorTransform *op) -> std::shared_ptr<TensorOperation> { return op != nullptr ? op->Parse() : nullptr; });
return std::make_shared<MapDataset>(shared_from_this(), transform_ops, VectorStringToChar(input_columns),
VectorStringToChar(output_columns), VectorStringToChar(project_columns), cache,
callbacks);
}
std::shared_ptr<MapDataset> Map(const std::vector<std::shared_ptr<TensorTransform>> &operations,
const std::vector<std::string> &input_columns = {},
const std::vector<std::string> &output_columns = {},
const std::vector<std::string> &project_columns = {},
const std::shared_ptr<DatasetCache> &cache = nullptr,
const std::vector<std::shared_ptr<DSCallback>> &callbacks = {}) {
std::vector<std::shared_ptr<TensorOperation>> transform_ops;
(void)std::transform(operations.begin(), operations.end(), std::back_inserter(transform_ops),
[](const std::shared_ptr<TensorTransform> &op) -> std::shared_ptr<TensorOperation> {
return op != nullptr ? op->Parse() : nullptr;
});
return std::make_shared<MapDataset>(shared_from_this(), transform_ops, VectorStringToChar(input_columns),
VectorStringToChar(output_columns), VectorStringToChar(project_columns), cache,
callbacks);
}
std::shared_ptr<MapDataset> Map(const std::vector<std::reference_wrapper<TensorTransform>> &operations,
const std::vector<std::string> &input_columns = {},
const std::vector<std::string> &output_columns = {},
const std::vector<std::string> &project_columns = {},
const std::shared_ptr<DatasetCache> &cache = nullptr,
const std::vector<std::shared_ptr<DSCallback>> &callbacks = {}) {
std::vector<std::shared_ptr<TensorOperation>> transform_ops;
(void)std::transform(operations.begin(), operations.end(), std::back_inserter(transform_ops),
[](TensorTransform &op) -> std::shared_ptr<TensorOperation> { return op.Parse(); });
return std::make_shared<MapDataset>(shared_from_this(), transform_ops, VectorStringToChar(input_columns),
VectorStringToChar(output_columns), VectorStringToChar(project_columns), cache,
callbacks);
}
std::shared_ptr<ProjectDataset> Project(const std::vector<std::string> &columns) {
return std::make_shared<ProjectDataset>(shared_from_this(), VectorStringToChar(columns));
}
std::shared_ptr<RenameDataset> Rename(const std::vector<std::string> &input_columns,
const std::vector<std::string> &output_columns) {
return std::make_shared<RenameDataset>(shared_from_this(), VectorStringToChar(input_columns),
VectorStringToChar(output_columns));
}
std::shared_ptr<RepeatDataset> Repeat(int32_t count = -1) {
return std::make_shared<RepeatDataset>(shared_from_this(), count);
}
std::shared_ptr<ShuffleDataset> Shuffle(int32_t buffer_size) {
return std::make_shared<ShuffleDataset>(shared_from_this(), buffer_size);
}
std::shared_ptr<SkipDataset> Skip(int32_t count) { return std::make_shared<SkipDataset>(shared_from_this(), count); }
std::shared_ptr<TakeDataset> Take(int32_t count = -1) {
return std::make_shared<TakeDataset>(shared_from_this(), count);
}
std::shared_ptr<ZipDataset> Zip(const std::vector<std::shared_ptr<Dataset>> &datasets) {
std::vector<std::shared_ptr<Dataset>> all_datasets = datasets;
all_datasets.push_back(shared_from_this());
return std::make_shared<ZipDataset>(all_datasets);
}
std::shared_ptr<DatasetNode> IRNode() { return ir_node_; }
protected:
std::shared_ptr<TreeGetters> tree_getters_;
std::shared_ptr<DatasetNode> ir_node_;
private:
// Char interface(CharIF) of GetColumnNames
std::vector<std::vector<char>> GetColumnNamesCharIF();
// Char interface(CharIF) of GetClassIndexing
std::vector<std::pair<std::vector<char>, std::vector<int32_t>>> GetClassIndexingCharIF();
// Char interface(CharIF) of CreateIterator
std::shared_ptr<Iterator> CreateIteratorCharIF(const std::vector<std::vector<char>> &columns, int32_t num_epochs);
// Char interface(CharIF) of DeviceQueue
bool DeviceQueueCharIF(const std::vector<char> &queue_name, const std::vector<char> &device_type, int32_t device_id,
int32_t num_epochs, bool send_epoch_end, int32_t total_batches, bool create_data_info_queue);
// Char interface(CharIF) of Save
bool SaveCharIF(const std::vector<char> &dataset_path, int32_t num_files, const std::vector<char> &dataset_type);
// Char interface(CharIF) of BuildSentencePieceVocab
std::shared_ptr<SentencePieceVocab> BuildSentencePieceVocabCharIF(
const std::vector<std::vector<char>> &col_names, int32_t vocab_size, float character_coverage,
SentencePieceModel model_type, const std::map<std::vector<char>, std::vector<char>> ¶ms);
// Char interface(CharIF) of BuildVocab
std::shared_ptr<Vocab> BuildVocabCharIF(const std::vector<std::vector<char>> &columns,
const std::pair<int64_t, int64_t> &freq_range, int64_t top_k,
const std::vector<std::vector<char>> &special_tokens, bool special_first);
};
/// \brief Schema description object.
/// Every std::string-based public method converts its text arguments to char vectors and delegates to a private
/// *_char / *CharIF counterpart; most other members are defined out of line.
class SchemaObj {
 public:
  /// \param[in] schema_file Schema file name, empty by default; forwarded as a char vector to the private
  ///     char constructor.
  explicit SchemaObj(const std::string &schema_file = "") : SchemaObj(StringToChar(schema_file)) {}

  ~SchemaObj() = default;

  Status Init();

  /// \brief Add a column given its DataType.
  Status add_column(const std::string &name, mindspore::DataType ms_type) {
    const auto name_chars = StringToChar(name);
    return add_column_char(name_chars, ms_type);
  }

  /// \brief Add a column given its type as a string.
  Status add_column(const std::string &name, const std::string &ms_type) {
    const auto name_chars = StringToChar(name);
    const auto type_chars = StringToChar(ms_type);
    return add_column_char(name_chars, type_chars);
  }

  /// \brief Add a column given its DataType and shape.
  Status add_column(const std::string &name, mindspore::DataType ms_type, const std::vector<int32_t> &shape) {
    const auto name_chars = StringToChar(name);
    return add_column_char(name_chars, ms_type, shape);
  }

  /// \brief Add a column given its type as a string and its shape.
  Status add_column(const std::string &name, const std::string &ms_type, const std::vector<int32_t> &shape) {
    const auto name_chars = StringToChar(name);
    const auto type_chars = StringToChar(ms_type);
    return add_column_char(name_chars, type_chars, shape);
  }

  /// \brief JSON text of the schema, obtained from to_json_char().
  std::string to_json() {
    const auto json_chars = to_json_char();
    return CharToString(json_chars);
  }

  Status schema_to_json(nlohmann::json *out_json);

  /// \brief Alias of to_json().
  std::string to_string() { return to_json(); }

  void set_dataset_type(const std::string &dataset_type);

  void set_num_rows(int32_t num_rows);

  int32_t get_num_rows() const;

  Status from_json(nlohmann::json json_obj);

  /// \brief std::string wrapper over FromJSONStringCharIF().
  Status FromJSONString(const std::string &json_string) {
    const auto json_chars = StringToChar(json_string);
    return FromJSONStringCharIF(json_chars);
  }

  /// \brief std::string wrapper over ParseColumnStringCharIF().
  Status ParseColumnString(const std::string &json_string) {
    const auto json_chars = StringToChar(json_string);
    return ParseColumnStringCharIF(json_chars);
  }

 private:
  Status parse_column(nlohmann::json columns);

  // Char constructor of SchemaObj
  explicit SchemaObj(const std::vector<char> &schema_file);

  // Char interface of add_column
  Status add_column_char(const std::vector<char> &name, mindspore::DataType ms_type);

  Status add_column_char(const std::vector<char> &name, const std::vector<char> &ms_type);

  Status add_column_char(const std::vector<char> &name, mindspore::DataType ms_type, const std::vector<int32_t> &shape);

  Status add_column_char(const std::vector<char> &name, const std::vector<char> &ms_type,
                         const std::vector<int32_t> &shape);

  // Char interface of to_json
  std::vector<char> to_json_char();

  // Char interface of FromJSONString
  Status FromJSONStringCharIF(const std::vector<char> &json_string);

  // Char interface of ParseColumnString
  Status ParseColumnStringCharIF(const std::vector<char> &json_string);

  // Implementation state; Data is only forward-declared here.
  struct Data;
  std::shared_ptr<Data> data_;
};
/// \brief Dataset type returned by Dataset::Batch(); its constructor is defined out of line.
class BatchDataset : public Dataset {
 public:
  /// \param[in] input Input dataset.
  /// \param[in] batch_size Forwarded to the out-of-line constructor.
  /// \param[in] drop_remainder false by default.
  BatchDataset(const std::shared_ptr<Dataset> &input, int32_t batch_size, bool drop_remainder = false);
  ~BatchDataset() override = default;
};
/// \brief Dataset built by Dataset::BucketBatchByLength().
/// That method passes the calling dataset as \p input and converts column_names and the pad_info keys to char
/// vectors before calling this constructor (defined out of line).
class BucketBatchByLengthDataset : public Dataset {
 public:
  BucketBatchByLengthDataset(
    const std::shared_ptr<Dataset> &input, const std::vector<std::vector<char>> &column_names,
    const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
    const std::function<MSTensorVec(MSTensorVec)> &element_length_function = nullptr,
    const std::map<std::vector<char>, std::pair<std::vector<int64_t>, MSTensor>> &pad_info = {},
    bool pad_to_bucket_boundary = false, bool drop_remainder = false);
  ~BucketBatchByLengthDataset() override = default;
};
/// \brief Dataset built by Dataset::Concat().
/// Concat() passes the calling dataset first in \p input, followed by the datasets given to it.
class ConcatDataset : public Dataset {
 public:
  explicit ConcatDataset(const std::vector<std::shared_ptr<Dataset>> &input);
  ~ConcatDataset() override = default;
};
/// \brief Dataset built by Dataset::Filter(), which passes the calling dataset as \p input.
class FilterDataset : public Dataset {
 public:
  FilterDataset(const std::shared_ptr<Dataset> &input, const std::function<MSTensorVec(MSTensorVec)> &predicate,
                const std::vector<std::vector<char>> &input_columns);
  ~FilterDataset() override = default;
};
/// \brief Dataset built by the Dataset::Map() overloads.
/// Map() Parse()s each TensorTransform into \p operations. The raw/shared pointer overloads turn a null transform
/// into a null entry, so \p operations may contain nullptr.
class MapDataset : public Dataset {
 public:
  MapDataset(const std::shared_ptr<Dataset> &input, const std::vector<std::shared_ptr<TensorOperation>> &operations,
             const std::vector<std::vector<char>> &input_columns, const std::vector<std::vector<char>> &output_columns,
             const std::vector<std::vector<char>> &project_columns, const std::shared_ptr<DatasetCache> &cache,
             const std::vector<std::shared_ptr<DSCallback>> &callbacks);
  ~MapDataset() override = default;
};
/// \brief Dataset built by Dataset::Project(), which passes the calling dataset as \p input.
class ProjectDataset : public Dataset {
 public:
  ProjectDataset(const std::shared_ptr<Dataset> &input, const std::vector<std::vector<char>> &columns);
  ~ProjectDataset() override = default;
};
/// \brief Dataset built by Dataset::Rename(), which passes the calling dataset as \p input.
class RenameDataset : public Dataset {
 public:
  RenameDataset(const std::shared_ptr<Dataset> &input, const std::vector<std::vector<char>> &input_columns,
                const std::vector<std::vector<char>> &output_columns);
  ~RenameDataset() override = default;
};
/// \brief Dataset built by Dataset::Repeat(), which defaults \p count to -1.
class RepeatDataset : public Dataset {
 public:
  RepeatDataset(const std::shared_ptr<Dataset> &input, int32_t count);
  ~RepeatDataset() override = default;
};
/// \brief Dataset built by Dataset::Shuffle(), which passes the calling dataset as \p input.
class ShuffleDataset : public Dataset {
 public:
  ShuffleDataset(const std::shared_ptr<Dataset> &input, int32_t buffer_size);
  ~ShuffleDataset() override = default;
};
/// \brief Dataset built by Dataset::Skip(), which passes the calling dataset as \p input.
class SkipDataset : public Dataset {
 public:
  SkipDataset(const std::shared_ptr<Dataset> &input, int32_t count);
  ~SkipDataset() override = default;
};
/// \brief Dataset built by Dataset::Take(), which defaults \p count to -1.
class TakeDataset : public Dataset {
 public:
  TakeDataset(const std::shared_ptr<Dataset> &input, int32_t count);
  ~TakeDataset() override = default;
};
/// \brief Dataset built by Dataset::Zip().
/// Zip() places the datasets passed to it first in \p inputs and appends the calling dataset last.
class ZipDataset : public Dataset {
 public:
  explicit ZipDataset(const std::vector<std::shared_ptr<Dataset>> &inputs);
  ~ZipDataset() override = default;
};
/// \brief Char-vector entry point behind Schema(); defined out of line.
std::shared_ptr<SchemaObj> SchemaCharIF(const std::vector<char> &schema_file);

/// \brief Create a SchemaObj through SchemaCharIF().
/// \param[in] schema_file Schema file name, empty by default; converted to a char vector before forwarding.
/// \return The pointer returned by SchemaCharIF().
inline std::shared_ptr<SchemaObj> Schema(const std::string &schema_file = "") {
  const auto schema_file_chars = StringToChar(schema_file);
  return SchemaCharIF(schema_file_chars);
}
/// \brief Source dataset named after the AG News corpus; its constructor is defined out of line.
class AGNewsDataset : public Dataset {
 public:
  AGNewsDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, int64_t num_samples,
                ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache);
  ~AGNewsDataset() override = default;
};

/// \brief Build an AGNewsDataset.
/// Defaults: usage "all", num_samples 0, shuffle ShuffleMode::kGlobal, num_shards 1, shard_id 0, cache nullptr.
/// String arguments are converted to char vectors; all other arguments are forwarded unchanged.
inline std::shared_ptr<AGNewsDataset> AGNews(const std::string &dataset_dir, const std::string &usage = "all",
                                             int64_t num_samples = 0, ShuffleMode shuffle = ShuffleMode::kGlobal,
                                             int32_t num_shards = 1, int32_t shard_id = 0,
                                             const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<AGNewsDataset>(dir_chars, usage_chars, num_samples, shuffle, num_shards, shard_id, cache);
}
/// \brief Source dataset named "Album"; its constructors are defined out of line.
/// The three constructor overloads take the sampler as a shared_ptr, a raw pointer, or a reference_wrapper.
class AlbumDataset : public Dataset {
 public:
  AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,
               const std::vector<std::vector<char>> &column_names, bool decode, const std::shared_ptr<Sampler> &sampler,
               const std::shared_ptr<DatasetCache> &cache);
  AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,
               const std::vector<std::vector<char>> &column_names, bool decode, const Sampler *sampler,
               const std::shared_ptr<DatasetCache> &cache);
  AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,
               const std::vector<std::vector<char>> &column_names, bool decode,
               const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
  ~AlbumDataset() override = default;
};

/// \brief Build an AlbumDataset with a shared_ptr sampler (a new RandomSampler by default).
/// String arguments are converted to char vectors; the rest are forwarded unchanged.
inline std::shared_ptr<AlbumDataset>
Album(const std::string &dataset_dir, const std::string &data_schema, const std::vector<std::string> &column_names = {},
      bool decode = false, const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
      const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto schema_chars = StringToChar(data_schema);
  const auto column_chars = VectorStringToChar(column_names);
  return std::make_shared<AlbumDataset>(dir_chars, schema_chars, column_chars, decode, sampler, cache);
}

/// \brief Build an AlbumDataset with a raw Sampler pointer.
inline std::shared_ptr<AlbumDataset> Album(const std::string &dataset_dir, const std::string &data_schema,
                                           const std::vector<std::string> &column_names, bool decode,
                                           const Sampler *sampler,
                                           const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto schema_chars = StringToChar(data_schema);
  const auto column_chars = VectorStringToChar(column_names);
  return std::make_shared<AlbumDataset>(dir_chars, schema_chars, column_chars, decode, sampler, cache);
}

/// \brief Build an AlbumDataset with a Sampler reference.
inline std::shared_ptr<AlbumDataset> Album(const std::string &dataset_dir, const std::string &data_schema,
                                           const std::vector<std::string> &column_names, bool decode,
                                           const std::reference_wrapper<Sampler> &sampler,
                                           const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto schema_chars = StringToChar(data_schema);
  const auto column_chars = VectorStringToChar(column_names);
  return std::make_shared<AlbumDataset>(dir_chars, schema_chars, column_chars, decode, sampler, cache);
}
/// \brief Source dataset named after the Amazon Review corpus; its constructor is defined out of line.
class AmazonReviewDataset : public Dataset {
 public:
  AmazonReviewDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, int64_t num_samples,
                      ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
                      const std::shared_ptr<DatasetCache> &cache);
  ~AmazonReviewDataset() override = default;
};

/// \brief Build an AmazonReviewDataset.
/// Defaults: usage "all", num_samples 0, shuffle ShuffleMode::kGlobal, num_shards 1, shard_id 0, cache nullptr.
/// String arguments are converted to char vectors; all other arguments are forwarded unchanged.
inline std::shared_ptr<AmazonReviewDataset> AmazonReview(const std::string &dataset_dir,
                                                         const std::string &usage = "all",
                                                         int64_t num_samples = 0,
                                                         ShuffleMode shuffle = ShuffleMode::kGlobal,
                                                         int32_t num_shards = 1, int32_t shard_id = 0,
                                                         const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<AmazonReviewDataset>(dir_chars, usage_chars, num_samples, shuffle, num_shards, shard_id,
                                               cache);
}
/// \brief Source dataset named after Caltech-256; its constructors are defined out of line.
class Caltech256Dataset : public Dataset {
 public:
  Caltech256Dataset(const std::vector<char> &dataset_dir, bool decode, const std::shared_ptr<Sampler> &sampler,
                    const std::shared_ptr<DatasetCache> &cache);
  Caltech256Dataset(const std::vector<char> &dataset_dir, bool decode, const Sampler *sampler,
                    const std::shared_ptr<DatasetCache> &cache);
  Caltech256Dataset(const std::vector<char> &dataset_dir, bool decode, const std::reference_wrapper<Sampler> &sampler,
                    const std::shared_ptr<DatasetCache> &cache);
  ~Caltech256Dataset() override = default;
};

// NOTE(review): the Caltech256() factories take (sampler, decode), while the constructors above (and the Album,
// CelebA, Cityscapes and Coco factories) take decode before sampler. The factories reorder the arguments internally.

/// \brief Build a Caltech256Dataset with a shared_ptr sampler (a new RandomSampler by default).
inline std::shared_ptr<Caltech256Dataset>
Caltech256(const std::string &dataset_dir, const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
           bool decode = false, const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  return std::make_shared<Caltech256Dataset>(dir_chars, decode, sampler, cache);
}

/// \brief Build a Caltech256Dataset with a raw Sampler pointer.
inline std::shared_ptr<Caltech256Dataset> Caltech256(const std::string &dataset_dir, const Sampler *sampler,
                                                     bool decode = false,
                                                     const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  return std::make_shared<Caltech256Dataset>(dir_chars, decode, sampler, cache);
}

/// \brief Build a Caltech256Dataset with a Sampler reference.
inline std::shared_ptr<Caltech256Dataset> Caltech256(const std::string &dataset_dir,
                                                     const std::reference_wrapper<Sampler> &sampler,
                                                     bool decode = false,
                                                     const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  return std::make_shared<Caltech256Dataset>(dir_chars, decode, sampler, cache);
}
/// \brief Source dataset named after CelebA; its constructors are defined out of line.
class CelebADataset : public Dataset {
 public:
  CelebADataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
                const std::shared_ptr<Sampler> &sampler, bool decode, const std::set<std::vector<char>> &extensions,
                const std::shared_ptr<DatasetCache> &cache);
  CelebADataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, const Sampler *sampler,
                bool decode, const std::set<std::vector<char>> &extensions, const std::shared_ptr<DatasetCache> &cache);
  CelebADataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
                const std::reference_wrapper<Sampler> &sampler, bool decode,
                const std::set<std::vector<char>> &extensions, const std::shared_ptr<DatasetCache> &cache);
  ~CelebADataset() override = default;
};

/// \brief Build a CelebADataset with a shared_ptr sampler (a new RandomSampler by default).
/// Defaults: usage "all", decode false, no extensions, cache nullptr. Strings and the extension set are converted
/// to char vectors.
inline std::shared_ptr<CelebADataset>
CelebA(const std::string &dataset_dir, const std::string &usage = "all",
       const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(), bool decode = false,
       const std::set<std::string> &extensions = {}, const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  const auto extension_chars = SetStringToChar(extensions);
  return std::make_shared<CelebADataset>(dir_chars, usage_chars, sampler, decode, extension_chars, cache);
}

/// \brief Build a CelebADataset with a raw Sampler pointer.
inline std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std::string &usage,
                                             const Sampler *sampler, bool decode = false,
                                             const std::set<std::string> &extensions = {},
                                             const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  const auto extension_chars = SetStringToChar(extensions);
  return std::make_shared<CelebADataset>(dir_chars, usage_chars, sampler, decode, extension_chars, cache);
}

/// \brief Build a CelebADataset with a Sampler reference.
inline std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std::string &usage,
                                             const std::reference_wrapper<Sampler> &sampler, bool decode = false,
                                             const std::set<std::string> &extensions = {},
                                             const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  const auto extension_chars = SetStringToChar(extensions);
  return std::make_shared<CelebADataset>(dir_chars, usage_chars, sampler, decode, extension_chars, cache);
}
/// \brief Source dataset named after CIFAR-10; its constructors are defined out of line.
class Cifar10Dataset : public Dataset {
 public:
  Cifar10Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
                 const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
  Cifar10Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, const Sampler *sampler,
                 const std::shared_ptr<DatasetCache> &cache);
  Cifar10Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
                 const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
  ~Cifar10Dataset() override = default;
};

/// \brief Build a Cifar10Dataset with a shared_ptr sampler (a new RandomSampler by default); usage defaults to "all".
inline std::shared_ptr<Cifar10Dataset>
Cifar10(const std::string &dataset_dir, const std::string &usage = "all",
        const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
        const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<Cifar10Dataset>(dir_chars, usage_chars, sampler, cache);
}

/// \brief Build a Cifar10Dataset with a raw Sampler pointer.
inline std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, const std::string &usage,
                                               const Sampler *sampler,
                                               const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<Cifar10Dataset>(dir_chars, usage_chars, sampler, cache);
}

/// \brief Build a Cifar10Dataset with a Sampler reference.
inline std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, const std::string &usage,
                                               const std::reference_wrapper<Sampler> &sampler,
                                               const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<Cifar10Dataset>(dir_chars, usage_chars, sampler, cache);
}
/// \brief Source dataset named after CIFAR-100; its constructors are defined out of line.
class Cifar100Dataset : public Dataset {
 public:
  Cifar100Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
                  const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
  Cifar100Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, const Sampler *sampler,
                  const std::shared_ptr<DatasetCache> &cache);
  Cifar100Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
                  const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
  ~Cifar100Dataset() override = default;
};

/// \brief Build a Cifar100Dataset with a shared_ptr sampler (a new RandomSampler by default); usage defaults to "all".
inline std::shared_ptr<Cifar100Dataset>
Cifar100(const std::string &dataset_dir, const std::string &usage = "all",
         const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
         const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<Cifar100Dataset>(dir_chars, usage_chars, sampler, cache);
}

/// \brief Build a Cifar100Dataset with a raw Sampler pointer.
inline std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir, const std::string &usage,
                                                 const Sampler *sampler,
                                                 const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<Cifar100Dataset>(dir_chars, usage_chars, sampler, cache);
}

/// \brief Build a Cifar100Dataset with a Sampler reference.
inline std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir, const std::string &usage,
                                                 const std::reference_wrapper<Sampler> &sampler,
                                                 const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<Cifar100Dataset>(dir_chars, usage_chars, sampler, cache);
}
/// \brief Source dataset named after Cityscapes; its constructors are defined out of line.
class CityscapesDataset : public Dataset {
 public:
  CityscapesDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
                    const std::vector<char> &quality_mode, const std::vector<char> &task, bool decode,
                    const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
  CityscapesDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
                    const std::vector<char> &quality_mode, const std::vector<char> &task, bool decode,
                    const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache);
  CityscapesDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
                    const std::vector<char> &quality_mode, const std::vector<char> &task, bool decode,
                    const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
  ~CityscapesDataset() override = default;
};

/// \brief Build a CityscapesDataset with a shared_ptr sampler (a new RandomSampler by default).
/// The four string arguments are converted to char vectors; decode defaults to false.
inline std::shared_ptr<CityscapesDataset> Cityscapes(
  const std::string &dataset_dir, const std::string &usage, const std::string &quality_mode, const std::string &task,
  bool decode = false, const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
  const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  const auto quality_chars = StringToChar(quality_mode);
  const auto task_chars = StringToChar(task);
  return std::make_shared<CityscapesDataset>(dir_chars, usage_chars, quality_chars, task_chars, decode, sampler, cache);
}

/// \brief Build a CityscapesDataset with a raw Sampler pointer.
inline std::shared_ptr<CityscapesDataset> Cityscapes(const std::string &dataset_dir, const std::string &usage,
                                                     const std::string &quality_mode, const std::string &task,
                                                     bool decode, const Sampler *sampler,
                                                     const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  const auto quality_chars = StringToChar(quality_mode);
  const auto task_chars = StringToChar(task);
  return std::make_shared<CityscapesDataset>(dir_chars, usage_chars, quality_chars, task_chars, decode, sampler, cache);
}

/// \brief Build a CityscapesDataset with a Sampler reference.
inline std::shared_ptr<CityscapesDataset> Cityscapes(const std::string &dataset_dir, const std::string &usage,
                                                     const std::string &quality_mode, const std::string &task,
                                                     bool decode, const std::reference_wrapper<Sampler> &sampler,
                                                     const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  const auto quality_chars = StringToChar(quality_mode);
  const auto task_chars = StringToChar(task);
  return std::make_shared<CityscapesDataset>(dir_chars, usage_chars, quality_chars, task_chars, decode, sampler, cache);
}
/// \brief Source dataset named after the CLUE benchmark; its constructor is defined out of line.
class CLUEDataset : public Dataset {
 public:
  CLUEDataset(const std::vector<std::vector<char>> &dataset_files, const std::vector<char> &task,
              const std::vector<char> &usage, int64_t num_samples, ShuffleMode shuffle, int32_t num_shards,
              int32_t shard_id, const std::shared_ptr<DatasetCache> &cache);
  ~CLUEDataset() override = default;
};

/// \brief Build a CLUEDataset from a list of files.
/// Defaults: task "AFQMC", usage "train", num_samples 0, shuffle ShuffleMode::kGlobal, num_shards 1, shard_id 0,
/// cache nullptr. String arguments are converted to char vectors; the rest are forwarded unchanged.
inline std::shared_ptr<CLUEDataset> CLUE(const std::vector<std::string> &dataset_files,
                                         const std::string &task = "AFQMC", const std::string &usage = "train",
                                         int64_t num_samples = 0, ShuffleMode shuffle = ShuffleMode::kGlobal,
                                         int32_t num_shards = 1, int32_t shard_id = 0,
                                         const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto file_chars = VectorStringToChar(dataset_files);
  const auto task_chars = StringToChar(task);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<CLUEDataset>(file_chars, task_chars, usage_chars, num_samples, shuffle, num_shards, shard_id,
                                       cache);
}
/// \brief Source dataset named after COCO; its constructors are defined out of line.
/// NOTE(review): decode and extra_metadata are taken as `const bool &` throughout; plain `bool` would be more
/// idiomatic, but the constructor signatures are part of the out-of-line ABI.
class CocoDataset : public Dataset {
 public:
  CocoDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
              const std::vector<char> &task, const bool &decode, const std::shared_ptr<Sampler> &sampler,
              const std::shared_ptr<DatasetCache> &cache, const bool &extra_metadata);
  CocoDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
              const std::vector<char> &task, const bool &decode, const Sampler *sampler,
              const std::shared_ptr<DatasetCache> &cache, const bool &extra_metadata);
  CocoDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
              const std::vector<char> &task, const bool &decode, const std::reference_wrapper<Sampler> &sampler,
              const std::shared_ptr<DatasetCache> &cache, const bool &extra_metadata);
  ~CocoDataset() override = default;
};

/// \brief Build a CocoDataset with a shared_ptr sampler (a new RandomSampler by default).
/// Defaults: task "Detection", decode false, cache nullptr, extra_metadata false.
inline std::shared_ptr<CocoDataset>
Coco(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task = "Detection",
     const bool &decode = false, const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
     const std::shared_ptr<DatasetCache> &cache = nullptr, const bool &extra_metadata = false) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto annotation_chars = StringToChar(annotation_file);
  const auto task_chars = StringToChar(task);
  return std::make_shared<CocoDataset>(dir_chars, annotation_chars, task_chars, decode, sampler, cache,
                                       extra_metadata);
}

/// \brief Build a CocoDataset with a raw Sampler pointer.
inline std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::string &annotation_file,
                                         const std::string &task, const bool &decode, const Sampler *sampler,
                                         const std::shared_ptr<DatasetCache> &cache = nullptr,
                                         const bool &extra_metadata = false) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto annotation_chars = StringToChar(annotation_file);
  const auto task_chars = StringToChar(task);
  return std::make_shared<CocoDataset>(dir_chars, annotation_chars, task_chars, decode, sampler, cache,
                                       extra_metadata);
}

/// \brief Build a CocoDataset with a Sampler reference.
inline std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::string &annotation_file,
                                         const std::string &task, const bool &decode,
                                         const std::reference_wrapper<Sampler> &sampler,
                                         const std::shared_ptr<DatasetCache> &cache = nullptr,
                                         const bool &extra_metadata = false) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto annotation_chars = StringToChar(annotation_file);
  const auto task_chars = StringToChar(task);
  return std::make_shared<CocoDataset>(dir_chars, annotation_chars, task_chars, decode, sampler, cache,
                                       extra_metadata);
}
/// \brief Source dataset named after CoNLL-2000; its constructor is defined out of line.
class CoNLL2000Dataset : public Dataset {
 public:
  CoNLL2000Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, int64_t num_samples,
                   ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
                   const std::shared_ptr<DatasetCache> &cache);
  ~CoNLL2000Dataset() override = default;
};

/// \brief Build a CoNLL2000Dataset.
/// Defaults: usage "all", num_samples 0, shuffle ShuffleMode::kGlobal, num_shards 1, shard_id 0, cache nullptr.
/// String arguments are converted to char vectors; all other arguments are forwarded unchanged.
inline std::shared_ptr<CoNLL2000Dataset> CoNLL2000(const std::string &dataset_dir,
                                                   const std::string &usage = "all", int64_t num_samples = 0,
                                                   ShuffleMode shuffle = ShuffleMode::kGlobal,
                                                   int32_t num_shards = 1, int32_t shard_id = 0,
                                                   const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<CoNLL2000Dataset>(dir_chars, usage_chars, num_samples, shuffle, num_shards, shard_id, cache);
}
/// \brief CSV dataset node.
/// \note File lists and column names are taken as std::vector<std::vector<char>>; the inline CSV() factory
///     converts from std::vector<std::string>.
class CSVDataset : public Dataset {
 public:
  /// \brief Constructor; defined out of line.
  CSVDataset(const std::vector<std::vector<char>> &dataset_files, char field_delim,
             const std::vector<std::shared_ptr<CsvBase>> &column_defaults,
             const std::vector<std::vector<char>> &column_names, int64_t num_samples, ShuffleMode shuffle,
             int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Defaulted destructor (overrides Dataset's virtual destructor).
  ~CSVDataset() override = default;
};
/// \brief Creates a CSVDataset from std::string-based arguments.
/// \param[in] dataset_files Input file list (converted with VectorStringToChar).
/// \param[in] field_delim Field delimiter character (default: ',').
/// \param[in] column_defaults Per-column default objects, forwarded unchanged (default: empty).
/// \param[in] column_names Column names (converted with VectorStringToChar; default: empty).
/// \param[in] num_samples Sample count forwarded to the dataset (default: 0).
/// \param[in] shuffle Shuffle mode (default: ShuffleMode::kGlobal).
/// \param[in] num_shards Number of shards (default: 1).
/// \param[in] shard_id Shard index (default: 0).
/// \param[in] cache Optional dataset cache (default: nullptr).
/// \return Shared pointer to the created CSVDataset.
inline auto CSV(const std::vector<std::string> &dataset_files, char field_delim = ',',
                const std::vector<std::shared_ptr<CsvBase>> &column_defaults = {},
                const std::vector<std::string> &column_names = {}, int64_t num_samples = 0,
                ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, int32_t shard_id = 0,
                const std::shared_ptr<DatasetCache> &cache = nullptr) -> std::shared_ptr<CSVDataset> {
  const auto files_chars = VectorStringToChar(dataset_files);
  const auto names_chars = VectorStringToChar(column_names);
  return std::make_shared<CSVDataset>(files_chars, field_delim, column_defaults, names_chars, num_samples, shuffle,
                                      num_shards, shard_id, cache);
}
/// \brief DBpedia dataset node.
/// \note Strings are taken as std::vector<char>; the inline DBpedia() factory converts from std::string.
class DBpediaDataset : public Dataset {
 public:
  /// \brief Constructor; defined out of line.
  DBpediaDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, int64_t num_samples,
                 ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
                 const std::shared_ptr<DatasetCache> &cache);

  /// \brief Defaulted destructor (overrides Dataset's virtual destructor).
  ~DBpediaDataset() override = default;
};
/// \brief Creates a DBpediaDataset from std::string arguments.
/// \param[in] dataset_dir Dataset directory (converted with StringToChar).
/// \param[in] usage Usage string (default: "all").
/// \param[in] num_samples Sample count forwarded to the dataset (default: 0).
/// \param[in] shuffle Shuffle mode (default: ShuffleMode::kGlobal).
/// \param[in] num_shards Number of shards (default: 1).
/// \param[in] shard_id Shard index (default: 0).
/// \param[in] cache Optional dataset cache (default: nullptr).
/// \return Shared pointer to the created DBpediaDataset.
inline auto DBpedia(const std::string &dataset_dir, const std::string &usage = "all", int64_t num_samples = 0,
                    ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, int32_t shard_id = 0,
                    const std::shared_ptr<DatasetCache> &cache = nullptr) -> std::shared_ptr<DBpediaDataset> {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<DBpediaDataset>(dir_chars, usage_chars, num_samples, shuffle, num_shards, shard_id, cache);
}
/// \brief DIV2K dataset node.
/// \note One constructor exists per sampler form (std::shared_ptr, raw const pointer, std::reference_wrapper);
///     the inline DIV2K() factories select the matching one after converting std::string arguments.
class DIV2KDataset : public Dataset {
 public:
  /// \brief Constructor taking a shared sampler; defined out of line.
  DIV2KDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
               const std::vector<char> &downgrade, int32_t scale, bool decode,
               const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking a raw sampler pointer; defined out of line.
  DIV2KDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
               const std::vector<char> &downgrade, int32_t scale, bool decode, const Sampler *sampler,
               const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking a sampler reference wrapper; defined out of line.
  DIV2KDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
               const std::vector<char> &downgrade, int32_t scale, bool decode,
               const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Defaulted destructor (overrides Dataset's virtual destructor).
  ~DIV2KDataset() override = default;
};
/// \brief Creates a DIV2KDataset, taking the sampler as a std::shared_ptr.
/// \param[in] dataset_dir Dataset directory (converted with StringToChar).
/// \param[in] usage Usage string (converted with StringToChar).
/// \param[in] downgrade Downgrade string (converted with StringToChar).
/// \param[in] scale Scale value, forwarded unchanged.
/// \param[in] decode Decode flag (default: false).
/// \param[in] sampler Sampler (default: a newly created RandomSampler).
/// \param[in] cache Optional dataset cache (default: nullptr).
/// \return Shared pointer to the created DIV2KDataset.
inline auto DIV2K(const std::string &dataset_dir, const std::string &usage, const std::string &downgrade,
                  int32_t scale, bool decode = false,
                  const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
                  const std::shared_ptr<DatasetCache> &cache = nullptr) -> std::shared_ptr<DIV2KDataset> {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  const auto downgrade_chars = StringToChar(downgrade);
  return std::make_shared<DIV2KDataset>(dir_chars, usage_chars, downgrade_chars, scale, decode, sampler, cache);
}

/// \brief DIV2K() overload taking the sampler as a raw `const Sampler *` (forwarded as-is; no ownership taken here).
/// \return Shared pointer to the created DIV2KDataset.
inline auto DIV2K(const std::string &dataset_dir, const std::string &usage, const std::string &downgrade,
                  int32_t scale, bool decode, const Sampler *sampler,
                  const std::shared_ptr<DatasetCache> &cache = nullptr) -> std::shared_ptr<DIV2KDataset> {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  const auto downgrade_chars = StringToChar(downgrade);
  return std::make_shared<DIV2KDataset>(dir_chars, usage_chars, downgrade_chars, scale, decode, sampler, cache);
}

/// \brief DIV2K() overload taking the sampler as a `std::reference_wrapper<Sampler>`.
/// \return Shared pointer to the created DIV2KDataset.
inline auto DIV2K(const std::string &dataset_dir, const std::string &usage, const std::string &downgrade,
                  int32_t scale, bool decode, const std::reference_wrapper<Sampler> &sampler,
                  const std::shared_ptr<DatasetCache> &cache = nullptr) -> std::shared_ptr<DIV2KDataset> {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  const auto downgrade_chars = StringToChar(downgrade);
  return std::make_shared<DIV2KDataset>(dir_chars, usage_chars, downgrade_chars, scale, decode, sampler, cache);
}
/// \brief EMNIST dataset node.
/// \note Every constructor takes (dataset_dir, name, usage, sampler, cache) in that order; they differ only in the
///     sampler form. The inline EMnist() factories convert std::string arguments.
class EMnistDataset : public Dataset {
 public:
  /// \brief Constructor taking a shared sampler; defined out of line.
  EMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &name, const std::vector<char> &usage,
                const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking a raw sampler pointer; defined out of line.
  EMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &name, const std::vector<char> &usage,
                const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking a sampler reference wrapper; defined out of line.
  EMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &name, const std::vector<char> &usage,
                const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Defaulted destructor (overrides Dataset's virtual destructor).
  ~EMnistDataset() override = default;
};
/// \brief Creates an EMnistDataset, taking the sampler as a std::shared_ptr.
/// \param[in] dataset_dir Dataset directory (converted with StringToChar).
/// \param[in] name Name string forwarded to EMnistDataset (converted with StringToChar).
/// \param[in] usage Usage string (default: "all").
/// \param[in] sampler Sampler (default: a newly created RandomSampler).
/// \param[in] cache Optional dataset cache (default: nullptr).
/// \return Shared pointer to the created EMnistDataset.
inline std::shared_ptr<EMnistDataset>
EMnist(const std::string &dataset_dir, const std::string &name, const std::string &usage = "all",
       const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
       const std::shared_ptr<DatasetCache> &cache = nullptr) {
  return std::make_shared<EMnistDataset>(StringToChar(dataset_dir), StringToChar(name), StringToChar(usage), sampler,
                                         cache);
}

/// \brief Creates an EMnistDataset, taking the sampler as a raw `const Sampler *`.
/// \note Parameters are ordered (dataset_dir, name, usage, ...) to match the other EMnist() overloads and every
///     EMnistDataset constructor. This overload used to declare (dataset_dir, usage, name, ...). Because both
///     strings have the same type, callers using the sibling overloads' order silently got the two values swapped.
///     Any caller written against the old order must swap these two arguments.
/// \param[in] dataset_dir Dataset directory (converted with StringToChar).
/// \param[in] name Name string (converted with StringToChar).
/// \param[in] usage Usage string (converted with StringToChar).
/// \param[in] sampler Raw sampler pointer, forwarded unchanged (this wrapper takes no ownership).
/// \param[in] cache Optional dataset cache (default: nullptr).
/// \return Shared pointer to the created EMnistDataset.
inline std::shared_ptr<EMnistDataset> EMnist(const std::string &dataset_dir, const std::string &name,
                                             const std::string &usage, const Sampler *sampler,
                                             const std::shared_ptr<DatasetCache> &cache = nullptr) {
  return std::make_shared<EMnistDataset>(StringToChar(dataset_dir), StringToChar(name), StringToChar(usage), sampler,
                                         cache);
}

/// \brief Creates an EMnistDataset, taking the sampler as a `std::reference_wrapper<Sampler>`.
/// \param[in] dataset_dir Dataset directory (converted with StringToChar).
/// \param[in] name Name string (converted with StringToChar).
/// \param[in] usage Usage string (converted with StringToChar).
/// \param[in] sampler Sampler reference wrapper, forwarded unchanged.
/// \param[in] cache Optional dataset cache (default: nullptr).
/// \return Shared pointer to the created EMnistDataset.
inline std::shared_ptr<EMnistDataset> EMnist(const std::string &dataset_dir, const std::string &name,
                                             const std::string &usage,
                                             const std::reference_wrapper<Sampler> &sampler,
                                             const std::shared_ptr<DatasetCache> &cache = nullptr) {
  return std::make_shared<EMnistDataset>(StringToChar(dataset_dir), StringToChar(name), StringToChar(usage), sampler,
                                         cache);
}
/// \brief EnWik9 dataset node.
/// \note The directory is taken as std::vector<char>; the inline EnWik9() factory converts from std::string.
class EnWik9Dataset : public Dataset {
 public:
  /// \brief Constructor; defined out of line.
  EnWik9Dataset(const std::vector<char> &dataset_dir, int64_t num_samples, ShuffleMode shuffle,
                int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Defaulted destructor (overrides Dataset's virtual destructor).
  ~EnWik9Dataset() override = default;
};
/// \brief Creates an EnWik9Dataset from a std::string directory.
/// \param[in] dataset_dir Dataset directory (converted with StringToChar).
/// \param[in] num_samples Sample count forwarded to the dataset (default: 0).
/// \param[in] shuffle Shuffle mode (default: ShuffleMode::kGlobal).
/// \param[in] num_shards Number of shards (default: 1).
/// \param[in] shard_id Shard index (default: 0).
/// \param[in] cache Optional dataset cache (default: nullptr).
/// \return Shared pointer to the created EnWik9Dataset.
inline auto EnWik9(const std::string &dataset_dir, int64_t num_samples = 0, ShuffleMode shuffle = ShuffleMode::kGlobal,
                   int32_t num_shards = 1, int32_t shard_id = 0,
                   const std::shared_ptr<DatasetCache> &cache = nullptr) -> std::shared_ptr<EnWik9Dataset> {
  const auto dir_chars = StringToChar(dataset_dir);
  return std::make_shared<EnWik9Dataset>(dir_chars, num_samples, shuffle, num_shards, shard_id, cache);
}
/// \brief FakeImage dataset node.
/// \note One constructor exists per sampler form; all other arguments are identical across them.
class FakeImageDataset : public Dataset {
 public:
  /// \brief Constructor taking a shared sampler; defined out of line.
  FakeImageDataset(int32_t num_images, const std::vector<int32_t> &image_size, int32_t num_classes,
                   int32_t base_seed, const std::shared_ptr<Sampler> &sampler,
                   const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking a raw sampler pointer; defined out of line.
  FakeImageDataset(int32_t num_images, const std::vector<int32_t> &image_size, int32_t num_classes,
                   int32_t base_seed, const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking a sampler reference wrapper; defined out of line.
  FakeImageDataset(int32_t num_images, const std::vector<int32_t> &image_size, int32_t num_classes,
                   int32_t base_seed, const std::reference_wrapper<Sampler> &sampler,
                   const std::shared_ptr<DatasetCache> &cache);

  /// \brief Defaulted destructor (overrides Dataset's virtual destructor).
  ~FakeImageDataset() override = default;
};
/// \brief Creates a FakeImageDataset, taking the sampler as a std::shared_ptr.
/// \param[in] num_images Image count (default: 1000).
/// \param[in] image_size Image size vector (default: {224, 224, 3}).
/// \param[in] num_classes Class count (default: 10).
/// \param[in] base_seed Base seed (default: 0).
/// \param[in] sampler Sampler (default: a newly created RandomSampler).
/// \param[in] cache Optional dataset cache (default: nullptr).
/// \return Shared pointer to the created FakeImageDataset.
inline auto FakeImage(int32_t num_images = 1000, const std::vector<int32_t> &image_size = {224, 224, 3},
                      int32_t num_classes = 10, int32_t base_seed = 0,
                      const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
                      const std::shared_ptr<DatasetCache> &cache = nullptr) -> std::shared_ptr<FakeImageDataset> {
  auto dataset = std::make_shared<FakeImageDataset>(num_images, image_size, num_classes, base_seed, sampler, cache);
  return dataset;
}

/// \brief FakeImage() overload taking the sampler as a raw `const Sampler *` (forwarded as-is).
/// \return Shared pointer to the created FakeImageDataset.
inline auto FakeImage(int32_t num_images, const std::vector<int32_t> &image_size, int32_t num_classes,
                      int32_t base_seed, const Sampler *sampler,
                      const std::shared_ptr<DatasetCache> &cache = nullptr) -> std::shared_ptr<FakeImageDataset> {
  auto dataset = std::make_shared<FakeImageDataset>(num_images, image_size, num_classes, base_seed, sampler, cache);
  return dataset;
}

/// \brief FakeImage() overload taking the sampler as a `std::reference_wrapper<Sampler>`.
/// \return Shared pointer to the created FakeImageDataset.
inline auto FakeImage(int32_t num_images, const std::vector<int32_t> &image_size, int32_t num_classes,
                      int32_t base_seed, const std::reference_wrapper<Sampler> &sampler,
                      const std::shared_ptr<DatasetCache> &cache = nullptr) -> std::shared_ptr<FakeImageDataset> {
  auto dataset = std::make_shared<FakeImageDataset>(num_images, image_size, num_classes, base_seed, sampler, cache);
  return dataset;
}
/// \brief Fashion-MNIST dataset node.
/// \note One constructor exists per sampler form; the inline FashionMnist() factories convert std::string arguments.
class FashionMnistDataset : public Dataset {
 public:
  /// \brief Constructor taking a shared sampler; defined out of line.
  FashionMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
                      const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking a raw sampler pointer; defined out of line.
  FashionMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
                      const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking a sampler reference wrapper; defined out of line.
  FashionMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
                      const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Defaulted destructor (overrides Dataset's virtual destructor).
  ~FashionMnistDataset() override = default;
};
/// \brief Creates a FashionMnistDataset, taking the sampler as a std::shared_ptr.
/// \param[in] dataset_dir Dataset directory (converted with StringToChar).
/// \param[in] usage Usage string (default: "all").
/// \param[in] sampler Sampler (default: a newly created RandomSampler).
/// \param[in] cache Optional dataset cache (default: nullptr).
/// \return Shared pointer to the created FashionMnistDataset.
inline auto FashionMnist(const std::string &dataset_dir, const std::string &usage = "all",
                         const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
                         const std::shared_ptr<DatasetCache> &cache = nullptr)
    -> std::shared_ptr<FashionMnistDataset> {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<FashionMnistDataset>(dir_chars, usage_chars, sampler, cache);
}

/// \brief FashionMnist() overload taking the sampler as a raw `const Sampler *` (forwarded as-is).
/// \return Shared pointer to the created FashionMnistDataset.
inline auto FashionMnist(const std::string &dataset_dir, const std::string &usage, const Sampler *sampler,
                         const std::shared_ptr<DatasetCache> &cache = nullptr)
    -> std::shared_ptr<FashionMnistDataset> {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<FashionMnistDataset>(dir_chars, usage_chars, sampler, cache);
}

/// \brief FashionMnist() overload taking the sampler as a `std::reference_wrapper<Sampler>`.
/// \return Shared pointer to the created FashionMnistDataset.
inline auto FashionMnist(const std::string &dataset_dir, const std::string &usage,
                         const std::reference_wrapper<Sampler> &sampler,
                         const std::shared_ptr<DatasetCache> &cache = nullptr)
    -> std::shared_ptr<FashionMnistDataset> {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<FashionMnistDataset>(dir_chars, usage_chars, sampler, cache);
}
/// \brief Flickr dataset node.
/// \note One constructor exists per sampler form; the inline Flickr() factories convert std::string arguments.
class FlickrDataset : public Dataset {
 public:
  /// \brief Constructor taking a shared sampler; defined out of line.
  FlickrDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file, bool decode,
                const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking a raw sampler pointer; defined out of line.
  FlickrDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file, bool decode,
                const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking a sampler reference wrapper; defined out of line.
  FlickrDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file, bool decode,
                const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Defaulted destructor (overrides Dataset's virtual destructor).
  ~FlickrDataset() override = default;
};
/// \brief Creates a FlickrDataset, taking the sampler as a std::shared_ptr.
/// \param[in] dataset_dir Dataset directory (converted with StringToChar).
/// \param[in] annotation_file Annotation file path (converted with StringToChar).
/// \param[in] decode Decode flag (default: false).
/// \param[in] sampler Sampler (default: a newly created RandomSampler).
/// \param[in] cache Optional dataset cache (default: nullptr).
/// \return Shared pointer to the created FlickrDataset.
inline auto Flickr(const std::string &dataset_dir, const std::string &annotation_file, bool decode = false,
                   const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
                   const std::shared_ptr<DatasetCache> &cache = nullptr) -> std::shared_ptr<FlickrDataset> {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto annotation_chars = StringToChar(annotation_file);
  return std::make_shared<FlickrDataset>(dir_chars, annotation_chars, decode, sampler, cache);
}

/// \brief Flickr() overload taking the sampler as a raw `const Sampler *` (forwarded as-is).
/// \return Shared pointer to the created FlickrDataset.
inline auto Flickr(const std::string &dataset_dir, const std::string &annotation_file, bool decode,
                   const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache = nullptr)
    -> std::shared_ptr<FlickrDataset> {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto annotation_chars = StringToChar(annotation_file);
  return std::make_shared<FlickrDataset>(dir_chars, annotation_chars, decode, sampler, cache);
}

/// \brief Flickr() overload taking the sampler as a `std::reference_wrapper<Sampler>`.
/// \return Shared pointer to the created FlickrDataset.
inline auto Flickr(const std::string &dataset_dir, const std::string &annotation_file, bool decode,
                   const std::reference_wrapper<Sampler> &sampler,
                   const std::shared_ptr<DatasetCache> &cache = nullptr) -> std::shared_ptr<FlickrDataset> {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto annotation_chars = StringToChar(annotation_file);
  return std::make_shared<FlickrDataset>(dir_chars, annotation_chars, decode, sampler, cache);
}
/// \brief ImageFolder dataset node.
/// \note Extensions and class-index keys are taken in std::vector<char> form; the inline ImageFolder() factories
///     convert from std::set<std::string> / std::map<std::string, int32_t>.
class ImageFolderDataset : public Dataset {
 public:
  /// \brief Constructor taking a shared sampler; defined out of line.
  ImageFolderDataset(const std::vector<char> &dataset_dir, bool decode, const std::shared_ptr<Sampler> &sampler,
                     const std::set<std::vector<char>> &extensions,
                     const std::map<std::vector<char>, int32_t> &class_indexing,
                     const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking a raw sampler pointer; defined out of line.
  ImageFolderDataset(const std::vector<char> &dataset_dir, bool decode, const Sampler *sampler,
                     const std::set<std::vector<char>> &extensions,
                     const std::map<std::vector<char>, int32_t> &class_indexing,
                     const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking a sampler reference wrapper; defined out of line.
  ImageFolderDataset(const std::vector<char> &dataset_dir, bool decode,
                     const std::reference_wrapper<Sampler> &sampler,
                     const std::set<std::vector<char>> &extensions,
                     const std::map<std::vector<char>, int32_t> &class_indexing,
                     const std::shared_ptr<DatasetCache> &cache);

  /// \brief Defaulted destructor (overrides Dataset's virtual destructor).
  ~ImageFolderDataset() override = default;
};
/// \brief Creates an ImageFolderDataset, taking the sampler as a std::shared_ptr.
/// \param[in] dataset_dir Dataset directory (converted with StringToChar).
/// \param[in] decode Decode flag (default: false).
/// \param[in] sampler Sampler (default: a newly created RandomSampler).
/// \param[in] extensions Extension set (converted with SetStringToChar; default: empty).
/// \param[in] class_indexing Name-to-index map (converted with MapStringToChar; default: empty).
/// \param[in] cache Optional dataset cache (default: nullptr).
/// \return Shared pointer to the created ImageFolderDataset.
inline auto ImageFolder(const std::string &dataset_dir, bool decode = false,
                        const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
                        const std::set<std::string> &extensions = {},
                        const std::map<std::string, int32_t> &class_indexing = {},
                        const std::shared_ptr<DatasetCache> &cache = nullptr) -> std::shared_ptr<ImageFolderDataset> {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto extension_chars = SetStringToChar(extensions);
  const auto class_index_chars = MapStringToChar(class_indexing);
  return std::make_shared<ImageFolderDataset>(dir_chars, decode, sampler, extension_chars, class_index_chars, cache);
}

/// \brief ImageFolder() overload taking the sampler as a raw `const Sampler *` (forwarded as-is).
/// \return Shared pointer to the created ImageFolderDataset.
inline auto ImageFolder(const std::string &dataset_dir, bool decode, const Sampler *sampler,
                        const std::set<std::string> &extensions = {},
                        const std::map<std::string, int32_t> &class_indexing = {},
                        const std::shared_ptr<DatasetCache> &cache = nullptr) -> std::shared_ptr<ImageFolderDataset> {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto extension_chars = SetStringToChar(extensions);
  const auto class_index_chars = MapStringToChar(class_indexing);
  return std::make_shared<ImageFolderDataset>(dir_chars, decode, sampler, extension_chars, class_index_chars, cache);
}

/// \brief ImageFolder() overload taking the sampler as a `std::reference_wrapper<Sampler>`.
/// \return Shared pointer to the created ImageFolderDataset.
inline auto ImageFolder(const std::string &dataset_dir, bool decode, const std::reference_wrapper<Sampler> &sampler,
                        const std::set<std::string> &extensions = {},
                        const std::map<std::string, int32_t> &class_indexing = {},
                        const std::shared_ptr<DatasetCache> &cache = nullptr) -> std::shared_ptr<ImageFolderDataset> {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto extension_chars = SetStringToChar(extensions);
  const auto class_index_chars = MapStringToChar(class_indexing);
  return std::make_shared<ImageFolderDataset>(dir_chars, decode, sampler, extension_chars, class_index_chars, cache);
}
/// \brief IMDB dataset node.
/// \note One constructor exists per sampler form; the inline IMDB() factories convert std::string arguments.
class IMDBDataset : public Dataset {
 public:
  /// \brief Constructor taking a shared sampler; defined out of line.
  IMDBDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
              const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking a raw sampler pointer; defined out of line.
  IMDBDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, const Sampler *sampler,
              const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking a sampler reference wrapper; defined out of line.
  IMDBDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
              const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Defaulted destructor (overrides Dataset's virtual destructor).
  ~IMDBDataset() override = default;
};
/// \brief Creates an IMDBDataset, taking the sampler as a std::shared_ptr.
/// \param[in] dataset_dir Dataset directory (converted with StringToChar).
/// \param[in] usage Usage string (default: "all").
/// \param[in] sampler Sampler (default: a newly created RandomSampler).
/// \param[in] cache Optional dataset cache (default: nullptr).
/// \return Shared pointer to the created IMDBDataset.
inline auto IMDB(const std::string &dataset_dir, const std::string &usage = "all",
                 const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
                 const std::shared_ptr<DatasetCache> &cache = nullptr) -> std::shared_ptr<IMDBDataset> {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<IMDBDataset>(dir_chars, usage_chars, sampler, cache);
}

/// \brief IMDB() overload taking the sampler as a raw `const Sampler *` (forwarded as-is).
/// \return Shared pointer to the created IMDBDataset.
inline auto IMDB(const std::string &dataset_dir, const std::string &usage, const Sampler *sampler,
                 const std::shared_ptr<DatasetCache> &cache = nullptr) -> std::shared_ptr<IMDBDataset> {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<IMDBDataset>(dir_chars, usage_chars, sampler, cache);
}

/// \brief IMDB() overload taking the sampler as a `std::reference_wrapper<Sampler>`.
/// \return Shared pointer to the created IMDBDataset.
inline auto IMDB(const std::string &dataset_dir, const std::string &usage,
                 const std::reference_wrapper<Sampler> &sampler,
                 const std::shared_ptr<DatasetCache> &cache = nullptr) -> std::shared_ptr<IMDBDataset> {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<IMDBDataset>(dir_chars, usage_chars, sampler, cache);
}
/// \brief IWSLT2016 dataset node.
/// \note Strings are taken as std::vector<char> (and the language pair as a vector of those); the inline
///     IWSLT2016() factory converts from std::string.
class IWSLT2016Dataset : public Dataset {
 public:
  /// \brief Constructor; defined out of line.
  IWSLT2016Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
                   const std::vector<std::vector<char>> &language_pair, const std::vector<char> &valid_set,
                   const std::vector<char> &test_set, int64_t num_samples, ShuffleMode shuffle,
                   int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Defaulted destructor (overrides Dataset's virtual destructor).
  ~IWSLT2016Dataset() override = default;
};
/// \brief Creates an IWSLT2016Dataset from std::string-based arguments.
/// \param[in] dataset_dir Dataset directory (converted with StringToChar).
/// \param[in] usage Usage string (default: "all").
/// \param[in] language_pair Language pair (converted with VectorStringToChar; default: {"de", "en"}).
/// \param[in] valid_set Validation-set name (default: "tst2013").
/// \param[in] test_set Test-set name (default: "tst2014").
/// \param[in] num_samples Sample count forwarded to the dataset (default: 0).
/// \param[in] shuffle Shuffle mode (default: ShuffleMode::kGlobal).
/// \param[in] num_shards Number of shards (default: 1).
/// \param[in] shard_id Shard index (default: 0).
/// \param[in] cache Optional dataset cache (default: nullptr).
/// \return Shared pointer to the created IWSLT2016Dataset.
inline auto IWSLT2016(const std::string &dataset_dir, const std::string &usage = "all",
                      const std::vector<std::string> &language_pair = {"de", "en"},
                      const std::string &valid_set = "tst2013", const std::string &test_set = "tst2014",
                      int64_t num_samples = 0, ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
                      int32_t shard_id = 0, const std::shared_ptr<DatasetCache> &cache = nullptr)
    -> std::shared_ptr<IWSLT2016Dataset> {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  const auto pair_chars = VectorStringToChar(language_pair);
  const auto valid_chars = StringToChar(valid_set);
  const auto test_chars = StringToChar(test_set);
  return std::make_shared<IWSLT2016Dataset>(dir_chars, usage_chars, pair_chars, valid_chars, test_chars, num_samples,
                                            shuffle, num_shards, shard_id, cache);
}
/// \brief IWSLT2017 dataset node.
/// \note Strings are taken as std::vector<char>; the inline IWSLT2017() factory converts from std::string.
class IWSLT2017Dataset : public Dataset {
 public:
  /// \brief Constructor; defined out of line.
  IWSLT2017Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
                   const std::vector<std::vector<char>> &language_pair, int64_t num_samples, ShuffleMode shuffle,
                   int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Defaulted destructor (overrides Dataset's virtual destructor).
  ~IWSLT2017Dataset() override = default;
};
/// \brief Creates an IWSLT2017Dataset from std::string-based arguments.
/// \param[in] dataset_dir Dataset directory (converted with StringToChar).
/// \param[in] usage Usage string (default: "all").
/// \param[in] language_pair Language pair (converted with VectorStringToChar; default: {"de", "en"}).
/// \param[in] num_samples Sample count forwarded to the dataset (default: 0).
/// \param[in] shuffle Shuffle mode (default: ShuffleMode::kGlobal).
/// \param[in] num_shards Number of shards (default: 1).
/// \param[in] shard_id Shard index (default: 0).
/// \param[in] cache Optional dataset cache (default: nullptr).
/// \return Shared pointer to the created IWSLT2017Dataset.
inline auto IWSLT2017(const std::string &dataset_dir, const std::string &usage = "all",
                      const std::vector<std::string> &language_pair = {"de", "en"}, int64_t num_samples = 0,
                      ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, int32_t shard_id = 0,
                      const std::shared_ptr<DatasetCache> &cache = nullptr) -> std::shared_ptr<IWSLT2017Dataset> {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  const auto pair_chars = VectorStringToChar(language_pair);
  return std::make_shared<IWSLT2017Dataset>(dir_chars, usage_chars, pair_chars, num_samples, shuffle, num_shards,
                                            shard_id, cache);
}
/// \brief KMNIST dataset node.
/// \note One constructor exists per sampler form; the inline KMnist() factories convert std::string arguments.
class KMnistDataset : public Dataset {
 public:
  /// \brief Constructor taking a shared sampler; defined out of line.
  KMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
                const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking a raw sampler pointer; defined out of line.
  KMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, const Sampler *sampler,
                const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking a sampler reference wrapper; defined out of line.
  KMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
                const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Defaulted destructor (overrides Dataset's virtual destructor).
  ~KMnistDataset() override = default;
};
/// \brief Creates a KMnistDataset, taking the sampler as a std::shared_ptr.
/// \param[in] dataset_dir Dataset directory (converted with StringToChar).
/// \param[in] usage Usage string (default: "all").
/// \param[in] sampler Sampler (default: a newly created RandomSampler).
/// \param[in] cache Optional dataset cache (default: nullptr).
/// \return Shared pointer to the created KMnistDataset.
inline auto KMnist(const std::string &dataset_dir, const std::string &usage = "all",
                   const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
                   const std::shared_ptr<DatasetCache> &cache = nullptr) -> std::shared_ptr<KMnistDataset> {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<KMnistDataset>(dir_chars, usage_chars, sampler, cache);
}

/// \brief KMnist() overload taking the sampler as a raw `const Sampler *` (forwarded as-is).
/// \return Shared pointer to the created KMnistDataset.
inline auto KMnist(const std::string &dataset_dir, const std::string &usage, const Sampler *sampler,
                   const std::shared_ptr<DatasetCache> &cache = nullptr) -> std::shared_ptr<KMnistDataset> {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<KMnistDataset>(dir_chars, usage_chars, sampler, cache);
}

/// \brief KMnist() overload taking the sampler as a `std::reference_wrapper<Sampler>`.
/// \return Shared pointer to the created KMnistDataset.
inline auto KMnist(const std::string &dataset_dir, const std::string &usage,
                   const std::reference_wrapper<Sampler> &sampler,
                   const std::shared_ptr<DatasetCache> &cache = nullptr) -> std::shared_ptr<KMnistDataset> {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<KMnistDataset>(dir_chars, usage_chars, sampler, cache);
}
/// \brief LJSpeech dataset node.
/// \note One constructor exists per sampler form. Declarations are listed in the same order as the sibling dataset
///     classes: shared_ptr, raw pointer, then reference_wrapper.
class LJSpeechDataset : public Dataset {
 public:
  /// \brief Constructor taking a shared sampler; defined out of line.
  LJSpeechDataset(const std::vector<char> &dataset_dir, const std::shared_ptr<Sampler> &sampler,
                  const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking a raw sampler pointer; defined out of line.
  LJSpeechDataset(const std::vector<char> &dataset_dir, const Sampler *sampler,
                  const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking a sampler reference wrapper; defined out of line.
  LJSpeechDataset(const std::vector<char> &dataset_dir, const std::reference_wrapper<Sampler> &sampler,
                  const std::shared_ptr<DatasetCache> &cache);

  /// \brief Defaulted destructor (overrides Dataset's virtual destructor).
  ~LJSpeechDataset() override = default;
};
inline std::shared_ptr<LJSpeechDataset>
LJSpeech(const std::string &dataset_dir, const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<LJSpeechDataset>(StringToChar(dataset_dir), sampler, cache);
}
inline std::shared_ptr<LJSpeechDataset> LJSpeech(const std::string &dataset_dir, Sampler *sampler,
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<LJSpeechDataset>(StringToChar(dataset_dir), sampler, cache);
}
inline std::shared_ptr<LJSpeechDataset> LJSpeech(const std::string &dataset_dir,
const std::reference_wrapper<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<LJSpeechDataset>(StringToChar(dataset_dir), sampler, cache);
}
/// \brief Manifest dataset node.
/// \note One constructor exists per sampler form; the inline Manifest() factories convert std::string arguments
///     and the class-index map keys.
class ManifestDataset : public Dataset {
 public:
  /// \brief Constructor taking a shared sampler; defined out of line.
  ManifestDataset(const std::vector<char> &dataset_file, const std::vector<char> &usage,
                  const std::shared_ptr<Sampler> &sampler,
                  const std::map<std::vector<char>, int32_t> &class_indexing, bool decode,
                  const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking a raw sampler pointer; defined out of line.
  ManifestDataset(const std::vector<char> &dataset_file, const std::vector<char> &usage, const Sampler *sampler,
                  const std::map<std::vector<char>, int32_t> &class_indexing, bool decode,
                  const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking a sampler reference wrapper; defined out of line.
  ManifestDataset(const std::vector<char> &dataset_file, const std::vector<char> &usage,
                  const std::reference_wrapper<Sampler> &sampler,
                  const std::map<std::vector<char>, int32_t> &class_indexing, bool decode,
                  const std::shared_ptr<DatasetCache> &cache);

  /// \brief Defaulted destructor (overrides Dataset's virtual destructor).
  ~ManifestDataset() override = default;
};
/// \brief Creates a ManifestDataset, taking the sampler as a std::shared_ptr.
/// \param[in] dataset_file Manifest file path (converted with StringToChar).
/// \param[in] usage Usage string (default: "train").
/// \param[in] sampler Sampler (default: a newly created RandomSampler).
/// \param[in] class_indexing Name-to-index map (converted with MapStringToChar; default: empty).
/// \param[in] decode Decode flag (default: false).
/// \param[in] cache Optional dataset cache (default: nullptr).
/// \return Shared pointer to the created ManifestDataset.
inline auto Manifest(const std::string &dataset_file, const std::string &usage = "train",
                     const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
                     const std::map<std::string, int32_t> &class_indexing = {}, bool decode = false,
                     const std::shared_ptr<DatasetCache> &cache = nullptr) -> std::shared_ptr<ManifestDataset> {
  const auto file_chars = StringToChar(dataset_file);
  const auto usage_chars = StringToChar(usage);
  const auto class_index_chars = MapStringToChar(class_indexing);
  return std::make_shared<ManifestDataset>(file_chars, usage_chars, sampler, class_index_chars, decode, cache);
}

/// \brief Manifest() overload taking the sampler as a raw `const Sampler *` (forwarded as-is).
/// \return Shared pointer to the created ManifestDataset.
inline auto Manifest(const std::string &dataset_file, const std::string &usage, const Sampler *sampler,
                     const std::map<std::string, int32_t> &class_indexing = {}, bool decode = false,
                     const std::shared_ptr<DatasetCache> &cache = nullptr) -> std::shared_ptr<ManifestDataset> {
  const auto file_chars = StringToChar(dataset_file);
  const auto usage_chars = StringToChar(usage);
  const auto class_index_chars = MapStringToChar(class_indexing);
  return std::make_shared<ManifestDataset>(file_chars, usage_chars, sampler, class_index_chars, decode, cache);
}

/// \brief Manifest() overload taking the sampler as a `std::reference_wrapper<Sampler>`.
/// \return Shared pointer to the created ManifestDataset.
inline auto Manifest(const std::string &dataset_file, const std::string &usage,
                     const std::reference_wrapper<Sampler> &sampler,
                     const std::map<std::string, int32_t> &class_indexing = {}, bool decode = false,
                     const std::shared_ptr<DatasetCache> &cache = nullptr) -> std::shared_ptr<ManifestDataset> {
  const auto file_chars = StringToChar(dataset_file);
  const auto usage_chars = StringToChar(usage);
  const auto class_index_chars = MapStringToChar(class_indexing);
  return std::make_shared<ManifestDataset>(file_chars, usage_chars, sampler, class_index_chars, decode, cache);
}
/// \brief MindData dataset node.
/// \note There are six constructors: {single file, file list} x {shared_ptr, raw pointer, reference_wrapper}
///     sampler. In every one, shuffle_mode defaults to ShuffleMode::kGlobal and cache defaults to nullptr.
class MindDataDataset : public Dataset {
 public:
  // ---- Single dataset file ----

  /// \brief Single file, shared sampler; defined out of line.
  MindDataDataset(const std::vector<char> &dataset_file, const std::vector<std::vector<char>> &columns_list,
                  const std::shared_ptr<Sampler> &sampler, const nlohmann::json *padded_sample, int64_t num_padded,
                  ShuffleMode shuffle_mode = ShuffleMode::kGlobal,
                  const std::shared_ptr<DatasetCache> &cache = nullptr);

  /// \brief Single file, raw sampler pointer; defined out of line.
  MindDataDataset(const std::vector<char> &dataset_file, const std::vector<std::vector<char>> &columns_list,
                  const Sampler *sampler, const nlohmann::json *padded_sample, int64_t num_padded,
                  ShuffleMode shuffle_mode = ShuffleMode::kGlobal,
                  const std::shared_ptr<DatasetCache> &cache = nullptr);

  /// \brief Single file, sampler reference wrapper; defined out of line.
  MindDataDataset(const std::vector<char> &dataset_file, const std::vector<std::vector<char>> &columns_list,
                  const std::reference_wrapper<Sampler> &sampler, const nlohmann::json *padded_sample,
                  int64_t num_padded, ShuffleMode shuffle_mode = ShuffleMode::kGlobal,
                  const std::shared_ptr<DatasetCache> &cache = nullptr);

  // ---- List of dataset files ----

  /// \brief File list, shared sampler; defined out of line.
  MindDataDataset(const std::vector<std::vector<char>> &dataset_files,
                  const std::vector<std::vector<char>> &columns_list, const std::shared_ptr<Sampler> &sampler,
                  const nlohmann::json *padded_sample, int64_t num_padded,
                  ShuffleMode shuffle_mode = ShuffleMode::kGlobal,
                  const std::shared_ptr<DatasetCache> &cache = nullptr);

  /// \brief File list, raw sampler pointer; defined out of line.
  MindDataDataset(const std::vector<std::vector<char>> &dataset_files,
                  const std::vector<std::vector<char>> &columns_list, const Sampler *sampler,
                  const nlohmann::json *padded_sample, int64_t num_padded,
                  ShuffleMode shuffle_mode = ShuffleMode::kGlobal,
                  const std::shared_ptr<DatasetCache> &cache = nullptr);

  /// \brief File list, sampler reference wrapper; defined out of line.
  MindDataDataset(const std::vector<std::vector<char>> &dataset_files,
                  const std::vector<std::vector<char>> &columns_list, const std::reference_wrapper<Sampler> &sampler,
                  const nlohmann::json *padded_sample, int64_t num_padded,
                  ShuffleMode shuffle_mode = ShuffleMode::kGlobal,
                  const std::shared_ptr<DatasetCache> &cache = nullptr);

  /// \brief Defaulted destructor (overrides Dataset's virtual destructor).
  ~MindDataDataset() override = default;
};
inline std::shared_ptr<MindDataDataset>
MindData(const std::string &dataset_file, const std::vector<std::string> &columns_list = {},
const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
nlohmann::json *padded_sample = nullptr, int64_t num_padded = 0,
ShuffleMode shuffle_mode = ShuffleMode::kGlobal, const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<MindDataDataset>(StringToChar(dataset_file), VectorStringToChar(columns_list), sampler,
padded_sample, num_padded, shuffle_mode, cache);
}
inline std::shared_ptr<MindDataDataset> MindData(const std::string &dataset_file,
const std::vector<std::string> &columns_list,
const Sampler *sampler, nlohmann::json *padded_sample = nullptr,
int64_t num_padded = 0,
ShuffleMode shuffle_mode = ShuffleMode::kGlobal,
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<MindDataDataset>(StringToChar(dataset_file), VectorStringToChar(columns_list), sampler,
padded_sample, num_padded, shuffle_mode, cache);
}
inline std::shared_ptr<MindDataDataset> MindData(const std::string &dataset_file,
const std::vector<std::string> &columns_list,
const std::reference_wrapper<Sampler> &sampler,
nlohmann::json *padded_sample = nullptr, int64_t num_padded = 0,
ShuffleMode shuffle_mode = ShuffleMode::kGlobal,
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<MindDataDataset>(StringToChar(dataset_file), VectorStringToChar(columns_list), sampler,
padded_sample, num_padded, shuffle_mode, cache);
}
inline std::shared_ptr<MindDataDataset>
MindData(const std::vector<std::string> &dataset_files, const std::vector<std::string> &columns_list = {},
const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
nlohmann::json *padded_sample = nullptr, int64_t num_padded = 0,
ShuffleMode shuffle_mode = ShuffleMode::kGlobal, const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<MindDataDataset>(VectorStringToChar(dataset_files), VectorStringToChar(columns_list), sampler,
padded_sample, num_padded, shuffle_mode, cache);
}
inline std::shared_ptr<MindDataDataset> MindData(const std::vector<std::string> &dataset_files,
const std::vector<std::string> &columns_list,
const Sampler *sampler, nlohmann::json *padded_sample = nullptr,
int64_t num_padded = 0,
ShuffleMode shuffle_mode = ShuffleMode::kGlobal,
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<MindDataDataset>(VectorStringToChar(dataset_files), VectorStringToChar(columns_list), sampler,
padded_sample, num_padded, shuffle_mode, cache);
}
inline std::shared_ptr<MindDataDataset> MindData(const std::vector<std::string> &dataset_files,
const std::vector<std::string> &columns_list,
const std::reference_wrapper<Sampler> &sampler,
nlohmann::json *padded_sample = nullptr, int64_t num_padded = 0,
ShuffleMode shuffle_mode = ShuffleMode::kGlobal,
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<MindDataDataset>(VectorStringToChar(dataset_files), VectorStringToChar(columns_list), sampler,
padded_sample, num_padded, shuffle_mode, cache);
}
/// \brief MNIST dataset node.
/// \note One constructor exists per sampler form; the inline Mnist() factories convert std::string arguments.
class MnistDataset : public Dataset {
 public:
  /// \brief Constructor taking a shared sampler; defined out of line.
  MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
               const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking a raw sampler pointer; defined out of line.
  MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, const Sampler *sampler,
               const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking a sampler reference wrapper; defined out of line.
  MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
               const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Defaulted destructor (overrides Dataset's virtual destructor).
  ~MnistDataset() override = default;
};
/// \brief Creates a MnistDataset, taking the sampler as a std::shared_ptr.
/// \param[in] dataset_dir Dataset directory (converted with StringToChar).
/// \param[in] usage Usage string (default: "all").
/// \param[in] sampler Sampler (default: a newly created RandomSampler).
/// \param[in] cache Optional dataset cache (default: nullptr).
/// \return Shared pointer to the created MnistDataset.
inline auto Mnist(const std::string &dataset_dir, const std::string &usage = "all",
                  const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
                  const std::shared_ptr<DatasetCache> &cache = nullptr) -> std::shared_ptr<MnistDataset> {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<MnistDataset>(dir_chars, usage_chars, sampler, cache);
}

/// \brief Mnist() overload taking the sampler as a raw `const Sampler *` (forwarded as-is).
/// \return Shared pointer to the created MnistDataset.
inline auto Mnist(const std::string &dataset_dir, const std::string &usage, const Sampler *sampler,
                  const std::shared_ptr<DatasetCache> &cache = nullptr) -> std::shared_ptr<MnistDataset> {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<MnistDataset>(dir_chars, usage_chars, sampler, cache);
}

/// \brief Mnist() overload taking the sampler as a `std::reference_wrapper<Sampler>`.
/// \return Shared pointer to the created MnistDataset.
inline auto Mnist(const std::string &dataset_dir, const std::string &usage,
                  const std::reference_wrapper<Sampler> &sampler,
                  const std::shared_ptr<DatasetCache> &cache = nullptr) -> std::shared_ptr<MnistDataset> {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<MnistDataset>(dir_chars, usage_chars, sampler, cache);
}
/// \brief PennTreebank dataset node.
/// \note Strings are taken as std::vector<char>; the inline PennTreebank() factory converts from std::string.
class PennTreebankDataset : public Dataset {
 public:
  /// \brief Constructor; defined out of line.
  PennTreebankDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, int64_t num_samples,
                      ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
                      const std::shared_ptr<DatasetCache> &cache);

  /// \brief Defaulted destructor (overrides Dataset's virtual destructor).
  ~PennTreebankDataset() override = default;
};
/// \brief Creates a PennTreebankDataset from std::string arguments.
/// \param[in] dataset_dir Dataset directory (converted with StringToChar).
/// \param[in] usage Usage string (default: "all").
/// \param[in] num_samples Sample count forwarded to the dataset (default: 0).
/// \param[in] shuffle Shuffle mode (default: ShuffleMode::kGlobal).
/// \param[in] num_shards Number of shards (default: 1).
/// \param[in] shard_id Shard index (default: 0).
/// \param[in] cache Optional dataset cache (default: nullptr).
/// \return Shared pointer to the created PennTreebankDataset.
inline auto PennTreebank(const std::string &dataset_dir, const std::string &usage = "all", int64_t num_samples = 0,
                         ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, int32_t shard_id = 0,
                         const std::shared_ptr<DatasetCache> &cache = nullptr)
    -> std::shared_ptr<PennTreebankDataset> {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<PennTreebankDataset>(dir_chars, usage_chars, num_samples, shuffle, num_shards, shard_id,
                                               cache);
}
/// \brief PhotoTour dataset node.
/// \note Every constructor takes (dataset_dir, name, usage, sampler, cache); they differ only in the sampler form.
class PhotoTourDataset : public Dataset {
 public:
  /// \brief Constructor taking a shared sampler; defined out of line.
  PhotoTourDataset(const std::vector<char> &dataset_dir, const std::vector<char> &name,
                   const std::vector<char> &usage, const std::shared_ptr<Sampler> &sampler,
                   const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking a raw sampler pointer; defined out of line.
  PhotoTourDataset(const std::vector<char> &dataset_dir, const std::vector<char> &name,
                   const std::vector<char> &usage, const Sampler *sampler,
                   const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking a sampler reference wrapper; defined out of line.
  PhotoTourDataset(const std::vector<char> &dataset_dir, const std::vector<char> &name,
                   const std::vector<char> &usage, const std::reference_wrapper<Sampler> &sampler,
                   const std::shared_ptr<DatasetCache> &cache);

  /// \brief Defaulted destructor (overrides Dataset's virtual destructor).
  ~PhotoTourDataset() override = default;
};
/// \brief Creates a PhotoTourDataset, taking the sampler as a std::shared_ptr.
/// \param[in] dataset_dir Dataset directory (converted with StringToChar).
/// \param[in] name Name string (converted with StringToChar).
/// \param[in] usage Usage string (default: "train").
/// \param[in] sampler Sampler (default: a newly created RandomSampler).
/// \param[in] cache Optional dataset cache (default: nullptr).
/// \return Shared pointer to the created PhotoTourDataset.
inline auto PhotoTour(const std::string &dataset_dir, const std::string &name, const std::string &usage = "train",
                      const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
                      const std::shared_ptr<DatasetCache> &cache = nullptr) -> std::shared_ptr<PhotoTourDataset> {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto name_chars = StringToChar(name);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<PhotoTourDataset>(dir_chars, name_chars, usage_chars, sampler, cache);
}

/// \brief PhotoTour() overload taking the sampler as a raw `const Sampler *` (forwarded as-is).
/// \return Shared pointer to the created PhotoTourDataset.
inline auto PhotoTour(const std::string &dataset_dir, const std::string &name, const std::string &usage,
                      const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache = nullptr)
    -> std::shared_ptr<PhotoTourDataset> {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto name_chars = StringToChar(name);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<PhotoTourDataset>(dir_chars, name_chars, usage_chars, sampler, cache);
}

/// \brief PhotoTour() overload taking the sampler as a `std::reference_wrapper<Sampler>`.
/// \return Shared pointer to the created PhotoTourDataset.
inline auto PhotoTour(const std::string &dataset_dir, const std::string &name, const std::string &usage,
                      const std::reference_wrapper<Sampler> &sampler,
                      const std::shared_ptr<DatasetCache> &cache = nullptr) -> std::shared_ptr<PhotoTourDataset> {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto name_chars = StringToChar(name);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<PhotoTourDataset>(dir_chars, name_chars, usage_chars, sampler, cache);
}
/// \brief Places365 dataset node.
/// \note One constructor exists per sampler form; the inline Places365() factories convert std::string arguments.
class Places365Dataset : public Dataset {
 public:
  /// \brief Constructor taking a shared sampler; defined out of line.
  Places365Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool small, bool decode,
                   const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking a raw sampler pointer; defined out of line.
  Places365Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool small, bool decode,
                   const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor taking a sampler reference wrapper; defined out of line.
  Places365Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool small, bool decode,
                   const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Defaulted destructor (overrides Dataset's virtual destructor).
  ~Places365Dataset() override = default;
};
/// \brief Creates a Places365Dataset, taking the sampler as a std::shared_ptr.
/// \param[in] dataset_dir Dataset directory (converted with StringToChar).
/// \param[in] usage Usage string (default: "train-standard").
/// \param[in] small Flag forwarded unchanged (default: false).
/// \param[in] decode Decode flag (default: true).
/// \param[in] sampler Sampler (default: a newly created RandomSampler).
/// \param[in] cache Optional dataset cache (default: nullptr).
/// \return Shared pointer to the created Places365Dataset.
inline auto Places365(const std::string &dataset_dir, const std::string &usage = "train-standard",
                      const bool small = false, const bool decode = true,
                      const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
                      const std::shared_ptr<DatasetCache> &cache = nullptr) -> std::shared_ptr<Places365Dataset> {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<Places365Dataset>(dir_chars, usage_chars, small, decode, sampler, cache);
}

/// \brief Places365() overload taking the sampler as a raw `const Sampler *` (forwarded as-is).
/// \return Shared pointer to the created Places365Dataset.
inline auto Places365(const std::string &dataset_dir, const std::string &usage, const bool small, const bool decode,
                      const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache = nullptr)
    -> std::shared_ptr<Places365Dataset> {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<Places365Dataset>(dir_chars, usage_chars, small, decode, sampler, cache);
}

/// \brief Places365() overload taking the sampler as a `std::reference_wrapper<Sampler>`.
/// \return Shared pointer to the created Places365Dataset.
inline auto Places365(const std::string &dataset_dir, const std::string &usage, const bool small, const bool decode,
                      const std::reference_wrapper<Sampler> &sampler,
                      const std::shared_ptr<DatasetCache> &cache = nullptr) -> std::shared_ptr<Places365Dataset> {
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<Places365Dataset>(dir_chars, usage_chars, small, decode, sampler, cache);
}
/// \class QMnistDataset
/// \brief Dataset type produced by the QMnist() factory overloads.
/// The constructors differ only in the sampler form: std::shared_ptr, raw `const Sampler *`,
/// or std::reference_wrapper.
class QMnistDataset : public Dataset {
 public:
  QMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool compat,
                const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
  QMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool compat,
                const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache);
  QMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool compat,
                const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
  ~QMnistDataset() override = default;
};

/// \brief Creates a QMnistDataset with a shared-pointer sampler.
/// \param[in] dataset_dir Directory string, converted with StringToChar and forwarded.
/// \param[in] usage Usage string; defaults to "all".
/// \param[in] compat Flag forwarded unchanged; defaults to true.
/// \param[in] sampler Sampler to use; defaults to a newly created RandomSampler.
/// \param[in] cache Optional dataset cache, nullptr by default.
/// \return The newly created QMnistDataset.
inline std::shared_ptr<QMnistDataset>
QMnist(const std::string &dataset_dir, const std::string &usage = "all", bool compat = true,
       const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
       const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const std::vector<char> dir_chars = StringToChar(dataset_dir);
  const std::vector<char> usage_chars = StringToChar(usage);
  return std::make_shared<QMnistDataset>(dir_chars, usage_chars, compat, sampler, cache);
}

/// \brief Creates a QMnistDataset with a raw-pointer sampler (usage and compat are required here).
/// \return The newly created QMnistDataset.
inline std::shared_ptr<QMnistDataset> QMnist(const std::string &dataset_dir, const std::string &usage, bool compat,
                                             const Sampler *sampler,
                                             const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const std::vector<char> dir_chars = StringToChar(dataset_dir);
  const std::vector<char> usage_chars = StringToChar(usage);
  return std::make_shared<QMnistDataset>(dir_chars, usage_chars, compat, sampler, cache);
}

/// \brief Creates a QMnistDataset with a sampler passed by std::reference_wrapper.
/// \return The newly created QMnistDataset.
inline std::shared_ptr<QMnistDataset> QMnist(const std::string &dataset_dir, const std::string &usage, bool compat,
                                             const std::reference_wrapper<Sampler> &sampler,
                                             const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const std::vector<char> dir_chars = StringToChar(dataset_dir);
  const std::vector<char> usage_chars = StringToChar(usage);
  return std::make_shared<QMnistDataset>(dir_chars, usage_chars, compat, sampler, cache);
}
/// \brief Builds a ConcatDataset from two datasets, enabling `ds1 + ds2`.
/// \param[in] datasets1 Dataset placed first in the list handed to ConcatDataset.
/// \param[in] datasets2 Dataset placed second in that list.
/// \return The newly created ConcatDataset.
inline std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &datasets1,
                                                const std::shared_ptr<Dataset> &datasets2) {
  // Name the element type explicitly rather than relying on class template argument deduction.
  return std::make_shared<ConcatDataset>(std::vector<std::shared_ptr<Dataset>>{datasets1, datasets2});
}
/// \class RandomDataDataset
/// \brief Dataset type produced by RandomData().
/// One constructor takes a SchemaObj, the other a schema path in std::vector<char> form.
class RandomDataDataset : public Dataset {
 public:
  RandomDataDataset(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema,
                    const std::vector<std::vector<char>> &columns_list, const std::shared_ptr<DatasetCache> &cache);
  RandomDataDataset(const int32_t &total_rows, const std::vector<char> &schema_path,
                    const std::vector<std::vector<char>> &columns_list, const std::shared_ptr<DatasetCache> &cache);
  ~RandomDataDataset() override = default;
};

/// \brief Creates a RandomDataDataset from either a schema object or a schema path.
/// \tparam T Schema argument type. `std::shared_ptr<SchemaObj>` (the default) and `std::nullptr_t` select the
///     SchemaObj constructor; any other type is converted with StringToChar and treated as a path.
///     Unlike TFRecord(), no existence check is performed on the path here.
/// \param[in] total_rows Row count forwarded unchanged; defaults to 0.
/// \param[in] schema Schema object or path; defaults to nullptr.
/// \param[in] columns_list Column names, converted with VectorStringToChar; defaults to empty.
/// \param[in] cache Optional dataset cache, nullptr by default.
/// \return The newly created RandomDataDataset.
template <typename T = std::shared_ptr<SchemaObj>>
std::shared_ptr<RandomDataDataset> RandomData(const int32_t &total_rows = 0, const T &schema = nullptr,
                                              const std::vector<std::string> &columns_list = {},
                                              const std::shared_ptr<DatasetCache> &cache = nullptr) {
  constexpr bool kSchemaIsObject =
    std::is_same<T, std::nullptr_t>::value || std::is_same<T, std::shared_ptr<SchemaObj>>::value;
  if constexpr (kSchemaIsObject) {
    std::shared_ptr<SchemaObj> schema_obj = schema;
    return std::make_shared<RandomDataDataset>(total_rows, std::move(schema_obj), VectorStringToChar(columns_list),
                                               cache);
  } else {
    return std::make_shared<RandomDataDataset>(total_rows, StringToChar(schema), VectorStringToChar(columns_list),
                                               cache);
  }
}
/// \class SBUDataset
/// \brief Dataset type produced by the SBU() factory overloads.
/// The constructors differ only in the sampler form: std::shared_ptr, raw `const Sampler *`,
/// or std::reference_wrapper.
class SBUDataset : public Dataset {
 public:
  SBUDataset(const std::vector<char> &dataset_dir, bool decode, const std::shared_ptr<Sampler> &sampler,
             const std::shared_ptr<DatasetCache> &cache);
  SBUDataset(const std::vector<char> &dataset_dir, bool decode, const Sampler *sampler,
             const std::shared_ptr<DatasetCache> &cache);
  SBUDataset(const std::vector<char> &dataset_dir, bool decode, const std::reference_wrapper<Sampler> &sampler,
             const std::shared_ptr<DatasetCache> &cache);
  ~SBUDataset() override = default;
};

/// \brief Creates an SBUDataset with a shared-pointer sampler.
/// \param[in] dataset_dir Directory string, converted with StringToChar and forwarded.
/// \param[in] decode Flag forwarded unchanged; defaults to false.
/// \param[in] sampler Sampler to use; defaults to a newly created RandomSampler.
/// \param[in] cache Optional dataset cache, nullptr by default.
/// \return The newly created SBUDataset.
inline std::shared_ptr<SBUDataset>
SBU(const std::string &dataset_dir, bool decode = false,
    const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
    const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const std::vector<char> dir_chars = StringToChar(dataset_dir);
  return std::make_shared<SBUDataset>(dir_chars, decode, sampler, cache);
}

/// \brief Creates an SBUDataset with a raw-pointer sampler (decode is required here).
/// \return The newly created SBUDataset.
inline std::shared_ptr<SBUDataset> SBU(const std::string &dataset_dir, bool decode, const Sampler *sampler,
                                       const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const std::vector<char> dir_chars = StringToChar(dataset_dir);
  return std::make_shared<SBUDataset>(dir_chars, decode, sampler, cache);
}

/// \brief Creates an SBUDataset with a sampler passed by std::reference_wrapper.
/// \return The newly created SBUDataset.
inline std::shared_ptr<SBUDataset> SBU(const std::string &dataset_dir, bool decode,
                                       const std::reference_wrapper<Sampler> &sampler,
                                       const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const std::vector<char> dir_chars = StringToChar(dataset_dir);
  return std::make_shared<SBUDataset>(dir_chars, decode, sampler, cache);
}
/// \class SemeionDataset
/// \brief Dataset type produced by the Semeion() factory overloads.
/// The constructors differ only in the sampler form: std::shared_ptr, raw `const Sampler *`,
/// or std::reference_wrapper.
class SemeionDataset : public Dataset {
 public:
  SemeionDataset(const std::vector<char> &dataset_dir, const std::shared_ptr<Sampler> &sampler,
                 const std::shared_ptr<DatasetCache> &cache);
  SemeionDataset(const std::vector<char> &dataset_dir, const Sampler *sampler,
                 const std::shared_ptr<DatasetCache> &cache);
  SemeionDataset(const std::vector<char> &dataset_dir, const std::reference_wrapper<Sampler> &sampler,
                 const std::shared_ptr<DatasetCache> &cache);
  ~SemeionDataset() override = default;
};

/// \brief Creates a SemeionDataset with a shared-pointer sampler.
/// \param[in] dataset_dir Directory string, converted with StringToChar and forwarded.
/// \param[in] sampler Sampler to use; defaults to a newly created RandomSampler.
/// \param[in] cache Optional dataset cache, nullptr by default.
/// \return The newly created SemeionDataset.
inline std::shared_ptr<SemeionDataset>
Semeion(const std::string &dataset_dir, const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
        const std::shared_ptr<DatasetCache> &cache = nullptr) {
  return std::make_shared<SemeionDataset>(StringToChar(dataset_dir), sampler, cache);
}

/// \brief Creates a SemeionDataset with a sampler passed by std::reference_wrapper.
/// \return The newly created SemeionDataset.
inline std::shared_ptr<SemeionDataset> Semeion(const std::string &dataset_dir,
                                               const std::reference_wrapper<Sampler> &sampler,
                                               const std::shared_ptr<DatasetCache> &cache = nullptr) {
  return std::make_shared<SemeionDataset>(StringToChar(dataset_dir), sampler, cache);
}

/// \brief Creates a SemeionDataset with a raw-pointer sampler.
/// \param[in] sampler Declared `const Sampler *` to match SemeionDataset's constructor and the other raw-pointer
///     factories in this header. Callers holding a `Sampler *` still bind through the implicit
///     pointer-to-const conversion, and callers holding a `const Sampler *` are now accepted too.
/// \return The newly created SemeionDataset.
inline std::shared_ptr<SemeionDataset> Semeion(const std::string &dataset_dir, const Sampler *sampler,
                                               const std::shared_ptr<DatasetCache> &cache = nullptr) {
  return std::make_shared<SemeionDataset>(StringToChar(dataset_dir), sampler, cache);
}
/// \class SogouNewsDataset
/// \brief Dataset type produced by SogouNews().
class SogouNewsDataset : public Dataset {
 public:
  SogouNewsDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, int64_t num_samples,
                   ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
                   const std::shared_ptr<DatasetCache> &cache);
  ~SogouNewsDataset() override = default;
};

/// \brief Creates a SogouNewsDataset.
/// \param[in] dataset_dir Directory string, converted with StringToChar and forwarded.
/// \param[in] usage Usage string; defaults to "all".
/// \param[in] num_samples Forwarded unchanged; defaults to 0.
/// \param[in] shuffle Shuffle mode; defaults to ShuffleMode::kGlobal.
/// \param[in] num_shards Forwarded unchanged; defaults to 1.
/// \param[in] shard_id Forwarded unchanged; defaults to 0.
/// \param[in] cache Optional dataset cache, nullptr by default.
/// \return The newly created SogouNewsDataset.
inline std::shared_ptr<SogouNewsDataset> SogouNews(const std::string &dataset_dir, const std::string &usage = "all",
                                                   int64_t num_samples = 0, ShuffleMode shuffle = ShuffleMode::kGlobal,
                                                   int32_t num_shards = 1, int32_t shard_id = 0,
                                                   const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const std::vector<char> dir_chars = StringToChar(dataset_dir);
  const std::vector<char> usage_chars = StringToChar(usage);
  return std::make_shared<SogouNewsDataset>(dir_chars, usage_chars, num_samples, shuffle, num_shards, shard_id,
                                            cache);
}
/// \class SpeechCommandsDataset
/// \brief Dataset type produced by the SpeechCommands() factory overloads.
/// The constructors differ only in the sampler form: std::shared_ptr, raw `const Sampler *`,
/// or std::reference_wrapper.
class SpeechCommandsDataset : public Dataset {
 public:
  SpeechCommandsDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
                        const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
  SpeechCommandsDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
                        const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache);
  SpeechCommandsDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
                        const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
  ~SpeechCommandsDataset() override = default;
};

/// \brief Creates a SpeechCommandsDataset with a shared-pointer sampler.
/// \param[in] dataset_dir Directory string, converted with StringToChar and forwarded.
/// \param[in] usage Usage string; defaults to "all".
/// \param[in] sampler Sampler to use; defaults to a newly created RandomSampler.
/// \param[in] cache Optional dataset cache, nullptr by default.
/// \return The newly created SpeechCommandsDataset.
inline std::shared_ptr<SpeechCommandsDataset>
SpeechCommands(const std::string &dataset_dir, const std::string &usage = "all",
               const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
               const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const std::vector<char> dir_chars = StringToChar(dataset_dir);
  const std::vector<char> usage_chars = StringToChar(usage);
  return std::make_shared<SpeechCommandsDataset>(dir_chars, usage_chars, sampler, cache);
}

/// \brief Creates a SpeechCommandsDataset with a raw-pointer sampler (usage is required here).
/// \return The newly created SpeechCommandsDataset.
inline std::shared_ptr<SpeechCommandsDataset>
SpeechCommands(const std::string &dataset_dir, const std::string &usage, const Sampler *sampler,
               const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const std::vector<char> dir_chars = StringToChar(dataset_dir);
  const std::vector<char> usage_chars = StringToChar(usage);
  return std::make_shared<SpeechCommandsDataset>(dir_chars, usage_chars, sampler, cache);
}

/// \brief Creates a SpeechCommandsDataset with a sampler passed by std::reference_wrapper.
/// \return The newly created SpeechCommandsDataset.
inline std::shared_ptr<SpeechCommandsDataset>
SpeechCommands(const std::string &dataset_dir, const std::string &usage, const std::reference_wrapper<Sampler> &sampler,
               const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const std::vector<char> dir_chars = StringToChar(dataset_dir);
  const std::vector<char> usage_chars = StringToChar(usage);
  return std::make_shared<SpeechCommandsDataset>(dir_chars, usage_chars, sampler, cache);
}
/// \class STL10Dataset
/// \brief Dataset type produced by the STL10() factory overloads.
/// The constructors differ only in the sampler form: std::shared_ptr, raw `const Sampler *`,
/// or std::reference_wrapper.
class STL10Dataset : public Dataset {
 public:
  STL10Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
               const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
  STL10Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, const Sampler *sampler,
               const std::shared_ptr<DatasetCache> &cache);
  STL10Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
               const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
  ~STL10Dataset() override = default;
};

/// \brief Creates an STL10Dataset with a shared-pointer sampler.
/// \param[in] dataset_dir Directory string, converted with StringToChar and forwarded.
/// \param[in] usage Usage string; defaults to "all".
/// \param[in] sampler Sampler to use; defaults to a newly created RandomSampler.
/// \param[in] cache Optional dataset cache, nullptr by default.
/// \return The newly created STL10Dataset.
inline std::shared_ptr<STL10Dataset>
STL10(const std::string &dataset_dir, const std::string &usage = "all",
      const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
      const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const std::vector<char> dir_chars = StringToChar(dataset_dir);
  const std::vector<char> usage_chars = StringToChar(usage);
  return std::make_shared<STL10Dataset>(dir_chars, usage_chars, sampler, cache);
}

/// \brief Creates an STL10Dataset with a raw-pointer sampler (usage is required here).
/// \return The newly created STL10Dataset.
inline std::shared_ptr<STL10Dataset> STL10(const std::string &dataset_dir, const std::string &usage,
                                           const Sampler *sampler,
                                           const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const std::vector<char> dir_chars = StringToChar(dataset_dir);
  const std::vector<char> usage_chars = StringToChar(usage);
  return std::make_shared<STL10Dataset>(dir_chars, usage_chars, sampler, cache);
}

/// \brief Creates an STL10Dataset with a sampler passed by std::reference_wrapper.
/// \return The newly created STL10Dataset.
inline std::shared_ptr<STL10Dataset> STL10(const std::string &dataset_dir, const std::string &usage,
                                           const std::reference_wrapper<Sampler> &sampler,
                                           const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const std::vector<char> dir_chars = StringToChar(dataset_dir);
  const std::vector<char> usage_chars = StringToChar(usage);
  return std::make_shared<STL10Dataset>(dir_chars, usage_chars, sampler, cache);
}
/// \class TedliumDataset
/// \brief Dataset type produced by the Tedlium() factory overloads.
/// The constructors differ only in the sampler form: std::shared_ptr, raw `const Sampler *`,
/// or std::reference_wrapper.
class TedliumDataset : public Dataset {
 public:
  TedliumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &release, const std::vector<char> &usage,
                 const std::vector<char> &extensions, const std::shared_ptr<Sampler> &sampler,
                 const std::shared_ptr<DatasetCache> &cache);
  TedliumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &release, const std::vector<char> &usage,
                 const std::vector<char> &extensions, const Sampler *sampler,
                 const std::shared_ptr<DatasetCache> &cache);
  TedliumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &release, const std::vector<char> &usage,
                 const std::vector<char> &extensions, const std::reference_wrapper<Sampler> &sampler,
                 const std::shared_ptr<DatasetCache> &cache);
  ~TedliumDataset() override = default;
};

/// \brief Creates a TedliumDataset with a shared-pointer sampler.
/// \param[in] dataset_dir Directory string, converted with StringToChar and forwarded.
/// \param[in] release Release string, converted and forwarded (required).
/// \param[in] usage Usage string; defaults to "all".
/// \param[in] extensions Extension string; defaults to ".sph".
/// \param[in] sampler Sampler to use; defaults to a newly created RandomSampler.
/// \param[in] cache Optional dataset cache, nullptr by default.
/// \return The newly created TedliumDataset.
inline std::shared_ptr<TedliumDataset> Tedlium(
  const std::string &dataset_dir, const std::string &release, const std::string &usage = "all",
  const std::string &extensions = ".sph", const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
  const std::shared_ptr<DatasetCache> &cache = nullptr) {
  return std::make_shared<TedliumDataset>(StringToChar(dataset_dir), StringToChar(release), StringToChar(usage),
                                          StringToChar(extensions), sampler, cache);
}

/// \brief Creates a TedliumDataset with a sampler passed by std::reference_wrapper.
/// \return The newly created TedliumDataset.
inline std::shared_ptr<TedliumDataset> Tedlium(const std::string &dataset_dir, const std::string &release,
                                               const std::string &usage, const std::string &extensions,
                                               const std::reference_wrapper<Sampler> &sampler,
                                               const std::shared_ptr<DatasetCache> &cache = nullptr) {
  return std::make_shared<TedliumDataset>(StringToChar(dataset_dir), StringToChar(release), StringToChar(usage),
                                          StringToChar(extensions), sampler, cache);
}

/// \brief Creates a TedliumDataset with a raw-pointer sampler.
/// \param[in] sampler Declared `const Sampler *` to match TedliumDataset's constructor and the other raw-pointer
///     factories in this header. Callers holding a `Sampler *` still bind through the implicit
///     pointer-to-const conversion, and callers holding a `const Sampler *` are now accepted too.
/// \return The newly created TedliumDataset.
inline std::shared_ptr<TedliumDataset> Tedlium(const std::string &dataset_dir, const std::string &release,
                                               const std::string &usage, const std::string &extensions,
                                               const Sampler *sampler,
                                               const std::shared_ptr<DatasetCache> &cache = nullptr) {
  return std::make_shared<TedliumDataset>(StringToChar(dataset_dir), StringToChar(release), StringToChar(usage),
                                          StringToChar(extensions), sampler, cache);
}
/// \class TextFileDataset
/// \brief Dataset type produced by TextFile().
class TextFileDataset : public Dataset {
 public:
  TextFileDataset(const std::vector<std::vector<char>> &dataset_files, int64_t num_samples, ShuffleMode shuffle,
                  int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache);
  ~TextFileDataset() override = default;
};

/// \brief Creates a TextFileDataset over a list of files.
/// \param[in] dataset_files File paths, converted with VectorStringToChar and forwarded.
/// \param[in] num_samples Forwarded unchanged; defaults to 0.
/// \param[in] shuffle Shuffle mode; defaults to ShuffleMode::kGlobal.
/// \param[in] num_shards Forwarded unchanged; defaults to 1.
/// \param[in] shard_id Forwarded unchanged; defaults to 0.
/// \param[in] cache Optional dataset cache, nullptr by default.
/// \return The newly created TextFileDataset.
inline std::shared_ptr<TextFileDataset> TextFile(const std::vector<std::string> &dataset_files,
                                                 int64_t num_samples = 0, ShuffleMode shuffle = ShuffleMode::kGlobal,
                                                 int32_t num_shards = 1, int32_t shard_id = 0,
                                                 const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const std::vector<std::vector<char>> file_chars = VectorStringToChar(dataset_files);
  return std::make_shared<TextFileDataset>(file_chars, num_samples, shuffle, num_shards, shard_id, cache);
}
/// \class TFRecordDataset
/// \brief Dataset type produced by TFRecord().
/// One constructor takes the schema as a path in std::vector<char> form, the other as a SchemaObj.
class TFRecordDataset : public Dataset {
 public:
  TFRecordDataset(const std::vector<std::vector<char>> &dataset_files, const std::vector<char> &schema,
                  const std::vector<std::vector<char>> &columns_list, int64_t num_samples, ShuffleMode shuffle,
                  int32_t num_shards, int32_t shard_id, bool shard_equal_rows,
                  const std::shared_ptr<DatasetCache> &cache);
  TFRecordDataset(const std::vector<std::vector<char>> &dataset_files, const std::shared_ptr<SchemaObj> &schema,
                  const std::vector<std::vector<char>> &columns_list, int64_t num_samples, ShuffleMode shuffle,
                  int32_t num_shards, int32_t shard_id, bool shard_equal_rows,
                  const std::shared_ptr<DatasetCache> &cache);
  ~TFRecordDataset() override = default;
};

/// \brief Creates a TFRecordDataset from either a schema object or a schema path.
/// \tparam T Schema argument type. `std::shared_ptr<SchemaObj>` (the default) and `std::nullptr_t` select the
///     SchemaObj constructor; any other type is converted to std::string and treated as a path.
/// \param[in] dataset_files File paths, converted with VectorStringToChar.
/// \param[in] schema Schema object or path; defaults to nullptr.
/// \param[in] columns_list Column names, converted with VectorStringToChar; defaults to empty.
/// \param[in] num_samples Forwarded unchanged; defaults to 0.
/// \param[in] shuffle Shuffle mode; defaults to ShuffleMode::kGlobal.
/// \param[in] num_shards Forwarded unchanged; defaults to 1.
/// \param[in] shard_id Forwarded unchanged; defaults to 0.
/// \param[in] shard_equal_rows Forwarded unchanged; defaults to false.
/// \param[in] cache Optional dataset cache, nullptr by default.
/// \return The newly created dataset. Returns nullptr, without any diagnostic, when a non-empty schema path
///     cannot be stat()-ed.
template <typename T = std::shared_ptr<SchemaObj>>
std::shared_ptr<TFRecordDataset> TFRecord(const std::vector<std::string> &dataset_files, const T &schema = nullptr,
                                          const std::vector<std::string> &columns_list = {}, int64_t num_samples = 0,
                                          ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
                                          int32_t shard_id = 0, bool shard_equal_rows = false,
                                          const std::shared_ptr<DatasetCache> &cache = nullptr) {
  constexpr bool kSchemaIsObject =
    std::is_same<T, std::nullptr_t>::value || std::is_same<T, std::shared_ptr<SchemaObj>>::value;
  if constexpr (kSchemaIsObject) {
    std::shared_ptr<SchemaObj> schema_obj = schema;
    return std::make_shared<TFRecordDataset>(VectorStringToChar(dataset_files), std::move(schema_obj),
                                             VectorStringToChar(columns_list), num_samples, shuffle, num_shards,
                                             shard_id, shard_equal_rows, cache);
  } else {
    std::string schema_path = schema;
    // An empty path skips the check; a non-empty one must be visible to stat() or no dataset is built.
    struct stat path_info {};
    if (!schema_path.empty() && stat(schema_path.c_str(), &path_info) != 0) {
      return nullptr;
    }
    return std::make_shared<TFRecordDataset>(VectorStringToChar(dataset_files), StringToChar(schema_path),
                                             VectorStringToChar(columns_list), num_samples, shuffle, num_shards,
                                             shard_id, shard_equal_rows, cache);
  }
}
/// \class UDPOSDataset
/// \brief Dataset type produced by UDPOS().
class UDPOSDataset : public Dataset {
 public:
  UDPOSDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, int64_t num_samples,
               ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache);
  ~UDPOSDataset() override = default;
};

/// \brief Creates a UDPOSDataset.
/// \param[in] dataset_dir Directory string, converted with StringToChar and forwarded.
/// \param[in] usage Usage string; defaults to "all".
/// \param[in] num_samples Forwarded unchanged; defaults to 0.
/// \param[in] shuffle Shuffle mode; defaults to ShuffleMode::kGlobal.
/// \param[in] num_shards Forwarded unchanged; defaults to 1.
/// \param[in] shard_id Forwarded unchanged; defaults to 0.
/// \param[in] cache Optional dataset cache, nullptr by default.
/// \return The newly created UDPOSDataset.
inline std::shared_ptr<UDPOSDataset> UDPOS(const std::string &dataset_dir, const std::string &usage = "all",
                                           int64_t num_samples = 0, ShuffleMode shuffle = ShuffleMode::kGlobal,
                                           int32_t num_shards = 1, int32_t shard_id = 0,
                                           const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const std::vector<char> dir_chars = StringToChar(dataset_dir);
  const std::vector<char> usage_chars = StringToChar(usage);
  return std::make_shared<UDPOSDataset>(dir_chars, usage_chars, num_samples, shuffle, num_shards, shard_id, cache);
}
/// \class USPSDataset
/// \brief Dataset type produced by USPS().
class USPSDataset : public Dataset {
 public:
  USPSDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, int64_t num_samples,
              ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache);
  ~USPSDataset() override = default;
};

/// \brief Creates a USPSDataset.
/// \param[in] dataset_dir Directory string, converted with StringToChar and forwarded.
/// \param[in] usage Usage string; defaults to "all".
/// \param[in] num_samples Forwarded unchanged; defaults to 0.
/// \param[in] shuffle Shuffle mode; defaults to ShuffleMode::kGlobal.
/// \param[in] num_shards Forwarded unchanged; defaults to 1.
/// \param[in] shard_id Forwarded unchanged; defaults to 0.
/// \param[in] cache Optional dataset cache, nullptr by default.
/// \return The newly created USPSDataset.
inline std::shared_ptr<USPSDataset> USPS(const std::string &dataset_dir, const std::string &usage = "all",
                                         int64_t num_samples = 0, ShuffleMode shuffle = ShuffleMode::kGlobal,
                                         int32_t num_shards = 1, int32_t shard_id = 0,
                                         const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const std::vector<char> dir_chars = StringToChar(dataset_dir);
  const std::vector<char> usage_chars = StringToChar(usage);
  return std::make_shared<USPSDataset>(dir_chars, usage_chars, num_samples, shuffle, num_shards, shard_id, cache);
}
/// \class VOCDataset
/// \brief Dataset type produced by the VOC() factory overloads.
/// The constructors differ only in the sampler form: std::shared_ptr, raw `const Sampler *`,
/// or std::reference_wrapper.
class VOCDataset : public Dataset {
 public:
  VOCDataset(const std::vector<char> &dataset_dir, const std::vector<char> &task, const std::vector<char> &usage,
             const std::map<std::vector<char>, int32_t> &class_indexing, bool decode,
             const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache, bool extra_metadata);
  VOCDataset(const std::vector<char> &dataset_dir, const std::vector<char> &task, const std::vector<char> &usage,
             const std::map<std::vector<char>, int32_t> &class_indexing, bool decode, const Sampler *sampler,
             const std::shared_ptr<DatasetCache> &cache, bool extra_metadata);
  VOCDataset(const std::vector<char> &dataset_dir, const std::vector<char> &task, const std::vector<char> &usage,
             const std::map<std::vector<char>, int32_t> &class_indexing, bool decode,
             const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache,
             bool extra_metadata);
  ~VOCDataset() override = default;
};

/// \brief Creates a VOCDataset with a shared-pointer sampler.
/// \param[in] dataset_dir Directory string, converted with StringToChar and forwarded.
/// \param[in] task Task string; defaults to "Segmentation".
/// \param[in] usage Usage string; defaults to "train".
/// \param[in] class_indexing Name-to-index map, converted with MapStringToChar; defaults to empty.
/// \param[in] decode Flag forwarded unchanged; defaults to false.
/// \param[in] sampler Sampler to use; defaults to a newly created RandomSampler.
/// \param[in] cache Optional dataset cache, nullptr by default.
/// \param[in] extra_metadata Flag forwarded unchanged; defaults to false.
/// \return The newly created VOCDataset.
inline std::shared_ptr<VOCDataset>
VOC(const std::string &dataset_dir, const std::string &task = "Segmentation", const std::string &usage = "train",
    const std::map<std::string, int32_t> &class_indexing = {}, bool decode = false,
    const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
    const std::shared_ptr<DatasetCache> &cache = nullptr, bool extra_metadata = false) {
  const std::vector<char> dir_chars = StringToChar(dataset_dir);
  const std::vector<char> task_chars = StringToChar(task);
  const std::vector<char> usage_chars = StringToChar(usage);
  const std::map<std::vector<char>, int32_t> indexing_chars = MapStringToChar(class_indexing);
  return std::make_shared<VOCDataset>(dir_chars, task_chars, usage_chars, indexing_chars, decode, sampler, cache,
                                      extra_metadata);
}

/// \brief Creates a VOCDataset with a raw-pointer sampler (task, usage, class_indexing and decode are required).
/// \return The newly created VOCDataset.
inline std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::string &task,
                                       const std::string &usage, const std::map<std::string, int32_t> &class_indexing,
                                       bool decode, const Sampler *sampler,
                                       const std::shared_ptr<DatasetCache> &cache = nullptr,
                                       bool extra_metadata = false) {
  const std::vector<char> dir_chars = StringToChar(dataset_dir);
  const std::vector<char> task_chars = StringToChar(task);
  const std::vector<char> usage_chars = StringToChar(usage);
  const std::map<std::vector<char>, int32_t> indexing_chars = MapStringToChar(class_indexing);
  return std::make_shared<VOCDataset>(dir_chars, task_chars, usage_chars, indexing_chars, decode, sampler, cache,
                                      extra_metadata);
}

/// \brief Creates a VOCDataset with a sampler passed by std::reference_wrapper.
/// \return The newly created VOCDataset.
inline std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::string &task,
                                       const std::string &usage, const std::map<std::string, int32_t> &class_indexing,
                                       bool decode, const std::reference_wrapper<Sampler> &sampler,
                                       const std::shared_ptr<DatasetCache> &cache = nullptr,
                                       bool extra_metadata = false) {
  const std::vector<char> dir_chars = StringToChar(dataset_dir);
  const std::vector<char> task_chars = StringToChar(task);
  const std::vector<char> usage_chars = StringToChar(usage);
  const std::map<std::vector<char>, int32_t> indexing_chars = MapStringToChar(class_indexing);
  return std::make_shared<VOCDataset>(dir_chars, task_chars, usage_chars, indexing_chars, decode, sampler, cache,
                                      extra_metadata);
}
/// \class WIDERFaceDataset
/// \brief Dataset type produced by the WIDERFace() factory overloads.
/// The constructors differ only in the sampler form: std::shared_ptr, raw `const Sampler *`,
/// or std::reference_wrapper.
class WIDERFaceDataset : public Dataset {
 public:
  WIDERFaceDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode,
                   const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
  WIDERFaceDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode,
                   const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache);
  WIDERFaceDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode,
                   const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
  ~WIDERFaceDataset() override = default;
};

/// \brief Creates a WIDERFaceDataset with a shared-pointer sampler.
/// \param[in] dataset_dir Directory string, converted with StringToChar and forwarded.
/// \param[in] usage Usage string; defaults to "all".
/// \param[in] decode Flag forwarded unchanged; defaults to false.
/// \param[in] sampler Sampler to use; defaults to a newly created RandomSampler.
/// \param[in] cache Optional dataset cache, nullptr by default.
/// \return The newly created WIDERFaceDataset.
inline std::shared_ptr<WIDERFaceDataset>
WIDERFace(const std::string &dataset_dir, const std::string &usage = "all", bool decode = false,
          const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
          const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const std::vector<char> dir_chars = StringToChar(dataset_dir);
  const std::vector<char> usage_chars = StringToChar(usage);
  return std::make_shared<WIDERFaceDataset>(dir_chars, usage_chars, decode, sampler, cache);
}

/// \brief Creates a WIDERFaceDataset with a raw-pointer sampler (usage and decode are required here).
/// \return The newly created WIDERFaceDataset.
inline std::shared_ptr<WIDERFaceDataset> WIDERFace(const std::string &dataset_dir, const std::string &usage,
                                                   bool decode, const Sampler *sampler,
                                                   const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const std::vector<char> dir_chars = StringToChar(dataset_dir);
  const std::vector<char> usage_chars = StringToChar(usage);
  return std::make_shared<WIDERFaceDataset>(dir_chars, usage_chars, decode, sampler, cache);
}

/// \brief Creates a WIDERFaceDataset with a sampler passed by std::reference_wrapper.
/// \return The newly created WIDERFaceDataset.
inline std::shared_ptr<WIDERFaceDataset> WIDERFace(const std::string &dataset_dir, const std::string &usage,
                                                   bool decode, const std::reference_wrapper<Sampler> &sampler,
                                                   const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const std::vector<char> dir_chars = StringToChar(dataset_dir);
  const std::vector<char> usage_chars = StringToChar(usage);
  return std::make_shared<WIDERFaceDataset>(dir_chars, usage_chars, decode, sampler, cache);
}
/// \class WikiTextDataset
/// \brief Dataset type produced by WikiText().
class WikiTextDataset : public Dataset {
 public:
  WikiTextDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, int64_t num_samples,
                  ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
                  const std::shared_ptr<DatasetCache> &cache);
  ~WikiTextDataset() override = default;
};

/// \brief Creates a WikiTextDataset.
/// \param[in] dataset_dir Directory string, converted with StringToChar and forwarded.
/// \param[in] usage Usage string; defaults to "all".
/// \param[in] num_samples Forwarded unchanged; defaults to 0.
/// \param[in] shuffle Shuffle mode; defaults to ShuffleMode::kGlobal.
/// \param[in] num_shards Forwarded unchanged; defaults to 1.
/// \param[in] shard_id Forwarded unchanged; defaults to 0.
/// \param[in] cache Optional dataset cache, nullptr by default.
/// \return The newly created WikiTextDataset.
inline std::shared_ptr<WikiTextDataset> WikiText(const std::string &dataset_dir, const std::string &usage = "all",
                                                 int64_t num_samples = 0, ShuffleMode shuffle = ShuffleMode::kGlobal,
                                                 int32_t num_shards = 1, int32_t shard_id = 0,
                                                 const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const std::vector<char> dir_chars = StringToChar(dataset_dir);
  const std::vector<char> usage_chars = StringToChar(usage);
  return std::make_shared<WikiTextDataset>(dir_chars, usage_chars, num_samples, shuffle, num_shards, shard_id,
                                           cache);
}
/// \class YahooAnswersDataset
/// \brief Dataset type produced by YahooAnswers().
class YahooAnswersDataset : public Dataset {
 public:
  YahooAnswersDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, int64_t num_samples,
                      ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
                      const std::shared_ptr<DatasetCache> &cache);
  ~YahooAnswersDataset() override = default;
};

/// \brief Creates a YahooAnswersDataset.
/// \param[in] dataset_dir Directory string, converted with StringToChar and forwarded.
/// \param[in] usage Usage string; defaults to "all".
/// \param[in] num_samples Forwarded unchanged; defaults to 0.
/// \param[in] shuffle Shuffle mode; defaults to ShuffleMode::kGlobal.
/// \param[in] num_shards Forwarded unchanged; defaults to 1.
/// \param[in] shard_id Forwarded unchanged; defaults to 0.
/// \param[in] cache Optional dataset cache, nullptr by default.
/// \return The newly created YahooAnswersDataset.
inline std::shared_ptr<YahooAnswersDataset> YahooAnswers(const std::string &dataset_dir,
                                                         const std::string &usage = "all", int64_t num_samples = 0,
                                                         ShuffleMode shuffle = ShuffleMode::kGlobal,
                                                         int32_t num_shards = 1, int32_t shard_id = 0,
                                                         const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const std::vector<char> dir_chars = StringToChar(dataset_dir);
  const std::vector<char> usage_chars = StringToChar(usage);
  return std::make_shared<YahooAnswersDataset>(dir_chars, usage_chars, num_samples, shuffle, num_shards, shard_id,
                                               cache);
}
/// \class YelpReviewDataset
/// \brief Dataset type produced by YelpReview().
class YelpReviewDataset : public Dataset {
 public:
  YelpReviewDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, int64_t num_samples,
                    ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
                    const std::shared_ptr<DatasetCache> &cache);
  ~YelpReviewDataset() override = default;
};

/// \brief Creates a YelpReviewDataset.
/// \param[in] dataset_dir Directory string, converted with StringToChar and forwarded.
/// \param[in] usage Usage string; defaults to "all".
/// \param[in] num_samples Forwarded unchanged; defaults to 0.
/// \param[in] shuffle Shuffle mode; defaults to ShuffleMode::kGlobal.
/// \param[in] num_shards Forwarded unchanged; defaults to 1.
/// \param[in] shard_id Forwarded unchanged; defaults to 0.
/// \param[in] cache Optional dataset cache, nullptr by default.
/// \return The newly created YelpReviewDataset.
inline std::shared_ptr<YelpReviewDataset> YelpReview(const std::string &dataset_dir, const std::string &usage = "all",
                                                     int64_t num_samples = 0,
                                                     ShuffleMode shuffle = ShuffleMode::kGlobal,
                                                     int32_t num_shards = 1, int32_t shard_id = 0,
                                                     const std::shared_ptr<DatasetCache> &cache = nullptr) {
  const std::vector<char> dir_chars = StringToChar(dataset_dir);
  const std::vector<char> usage_chars = StringToChar(usage);
  return std::make_shared<YelpReviewDataset>(dir_chars, usage_chars, num_samples, shuffle, num_shards, shard_id,
                                             cache);
}
class YesNoDataset : public Dataset {
public:
YesNoDataset(const std::vector<char> &dataset_dir, const std::shared_ptr<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache);
YesNoDataset(const std::vector<char> &dataset_dir, const Sampler *sampler,
const std::shared_ptr<DatasetCache> &cache);
YesNoDataset(const std::vector<char> &dataset_dir, const std::reference_wrapper<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache);
~YesNoDataset() override = default;
};
inline std::shared_ptr<YesNoDataset>
YesNo(const std::string &dataset_dir, const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<YesNoDataset>(StringToChar(dataset_dir), sampler, cache);
}
inline std::shared_ptr<YesNoDataset> YesNo(const std::string &dataset_dir, Sampler *sampler,
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<YesNoDataset>(StringToChar(dataset_dir), sampler, cache);
}
inline std::shared_ptr<YesNoDataset> YesNo(const std::string &dataset_dir,
const std::reference_wrapper<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<YesNoDataset>(StringToChar(dataset_dir), sampler, cache);
}
/// \brief std::vector<char>-based entry point behind CreateDatasetCache(); it is only declared here.
/// Parameters mirror CreateDatasetCache(), except that \p hostname is an optional char vector.
std::shared_ptr<DatasetCache> CreateDatasetCacheCharIF(
  session_id_type id, uint64_t mem_sz, bool spill, const std::optional<std::vector<char>> &hostname = std::nullopt,
  const std::optional<int32_t> &port = std::nullopt, const std::optional<int32_t> &num_connections = std::nullopt,
  const std::optional<int32_t> &prefetch_sz = std::nullopt);

/// \brief Creates a DatasetCache handle by forwarding to CreateDatasetCacheCharIF().
/// \param[in] id Session id, forwarded unchanged.
/// \param[in] mem_sz Forwarded unchanged.
/// \param[in] spill Forwarded unchanged.
/// \param[in] hostname Optional host string. When present it is copied into a std::vector<char>; when absent
///     std::nullopt is forwarded. NOTE(review): presumably the cache server host - confirm.
/// \param[in] port Optional, forwarded unchanged (std::nullopt by default).
/// \param[in] num_connections Optional, forwarded unchanged (std::nullopt by default).
/// \param[in] prefetch_sz Optional, forwarded unchanged (std::nullopt by default).
/// \return Whatever CreateDatasetCacheCharIF() returns.
inline std::shared_ptr<DatasetCache> CreateDatasetCache(
  session_id_type id, uint64_t mem_sz, bool spill, const std::optional<std::string> &hostname = std::nullopt,
  const std::optional<int32_t> &port = std::nullopt, const std::optional<int32_t> &num_connections = std::nullopt,
  const std::optional<int32_t> &prefetch_sz = std::nullopt) {
  // Starts disengaged; only filled when the caller supplied a hostname.
  std::optional<std::vector<char>> hostname_chars;
  if (hostname.has_value()) {
    hostname_chars.emplace(hostname->begin(), hostname->end());
  }
  return CreateDatasetCacheCharIF(id, mem_sz, spill, hostname_chars, port, num_connections, prefetch_sz);
}
/// \brief Builds a ZipDataset from the given datasets.
/// \param[in] datasets Datasets to combine, forwarded unchanged to the ZipDataset constructor.
/// \return The newly created ZipDataset.
inline std::shared_ptr<ZipDataset> Zip(const std::vector<std::shared_ptr<Dataset>> &datasets) {
  auto zipped = std::make_shared<ZipDataset>(datasets);
  return zipped;
}
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_DATASETS_H_