离线构建自定义算子
概述
MindSpore Lite的转换工具除了基本的模型转换功能之外,还支持用户对模型进行自定义的优化与构建,生成用户自定义算子的模型。
我们提供了一套注册机制,允许用户基于转换工具进行能力扩展:包括节点解析扩展、模型解析扩展以及图优化扩展,用户可以根据自身的需要对模型实现自定义的解析与融合优化。
节点解析扩展:用户自定义模型中某一节点的解析过程,支持ONNX、CAFFE、TF、TFLITE。接口可参考NodeParser、NodeParserRegistry。 模型解析扩展:用户自定义模型的整个解析过程,支持ONNX、CAFFE、TF、TFLITE。接口可参考ModelParser、ModelParserRegistry。 图优化扩展:模型解析之后,将获得MindSpore定义的图结构,用户可基于此结构自定义图的优化过程。接口可参考PassBase、PassPosition、PassRegistry。
节点解析扩展需要依赖flatbuffers和protobuf及三方框架的序列化文件,并且flatbuffers和protobuf需要与发布件采用的版本一致,序列化文件需保证兼容发布件采用的序列化文件。发布件中不提供flatbuffers、protobuf及序列化文件,用户需自行编译,并生成序列化文件。用户可以从MindSpore仓中获取flatbuffers、probobuf、ONNX原型文件、CAFFE原型文件、TF原型文件和TFLITE原型文件。
MindSpore Lite还提供了一系列的注册宏,以便于用户侧的扩展接入转换工具。注册宏包括节点解析注册REG_NODE_PARSER、模型解析注册REG_MODEL_PARSER、图优化注册REG_PASS、图优化调度注册REG_SCHEDULED_PASS。
MindSpore Lite转换工具的扩展能力,目前仅支持Linux系统。
本章节将通过MindSpore Lite转换工具扩展功能的示例程序,涵盖节点扩展案例、优化扩展案例以及编译链接全流程,来使用户能够快速了解转换工具的扩展功能的使用。
模型解析扩展,鉴于是模块化的扩展能力,本章不做详细介绍,但会提供一个简化的单元案例,以供用户参考。
本章节以add.tflite模型为例。该模型仅包含一个简单的Add算子,通过自定义的节点解析、图优化,将Add算子转化为Custom算子,最终输出Custom单算子模型。
相关代码放置在mindspore/lite/examples/converter_extend目录。
节点扩展
自定义节点解析:用户需继承NodeParser,继而根据不同的框架,选择不同的重载接口。
节点解析注册:用户调用注册接口REG_NODE_PARSER,完成自定义的节点解析接入转换工具。
class AddParserTutorial : public NodeParser { // 继承基类
public:
AddParserTutorial() = default;
~AddParserTutorial() = default;
ops::PrimitiveC *Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, // 重载接口
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
REG_NODE_PARSER(kFmkTypeTflite, ADD, std::make_shared<AddParserTutorial>()); // 调用注册接口
示例代码请参考node_parser。
模型扩展
示例代码请参考MindSpore仓的模型扩展的单元案例ModelParserRegistryTest。
优化扩展
优化注册:调用优化的注册接口REG_PASS,完成自定义把用户自己实现的Pass类注册进MindSpore Lite里。
class PassTutorial : public registry::PassBase { // 继承基类
public:
PassTutorial() : PassBase("PassTutorial") {}
~PassTutorial() = default;
bool Execute(const api::FuncGraphPtr &func_graph) override; // 重载接口
private:
AnfNodePtr CreateCustomOp(const api::FuncGraphPtr func_graph, const CNodePtr &cnode);
};
using mindspore::registry::POSITION_BEGIN; // 选择调度位置
REG_PASS(PassTutorial, opt::PassTutorial) // 注册扩展类
REG_SCHEDULED_PASS(POSITION_BEGIN, {"PassTutorial"}) // 注册调度逻辑
示例代码可参考pass。
在离线转换阶段,我们会对模型的每一个节点的输出张量进行推断,包括输出张量的Format、DataType以及Shape,因此,离线转换阶段,用户需提供自己实现的算子的推断过程,这里用户可以参考算子Infershape扩展说明,示例代码可参考infer。
示例演示
编译
环境要求
编译准备
MindSpore Lite的发布件不会提供其他框架下的序列化文件,因此,用户需自行编译获得,请参考概述。
本示例采用的是tflite模型,用户需编译flatbuffers,从MindSpore仓中获取TFLITE原型文件,最终生成tflite的序列化文件。
在
mindspore/lite/examples/converter_extend
目录下创建schema
文件目录,继而将生成的序列化文件置于schema
目录下。编译构建
在
mindspore/lite/examples/converter_extend
目录下执行build.sh,将自动下载MindSpore Lite发布件并编译Demo。bash build.sh
若使用该build脚本下载MindSpore Lite发布件失败,请手动下载硬件平台为CPU、操作系统为Ubuntu-x64的MindSpore Lite发布件mindspore-lite-{version}-linux-x64.tar.gz,将解压后
tools/converter/lib
目录、tools/converter/include
目录拷贝到mindspore/lite/examples/converter_extend
目录下。通过手动下载并且将文件放到指定位置后,需要再次执行build.sh脚本才能完成编译构建。
编译输出
在
mindspore/lite/examples/converter_extend/build
目录下生成了libconverter_extend_tutorial.so
的动态库。
执行程序
拷贝动态库
将生成的
libconverter_extend_tutorial.so
动态库文件拷贝到发布件的tools/converter/lib
下。进入发布件的转换目录
cd ${PACKAGE_ROOT_PATH}/tools/converter/converter
创建converter的配置文件(converter.cfg,详细可参考扩展配置),文件内容如下:
[registry] plugin_path=libconverter_extend_tutorial.so # 用户请配置动态库的正确路径
将转换工具需要的动态链接库加入环境变量
LD_LIBRARY_PATH
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:${PACKAGE_ROOT_PATH}/tools/converter/lib
执行converter
./converter_lite --fmk=TFLITE --modelFile=add.tflite --configFile=converter.cfg --outputFile=add_extend
执行完后,将生成名为add_extend.ms
的模型文件,文件路径由参数outputFile
决定。
扩展配置
在转换阶段,为了能够加载扩展模块,用户需要配置扩展动态库路径。扩展相关的参数有plugin_path
,disable_fusion
。参数的详细介绍如下所示:
参数 |
属性 |
功能描述 |
参数类型 |
默认值 |
取值范围 |
---|---|---|---|---|---|
plugin_path |
可选 |
第三方库加载路径 |
String |
- |
如有多个请用 |
disable_fusion |
可选 |
是否关闭融合优化 |
String |
off |
off、on |
fusion_blacklists |
可选 |
关闭指定融合算子名称 |
String |
- |
如有多个请用 |
发布件中已为用户生成好默认的配置文件(converter.cfg)。该配置文件内容如下:
[registry]
plugin_path=libconverter_extend_tutorial.so # 用户请配置动态库的正确路径
如果用户需要关闭指定算子融合优化,关闭指定名单融合配置如下所示:
[registry]
# 当参数disable_fusion=off时,可通过配置fusion_blacklists关闭指定融合优化,当参数disable_fusion=on时,关闭所有融合优化,参数fusion_blacklists不生效。
disable_fusion=off
fusion_blacklists=ConvActivationFusion,MatMulActivationFusion
融合算子名单如下所示:
序号 |
融合算子名称 |
---|---|
1 |
AddConcatActivationFusion |
2 |
SqueezeFusion |
3 |
TransposeFusion |
4 |
ReshapeReshapeFusion |
5 |
ConvBiasaddFusion |
6 |
ConvBatchNormFusion |
7 |
ConvScaleFusion |
8 |
GroupNormFusion |
9 |
TfNormFusion |
10 |
OnnxLayerNormFusion |
11 |
OnnxLayerNormFusion2 |
12 |
BatchMatMulFusion |
13 |
BatchNormToScaleFusion |
14 |
SigmoidMulFusion |
15 |
ActivationFusion |
16 |
ConvActivationFusion |
17 |
ConvTupleGetItemFusion |
18 |
ConvTupleActivationFusion |
19 |
TfliteLstmCellFusion |
20 |
TfLstmCellFusion |
21 |
TfBidirectionGruFusion |
22 |
TfGeLUFusion |
23 |
OnnxGeLUFusion |
24 |
TfliteRelPosMultiHeadAttentionFusion |
25 |
GLUFusion |
26 |
ConstFoldPass |
27 |
AffineFusion |
28 |
AffineActivationFusion |
29 |
ConvConvFusion |
30 |
ConvPadFusion |
31 |
MatMulAddFusion |
32 |
MatMulMulFusion |
33 |
TransposeMatMulFusion |
34 |
MulAddFusion |
35 |
ScaleActivationFusion |
36 |
ScaleScaleFusion |
37 |
FullConnectedFusion |
38 |
FullconnectedAddFusion |
39 |
TensorDotFusion |
40 |
MatMulActivationFusion |