Program Listing for File register_kernel.h
↰ Return to documentation for file (include/runtime/include/registry/register_kernel.h)
#ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_REGISTER_KERNEL_H_
#define MINDSPORE_LITE_INCLUDE_REGISTRY_REGISTER_KERNEL_H_
#include <functional>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "schema/model_generated.h"
#include "include/api/context.h"
#include "include/api/types.h"
#include "include/api/kernel.h"
#include "include/api/data_type.h"
#include "include/api/status.h"
namespace mindspore {
namespace registry {
/// \brief Kernel descriptor exchanged with RegisterKernel::GetCreator.
///
/// Plain aggregate; the member order is relied on by brace initialisation, so it must not change.
struct KernelDesc {
  DataType data_type;    ///< Data type associated with the kernel.
  int type;              ///< Integer type identifier; forwarded as KernelDescHelper::type.
  std::string arch;      ///< Architecture name; the only field GetCreator writes back to the caller.
  std::string provider;  ///< Provider name.
};
/// \brief Counterpart of KernelDesc that carries its strings as std::vector<char>.
///
/// Built by the std::string overload of RegisterKernel::GetCreator and handed to the private
/// overload. Plain aggregate; the member order is relied on by brace initialisation.
struct KernelDescHelper {
  DataType data_type;          ///< Copy of KernelDesc::data_type.
  int type;                    ///< Copy of KernelDesc::type.
  std::vector<char> arch;      ///< KernelDesc::arch converted with StringToChar.
  std::vector<char> provider;  ///< KernelDesc::provider converted with StringToChar.
};
/// \brief Factory callback type used to create a kernel.
///
/// \param inputs    Input tensors handed to the factory.
/// \param outputs   Output tensors handed to the factory.
/// \param primitive Schema primitive the kernel is created for.
/// \param ctx       Context passed through to the factory.
/// \return Shared pointer to the created kernel::Kernel. An empty std::function (e.g. one built
///         from nullptr, as RegisterKernel::GetCreator returns on invalid arguments) denotes
///         "no creator".
using CreateKernel = std::function<std::shared_ptr<kernel::Kernel>(
  const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs, const schema::Primitive *primitive,
  const mindspore::Context *ctx)>;
/// \brief Static entry points for registering kernel creators and looking them up.
///
/// Each operation exists twice: a public overload taking std::string (declared inline here and
/// defined later in this header) and a private overload taking std::vector<char>, which is only
/// declared in this header. The public overloads convert their strings and forward to the
/// private ones.
class MS_API RegisterKernel {
 public:
  /// \brief Register `creator` for an op identified by the integer `type`.
  /// \return Status reported by the private overload.
  static inline Status RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type,
                                 const CreateKernel creator);

  /// \brief Register `creator` for an op identified by the string `type`.
  /// \return Status reported by the private overload.
  static inline Status RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
                                       const std::string &type, const CreateKernel creator);

  /// \brief Look up the creator matching `primitive` and `desc`.
  /// \param desc In/out descriptor; its `arch` field is overwritten after the lookup.
  /// \return The creator, or an empty CreateKernel when `primitive` or `desc` is null.
  static inline CreateKernel GetCreator(const schema::Primitive *primitive, KernelDesc *desc);

 private:
  // std::vector<char> counterparts of the public API; defined outside this header.
  static Status RegKernel(const std::vector<char> &arch, const std::vector<char> &provider, DataType data_type,
                          int type, const CreateKernel creator);
  static Status RegCustomKernel(const std::vector<char> &arch, const std::vector<char> &provider,
                                DataType data_type, const std::vector<char> &type, const CreateKernel creator);
  static CreateKernel GetCreator(const schema::Primitive *primitive, KernelDescHelper *desc);
};
class MS_API KernelReg {
public:
~KernelReg() = default;
KernelReg(const std::string &arch, const std::string &provider, DataType data_type, int op_type,
const CreateKernel creator) {
(void)RegisterKernel::RegKernel(arch, provider, data_type, op_type, creator);
}
KernelReg(const std::string &arch, const std::string &provider, DataType data_type, const std::string &op_type,
const CreateKernel creator) {
(void)RegisterKernel::RegCustomKernel(arch, provider, data_type, op_type, creator);
}
};
/// \brief std::string front end for registering a kernel creator by integer op type.
///
/// Converts `arch` and `provider` with StringToChar and forwards all arguments to the private
/// std::vector<char> overload, returning its Status unchanged.
Status RegisterKernel::RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type,
                                 const CreateKernel creator) {
  const auto arch_chars = StringToChar(arch);
  const auto provider_chars = StringToChar(provider);
  return RegKernel(arch_chars, provider_chars, data_type, type, creator);
}
/// \brief std::string front end for registering a kernel creator by string op type.
///
/// Converts `arch`, `provider` and `type` with StringToChar and forwards all arguments to the
/// private std::vector<char> overload, returning its Status unchanged.
Status RegisterKernel::RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
                                       const std::string &type, const CreateKernel creator) {
  const auto arch_chars = StringToChar(arch);
  const auto provider_chars = StringToChar(provider);
  const auto type_chars = StringToChar(type);
  return RegCustomKernel(arch_chars, provider_chars, data_type, type_chars, creator);
}
/// \brief std::string front end for looking up a kernel creator.
///
/// \param primitive Schema primitive to look up; a null pointer yields an empty creator.
/// \param desc      In/out descriptor; a null pointer yields an empty creator. On return its
///                  `arch` field holds the value left in the helper descriptor by the lookup.
/// \return The creator returned by the private std::vector<char> overload.
CreateKernel RegisterKernel::GetCreator(const schema::Primitive *primitive, KernelDesc *desc) {
  if (desc == nullptr) {
    return nullptr;
  }
  if (primitive == nullptr) {
    return nullptr;
  }

  // Mirror the caller's descriptor into the std::vector<char> form the private overload expects.
  KernelDescHelper helper{desc->data_type, desc->type, StringToChar(desc->arch), StringToChar(desc->provider)};
  CreateKernel creator = GetCreator(primitive, &helper);

  // NOTE(review): only `arch` is propagated back to the caller; confirm that the private
  // overload never updates data_type, type or provider.
  desc->arch = CharToString(helper.arch);
  return creator;
}
// Defines a static KernelReg object inside an anonymous namespace; constructing it registers
// `creator`. `arch` and `provider` are stringified, while `data_type` and `op_type` are passed
// through as written (with an integral or enum `op_type` this selects the KernelReg constructor
// taking `int op_type`). All four of arch, provider, data_type and op_type are also pasted into
// the variable name, so each must be a plain identifier for the resulting name to be valid.
#define REGISTER_KERNEL(arch, provider, data_type, op_type, creator)                        \
  namespace {                                                                               \
  static mindspore::registry::KernelReg g_##arch##provider##data_type##op_type##kernelReg( \
    #arch, #provider, data_type, op_type, creator);                                        \
  }  // namespace
// Defines a static KernelReg object inside an anonymous namespace; constructing it registers
// `creator`. `arch`, `provider` and `op_type` are all stringified, so the op type reaches
// KernelReg as a string and is registered through RegisterKernel::RegCustomKernel. `data_type`
// is passed through as written. All four of arch, provider, data_type and op_type are also
// pasted into the variable name, so each must be a plain identifier for the name to be valid.
#define REGISTER_CUSTOM_KERNEL(arch, provider, data_type, op_type, creator)                 \
  namespace {                                                                               \
  static mindspore::registry::KernelReg g_##arch##provider##data_type##op_type##kernelReg( \
    #arch, #provider, data_type, #op_type, creator);                                       \
  }  // namespace
} // namespace registry
} // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_REGISTRY_REGISTER_KERNEL_H_