Program Listing for File register_kernel_interface.h

Return to documentation for file (include/converter/include/registry/register_kernel_interface.h)

#ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_REGISTER_KERNEL_INTERFACE_H_
#define MINDSPORE_LITE_INCLUDE_REGISTRY_REGISTER_KERNEL_INTERFACE_H_

#include <functional>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "include/kernel_interface.h"
#include "schema/model_generated.h"

namespace mindspore {
namespace kernel {
// Forward declaration only: this header refers to Kernel solely through `const kernel::Kernel *`
// parameters, so the complete type is not required here.
class Kernel;
}
namespace registry {
/// \brief Callable type that takes no arguments and returns a std::shared_ptr<kernel::KernelInterface>.
using KernelInterfaceCreator = std::function<std::shared_ptr<kernel::KernelInterface>()>;

/// \brief Static entry points for registering and looking up KernelInterface creators.
///
/// Every operation exists twice: a public inline overload taking std::string, and a private overload
/// taking std::vector<char>. The public overloads (defined further down in this header) convert their
/// string arguments with StringToChar() and forward to the private overload of the same name. The private
/// overloads are only declared here; their definitions are not part of this header.
///
/// NOTE(review): the registration/lookup semantics themselves (duplicate handling, thread-safety, what a
/// failed lookup returns) live in the out-of-line implementation and cannot be confirmed from this file.
class MS_API RegisterKernelInterface {
 public:
  /// \brief Registration entry point keyed by a provider name and a string op type.
  ///
  /// \param[in] provider Provider name.
  /// \param[in] op_type Op type expressed as a string.
  /// \param[in] creator Callable returning a std::shared_ptr<kernel::KernelInterface>.
  ///
  /// \return The Status returned by the private std::vector<char> overload.
  inline static Status CustomReg(const std::string &provider, const std::string &op_type,
                                 const KernelInterfaceCreator creator);

  /// \brief Registration entry point keyed by a provider name and an integer op type.
  ///
  /// \param[in] provider Provider name.
  /// \param[in] op_type Op type expressed as an int (presumably a schema primitive type id -- TODO confirm).
  /// \param[in] creator Callable returning a std::shared_ptr<kernel::KernelInterface>.
  ///
  /// \return The Status returned by the private std::vector<char> overload.
  inline static Status Reg(const std::string &provider, int op_type, const KernelInterfaceCreator creator);

  /// \brief Lookup entry point for a KernelInterface.
  ///
  /// \param[in] provider Provider name.
  /// \param[in] primitive Pointer to the schema primitive the lookup is performed for.
  /// \param[in] kernel Optional kernel pointer; defaults to nullptr.
  ///
  /// \return The std::shared_ptr<kernel::KernelInterface> returned by the private overload.
  inline static std::shared_ptr<kernel::KernelInterface> GetKernelInterface(const std::string &provider,
                                                                            const schema::Primitive *primitive,
                                                                            const kernel::Kernel *kernel = nullptr);

 private:
  // std::vector<char> counterparts of the public overloads above. They are declared without `inline` and
  // are not defined in this header; the public std::string overloads forward to them.
  static Status CustomReg(const std::vector<char> &provider, const std::vector<char> &op_type,
                          const KernelInterfaceCreator creator);
  static Status Reg(const std::vector<char> &provider, int op_type, const KernelInterfaceCreator creator);
  static std::shared_ptr<kernel::KernelInterface> GetKernelInterface(const std::vector<char> &provider,
                                                                     const schema::Primitive *primitive,
                                                                     const kernel::Kernel *kernel = nullptr);
};

/// \brief Helper whose construction forwards one registration call to RegisterKernelInterface.
///
/// The class has no data members; each constructor simply invokes the matching static function on
/// RegisterKernelInterface and discards the Status it returns.
class MS_API KernelInterfaceReg {
 public:
  /// \brief Forward to RegisterKernelInterface::Reg for an integer op type.
  ///
  /// \param[in] provider Provider name.
  /// \param[in] op_type Op type expressed as an int.
  /// \param[in] creator Callable returning a std::shared_ptr<kernel::KernelInterface>.
  KernelInterfaceReg(const std::string &provider, int op_type, const KernelInterfaceCreator creator) {
    // The returned Status is deliberately discarded.
    static_cast<void>(RegisterKernelInterface::Reg(provider, op_type, creator));
  }

  /// \brief Forward to RegisterKernelInterface::CustomReg for a string op type.
  ///
  /// \param[in] provider Provider name.
  /// \param[in] op_type Op type expressed as a string.
  /// \param[in] creator Callable returning a std::shared_ptr<kernel::KernelInterface>.
  KernelInterfaceReg(const std::string &provider, const std::string &op_type, const KernelInterfaceCreator creator) {
    // The returned Status is deliberately discarded.
    static_cast<void>(RegisterKernelInterface::CustomReg(provider, op_type, creator));
  }

  virtual ~KernelInterfaceReg() = default;
};

/// \brief std::string front end of CustomReg.
///
/// Converts both string arguments with StringToChar() and hands them, together with the creator, to the
/// private std::vector<char> overload; that overload's Status is returned unchanged.
Status RegisterKernelInterface::CustomReg(const std::string &provider, const std::string &op_type,
                                          const KernelInterfaceCreator creator) {
  const auto provider_chars = StringToChar(provider);
  const auto op_type_chars = StringToChar(op_type);
  return CustomReg(provider_chars, op_type_chars, creator);
}

/// \brief std::string front end of Reg.
///
/// Converts the provider name with StringToChar() and forwards it, the integer op type and the creator to
/// the private std::vector<char> overload; that overload's Status is returned unchanged.
Status RegisterKernelInterface::Reg(const std::string &provider, int op_type, const KernelInterfaceCreator creator) {
  const auto provider_chars = StringToChar(provider);
  return Reg(provider_chars, op_type, creator);
}

/// \brief std::string front end of GetKernelInterface.
///
/// Converts the provider name with StringToChar() and forwards it, the primitive pointer and the kernel
/// pointer to the private std::vector<char> overload; that overload's result is returned unchanged.
std::shared_ptr<kernel::KernelInterface> RegisterKernelInterface::GetKernelInterface(const std::string &provider,
                                                                                     const schema::Primitive *primitive,
                                                                                     const kernel::Kernel *kernel) {
  const auto provider_chars = StringToChar(provider);
  return GetKernelInterface(provider_chars, primitive, kernel);
}

// Defines a KernelInterfaceReg object named g_<provider><op_type>_inter_reg inside an anonymous namespace.
// `provider` is stringized for the first constructor argument, while `op_type` is passed through as written,
// so it must be an expression accepted by a KernelInterfaceReg constructor (presumably the int overload --
// TODO confirm against callers). Both `provider` and `op_type` are token-pasted into the variable name and
// therefore have to be plain identifier fragments.
#define REGISTER_KERNEL_INTERFACE(provider, op_type, creator)                                                    \
  namespace {                                                                                                    \
  static mindspore::registry::KernelInterfaceReg g_##provider##op_type##_inter_reg(#provider, op_type, creator); \
  }  // namespace

// Defines a KernelInterfaceReg object named g_<provider><op_type>_custom_inter_reg inside an anonymous
// namespace. Both `provider` and `op_type` are stringized, so the object is constructed from two string
// literals plus `creator`. Both names are also token-pasted into the variable name and therefore have to be
// plain identifier fragments.
#define REGISTER_CUSTOM_KERNEL_INTERFACE(provider, op_type, creator)                                           \
  namespace {                                                                                                  \
  static mindspore::registry::KernelInterfaceReg g_##provider##op_type##_custom_inter_reg(#provider, #op_type, \
                                                                                          creator);            \
  }  // namespace
}  // namespace registry
}  // namespace mindspore

#endif  // MINDSPORE_LITE_INCLUDE_REGISTRY_REGISTER_KERNEL_INTERFACE_H_