Program Listing for File iterator.h
↰ Return to documentation for file (include/iterator.h)
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_ITERATOR_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_ITERATOR_H_
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "include/api/dual_abi_helper.h"
#include "include/api/status.h"
#include "include/api/types.h"
namespace mindspore {
namespace dataset {
// Forward declarations of types defined elsewhere. This header only needs their names
// (they are used, if at all, behind pointers / smart pointers), not their full definitions.
class ExecutionTree;
class DatasetOp;
class Tensor;
class NativeRuntimeContext;
class IteratorConsumer;
class PullBasedIteratorConsumer;
class Dataset;
// Row representation keyed by std::string (presumably the column name -- TODO confirm).
using MSTensorMap = std::unordered_map<std::string, mindspore::MSTensor>;
// Row representation keyed by std::vector<char>. Iterator::GetNextRow(MSTensorMap *) fills one of
// these through GetNextRowCharIF and converts it to an MSTensorMap with TensorMapCharToString.
using MSTensorMapChar = std::map<std::vector<char>, mindspore::MSTensor>;
// Row representation as a plain sequence of tensors.
using MSTensorVec = std::vector<mindspore::MSTensor>;
// Abstract class for iterating over the dataset.
class Iterator {
public:
Iterator();
~Iterator();
Status BuildAndLaunchTree(std::shared_ptr<Dataset> ds, int32_t num_epochs);
Status GetNextRow(MSTensorMap *row) {
if (row == nullptr) {
return Status(kMDUnexpectedError, "Got nullptr when GetNext row.");
}
MSTensorMapChar row_;
row_.clear();
row->clear();
Status s = GetNextRowCharIF(&row_);
TensorMapCharToString(&row_, row);
return s;
}
Status GetNextRowCharIF(MSTensorMapChar *row);
virtual Status GetNextRow(MSTensorVec *row);
void Stop();
class _Iterator {
public:
explicit _Iterator(Iterator *lt);
// Destructor
~_Iterator() {
if (cur_row_ != nullptr) {
delete cur_row_;
}
}
_Iterator &operator++(); // prefix ++ overload
MSTensorMap &operator*() { return *cur_row_; } // dereference operator
MSTensorMap *operator->() { return cur_row_; }
bool operator!=(const _Iterator &rhs) { return cur_row_ != rhs.cur_row_; }
private:
int ind_; // the cur node our Iterator points to
Iterator *lt_;
MSTensorMap *cur_row_;
};
_Iterator begin() { return _Iterator(this); }
_Iterator end() { return _Iterator(nullptr); }
private:
std::unique_ptr<NativeRuntimeContext> runtime_context_;
IteratorConsumer *consumer_;
};
/// \brief Iterator subclass that overrides the vector form of GetNextRow and holds its own
///     PullBasedIteratorConsumer.
///
/// \note Declaring GetNextRow and BuildAndLaunchTree here hides the other base-class overloads
///     of those names (GetNextRow(MSTensorMap *) and BuildAndLaunchTree(ds, num_epochs)) when
///     they are looked up through a PullIterator; they remain reachable through Iterator.
class PullIterator : public Iterator {
 public:
  /// \brief Constructor (defined out of line).
  PullIterator();

  /// \brief Destructor.
  ~PullIterator() = default;

  /// \brief Build and launch the pipeline for the given dataset (single-argument form).
  ///     NOTE(review): behaviour inferred from the name; the definition is not in this header.
  /// \param[in] ds Shared pointer to the dataset to iterate over.
  /// \return Status reported by the implementation.
  Status BuildAndLaunchTree(std::shared_ptr<Dataset> ds);

  /// \brief Fetch the next row as a vector of tensors; overrides Iterator::GetNextRow.
  /// \param[out] row Destination vector.
  /// \return Status reported by the implementation.
  Status GetNextRow(MSTensorVec *row) override;

  /// \brief Fetch several rows at once.
  /// \param[in] num_rows Presumably the number of rows requested -- confirm against the definition.
  /// \param[out] row Destination vector of rows.
  /// \return Status reported by the implementation.
  Status GetRows(int32_t num_rows, std::vector<MSTensorVec> *row);

 private:
  std::unique_ptr<PullBasedIteratorConsumer> pull_consumer_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_ITERATOR_H_