LiteSession
import com.mindspore.lite.LiteSession;
LiteSession定义了MindSpore Lite中的会话,用于进行Model的编译和前向推理。
公有成员函数
function |
---|
boolean export(String modelFilename, int model_type, int quantization_type) |
boolean isTrain() |
boolean isEval() |
boolean setLearningRate(float learning_rate) |
boolean setupVirtualBatch(int virtualBatchMultiplier, float learningRate, float momentum) |
List |
boolean updateFeatures(List |
init
public boolean init(MSConfig config)
初始化LiteSession。
参数
MSConfig
: MSConfig类。
返回值
初始化是否成功。
bindThread
public void bindThread(boolean if_bind)
尝试将线程池中的线程绑定到指定的CPU内核,或从指定的CPU内核进行解绑。
参数
if_bind
: 是否对线程进行绑定或解绑。
compileGraph
public boolean compileGraph(Model model)
编译MindSpore Lite模型。
参数
Model
: 需要被编译的模型。
返回值
编译是否成功。
runGraph
public boolean runGraph()
运行图进行推理。
返回值
推理是否成功。
getInputs
public List<MSTensor> getInputs()
获取MindSpore Lite模型的MSTensors输入。
返回值
所有输入MSTensor组成的List。
getInputsByTensorName
public MSTensor getInputByTensorName(String tensorName)
通过节点名获取MindSpore Lite模型的MSTensors输入。
参数
tensorName
: 张量名。
返回值
tensorName所对应的输入MSTensor。
getOutputsByNodeName
public List<MSTensor> getOutputsByNodeName(String nodeName)
通过节点名获取MindSpore Lite模型的MSTensors输出。
参数
nodeName
: 节点名。
返回值
该节点所有输出MSTensor组成的List。
getOutputMapByTensor
public Map<String, MSTensor> getOutputMapByTensor()
获取与张量名相关联的MindSpore Lite模型的MSTensors输出。
返回值
输出张量名和MSTensor的组成的Map。
getOutputTensorNames
public List<String> getOutputTensorNames()
获取由当前会话所编译的模型的输出张量名。
返回值
按顺序排列的输出张量名组成的List。
getOutputByTensorName
public MSTensor getOutputByTensorName(String tensorName)
通过张量名获取MindSpore Lite模型的MSTensors输出。
参数
tensorName
: 张量名。
返回值
该张量所对应的MSTensor。
resize
public boolean resize(List<MSTensor> inputs, int[][] dims)
调整输入的形状。
参数
inputs
: 模型对应的所有输入。dims
: 输入对应的新的shape,顺序注意要与inputs一致。
返回值
调整输入形状是否成功。
free
public void free()
释放LiteSession。
export
public boolean export(String modelFilename, int model_type, int quantization_type)
导出模型。
参数
modelFilename
: 模型文件名称。model_type
: 训练或者推理类型。quantization_type
: 量化类型。
返回值
导出模型是否成功。
train
public void train()
切换训练模式。
eval
public void eval()
切换推理模式。
istrain
public void isTrain()
是否训练模式。
iseval
public void isEval()
是否推理模式。
setLearningRate
public boolean setLearningRate(float learning_rate)
设置学习率。
参数
learning_rate
: 学习率。
返回值
学习率设置是否成功。
setupVirtualBatch
public boolean setupVirtualBatch(int virtualBatchMultiplier, float learningRate, float momentum)
设置虚批次系数。
参数
virtualBatchMultiplier
: 虚批次系数。learningRate
: 学习率。momentum
: 动量系数。
返回值
虚批次系数设置是否成功。
getFeaturesMap
public List<MSTensor> getFeaturesMap()
获取权重参数。
返回值
权重参数列表。
updateFeatures
public boolean updateFeatures(List<MSTensor> features)
更新权重参数。
参数
features
: 新的权重参数列表。
返回值
权重是否更新成功。