Program Listing for File delegate.h
↰ Return to documentation for file (include/delegate.h)
#ifndef MINDSPORE_INCLUDE_API_DELEGATE_H
#define MINDSPORE_INCLUDE_API_DELEGATE_H
#include <map>
#include <vector>
#include <memory>
#include "schema/model_generated.h"
#include "include/api/kernel.h"
#include "include/api/status.h"
namespace mindspore {
/// Version of the model schema associated with a DelegateModel.
/// Values are spelled out explicitly: -1, 0, 1.
enum SchemaVersion {
  SCHEMA_INVALID = -1,  ///< Invalid / unknown schema version.
  SCHEMA_CUR = 0,       ///< NOTE(review): presumably the current schema — confirm.
  SCHEMA_V0 = 1,        ///< NOTE(review): presumably the legacy v0 schema — confirm.
};
using KernelIter = std::vector<kernel::Kernel *>::iterator;
template <class T>
class DelegateModel {
public:
DelegateModel(std::vector<kernel::Kernel *> *kernels, const std::vector<MSTensor> &inputs,
const std::vector<MSTensor> &outputs, const std::map<kernel::Kernel *, const T *> &primitives,
SchemaVersion version)
: kernels_(kernels), inputs_(inputs), outputs_(outputs), primitives_(primitives), version_(version) {}
~DelegateModel() = default;
const T *GetPrimitive(kernel::Kernel *kernel) const {
if (primitives_.find(kernel) != primitives_.end()) {
return primitives_.at(kernel);
} else {
return nullptr;
}
}
KernelIter BeginKernelIterator() { return kernels_->begin(); }
KernelIter EndKernelIterator() { return kernels_->end(); }
KernelIter Replace(KernelIter from, KernelIter end, kernel::Kernel *graph_kernel) {
size_t insert_index = from - BeginKernelIterator();
if (insert_index >= kernels_->size()) {
return BeginKernelIterator();
}
kernels_->erase(from, end);
kernels_->insert(BeginKernelIterator() + insert_index, graph_kernel);
return BeginKernelIterator() + insert_index + 1;
}
const std::vector<mindspore::MSTensor> &inputs() { return this->inputs_; }
const std::vector<mindspore::MSTensor> &outputs() { return this->outputs_; }
const SchemaVersion GetVersion() { return version_; }
protected:
std::vector<kernel::Kernel *> *kernels_;
const std::vector<mindspore::MSTensor> &inputs_;
const std::vector<mindspore::MSTensor> &outputs_;
const std::map<kernel::Kernel *, const T *> &primitives_;
SchemaVersion version_;
};
/// @brief Abstract interface for a model delegate; derived classes implement Init() and Build().
///
/// The virtual destructor allows a derived delegate to be destroyed through a Delegate pointer.
/// NOTE(review): the declaration order of the virtual functions below fixes the vtable layout;
/// do not reorder them if binary compatibility with existing delegates matters.
class Delegate {
public:
Delegate() = default;
virtual ~Delegate() = default;
/// @brief Initialize the delegate.
/// @return Status of the initialization (success/failure codes are defined by Status).
/// NOTE(review): presumably called before Build() — confirm with the caller.
virtual Status Init() = 0;
/// @brief Process the given model.
/// @param model Model view to inspect. It is passed as a non-const pointer, so the delegate
///              may modify it (e.g. replace kernels through DelegateModel::Replace).
/// @return Status of the build step.
virtual Status Build(DelegateModel<schema::Primitive> *model) = 0;
};
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_DELEGATE_H