Program Listing for File register_kernel.h

Return to documentation for file (include/converter/include/registry/register_kernel.h)

#ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_REGISTER_KERNEL_H_
#define MINDSPORE_LITE_INCLUDE_REGISTRY_REGISTER_KERNEL_H_

#include <functional>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "schema/model_generated.h"
#include "include/api/context.h"
#include "include/api/types.h"
#include "include/api/kernel.h"
#include "include/api/data_type.h"
#include "include/api/status.h"

namespace mindspore {
namespace registry {
/// @brief Descriptor used to look up a registered kernel creator.
///
/// The scalar members are value-initialized so that a default-constructed
/// descriptor never carries indeterminate values into
/// RegisterKernel::GetCreator, which copies every field. Member order is part
/// of the aggregate-initialization interface and must not change.
struct KernelDesc {
  DataType data_type{};  ///< Data type handled by the kernel.
  int type{0};           ///< Operator type.
  std::string arch;      ///< Architecture name; RegisterKernel::GetCreator may overwrite it.
  std::string provider;  ///< Provider name.
};

/// @brief Counterpart of KernelDesc whose string fields are carried as std::vector<char>.
///
/// Built from a KernelDesc by RegisterKernel::GetCreator and handed to the
/// private GetCreator overload. NOTE(review): presumably this keeps std::string
/// out of the library boundary (dual-ABI) — confirm.
/// Scalar members are value-initialized. Member order matches KernelDesc and is
/// relied on by brace-initialization, so it must not change.
struct KernelDescHelper {
  DataType data_type{};        ///< Data type handled by the kernel.
  int type{0};                 ///< Operator type.
  std::vector<char> arch;      ///< Architecture name as raw characters.
  std::vector<char> provider;  ///< Provider name as raw characters.
};

/// @brief Factory signature used to build a kernel instance.
///
/// Receives the node's input and output tensors, its schema primitive and the
/// runtime context, and returns the created kernel as a shared_ptr.
/// NOTE(review): the meaning of a null return (e.g. "creation failed") is not
/// established in this header — confirm against callers.
/// Requires <functional> for std::function.
using CreateKernel = std::function<std::shared_ptr<kernel::Kernel>(
  const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs, const schema::Primitive *primitive,
  const mindspore::Context *ctx)>;

/// @brief Static entry points for registering kernel creators and retrieving them.
///
/// The public overloads take std::string arguments and are defined inline in
/// this header. Each one converts its strings with StringToChar and forwards to
/// the private std::vector<char> overload of the same name. The private
/// overloads are not defined in this header.
class MS_API RegisterKernel {
 public:
  /// @brief Registers @p creator for the built-in operator type @p type.
  /// @param arch      Architecture name.
  /// @param provider  Provider name.
  /// @param data_type Data type the kernel handles.
  /// @param type      Operator type.
  /// @param creator   Factory used to build the kernel.
  /// @return Status reported by the underlying registration.
  static inline Status RegKernel(const std::string &arch,
                                 const std::string &provider,
                                 DataType data_type,
                                 int type,
                                 const CreateKernel creator);

  /// @brief Registers @p creator for the custom operator named @p type.
  /// @param arch      Architecture name.
  /// @param provider  Provider name.
  /// @param data_type Data type the kernel handles.
  /// @param type      Name of the custom operator.
  /// @param creator   Factory used to build the kernel.
  /// @return Status reported by the underlying registration.
  static inline Status RegCustomKernel(const std::string &arch,
                                       const std::string &provider,
                                       DataType data_type,
                                       const std::string &type,
                                       const CreateKernel creator);

  /// @brief Looks up the creator for @p primitive as described by @p desc.
  /// @param primitive Operator primitive; when null the lookup is skipped.
  /// @param desc      In/out descriptor; when null the lookup is skipped.
  ///                  Its fields (at least arch) may be rewritten by the lookup.
  /// @return The creator found by the lookup; empty when either argument is null.
  static inline CreateKernel GetCreator(const schema::Primitive *primitive, KernelDesc *desc);

 private:
  // Implementations behind the public overloads, taking strings as
  // std::vector<char>. Defined outside this header.
  static Status RegKernel(const std::vector<char> &arch,
                          const std::vector<char> &provider,
                          DataType data_type,
                          int type,
                          const CreateKernel creator);
  static Status RegCustomKernel(const std::vector<char> &arch,
                                const std::vector<char> &provider,
                                DataType data_type,
                                const std::vector<char> &type,
                                const CreateKernel creator);
  static CreateKernel GetCreator(const schema::Primitive *primitive, KernelDescHelper *desc);
};

class MS_API KernelReg {
 public:
  ~KernelReg() = default;

  KernelReg(const std::string &arch, const std::string &provider, DataType data_type, int op_type,
            const CreateKernel creator) {
    (void)RegisterKernel::RegKernel(arch, provider, data_type, op_type, creator);
  }

  KernelReg(const std::string &arch, const std::string &provider, DataType data_type, const std::string &op_type,
            const CreateKernel creator) {
    (void)RegisterKernel::RegCustomKernel(arch, provider, data_type, op_type, creator);
  }
};

// String-based registration entry point. Converts the architecture and provider
// names with StringToChar, then delegates to the private character-buffer
// overload and returns its Status unchanged.
Status RegisterKernel::RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type,
                                 const CreateKernel creator) {
  const auto arch_chars = StringToChar(arch);
  const auto provider_chars = StringToChar(provider);
  return RegKernel(arch_chars, provider_chars, data_type, type, creator);
}

// String-based custom-operator registration. Converts the architecture,
// provider and custom operator names with StringToChar, then delegates to the
// private character-buffer overload and returns its Status unchanged.
Status RegisterKernel::RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
                                       const std::string &type, const CreateKernel creator) {
  const auto arch_chars = StringToChar(arch);
  const auto provider_chars = StringToChar(provider);
  const auto type_chars = StringToChar(type);
  return RegCustomKernel(arch_chars, provider_chars, data_type, type_chars, creator);
}

// String-based lookup entry point.
// Builds a KernelDescHelper from *desc and calls the private overload, which
// receives a mutable descriptor. Afterwards it copies every descriptor field
// back into *desc.
// Returns an empty CreateKernel, without performing the lookup, when either
// pointer is null.
CreateKernel RegisterKernel::GetCreator(const schema::Primitive *primitive, KernelDesc *desc) {
  if (desc == nullptr || primitive == nullptr) {
    return nullptr;
  }
  KernelDescHelper kernel_desc = {desc->data_type, desc->type, StringToChar(desc->arch), StringToChar(desc->provider)};
  auto ret = GetCreator(primitive, &kernel_desc);
  // Mirror all fields, not only arch, so that no update made by the private
  // overload is silently dropped. For fields it leaves untouched this is a
  // round trip. NOTE(review): that relies on CharToString inverting StringToChar.
  desc->data_type = kernel_desc.data_type;
  desc->type = kernel_desc.type;
  desc->arch = CharToString(kernel_desc.arch);
  desc->provider = CharToString(kernel_desc.provider);
  return ret;
}

// Registers `creator` for a built-in operator type by defining a KernelReg
// object inside an anonymous namespace at the expansion site.
//   arch, provider : bare tokens, stringized (#arch, #provider) to form the
//                    architecture and provider names.
//   data_type      : passed unchanged as the DataType argument.
//   op_type        : passed unchanged to KernelReg; expected to select the
//                    int op_type constructor.
//   creator        : the CreateKernel factory.
// arch, provider, data_type and op_type are also token-pasted (##) into the
// object's name. Each must therefore be a single identifier token: pass an
// unqualified constant rather than something like `DataType::kNumberTypeFloat32`.
// The same argument combination can be expanded only once per translation unit,
// otherwise the object name is redefined.
// Expand at namespace scope, because the expansion opens a namespace.
#define REGISTER_KERNEL(arch, provider, data_type, op_type, creator)                                                   \
  namespace {                                                                                                          \
  static mindspore::registry::KernelReg g_##arch##provider##data_type##op_type##kernelReg(#arch, #provider, data_type, \
                                                                                          op_type, creator);           \
  }  // namespace

// Registers `creator` for a custom operator by defining a KernelReg object
// inside an anonymous namespace at the expansion site.
//   arch, provider, op_type : bare tokens, stringized (#arch, #provider,
//                             #op_type). The stringized op_type selects the
//                             std::string op_type constructor of KernelReg.
//   data_type               : passed unchanged as the DataType argument.
//   creator                 : the CreateKernel factory.
// arch, provider, data_type and op_type are also token-pasted (##) into the
// object's name, so each must be a single identifier token.
// The same argument combination can be expanded only once per translation unit.
// Expand at namespace scope, because the expansion opens a namespace.
#define REGISTER_CUSTOM_KERNEL(arch, provider, data_type, op_type, creator)                                            \
  namespace {                                                                                                          \
  static mindspore::registry::KernelReg g_##arch##provider##data_type##op_type##kernelReg(#arch, #provider, data_type, \
                                                                                          #op_type, creator);          \
  }  // namespace
}  // namespace registry
}  // namespace mindspore

#endif  // MINDSPORE_LITE_INCLUDE_REGISTRY_REGISTER_KERNEL_H_