mindspore::kernel
接口汇总
类名 |
描述 |
---|---|
算子基类。 |
|
算子扩展能力基类。 |
Kernel
#include <kernel.h>
Kernel是算子实现的基类,定义了几个必须实现的接口。
构造函数
Kernel
Kernel()
Kernel(const std::vector<mindspore::MSTensor> &inputs, const std::vector<mindspore::MSTensor> &outputs,
const schema::Primitive *primitive, const mindspore::Context *ctx)
Kernel的默认与带参构造函数,构造Kernel实例。
析构函数
~Kernel
virtual ~Kernel()
Kernel的析构函数。
公有成员函数
Prepare
virtual int Prepare()
进行算子运行前相关的准备工作,MindSpore Lite 框架运行时会对所有算子执行一遍Prepare后再执行Execute。
Execute
virtual int Execute()
运行算子。
ReSize
virtual int ReSize()
在用户调用Model::Resize
接口时,或是模型推理中需要重新推理算子形状时,会调用到该接口。
在ReSize
函数中,若有必要,根据输入的形状态重新推理输出形状,并分配算子运算中需要的内存。
type
virtual schema::PrimitiveType type()
返回算子的类型。
set_inputs
virtual void set_inputs(const std::vector<mindspore::MSTensor> &in_tensors)
设置算子的输入列表。
参数
in_tensors
: 算子的所有输入MSTensor列表。
set_input
virtual set_input(mindspore::MSTensor in_tensor, int index)
设置算子指定位置的输入。
参数
in_tensor
: 算子的输入MSTensor。index
: 算子输入在所有输入中的下标,从0开始计数。
set_outputs
virtual void set_outputs(const std::vector<mindspore::MSTensor> &out_tensors)
设置算子的输出列表。
参数
out_tensor
: 算子的所有输出MSTensor列表。
set_output
virtual void set_output(mindspore::MSTensor out_tensor, int index)
设置算子指定位置的输出。
参数
out_tensor
: 算子的输出MSTensor。index
: 算子输出在所有输出中的下标,从0开始计数。
inputs
virtual const std::vector<mindspore::MSTensor *> &inputs()
返回算子的所有输入MSTensor列表。
outputs
virtual const std::vector<mindspore::MSTensor *> &outputs()
返回算子的所有输出MSTensor列表。
name
std::string name()
返回算子的名称。
set_name
void set_name(const std::string &name)
设置算子的名称。
参数
name
: 算子名称。
context
const lite::Context *context() const
返回算子对应的Context。
primitive
const schema::Primitive *primitive() const
返回算子经由flatbuffers反序化为Primitive后的结果。
GetAttr
std::string GetAttr(const std::string &key) const
获取指定配置名对应的配置。
参数
key
: 配置名。
SetConfig
void SetConfig(const std::map<std::string, std::map<std::string, std::string>> *config)
保存配置内容的常量指针到kernel里,该接口当前是由框架在加载配置文件时自动触发调用的,不建议用户使用。
参数
config
: 配置的常量指针。
GetConfig
std::map<std::string, std::string> GetConfig(const std::string §ion) const
获取指定章节名对应的配置。
参数
section
: 配置的章节名称。
KernelInterface
#include <kernel_interface.h>
算子扩展能力基类。
~KernelInterface
virtual ~KernelInterface()
析构函数。
KernelInterfaceCreator
using KernelInterfaceCreator = std::function<std::shared_ptr<KernelInterface>()>
创建KernelInterface的函数原型声明。
公有成员函数
Infer
算子的InferShape能力,用于根据输入推导出输出的形状、数据类型以及format。
virtual int Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs, const schema::Primitive *primitive, const Kernel *kernel)
Infer
算子的InferShape能力,用于根据输入推导出输出的shape、数据类型以及format。
该接口已不推荐使用,建议使用带有kernel参数的Infer接口。因为如果模型通过以下Build接口执行编译,编译后框架会自动释放模型的内存,导致primitive不可用。
Status Build(GraphCell graph, const std::shared_ptr<Context> &model_context = nullptr,
const std::shared_ptr<TrainCfg> &train_cfg = nullptr)
virtual int Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs, const schema::Primitive *primitive)