离线构建自定义算子

查看源文件

概述

MindSpore Lite的转换工具除了基本的模型转换功能之外,还支持用户对模型进行自定义的优化与构建,生成用户自定义算子的模型。

我们提供了一套注册机制,允许用户基于转换工具进行能力扩展:包括节点解析扩展、模型解析扩展以及图优化扩展,用户可以根据自身的需要对模型实现自定义的解析与融合优化。

节点解析扩展:用户自定义模型中某一节点的解析过程,支持ONNX、CAFFE、TF、TFLITE。接口可参考NodeParserNodeParserRegistry。 模型解析扩展:用户自定义模型的整个解析过程,支持ONNX、CAFFE、TF、TFLITE。接口可参考ModelParserModelParserRegistry。 图优化扩展:模型解析之后,将获得MindSpore定义的图结构,用户可基于此结构自定义图的优化过程。接口可参考PassBasePassPositionPassRegistry

节点解析扩展需要依赖flatbuffers和protobuf及三方框架的序列化文件,并且flatbuffers和protobuf需要与发布件采用的版本一致,序列化文件需保证兼容发布件采用的序列化文件。发布件中不提供flatbuffers、protobuf及序列化文件,用户需自行编译,并生成序列化文件。用户可以从MindSpore仓中获取flatbuffersprobobufONNX原型文件CAFFE原型文件TF原型文件TFLITE原型文件

MindSpore Lite还提供了一系列的注册宏,以便于用户侧的扩展接入转换工具。注册宏包括节点解析注册REG_NODE_PARSER、模型解析注册REG_MODEL_PARSER、图优化注册REG_PASS、图优化调度注册REG_SCHEDULED_PASS

MindSpore Lite转换工具的扩展能力,目前仅支持Linux系统。

本章节将通过MindSpore Lite转换工具扩展功能的示例程序,涵盖节点扩展案例、优化扩展案例以及编译链接全流程,来使用户能够快速了解转换工具的扩展功能的使用。

模型解析扩展,鉴于是模块化的扩展能力,本章不做详细介绍,但会提供一个简化的单元案例,以供用户参考。

本章节以add.tflite模型为例。该模型仅包含一个简单的Add算子,通过自定义的节点解析、图优化,将Add算子转化为Custom算子,最终输出Custom单算子模型。

相关代码放置在mindspore/lite/examples/converter_extend目录。

节点扩展

  1. 自定义节点解析:用户需继承NodeParser,继而根据不同的框架,选择不同的重载接口。

  2. 节点解析注册:用户调用注册接口REG_NODE_PARSER,完成自定义的节点解析接入转换工具。

class AddParserTutorial : public NodeParser {  // 继承基类
 public:
  AddParserTutorial() = default;
  ~AddParserTutorial() = default;
  ops::PrimitiveC *Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,            // 重载接口
                         const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
                         const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};

REG_NODE_PARSER(kFmkTypeTflite, ADD, std::make_shared<AddParserTutorial>());     // 调用注册接口

示例代码请参考node_parser

模型扩展

示例代码请参考MindSpore仓的模型扩展的单元案例ModelParserRegistryTest

优化扩展

  1. 自定义优化:用户需继承PassBase,重载Execute接口函数Execute

  2. 优化注册:调用优化的注册接口REG_PASS,完成自定义把用户自己实现的Pass类注册进MindSpore Lite里。

class PassTutorial : public registry::PassBase {  // 继承基类
 public:
  PassTutorial() : PassBase("PassTutorial") {}

  ~PassTutorial() = default;

  bool Execute(const api::FuncGraphPtr &func_graph) override;     // 重载接口

 private:
  AnfNodePtr CreateCustomOp(const api::FuncGraphPtr func_graph, const CNodePtr &cnode);
};

using mindspore::registry::POSITION_BEGIN;            // 选择调度位置
REG_PASS(PassTutorial, opt::PassTutorial)             // 注册扩展类
REG_SCHEDULED_PASS(POSITION_BEGIN, {"PassTutorial"})  // 注册调度逻辑

示例代码可参考pass

在离线转换阶段,我们会对模型的每一个节点的输出张量进行推断,包括输出张量的Format、DataType以及Shape,因此,离线转换阶段,用户需提供自己实现的算子的推断过程,这里用户可以参考算子Infershape扩展说明,示例代码可参考infer

示例演示

编译

  • 环境要求

    • 系统环境:Linux x86_64,推荐使用Ubuntu 18.04.02LTS

    • 编译依赖:

  • 编译准备

    MindSpore Lite的发布件不会提供其他框架下的序列化文件,因此,用户需自行编译获得,请参考概述

    本示例采用的是tflite模型,用户需编译flatbuffers,从MindSpore仓中获取TFLITE原型文件,最终生成tflite的序列化文件。

    mindspore/lite/examples/converter_extend目录下创建schema文件目录,继而将生成的序列化文件置于schema目录下。

  • 编译构建

    mindspore/lite/examples/converter_extend目录下执行build.sh,将自动下载MindSpore Lite发布件并编译Demo。

    bash build.sh
    

    若使用该build脚本下载MindSpore Lite发布件失败,请手动下载硬件平台为CPU、操作系统为Ubuntu-x64的MindSpore Lite发布件mindspore-lite-{version}-linux-x64.tar.gz,将解压后tools/converter/lib目录、tools/converter/include目录拷贝到mindspore/lite/examples/converter_extend目录下。

    通过手动下载并且将文件放到指定位置后,需要再次执行build.sh脚本才能完成编译构建。

  • 编译输出

    mindspore/lite/examples/converter_extend/build目录下生成了libconverter_extend_tutorial.so的动态库。

执行程序

  1. 拷贝动态库

    将生成的libconverter_extend_tutorial.so动态库文件拷贝到发布件的tools/converter/lib下。

  2. 进入发布件的转换目录

    cd ${PACKAGE_ROOT_PATH}/tools/converter/converter
    
  3. 创建converter的配置文件(converter.cfg,详细可参考扩展配置),文件内容如下:

    [registry]
    plugin_path=libconverter_extend_tutorial.so      # 用户请配置动态库的正确路径
    
  4. 将转换工具需要的动态链接库加入环境变量LD_LIBRARY_PATH

    export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/tools/converter/lib
    
  5. 执行converter

    ./converter_lite --fmk=TFLITE --modelFile=add.tflite --configFile=converter.cfg --outputFile=add_extend
    

执行完后,将生成名为add_extend.ms的模型文件,文件路径由参数outputFile决定。

扩展配置

在转换阶段,为了能够加载扩展模块,用户需要配置扩展动态库路径。扩展相关的参数有plugin_pathdisable_fusion。参数的详细介绍如下所示:

参数

属性

功能描述

参数类型

默认值

取值范围

plugin_path

可选

第三方库加载路径

String

-

如有多个请用;分隔

disable_fusion

可选

是否关闭融合优化

String

off

off、on

fusion_blacklists

可选

关闭指定融合算子名称

String

-

如有多个请用,分隔

发布件中已为用户生成好默认的配置文件(converter.cfg)。该配置文件内容如下:

[registry]
plugin_path=libconverter_extend_tutorial.so      # 用户请配置动态库的正确路径

如果用户需要关闭指定算子融合优化,关闭指定名单融合配置如下所示:

[registry]
# 当参数disable_fusion=off时,可通过配置fusion_blacklists关闭指定融合优化,当参数disable_fusion=on时,关闭所有融合优化,参数fusion_blacklists不生效。
disable_fusion=off
fusion_blacklists=ConvActivationFusion,MatMulActivationFusion

融合算子名单如下所示:

序号

融合算子名称

1

AddConcatActivationFusion

2

SqueezeFusion

3

TransposeFusion

4

ReshapeReshapeFusion

5

ConvBiasaddFusion

6

ConvBatchNormFusion

7

ConvScaleFusion

8

GroupNormFusion

9

TfNormFusion

10

OnnxLayerNormFusion

11

OnnxLayerNormFusion2

12

BatchMatMulFusion

13

BatchNormToScaleFusion

14

SigmoidMulFusion

15

ActivationFusion

16

ConvActivationFusion

17

ConvTupleGetItemFusion

18

ConvTupleActivationFusion

19

TfliteLstmCellFusion

20

TfLstmCellFusion

21

TfBidirectionGruFusion

22

TfGeLUFusion

23

OnnxGeLUFusion

24

TfliteRelPosMultiHeadAttentionFusion

25

GLUFusion

26

ConstFoldPass

27

AffineFusion

28

AffineActivationFusion

29

ConvConvFusion

30

ConvPadFusion

31

MatMulAddFusion

32

MatMulMulFusion

33

TransposeMatMulFusion

34

MulAddFusion

35

ScaleActivationFusion

36

ScaleScaleFusion

37

FullConnectedFusion

38

FullconnectedAddFusion

39

TensorDotFusion

40

MatMulActivationFusion