mindspore::kernel
接口汇总
类名 |
描述 |
---|---|
算子基类。 |
|
算子扩展能力基类。 |
Kernel
#include <kernel.h>
Kernel是算子实现的基类,定义了几个必须实现的接口。
构造函数
Kernel
Kernel()
Kernel(const std::vector<mindspore::MSTensor> &inputs, const std::vector<mindspore::MSTensor> &outputs,
const schema::Primitive *primitive, const mindspore::Context *ctx)
Kernel的默认与带参构造函数,构造Kernel实例。
析构函数
~Kernel
virtual ~Kernel()
Kernel的析构函数。
公有成员函数
Prepare
virtual int Prepare()
进行算子运行前相关的准备工作,MindSpore Lite 框架运行时会对所有算子执行一遍Prepare后再执行Execute。
Execute
virtual int Execute()
运行算子。
ReSize
virtual int ReSize()
根据输入的形状态重新分配算子需要的内存。
type
virtual schema::PrimitiveType type()
返回算子的类型。
set_inputs
virtual void set_inputs(const std::vector<mindspore::MSTensor> &in_tensors)
设置算子的输入列表。
参数
in_tensors
: 算子的所有输入MSTensor列表。
set_input
virtual set_input(mindspore::MSTensor in_tensor, int index)
设置算子指定位置的输入。
参数
in_tensor
: 算子的输入MSTensor。index
: 算子输入在所有输入中的下标,从0开始计数。
set_outputs
virtual void set_outputs(const std::vector<mindspore::MSTensor> &out_tensors)
设置算子的输出列表。
参数
out_tensor
: 算子的所有输出MSTensor列表。
set_output
virtual void set_output(mindspore::MSTensor out_tensor, int index)
设置算子指定位置的输出。
参数
out_tensor
: 算子的输出MSTensor。index
: 算子输出在所有输出中的下标,从0开始计数。
inputs
virtual const std::vector<mindspore::MSTensor *> &inputs()
返回算子的所有输入MSTensor列表。
outputs
virtual const std::vector<mindspore::MSTensor *> &outputs()
返回算子的所有输出MSTensor列表。
name
std::string name()
返回算子的名称。
set_name
void set_name(const std::string &name)
设置算子的名称。
参数
name
: 算子名称。
context
const lite::Context *context() const
返回算子对应的Context。
primitive
const schema::Primitive *primitive() const
返回算子经由flatbuffers反序化为Primitive后的结果。
GetAttr
std::string GetAttr(const std::string &key) const
参数
key
: 配置名。
获取指定配置名对应的配置。
KernelInterface
#include <kernel_interface.h>
算子扩展能力基类。
~KernelInterface
virtual ~KernelInterface()
析构函数。
KernelInterfaceCreator
using KernelInterfaceCreator = std::function<std::shared_ptr<KernelInterface>()>
创建KernelInterface的函数原型声明。
公有成员函数
Infer
算子的InferShape能力,用于根据输入推导出输出的shape、数据类型以及format。
virtual int Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs, const schema::Primitive *primitive)