Program Listing for File iterator.h

Return to documentation for file (include/iterator.h)

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_ITERATOR_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_ITERATOR_H_

#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "include/api/dual_abi_helper.h"
#include "include/api/status.h"
#include "include/api/types.h"

namespace mindspore {
namespace dataset {

// Forward declarations. None of these types needs a complete definition in this header:
// Dataset is only named through std::shared_ptr, the runtime/consumer types only through
// pointer or std::unique_ptr members, and DatasetOp, ExecutionTree and Tensor are not
// referenced by any declaration below.
class Dataset;
class DatasetOp;
class ExecutionTree;
class Tensor;

// Runtime context and consumer types held by the iterator classes below.
class IteratorConsumer;
class NativeRuntimeContext;
class PullBasedIteratorConsumer;

// A row keyed by column name (std::string keys).
using MSTensorMap = std::unordered_map<std::string, mindspore::MSTensor>;
// The same row keyed by std::vector<char>; this is the form exchanged through GetNextRowCharIF.
using MSTensorMapChar = std::map<std::vector<char>, mindspore::MSTensor>;
// A row as a plain sequence of tensors, without column names.
using MSTensorVec = std::vector<mindspore::MSTensor>;

// Abstract class for iterating over the dataset.
class Iterator {
 public:
  Iterator();

  ~Iterator();

  Status BuildAndLaunchTree(std::shared_ptr<Dataset> ds, int32_t num_epochs);

  Status GetNextRow(MSTensorMap *row) {
    if (row == nullptr) {
      return Status(kMDUnexpectedError, "Got nullptr when GetNext row.");
    }
    MSTensorMapChar row_;
    row_.clear();
    row->clear();
    Status s = GetNextRowCharIF(&row_);
    TensorMapCharToString(&row_, row);
    return s;
  }

  Status GetNextRowCharIF(MSTensorMapChar *row);

  virtual Status GetNextRow(MSTensorVec *row);

  void Stop();

  class _Iterator {
   public:
    explicit _Iterator(Iterator *lt);

    // Destructor
    ~_Iterator() {
      if (cur_row_ != nullptr) {
        delete cur_row_;
      }
    }

    _Iterator &operator++();                        // prefix ++ overload
    MSTensorMap &operator*() { return *cur_row_; }  // dereference operator
    MSTensorMap *operator->() { return cur_row_; }

    bool operator!=(const _Iterator &rhs) { return cur_row_ != rhs.cur_row_; }

   private:
    int ind_;  // the cur node our Iterator points to
    Iterator *lt_;
    MSTensorMap *cur_row_;
  };

  _Iterator begin() { return _Iterator(this); }

  _Iterator end() { return _Iterator(nullptr); }

 private:
  std::unique_ptr<NativeRuntimeContext> runtime_context_;
  IteratorConsumer *consumer_;
};

/// \brief Iterator variant that holds its own PullBasedIteratorConsumer (pull_consumer_),
/// separate from the base class's private consumer_.
///
/// Declaring GetNextRow and BuildAndLaunchTree here hides the other base-class overloads of
/// those names: GetNextRow(MSTensorMap *) and BuildAndLaunchTree(ds, num_epochs) are not
/// reachable through a PullIterator without qualification.
class PullIterator : public Iterator {
 public:
  /// \brief Constructor (defined out of line).
  PullIterator();

  // NOTE(review): this defaulted destructor is defined inline. Any translation unit that destroys
  // a PullIterator therefore instantiates std::default_delete<PullBasedIteratorConsumer>, which
  // requires the complete type, while only a forward declaration is visible from this header.
  // Consider declaring it here and defining it as `= default` in the .cc file.
  ~PullIterator() = default;

  /// \brief Builds and launches the pipeline for ds. Unlike the base class it takes no num_epochs.
  /// \param[in] ds The dataset to iterate over.
  /// \return Status error code.
  Status BuildAndLaunchTree(std::shared_ptr<Dataset> ds);

  /// \brief Fetches the next row as a vector of tensors (overrides Iterator::GetNextRow(MSTensorVec *)).
  /// \param[out] row Output tensor row.
  /// \return Status error code.
  Status GetNextRow(MSTensorVec *const row) override;

  /// \brief Fetches num_rows rows at once (end-of-data behavior is defined in the implementation).
  /// \param[in] num_rows Number of rows requested.
  /// \param[out] row Output container; one MSTensorVec per fetched row.
  /// \return Status error code.
  Status GetRows(int32_t num_rows, std::vector<MSTensorVec> *const row);

 private:
  std::unique_ptr<PullBasedIteratorConsumer> pull_consumer_;
};
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_ITERATOR_H_