Program Listing for File delegate_api.h

Return to documentation for file (include/runtime/include/api/delegate_api.h)

#ifndef MINDSPORE_INCLUDE_API_DELEGATE_API_H
#define MINDSPORE_INCLUDE_API_DELEGATE_API_H

#include <map>
#include <vector>
#include <memory>
#include "include/api/status.h"
#include "include/api/types.h"
namespace mindspore {
/// \brief Non-template base of the delegate interface; stores the delegate's input and output tensor lists.
///
/// The tensor lists given to the constructor are copied into the object and exposed read-only through
/// inputs() and outputs(). Derived classes can reach the stored lists directly via the protected members.
class AbstractDelegate {
 public:
  /// \brief Construct a delegate whose input and output tensor lists are both empty.
  AbstractDelegate() = default;

  /// \brief Construct a delegate from the given tensor lists.
  ///
  /// \param[in] inputs Tensors to store as this delegate's inputs; the vector is copied.
  /// \param[in] outputs Tensors to store as this delegate's outputs; the vector is copied.
  AbstractDelegate(const std::vector<mindspore::MSTensor> &inputs, const std::vector<mindspore::MSTensor> &outputs)
      : inputs_(inputs), outputs_(outputs) {}

  /// \brief Virtual so that derived delegates are destroyed correctly through a base-class pointer.
  virtual ~AbstractDelegate() = default;

  /// \brief Get the stored input tensors.
  ///
  /// Const-qualified because it only reads state, so it is also callable on const delegates.
  ///
  /// \return Read-only reference to the internal input list; it refers to a member of this object.
  const std::vector<mindspore::MSTensor> &inputs() const { return inputs_; }

  /// \brief Get the stored output tensors.
  ///
  /// Const-qualified because it only reads state, so it is also callable on const delegates.
  ///
  /// \return Read-only reference to the internal output list; it refers to a member of this object.
  const std::vector<mindspore::MSTensor> &outputs() const { return outputs_; }

 protected:
  // Copies of the lists passed to the constructor (empty after default construction).
  std::vector<mindspore::MSTensor> inputs_;
  std::vector<mindspore::MSTensor> outputs_;
};

/// \brief Abstract delegate interface, parameterized on the graph, node and kernel types it works with.
///
/// Adds three pure virtual hooks on top of AbstractDelegate; concrete delegates must implement all of them.
///
/// \tparam Graph Type of the graph object passed to ReplaceNodes().
/// \tparam Node Type of the node object passed to IsDelegateNode() and CreateKernel().
/// \tparam Kernel Type of the kernel object returned by CreateKernel().
template <class Graph, class Node, class Kernel>
class IDelegate : public AbstractDelegate {
 public:
  /// \brief Construct a delegate with a default-constructed AbstractDelegate base.
  IDelegate() = default;

  /// \brief Construct a delegate, forwarding both tensor lists to the AbstractDelegate base.
  ///
  /// \param[in] inputs Input tensors handed to the base class.
  /// \param[in] outputs Output tensors handed to the base class.
  IDelegate(const std::vector<mindspore::MSTensor> &inputs,
            const std::vector<mindspore::MSTensor> &outputs)
      : AbstractDelegate(inputs, outputs) {}

  virtual ~IDelegate() = default;

  /// \brief Hook that receives the graph to work on.
  ///
  /// NOTE(review): presumably swaps the nodes this delegate supports for delegate-backed ones;
  /// confirm against a concrete implementation.
  ///
  /// \param[in] graph Shared pointer to the graph.
  virtual void ReplaceNodes(const std::shared_ptr<Graph> &graph) = 0;

  /// \brief Predicate over a single node.
  ///
  /// NOTE(review): by its name this reports whether the node is handled by the delegate; confirm.
  ///
  /// \param[in] node Shared pointer to the node being queried.
  ///
  /// \return Implementation-defined boolean answer for the node.
  virtual bool IsDelegateNode(const std::shared_ptr<Node> &node) = 0;

  /// \brief Produce a kernel object for the given node.
  ///
  /// \param[in] node Shared pointer to the node a kernel is requested for.
  ///
  /// \return Shared pointer to the kernel; whether it may be null is up to the implementation.
  virtual std::shared_ptr<Kernel> CreateKernel(const std::shared_ptr<Node> &node) = 0;
};
}  // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_DELEGATE_API_H