mindspore::registry
接口汇总
类名 |
描述 |
---|---|
扩展Node解析的注册类。 |
|
注册扩展Node解析。 |
|
扩展Model解析的注册类。 |
|
注册扩展Model解析。 |
|
Pass的基类。 |
|
扩展Pass的运行位置。 |
|
扩展Pass注册构造类。 |
|
注册扩展Pass。 |
|
注册扩展Pass的调度顺序。 |
|
算子注册实现类。 |
|
算子注册构造类。 |
|
注册算子。 |
|
注册Custom算子注册。 |
|
算子扩展能力注册实现类。 |
|
算子扩展能力注册构造类。 |
|
注册算子扩展能力。 |
|
注册Custom算子扩展能力。 |
NodeParserRegistry
#include <node_parser_registry.h>
NodeParserRegistry类用于注册及获取NodeParser类型的共享智能指针。
NodeParserRegistry
NodeParserRegistry(converter::FmkType fmk_type, const std::string &node_type,
const converter::NodeParserPtr &node_parser);
构造函数。
参数
fmk_type
: 框架类型,具体见FmkType说明。node_type
: 节点的类型。node_parser
: NodeParser类型的共享智能指针实例, 具体见NodeParserPtr说明。
~NodeParserRegistry
~NodeParserRegistry = default;
析构函数。
公有成员函数
GetNodeParser
static converter::NodeParserPtr GetNodeParser(converter::FmkType fmk_type, const std::string &node_type);
静态方法,获取NodeParser类型的共享智能指针实例。
参数
fmk_type
: 框架类型,具体见FmkType说明。node_type
: 节点的类型。
REG_NODE_PARSER
#include <node_parser_registry.h>
#define REG_NODE_PARSER(fmk_type, node_type, node_parser)
注册NodeParser宏。
参数
fmk_type
: 框架类型,具体见FmkType说明。node_type
: 节点的类型。node_parser
: NodeParser类型的共享智能指针实例, 具体见NodeParserPtr说明。
ModelParserCreator
#include <model_parser_registry.h>
typedef converter::ModelParser *(*ModelParserCreator)()
创建ModelParser的函数原型声明。
ModelParserRegistry
#include <model_parser_registry.h>
ModelParserRegistry类用于注册及获取ModelParserCreator类型的函数指针。
ModelParserRegistry
ModelParserRegistry(FmkType fmk, ModelParserCreator creator)
构造函数,构造ModelParserRegistry对象,进行Model解析注册。
参数
fmk
: 框架类型,具体见FmkType说明。creator
: ModelParserCreator类型的函数指针, 具体见ModelParserCreator说明。
~ModelParserRegistry
~ModelParserRegistry()
析构函数。
公有成员函数
GetModelParser
static ModelParser *GetModelParser(FmkType fmk)
获取ModelParserCreator类型的函数指针。
参数
fmk
: 框架类型,具体见FmkType说明。
REG_MODEL_PARSER
#include <model_parser_registry.h>
#define REG_MODEL_PARSER(fmk, parserCreator)
注册ModelParserCreator类。
参数
fmk
: 框架类型,具体见FmkType说明。creator
: ModelParserCreator类型的函数指针, 具体见ModelParserCreator说明。
用户自定义的ModelParser,框架类型必须满足设定支持的框架类型FmkType。
PassBase
#include <pass_base.h>
PassBase定义了图优化的基类,以供用户继承并自定义图优化算法。
PassBase
PassBase(const std::string &name = "PassBase")
构造函数,构造PassBase类对象。
参数
name
: PassBase类对象的标识,需保证唯一性。
~PassBase
virtual ~PassBase() = default;
析构函数。
公有成员函数
Execute
virtual bool Execute(const api::FuncGraphPtr &func_graph) = 0;
对图进行操作的接口函数。
参数
func_graph
: FuncGraph的指针类对象。
PassBasePtr
#include <pass_base.h>
PassBase类的共享智能指针类型。
using PassBasePtr = std::shared_ptr<PassBase>
PassPosition
#include <pass_registry.h>
enum类型变量,定义扩展Pass的运行位置。
enum PassPosition {
POSITION_BEGIN = 0, // 扩展Pass运行于内置融合Pass前
POSITION_END = 1 // 扩展Pass运行于内置融合Pass后
};
PassRegistry
#include <pass_registry.h>
PassRegistry类用于注册及获取Pass类实例。
PassRegistry
PassRegistry(const std::string &pass_name, const PassBasePtr &pass)
构造函数,构造PassRegistry对象,进行注册Pass。
参数
pass_name
: Pass的命名标识,保证唯一性。pass
: PassBase类实例。
PassRegistry(PassPosition position, const std::vector<std::string> &names)
构造函数,构造PassRegistry对象,指定扩展Pass的运行位置及其运行顺序。
参数
position
: 扩展Pass的运行位置,具体见PassPosition说明。names
: 用户指定在该运行位置处,调用Pass的命名标识,命名标识的顺序即为指定Pass的调用顺序。
~PassRegistry
~PassRegistry()
析构函数。
公有成员函数
GetOuterScheduleTask
static std::vector<std::string> GetOuterScheduleTask(PassPosition position)
获取指定位置处,外部设定的调度任务。
参数
position
: 扩展Pass的运行位置,具体见PassPosition说明。
GetPassFromStoreRoom
static PassBasePtr GetPassFromStoreRoom(const std::string &pass_name)
获取PassBase实例,根据指定的Pass命名标识。
参数
pass_name
: Pass的命名标识。
REG_PASS
#include <pass_registry.h>
#define REG_PASS(name, pass)
注册Pass宏。
参数
name
: Pass的命名标识,保证唯一性。pass
: PassBase类实例。
REG_SCHEDULED_PASS
#include <pass_registry.h>
#define REG_SCHEDULED_PASS(position, names)
指定扩展Pass的运行位置及其运行顺序。
参数
position
: 扩展Pass的运行位置,具体见PassPosition说明。names
: 用户指定在该运行位置处,调用Pass的命名标识,命名标识的顺序即为指定Pass的调用顺序。
MindSpore Lite开放了部分内置Pass,请见以下说明。用户可以在
names
参数中添加内置Pass的命名标识,以在指定运行处调用内置Pass。
ConstFoldPass
: 将输入均是常量的节点进行离线计算,导出的模型将不含该节点。特别地,针对shape算子,在inputShape给定的情形下,也会触发预计算。
DumpGraph
: 导出当前状态下的模型。请确保当前模型为NHWC或者NCHW格式的模型,例如卷积算子等。
ToNCHWFormat
: 将当前状态下的模型转换为NCHW的格式,例如,四维的图输入、卷积算子等。
ToNHWCFormat
: 将当前状态下的模型转换为NHWC的格式,例如,四维的图输入、卷积算子等。
DecreaseTransposeAlgo
: transpose算子的优化算法,删除冗余的transpose算子。
ToNCHWFormat
与ToNHWCFormat
需配套使用。在开放的运行位置处,用户所得到的模型已统一为NHWC的格式,用户也需确保在当前运行位置处返回之时,模型也是NHWC的格式。例: 指定names为{“ToNCHWFormat”, “UserPass”,”ToNHWCFormat”}。
KernelDesc
#include <registry/register_kernel.h>
struct类型结构体,定义扩展kernel的基本属性。
struct KernelDesc {
DataType data_type; // kernel的计算数据类型
int type; // 算子的类型
std::string arch; // 设备标识
std::string provider; // 用户标识
};
RegisterKernel
#include <registry/register_kernel.h>
CreateKernel
using CreateKernel = std::function<std::shared_ptr<kernel::Kernel>(
const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs, const schema::Primitive *primitive,
const mindspore::Context *ctx)>
创建算子的函数原型声明。
公有成员函数
RegKernel
static Status RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type, const CreateKernel creator)
算子注册。
参数
arch
: 算子运行的平台,由用户自定义,如果算子是运行在CPU平台,或者算子运行完后的output tensor里的内存是在CPU平台上的,则此处也写CPU,MindSpore Lite内部会切成一个子图,在异构并行场景下有助于性能提升。provider
: 生产商名,由用户自定义。data_type
: 算子支持的数据类型,具体见DataType。op_type
: 算子类型,定义在ops.fbs中,编绎时会生成到ops_generated.h,该文件可以在发布件中获取。creator
: 创建算子的函数指针,具体见CreateKernel的说明。
RegCustomKernel
static Status RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type, const std::string &type, const CreateKernel creator)
Custom算子注册。
参数
arch
: 算子运行的平台,由用户自定义,如果算子是运行在CPU平台,或者算子运行完后的output tensor里的内存是在CPU平台上的,则此处也写CPU,MindSpore Lite内部会切成一个子图,在异构并行场景下有助于性能提升。provider
: 生产商名,由用户自定义。data_type
: 算子支持的数据类型,具体见DataType。type
: 算子类型,由用户自定义,确保唯一即可。creator
: 创建算子的函数指针,具体见CreateKernel的说明。
GetCreator
static CreateKernel GetCreator(const schema::Primitive *primitive, KernelDesc *desc);
获取算子的创建函数。
参数
primitive
: 算子经由flatbuffers反序化为Primitive后的结果。desc
: 算子的基本属性,具体见KernelDesc说明。
KernelReg
#include <registry/register_kernel.h>
~KernelReg
~KernelReg() = default
析构函数。
KernelReg
KernelReg(const std::string &arch, const std::string &provider, DataType data_type, int op_type, const CreateKernel creator)
构造函数,构造注册算子,进行算子注册。
参数
arch
: 算子运行的平台,由用户自定义,如果算子是运行在CPU平台,或者算子运行完后的output tensor里的内存是在CPU平台上的,则此处也写CPU,MindSpore Lite内部会切成一个子图,在异构并行场景下有助于性能提升。provider
: 生产商名,由用户自定义。data_type
: 算子支持的数据类型,具体见DataType。op_type
: 算子类型,定义在ops.fbs中,编绎时会生成到ops_generated.h,该文件可以在发布件中获取。creator
: 创建算子的函数指针,具体见CreateKernel的说明。
KernelReg(const std::string &arch, const std::string &provider, DataType data_type, const std::string &op_type, const CreateKernel creator)
构造函数,构造注册Custom算子,进行算子注册。
参数
arch
: 算子运行的平台,由用户自定义,如果算子是运行在CPU平台,或者算子运行完后的output tensor里的内存是在CPU平台上的,则此处也写CPU,MindSpore Lite内部会切成一个子图,在异构并行场景下有助于性能提升。provider
: 生产商名,由用户自定义。data_type
: 算子支持的数据类型,具体见DataType。op_type
: 算子类型,由用户自定义,确保唯一即可。creator
: 创建算子的函数指针,具体见CreateKernel的说明。
REGISTER_KERNEL
#define REGISTER_KERNEL(arch, provider, data_type, op_type, creator)
注册算子宏。
参数
arch
: 算子运行的平台,由用户自定义,如果算子是运行在CPU平台,或者算子运行完后的output tensor里的内存是在CPU平台上的,则此处也写CPU,MindSpore Lite内部会切成一个子图,在异构并行场景下有助于性能提升。provider
: 生产商名,由用户自定义。data_type
: 算子支持的数据类型,具体见DataType。op_type
: 算子类型,定义在ops.fbs中,编绎时会生成到ops_generated.h,该文件可以在发布件中获取。creator
: 创建算子的函数指针,具体见CreateKernel的说明。
REGISTER_CUSTOM_KERNEL
#define REGISTER_CUSTOM_KERNEL(arch, provider, data_type, op_type, creator)
注册Custom算子。
参数
arch
: 算子运行的平台,由用户自定义,如果算子是运行在CPU平台,或者算子运行完后的output tensor里的内存是在CPU平台上的,则此处也写CPU,MindSpore Lite内部会切成一个子图,在异构并行场景下有助于性能提升。provider
: 生产商名,由用户自定义。data_type
: 算子支持的数据类型,具体见DataType。op_type
: 算子类型,由用户自定义,确保唯一即可。creator
: 创建算子的函数指针,具体见CreateKernel的说明。
KernelInterfaceCreator
#include <registry/register_kernel_interface.h>
定义创建算子的函数指针类型。
using KernelInterfaceCreator = std::function<std::shared_ptr<kernel::KernelInterface>()>;
RegisterKernelInterface
#include <registry/register_kernel_interface.h>
算子扩展能力注册实现类。
公有成员函数
CustomReg
static Status CustomReg(const std::string &provider, const std::string &op_type, const KernelInterfaceCreator creator)
Custom算子的扩展能力注册。
参数
provider
: 生产商,由用户自定义。op_type
: 算子类型,由用户自定义。creator
: KernelInterface的创建函数,详细见KernelInterfaceCreator的说明。
Reg
static Status Reg(const std::string &provider, int op_type, const KernelInterfaceCreator creator)
算子的扩展能力注册。
参数
provider
: 生产商,由用户自定义。op_type
: 算子类型,定义在ops.fbs中,编绎时会生成到ops_generated.h,该文件可以在发布件中获取。creator
: KernelInterface的创建函数,详细见KernelInterfaceCreator的说明。
GetKernelInterface
static std::shared_ptr<kernel::KernelInterface> GetKernelInterface(const std::string &provider, const schema::Primitive *primitive, const kernel::Kernel *kernel)
获取注册的算子扩展能力。
参数
provider
:生产商名,由用户自定义。primitive
:算子经过flatbuffers反序化后的结果,存储算子属性。kernel
:算子的内核,不传的话默认为空,为空时必须保证primitive非空有效。
KernelInterfaceReg
#include <registry/register_kernel_interface.h>
算子扩展能力注册构造类。
KernelInterfaceReg
KernelInterfaceReg(const std::string &provider, int op_type, const KernelInterfaceCreator creator)
构造函数,构造注册算子的扩展能力。
参数
provider
: 生产商,由用户自定义。op_type
: 算子类型,定义在ops.fbs中,编绎时会生成到ops_generated.h,该文件可以在发布件中获取。creator
: KernelInterface的创建函数,详细见KernelInterfaceCreator的说明。
KernelInterfaceReg(const std::string &provider, const std::string &op_type, const KernelInterfaceCreator creator)
构造函数,构造注册custom算子的扩展能力。
参数
provider
: 生产商,由用户自定义。op_type
: 算子类型,由用户自定义。creator
: KernelInterface的创建函数,详细见KernelInterfaceCreator的说明。
REGISTER_KERNEL_INTERFACE
#include <registry/register_kernel_interface.h>
注册KernelInterface的实现。
#define REGISTER_KERNEL_INTERFACE(provider, op_type, creator)
参数
provider
: 生产商,由用户自定义。op_type
: 算子类型,定义在ops.fbs中,编绎时会生成到ops_generated.h,该文件可以在发布件中获取。creator
: 创建KernelInterface的函数指针,具体见KernelInterfaceCreator的说明。
REGISTER_CUSTOM_KERNEL_INTERFACE
#include <registry/register_kernel_interface.h>
注册Custom算子对应的KernelInterface实现。
#define REGISTER_CUSTOM_KERNEL_INTERFACE(provider, op_type, creator)
参数
provider
: 生产商名,由用户自定义。op_type
: 算子类型,由用户自定义,确保唯一同时要与REGISTER_CUSTOM_KERNEL时注册的op_type保持一致。creator
: 创建算子的函数指针,具体见KernelInterfaceCreator的说明。