Program Listing for File register_kernel_interface.h
↰ Return to documentation for file (include/converter/include/registry/register_kernel_interface.h)
#ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_REGISTER_KERNEL_INTERFACE_H_
#define MINDSPORE_LITE_INCLUDE_REGISTRY_REGISTER_KERNEL_INTERFACE_H_
#include <functional>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "include/kernel_interface.h"
#include "schema/model_generated.h"
namespace mindspore {
// Forward declaration only: this header refers to kernel::Kernel solely through
// `const kernel::Kernel *` parameters, so the full definition is not required here.
namespace kernel { class Kernel; }  // namespace kernel
namespace registry {
/// @brief Factory callable, taking no arguments, that produces a kernel::KernelInterface held by a
///        std::shared_ptr. Creators of this type are handed to RegisterKernelInterface::Reg /
///        CustomReg and to the KernelInterfaceReg constructors.
/// @note Requires <functional> for std::function; it is included directly by this header rather
///       than relied upon transitively.
using KernelInterfaceCreator = std::function<std::shared_ptr<kernel::KernelInterface>()>;
/// @brief Static entry points for registering and looking up kernel::KernelInterface creators.
///
/// The public std::string overloads are defined inline in this header and forward to the private
/// std::vector<char> overloads (converting with StringToChar). The private overloads are
/// implemented out of line. NOTE(review): the char-vector indirection presumably keeps
/// std::string out of the exported ABI — confirm.
class MS_API RegisterKernelInterface {
 public:
  /// @brief Register a KernelInterface creator for a custom operator type.
  /// @param provider Name of the provider that owns the creator.
  /// @param op_type Custom operator type, given by name.
  /// @param creator Factory producing the kernel::KernelInterface.
  /// @return Status reported by the underlying (private) registration.
  static inline Status CustomReg(const std::string &provider,
                                 const std::string &op_type,
                                 const KernelInterfaceCreator creator);

  /// @brief Register a KernelInterface creator for an operator identified by an integer type.
  /// @param provider Name of the provider that owns the creator.
  /// @param op_type Integer operator type (presumably a schema::PrimitiveType value — confirm).
  /// @param creator Factory producing the kernel::KernelInterface.
  /// @return Status reported by the underlying (private) registration.
  static inline Status Reg(const std::string &provider,
                           int op_type,
                           const KernelInterfaceCreator creator);

  /// @brief Look up the KernelInterface registered for @p provider and @p primitive.
  /// @param provider Name of the provider to search.
  /// @param primitive Schema primitive describing the operator.
  /// @param kernel Optional kernel forwarded to the lookup; defaults to nullptr.
  /// @return The interface produced by the underlying lookup; its result on a miss is defined by
  ///         the out-of-line implementation and is not visible in this header.
  static inline std::shared_ptr<kernel::KernelInterface> GetKernelInterface(
      const std::string &provider,
      const schema::Primitive *primitive,
      const kernel::Kernel *kernel = nullptr);

 private:
  // Out-of-line counterparts of the public overloads above, taking char vectors instead of
  // std::string.
  static Status CustomReg(const std::vector<char> &provider,
                          const std::vector<char> &op_type,
                          const KernelInterfaceCreator creator);
  static Status Reg(const std::vector<char> &provider,
                    int op_type,
                    const KernelInterfaceCreator creator);
  static std::shared_ptr<kernel::KernelInterface> GetKernelInterface(
      const std::vector<char> &provider,
      const schema::Primitive *primitive,
      const kernel::Kernel *kernel = nullptr);
};
/// @brief Registration helper: constructing an instance registers a KernelInterface creator.
///
/// The REGISTER_KERNEL_INTERFACE / REGISTER_CUSTOM_KERNEL_INTERFACE macros use this class to
/// define file-local registration objects. The Status returned by the registration call is
/// intentionally discarded, so a failed registration is not reported here.
class MS_API KernelInterfaceReg {
 public:
  /// @brief Register @p creator for an integer operator type via RegisterKernelInterface::Reg.
  KernelInterfaceReg(const std::string &provider, int op_type, const KernelInterfaceCreator creator) {
    static_cast<void>(RegisterKernelInterface::Reg(provider, op_type, creator));
  }

  /// @brief Register @p creator for a named custom operator type via
  ///        RegisterKernelInterface::CustomReg.
  KernelInterfaceReg(const std::string &provider, const std::string &op_type,
                     const KernelInterfaceCreator creator) {
    static_cast<void>(RegisterKernelInterface::CustomReg(provider, op_type, creator));
  }

  virtual ~KernelInterfaceReg() = default;
};
// Inline std::string overload: convert both names with StringToChar, then delegate to the
// private char-vector overload.
Status RegisterKernelInterface::CustomReg(const std::string &provider, const std::string &op_type,
                                          const KernelInterfaceCreator creator) {
  const auto provider_chars = StringToChar(provider);
  const auto op_type_chars = StringToChar(op_type);
  return RegisterKernelInterface::CustomReg(provider_chars, op_type_chars, creator);
}
// Inline std::string overload: convert the provider name with StringToChar, then delegate to the
// private char-vector overload. The integer op_type is passed through unchanged.
Status RegisterKernelInterface::Reg(const std::string &provider, int op_type, const KernelInterfaceCreator creator) {
  const auto provider_chars = StringToChar(provider);
  return RegisterKernelInterface::Reg(provider_chars, op_type, creator);
}
// Inline std::string overload: convert the provider name with StringToChar and delegate the
// lookup to the private char-vector overload, forwarding primitive and kernel as given.
std::shared_ptr<kernel::KernelInterface> RegisterKernelInterface::GetKernelInterface(
    const std::string &provider, const schema::Primitive *primitive, const kernel::Kernel *kernel) {
  const auto provider_chars = StringToChar(provider);
  return RegisterKernelInterface::GetKernelInterface(provider_chars, primitive, kernel);
}
// Registers `creator` for operator `op_type` of `provider` by defining a file-local
// KernelInterfaceReg object whose constructor performs the registration. Expand at namespace
// scope (the expansion opens an unnamed namespace).
//  - provider: stringized (#provider) to form the provider name.
//  - op_type:  passed unchanged to the KernelInterfaceReg constructor.
// Both arguments are also token-pasted into the object's name, so each must be a single
// identifier-like token. The `_` between them keeps distinct pairs such as (A, BC) and (AB, C)
// from producing the same identifier. Names in an unnamed namespace already have internal
// linkage, so no `static` is needed.
#define REGISTER_KERNEL_INTERFACE(provider, op_type, creator)                                          \
  namespace {                                                                                          \
  mindspore::registry::KernelInterfaceReg g_##provider##_##op_type##_inter_reg(#provider, op_type, creator); \
  }  // namespace
// Registers `creator` for the custom operator `op_type` of `provider` by defining a file-local
// KernelInterfaceReg object whose (string, string) constructor performs the registration.
// Expand at namespace scope (the expansion opens an unnamed namespace).
//  - provider and op_type: both are stringized (#provider, #op_type) to form the names passed on.
// Both arguments are also token-pasted into the object's name, so each must be a single
// identifier-like token. The `_` between them keeps distinct pairs such as (A, BC) and (AB, C)
// from producing the same identifier. Names in an unnamed namespace already have internal
// linkage, so no `static` is needed.
#define REGISTER_CUSTOM_KERNEL_INTERFACE(provider, op_type, creator)                              \
  namespace {                                                                                     \
  mindspore::registry::KernelInterfaceReg g_##provider##_##op_type##_custom_inter_reg(#provider, \
                                                                                      #op_type,  \
                                                                                      creator);  \
  }  // namespace
} // namespace registry
} // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_REGISTRY_REGISTER_KERNEL_INTERFACE_H_