Program Listing for File delegate.h

Return to documentation for file (include/runtime/include/api/delegate.h)

#ifndef MINDSPORE_INCLUDE_API_DELEGATE_H
#define MINDSPORE_INCLUDE_API_DELEGATE_H

#include <map>
#include <vector>
#include <memory>
#include <string>
#include <utility>
#include "schema/model_generated.h"
#include "include/api/kernel.h"
#include "include/api/delegate_api.h"

namespace mindspore {
/// @brief Tags which schema version a model's primitives follow.
/// SCHEMA_INVALID presumably marks an unknown/unsupported version, SCHEMA_CUR the current schema and
/// SCHEMA_V0 a legacy one (inferred from the names — confirm against the model loader).
enum SchemaVersion {
  SCHEMA_INVALID = -1,
  SCHEMA_CUR = 0,
  SCHEMA_V0 = 1,
};

using KernelIter = std::vector<kernel::Kernel *>::iterator;

template <class T>
class MS_API DelegateModel {
 public:
  DelegateModel() = default;
  DelegateModel(std::vector<kernel::Kernel *> *kernels, const std::vector<MSTensor> &inputs,
                const std::vector<MSTensor> &outputs, const std::map<kernel::Kernel *, const T *> &primitives,
                SchemaVersion version)
      : kernels_(kernels), inputs_(inputs), outputs_(outputs), primitives_(primitives), version_(version) {}

  ~DelegateModel() = default;

  const T *GetPrimitive(kernel::Kernel *kernel) const {
    if (primitives_.find(kernel) != primitives_.end()) {
      return primitives_.at(kernel);
    } else {
      return nullptr;
    }
  }

  KernelIter BeginKernelIterator() { return kernels_->begin(); }

  KernelIter EndKernelIterator() { return kernels_->end(); }

  KernelIter Replace(KernelIter from, KernelIter end, kernel::Kernel *graph_kernel) {
    size_t insert_index = from - BeginKernelIterator();
    if (insert_index >= kernels_->size()) {
      return BeginKernelIterator();
    }
    kernels_->erase(from, end);
    kernels_->insert(BeginKernelIterator() + insert_index, graph_kernel);
    return BeginKernelIterator() + insert_index + 1;
  }

  std::vector<kernel::Kernel *> *nodes() { return kernels_; }

  const std::vector<mindspore::MSTensor> &inputs() { return this->inputs_; }

  const std::vector<mindspore::MSTensor> &outputs() { return this->outputs_; }

  SchemaVersion GetVersion() const { return version_; }

 protected:
  std::vector<kernel::Kernel *> *kernels_;
  const std::vector<mindspore::MSTensor> &inputs_;
  const std::vector<mindspore::MSTensor> &outputs_;
  const std::map<kernel::Kernel *, const T *> &primitives_;
  SchemaVersion version_;
};

// The Lite delegate uses kernel::Kernel as its graph node type.
using LiteDelegateGraph = DelegateModel<schema::Primitive>;

/// @brief Base class for Lite delegates.
///
/// In Lite the graph node and the kernel are the same kernel::Kernel object, so the IDelegate hooks get
/// trivial implementations here: CreateKernel() hands the node back unchanged, IsDelegateNode() always
/// reports false and ReplaceNodes() does nothing. Concrete delegates implement Init() and Build().
class Delegate : public IDelegate<LiteDelegateGraph, kernel::Kernel, kernel::Kernel> {
 public:
  Delegate() = default;
  Delegate(const std::vector<mindspore::MSTensor> &inputs, const std::vector<mindspore::MSTensor> &outputs)
      : IDelegate<LiteDelegateGraph, kernel::Kernel, kernel::Kernel>(inputs, outputs) {}
  virtual ~Delegate() = default;

  /// @brief Initialization hook implemented by each concrete delegate.
  virtual Status Init() = 0;

  // Node and kernel are one and the same here, so no conversion is needed.
  std::shared_ptr<kernel::Kernel> CreateKernel(const std::shared_ptr<kernel::Kernel> &node) override {
    return node;
  }

  bool IsDelegateNode(const std::shared_ptr<kernel::Kernel> & /* node */) override { return false; }

  void ReplaceNodes(const std::shared_ptr<LiteDelegateGraph> & /* graph */) override {}

  /// @brief Delegate-specific processing of `model`.
  /// NOTE(review): implementations presumably swap supported kernels via LiteDelegateGraph::Replace() — confirm.
  virtual Status Build(LiteDelegateGraph *model) = 0;
};

/// @brief Delegate presumably backed by Apple Core ML (inferred from the name — confirm in the implementation).
///
/// Only the interface is declared here; the constructor, Init() and Build() are defined out of line.
class MS_API CoreMLDelegate : public Delegate {
 public:
  CoreMLDelegate();
  Status Init() override;
  Status Build(LiteDelegateGraph *model) override;

 protected:
  // Implementation object held behind the public interface.
  // NOTE(review): presumably Init()/Build() forward to it — verify in the .cc file.
  std::shared_ptr<Delegate> impl_;
};
}  // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_DELEGATE_H