Program Listing for File delegate.h
↰ Return to documentation for file (include/runtime/include/api/delegate.h)
#ifndef MINDSPORE_INCLUDE_API_DELEGATE_H
#define MINDSPORE_INCLUDE_API_DELEGATE_H
#include <map>
#include <vector>
#include <memory>
#include <string>
#include <utility>
#include "schema/model_generated.h"
#include "include/api/kernel.h"
#include "include/api/delegate_api.h"
namespace mindspore {
/// Identifies which schema version a model uses.
/// NOTE(review): presumably SCHEMA_CUR is the current schema and SCHEMA_V0 a legacy one — confirm.
enum SchemaVersion {
  SCHEMA_INVALID = -1,  ///< Sentinel for an unknown or unsupported version.
  SCHEMA_CUR = 0,
  SCHEMA_V0 = 1,
};
// Iterator over a vector of kernel pointers; DelegateModel uses it to walk and rewrite its kernel list.
using KernelIter = std::vector<kernel::Kernel *>::iterator;
template <class T>
class MS_API DelegateModel {
public:
DelegateModel() = default;
DelegateModel(std::vector<kernel::Kernel *> *kernels, const std::vector<MSTensor> &inputs,
const std::vector<MSTensor> &outputs, const std::map<kernel::Kernel *, const T *> &primitives,
SchemaVersion version)
: kernels_(kernels), inputs_(inputs), outputs_(outputs), primitives_(primitives), version_(version) {}
~DelegateModel() = default;
const T *GetPrimitive(kernel::Kernel *kernel) const {
if (primitives_.find(kernel) != primitives_.end()) {
return primitives_.at(kernel);
} else {
return nullptr;
}
}
KernelIter BeginKernelIterator() { return kernels_->begin(); }
KernelIter EndKernelIterator() { return kernels_->end(); }
KernelIter Replace(KernelIter from, KernelIter end, kernel::Kernel *graph_kernel) {
size_t insert_index = from - BeginKernelIterator();
if (insert_index >= kernels_->size()) {
return BeginKernelIterator();
}
kernels_->erase(from, end);
kernels_->insert(BeginKernelIterator() + insert_index, graph_kernel);
return BeginKernelIterator() + insert_index + 1;
}
std::vector<kernel::Kernel *> *nodes() { return kernels_; }
const std::vector<mindspore::MSTensor> &inputs() { return this->inputs_; }
const std::vector<mindspore::MSTensor> &outputs() { return this->outputs_; }
SchemaVersion GetVersion() const { return version_; }
protected:
std::vector<kernel::Kernel *> *kernels_;
const std::vector<mindspore::MSTensor> &inputs_;
const std::vector<mindspore::MSTensor> &outputs_;
const std::map<kernel::Kernel *, const T *> &primitives_;
SchemaVersion version_;
};
// The lite delegate uses kernel::Kernel as its graph node; the graph is a DelegateModel
// whose per-kernel primitive type is schema::Primitive.
using LiteDelegateGraph = DelegateModel<schema::Primitive>;
/// @brief Abstract delegate whose graph type is LiteDelegateGraph and whose node and
///        kernel types are both kernel::Kernel.
///
/// Subclasses must implement Init() and Build(). The IDelegate hooks overridden here
/// have trivial bodies: CreateKernel() returns its argument, IsDelegateNode() returns
/// false and ReplaceNodes() does nothing.
class Delegate : public IDelegate<LiteDelegateGraph, kernel::Kernel, kernel::Kernel> {
 public:
  Delegate() = default;

  /// Passes the input and output tensor lists on to the IDelegate base.
  Delegate(const std::vector<mindspore::MSTensor> &inputs, const std::vector<mindspore::MSTensor> &outputs)
      : IDelegate(inputs, outputs) {}

  virtual ~Delegate() = default;

  /// Implemented by subclasses.
  virtual Status Init() = 0;

  /// Node and kernel are the same type here, so the node is handed back unchanged.
  std::shared_ptr<kernel::Kernel> CreateKernel(const std::shared_ptr<kernel::Kernel> &node) override {
    return node;
  }

  /// Always reports false.
  bool IsDelegateNode(const std::shared_ptr<kernel::Kernel> & /*node*/) override { return false; }

  /// Intentionally empty.
  void ReplaceNodes(const std::shared_ptr<LiteDelegateGraph> & /*graph*/) override {}

  /// Implemented by subclasses; receives the graph as a raw pointer.
  virtual Status Build(LiteDelegateGraph *model) = 0;
};
// Delegate subclass named for CoreML. Only declarations appear in this header; the
// constructor, Init() and Build() are defined elsewhere.
class MS_API CoreMLDelegate : public Delegate {
public:
CoreMLDelegate();
// Overrides the pure virtual Delegate::Init().
Status Init() override;
// Overrides the pure virtual Delegate::Build().
Status Build(LiteDelegateGraph *model) override;
protected:
// NOTE(review): presumably the concrete implementation that Init()/Build() forward to —
// confirm against the definitions.
std::shared_ptr<Delegate> impl_;
};
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_DELEGATE_H