mindspore::session

查看源文件

LiteSession

#include <lite_session.h>

LiteSession定义了MindSpore Lite中的会话,用于进行Model的编译和前向推理。

构造函数和析构函数

LiteSession

LiteSession()

MindSpore Lite LiteSession的构造函数,使用默认参数。

~LiteSession

~LiteSession()

MindSpore Lite LiteSession的析构函数。

公有成员函数

BindThread

virtual void BindThread(bool if_bind)

尝试将线程池中的线程绑定到指定的cpu内核,或从指定的cpu内核进行解绑。

  • 参数

    • if_bind: 定义了对线程进行绑定或解绑。

CompileGraph

virtual int CompileGraph(lite::Model *model)

编译MindSpore Lite模型。

CompileGraph必须在RunGraph方法之前调用。

  • 参数

    • model: 定义了需要被编译的模型。

  • 返回值

    STATUS,即编译图的错误码。STATUS在errorcode.h中定义。

GetInputs

virtual std::vector <tensor::MSTensor *> GetInputs() const

获取MindSpore Lite模型的MSTensors输入。

  • 返回值

    MindSpore Lite MSTensor向量。

GetInputsByTensorName

mindspore::tensor::MSTensor *GetInputsByTensorName(const std::string &name) const

通过tensor名获取MindSpore Lite模型的MSTensors输入。

  • 参数

    • name: 定义了tensor名。

  • 返回值

    MindSpore Lite MSTensor。

RunGraph

virtual int RunGraph(const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr)

运行带有回调函数的会话。

RunGraph必须在CompileGraph方法之后调用。

  • 参数

    • before: 一个KernelCallBack 结构体。定义了运行每个节点之前调用的回调函数。

    • after: 一个KernelCallBack 结构体。定义了运行每个节点之后调用的回调函数。

  • 返回值

    STATUS ,即编译图的错误码。STATUS在errorcode.h中定义。

GetOutputsByNodeName

virtual std::vector <tensor::MSTensor *> GetOutputsByNodeName(const std::string &node_name) const

通过节点名获取MindSpore Lite模型的MSTensors输出。

  • 参数

    • node_name: 定义了节点名。

  • 返回值

    MindSpore Lite MSTensor向量。

GetOutputs

virtual std::unordered_map <std::string, mindspore::tensor::MSTensor *> GetOutputs() const

获取与张量名相关联的MindSpore Lite模型的MSTensors输出。

  • 返回值

    包含输出张量名和MindSpore Lite MSTensor的容器类型变量。

GetOutputTensorNames

virtual std::vector <std::string> GetOutputTensorNames() const

获取由当前会话所编译的模型的输出张量名。

  • 返回值

    字符串向量,其中包含了按顺序排列的输出张量名。

GetOutputByTensorName

virtual mindspore::tensor::MSTensor *GetOutputByTensorName(const std::string &tensor_name) const

通过张量名获取MindSpore Lite模型的MSTensors输出。

  • 参数

    • tensor_name: 定义了张量名。

  • 返回值

    指向MindSpore Lite MSTensor的指针。

Resize

virtual int Resize(const std::vector <tensor::MSTensor *> &inputs, const std::vector<std::vector<int>> &dims)

调整输入的形状。

  • 参数

    • inputs: 模型对应的所有输入。

    • dims: 输入对应的新的shape,顺序注意要与inputs一致。

  • 返回值

    STATUS,即编译图的错误码。STATUS在errorcode.h中定义。

Train

virtual int Train() = 0;

设置为训练模式。

  • 返回值

    STATUS,即编译图的错误码。STATUS在errorcode.h中定义。

IsTrain

bool IsTrain() { return train_mode_ == true; }

检查当前模型是否为训练模式。

  • 返回值

    true 或 false,即当前模型是否为训练模式。

Eval

virtual int Eval() = 0;

设置为验证模式。

  • 返回值

    STATUS,即编译图的错误码。STATUS在errorcode.h中定义。

IsEval

bool IsEval() { return train_mode_ == false; }

检查当前模型是否为验证模式。

  • 返回值

    true 或 false,即当前模型是否为验证模式。

SetLearningRate

virtual int SetLearningRate(float learning_rate) = 0;

为当前模型设置学习率。

  • 返回值

    STATUS,即编译图的错误码。STATUS在errorcode.h中定义。

GetLearningRate

virtual float GetLearningRate() = 0;

获取当前模型的学习率。

  • 返回值

    当前模型的学习率, 如果未设置优化器则返回0.0。

SetupVirtualBatch

virtual int SetupVirtualBatch(int virtual_batch_multiplier, float lr = -1.0f, float momentum = -1.0f) = 0;

用户自定义虚拟批次数,,用于减少内存消耗。

  • 参数

    • virtual_batch_multiplier: 自定义虚拟批次数。

    • lr: 自定义学习率。

    • momentum: 自定义动量。

  • 返回值

    STATUS,即编译图的错误码。STATUS在errorcode.h中定义。

GetPredictions

virtual std::vector<tensor::MSTensor *> GetPredictions() const = 0;

获取训练模型的预测结果。

  • 返回值

    预测结果张量指针数组。

Export

virtual int Export(const std::string &file_name, lite::ModelType model_type = lite::MT_TRAIN,
                     lite::QuantizationType quant_type = lite::QT_DEFAULT, lite::FormatType format= lite::FT_FLATBUFFERS) const = 0;

保存已训练模型。

  • 参数

    • filename: 保存模型的文件名。

    • model_type: 训练或推理。

    • quant_type: 量化类型。

    • format: 保存模型格式。

  • 返回值

    STATUS,即编译图的错误码。STATUS在errorcode.h中定义。

GetFeatureMaps

 virtual std::vector<tensor::MSTensor *> GetFeatureMaps() const = 0;

获取训练模型权重。

  • 返回值

    权重列表。

UpdateFeatureMaps

 virtual int UpdateFeatureMaps(const std::vector<tensor::MSTensor *> &features) = 0;

更新训练模型权重。

  • 参数

    • features: 新的权重列表。

  • 返回值

    STATUS,即编译图的错误码。STATUS在errorcode.h中定义。

静态公有成员函数

CreateSession

static LiteSession *CreateSession(const lite::Context *context)

用于创建一个LiteSession指针的静态方法。

  • 参数

    • context: 定义了所要创建的session的上下文。

  • 返回值

    指向MindSpore Lite LiteSession的指针。

static LiteSession *CreateSession(const char *model_buf, size_t size, const lite::Context *context);

用于创建一个LiteSession指针的静态方法。返回的Lite Session指针已经完成了model_buf的读入和图编译。

  • 参数

    • model_buf: 定义了读取模型文件的缓存区。

    • size: 定义了模型缓存区的字节数。

    • context: 定义了所要创建的session的上下文。

  • 返回值

    指向MindSpore Lite LiteSession的指针。

TrainSession

#include <train_session.h>

TrainSession定义了MindSpore Lite 训练过程中的会话,用于进行Model的编译和训练。

构造函数和析构函数

TrainSession

TrainSession()

MindSpore Lite TrainSession的构造函数,使用默认参数。

~TrainSession

~TrainSession()

MindSpore Lite TrainSession的析构函数。

公有成员函数

CreateTransferSession

static TrainSession *CreateTransferSession(const std::string &filename_backbone, const std::string &filename_head, const lite::Context *context, bool train_mode = false, const lite::TrainCfg *cfg = nullptr);

创建迁移学习训练会话指针的静态方法。

  • 参数

    • filename_backbone: 主干网络的名称。

    • filename_head: 顶层网络的名称。

    • context: 指向目标会话的指针。

    • train_mode: 是否开启训练模式。

    • cfg: 训练相关配置。

  • 返回值

    指向训练会话的指针。

CreateTrainSession

static LiteSession *CreateTrainSession(const std::string &filename, const lite::Context *context, bool train_mode = false, const lite::TrainCfg *cfg = nullptr);

创建训练会话指针的静态方法。

  • 参数

    • filename: 指向文件名称。

    • context: 指向会话指针

    • train_mode: 是否开启训练模式。

    • cfg: 训练相关配置。

  • 返回值

    指向训练会话的指针。

TrainLoop

#include <ltrain_loop.h>

继承于Session,可设置训练参数和数据预处理函数,用于减少模型训练的资源消耗。

构造函数和析构函数

~TrainLoop

virtual ~TrainLoop() = default;

虚析构函数。

公有成员函数

CreateTrainLoop

static TrainLoop *CreateTrainLoop(session::TrainSession *train_session, lite::Context *context, int batch_size = -1);

创建迭代训练指针的静态方法。

  • 参数

    • model_filename: 模型文件名。

    • context: 指向目标会话的指针。

    • batch_size: 批次数。

  • 返回值

    指向迭代训练对象的指针。

Reset

virtual int Reset() = 0;

重置迭代次数为0。

  • 返回值

    STATUS,即编译图的错误码。STATUS在errorcode.h中定义。

train_session

virtual session::TrainSession *train_session() = 0;

获取TrainSession会话对象。

  • 返回值

    指向训练会话对象的指针。

Init

virtual int Init(std::vector<mindspore::session::Metrics *> metrics) = 0;

初始化模型评估矩阵。

  • 参数

    • metrics: 模型评估矩阵指针数组。

  • 返回值

    STATUS,即编译图的错误码。STATUS在errorcode.h中定义。

GetMetrics

virtual std::vector<mindspore::session::Metrics *> GetMetrics() = 0;

获取模型评估矩阵。

  • 返回值

    模型评估矩阵指针数组。

SetKernelCallBack

virtual int SetKernelCallBack(const KernelCallBack &before, const KernelCallBack &after) = 0;

设置运行时回调函数。

  • 参数

    • before: 执行前回调。

    • after: 执行后回调。

  • 返回值

    STATUS,即编译图的错误码。STATUS在errorcode.h中定义。

Train

virtual int Train(int epochs, mindspore::dataset::Dataset *dataset, std::vector<TrainLoopCallBack *> cbs, LoadDataFunc load_func = nullptr)= 0;

执行迭代训练。

  • 参数

    • epochs: 迭代次数。

    • dataset: 指向MindData类对象的指针。

    • cbs: 对象指针数组。

    • load_func: 类模板函数对象。

  • 返回值

    STATUS,即编译图的错误码。STATUS在errorcode.h中定义。

Eval

virtual int Eval(mindspore::dataset::Dataset *dataset, std::vector<TrainLoopCallBack *> cbs, LoadDataFunc load_func = nullptr, int max_steps = INT_MAX) = 0;

执行推理。

  • 参数

    • dataset: 指向MindData类对象的指针。

    • cbs: 对象指针数组。

    • load_func: 类模板函数对象。

    • max_steps: 重复迭代次数。

  • 返回值

    STATUS,即编译图的错误码。STATUS在errorcode.h中定义。

TrainLoopCallback

#include <ltrain_loop_callback.h>

在模型训练中执行回调函数。

构造函数和析构函数

~TrainLoopCallback

virtual ~TrainLoopCallback() = default;

析构函数。

Public Member Functions

Begin

virtual void Begin(const TrainLoopCallBackData &cb_data) {}

在模型训练前执行。

  • 参数

    • cb_data: 回调函数对象。

End

virtual void End(const TrainLoopCallBackData &cb_data) {}

在模型训练后执行回调。

  • 参数

    • cb_data: 回调函数对象。

EpochBegin

virtual void EpochBegin(const TrainLoopCallBackData &cb_data) {}

每次迭代开始前执行回调。

  • 参数

    • cb_data: 回调函数对象。

EpochEnd

virtual int EpochEnd(const TrainLoopCallBackData &cb_data) { return RET_CONTINUE; }

每次迭代结束后执行回调。

  • 参数

    • cb_data: 回调函数对象。

  • 返回 STATUS,即编译图的错误码。STATUS在errorcode.h中定义。

StepBegin

virtual void StepBegin(const TrainLoopCallBackData &cb_data) {}

每一步开始前执行回调。

  • 参数

    • cb_data: 回调函数对象。

StepEnd

virtual void StepEnd(const TrainLoopCallBackData &cb_data) {}

每一步开始后执行回调。

  • 参数

    • cb_data: 回调函数对象。

Metrics

#include <metrics.h>

训练模型评估矩阵类

构造函数和析构函数

~Metrics

virtual ~Metrics() = default;

析构函数。

Public Member Functions

Clear

virtual void Clear() {}

将成员变量total_accuracy_total_steps_置为零。

Eval

virtual float Eval() {}

评估模型。

Update

virtual void Update(std::vector<tensor::MSTensor *> inputs, std::vector<tensor::MSTensor *> outputs) = 0;

更新成员变量total_accuracy_total_steps_的值。