Program Listing for File delegate.h

Return to documentation for file (include/delegate.h)

#ifndef MINDSPORE_INCLUDE_API_DELEGATE_H
#define MINDSPORE_INCLUDE_API_DELEGATE_H

#include <map>
#include <vector>
#include <memory>
#include "schema/model_generated.h"
#include "include/api/kernel.h"
#include "include/api/status.h"

namespace mindspore {
/// \brief Schema version tag carried by a DelegateModel.
///
/// NOTE(review): going by the names, SCHEMA_CUR is the current schema layout and SCHEMA_V0 a
/// legacy one; confirm against the schema definitions.
enum SchemaVersion {
  SCHEMA_INVALID = -1,  // sentinel for "no valid schema"
  SCHEMA_CUR = 0,
  SCHEMA_V0 = 1,
};

/// \brief Mutable iterator over a vector of kernel pointers; this is the iterator type that
/// DelegateModel hands out and accepts in its kernel-list functions.
using KernelIter = std::vector<kernel::Kernel *>::iterator;

/// \brief Non-owning view of a model's kernel list, tensors and primitives.
///
/// The object stores the pointer and the references it is constructed with; it copies nothing
/// and frees nothing. The kernel vector, both tensor vectors and the primitive map must
/// therefore outlive the DelegateModel (do not pass temporaries).
///
/// \tparam T Primitive type held in the kernel-to-primitive map.
template <class T>
class DelegateModel {
 public:
  /// \brief Bind the model to caller-owned containers.
  ///
  /// \param[in] kernels Kernel list that the iterator functions and Replace() operate on. It is
  ///            dereferenced without a null check, so it must not be nullptr.
  /// \param[in] inputs Input tensors; kept by reference.
  /// \param[in] outputs Output tensors; kept by reference.
  /// \param[in] primitives Map from kernel to its primitive; kept by reference.
  /// \param[in] version Schema version reported by GetVersion().
  DelegateModel(std::vector<kernel::Kernel *> *kernels, const std::vector<MSTensor> &inputs,
                const std::vector<MSTensor> &outputs, const std::map<kernel::Kernel *, const T *> &primitives,
                SchemaVersion version)
      : kernels_(kernels), inputs_(inputs), outputs_(outputs), primitives_(primitives), version_(version) {}

  ~DelegateModel() = default;

  /// \brief Look up the primitive registered for a kernel.
  ///
  /// \param[in] kernel Key to search for in the primitive map.
  ///
  /// \return The mapped primitive, or nullptr when the kernel has no entry.
  const T *GetPrimitive(kernel::Kernel *kernel) const {
    // One lookup instead of find() followed by at().
    const auto found = primitives_.find(kernel);
    return found == primitives_.end() ? nullptr : found->second;
  }

  /// \brief Iterator to the first kernel of the bound kernel list.
  KernelIter BeginKernelIterator() { return kernels_->begin(); }

  /// \brief Past-the-end iterator of the bound kernel list.
  KernelIter EndKernelIterator() { return kernels_->end(); }

  /// \brief Replace the kernels in [from, end) with a single kernel.
  ///
  /// \param[in] from First kernel to remove; must point at an existing element of the list.
  /// \param[in] end One past the last kernel to remove; must satisfy from <= end <= EndKernelIterator().
  /// \param[in] graph_kernel Kernel inserted where the removed range started. It is stored as
  ///            given, without a null check.
  ///
  /// \return Iterator to the element right after the inserted kernel, or BeginKernelIterator()
  ///         when the range is rejected (in which case the list is left untouched).
  KernelIter Replace(KernelIter from, KernelIter end, kernel::Kernel *graph_kernel) {
    const KernelIter first = BeginKernelIterator();
    const KernelIter last = EndKernelIterator();
    // `from` must address a real element (same rule as before). `end` is checked as well so that
    // a reversed or overshooting range never reaches vector::erase, where it would be undefined
    // behaviour.
    if (from < first || from >= last || end < from || end > last) {
      return first;
    }
    // erase() yields the position where the removed range started; insert() yields the position
    // of the new element and stays valid even if the vector reallocates.
    KernelIter pos = kernels_->erase(from, end);
    pos = kernels_->insert(pos, graph_kernel);
    return pos + 1;
  }

  /// \brief Input tensors the model was constructed with.
  const std::vector<mindspore::MSTensor> &inputs() const { return this->inputs_; }

  /// \brief Output tensors the model was constructed with.
  const std::vector<mindspore::MSTensor> &outputs() const { return this->outputs_; }

  /// \brief Schema version the model was constructed with.
  ///
  /// Returned by plain value: a top-level const on a by-value enum return is ignored by the
  /// language and only produces a compiler warning.
  SchemaVersion GetVersion() const { return version_; }

 protected:
  std::vector<kernel::Kernel *> *kernels_;                    // not owned
  const std::vector<mindspore::MSTensor> &inputs_;            // must outlive this object
  const std::vector<mindspore::MSTensor> &outputs_;           // must outlive this object
  const std::map<kernel::Kernel *, const T *> &primitives_;   // must outlive this object
  SchemaVersion version_;
};

/// \brief Abstract delegate interface. Both Init() and Build() are pure virtual, so a concrete
/// delegate has to derive from this class and implement them.
class  Delegate {
 public:
  Delegate() = default;

  // Virtual so that a derived delegate can be destroyed through a Delegate pointer.
  virtual ~Delegate() = default;

  /// \brief Initialization hook implemented by the concrete delegate.
  ///
  /// \return Status reported by the implementation.
  ///         NOTE(review): when this is invoked relative to Build() is not visible in this
  ///         header; confirm with the caller.
  virtual Status Init() = 0;

  /// \brief Build hook implemented by the concrete delegate.
  ///
  /// \param[in] model Kernel-list view over schema::Primitive. Presumably the delegate walks it
  ///            with BeginKernelIterator()/EndKernelIterator() and substitutes kernels through
  ///            Replace(); confirm against an implementation.
  ///
  /// \return Status reported by the implementation.
  virtual Status Build(DelegateModel<schema::Primitive> *model) = 0;
};
}  // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_DELEGATE_H