Model
import com.mindspore.Model;
Model defines a MindSpore model, used for compiling and running.
Public Member Functions
build
public boolean build(Graph graph, MSContext context, TrainCfg cfg)
Compiles a MindSpore model from a computational graph.
Parameters
graph
: computational graph.
context
: compile context.
cfg
: train config.
Returns
Whether the build is successful.
public boolean build(MappedByteBuffer buffer, int modelType, MSContext context, char[] dec_key, String dec_mode)
Compiles a MindSpore model from a computational graph buffer, decrypting it with the given key and mode.
Parameters
buffer
: computational graph buffer.
modelType
: computational graph type. Options: MindIR, ONNX.
context
: compile context.
dec_key
: the key used to decrypt the encrypted model. The key length is 16, 24, or 32.
dec_mode
: the decryption mode. Options: AES-GCM, AES-CBC.
Returns
Whether the build is successful.
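When the model is stored as a file, one way to obtain the MappedByteBuffer that the buffer-based build overloads expect is to memory-map the file with java.nio. The sketch below is a minimal illustration only; the class name ModelBuffers is hypothetical and not part of the MindSpore API.

```java
import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;

// Hypothetical helper (not part of MindSpore): maps a model file into memory.
final class ModelBuffers {
    private ModelBuffers() {}

    // Maps the whole file read-only. A mapping stays valid after its channel
    // is closed, so the returned buffer can still be passed to Model.build(...).
    static MappedByteBuffer map(String modelPath) throws IOException {
        Path path = Paths.get(modelPath);
        try (FileChannel channel = FileChannel.open(path, StandardOpenOption.READ)) {
            return channel.map(FileChannel.MapMode.READ_ONLY, 0, channel.size());
        }
    }
}
```

For example, `model.build(ModelBuffers.map(path), ModelType.MT_MINDIR, context, decKey, "AES-GCM")` would compile an encrypted model from the mapped buffer, assuming `model`, `path`, `context`, and `decKey` have been prepared elsewhere.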
public boolean build(final MappedByteBuffer buffer, int modelType, MSContext context)
Compiles a MindSpore model from a computational graph buffer, without decryption. The default model type is MindIR.
Parameters
buffer
: computational graph buffer.
modelType
: computational graph type. Options: MindIR, ONNX.
context
: compile context.
Returns
Whether the build is successful.
public boolean build(String modelPath, int modelType, MSContext context, char[] dec_key, String dec_mode)
Compiles a MindSpore model from a computational graph file, decrypting it with the given key and mode.
Parameters
modelPath
: computational graph file.
modelType
: computational graph type. Options: MindIR, ONNX.
context
: compile context.
dec_key
: the key used to decrypt the encrypted model. The key length is 16, 24, or 32.
dec_mode
: the decryption mode. Options: AES-GCM, AES-CBC.
Returns
Whether the build is successful.
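The decryption overloads list fixed options for dec_key and dec_mode. A caller might check these before calling build, to report a clearer error than a failed build. The helper below is a hypothetical sketch that encodes only the options listed above; the library may apply further checks of its own.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical pre-check (not part of MindSpore) for the decryption arguments.
final class DecryptionParams {
    // Options taken from the dec_key and dec_mode descriptions.
    private static final List<Integer> KEY_LENGTHS = Arrays.asList(16, 24, 32);
    private static final List<String> MODES = Arrays.asList("AES-GCM", "AES-CBC");

    private DecryptionParams() {}

    // Throws IllegalArgumentException if either argument is outside the documented options.
    static void check(char[] decKey, String decMode) {
        if (decKey == null || !KEY_LENGTHS.contains(decKey.length)) {
            throw new IllegalArgumentException("dec_key length must be 16, 24, or 32");
        }
        if (!MODES.contains(decMode)) {
            throw new IllegalArgumentException("dec_mode must be AES-GCM or AES-CBC");
        }
    }
}
```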
public boolean build(String modelPath, int modelType, MSContext context)
Compiles a MindSpore model from a computational graph file, without decryption.
Parameters
modelPath
: computational graph file.
modelType
: computational graph type. Options: MindIR, ONNX.
context
: compile context.
Returns
Whether the build is successful.
predict
public boolean predict()
Runs inference.
Returns
Whether the prediction is successful.
runStep
public boolean runStep()
Runs one training step.
Returns
Whether the run is successful.
resize
public boolean resize(List<MSTensor> inputs, int[][] dims)
Resizes the shapes of the inputs.
Parameters
inputs
: Model inputs.
dims
: The new shapes of the inputs.
Returns
Whether the resize is successful.
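The dims argument holds one shape per input. Assuming row i of dims describes inputs.get(i), and that the current shapes are read from each MSTensor (for example with getShape(), assumed here), a hypothetical helper that changes only the batch dimension could look like this:

```java
// Hypothetical helper (not part of MindSpore) for building the dims argument of resize.
final class ResizeDims {
    private ResizeDims() {}

    // Returns a copy of the given shapes with dimension 0 (assumed to be the
    // batch dimension) set to newBatch. The input array is left unchanged.
    static int[][] withBatch(int[][] shapes, int newBatch) {
        int[][] dims = new int[shapes.length][];
        for (int i = 0; i < shapes.length; i++) {
            dims[i] = shapes[i].clone();
            if (dims[i].length > 0) {
                dims[i][0] = newBatch;
            }
        }
        return dims;
    }
}
```

A call such as `model.resize(inputs, ResizeDims.withBatch(shapes, 4))` would then request a batch size of 4 for every input.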
getInputs
public List<MSTensor> getInputs()
Gets the input MSTensors of the MindSpore model.
Returns
The MindSpore MSTensor list.
getOutputs
public List<MSTensor> getOutputs()
Gets the output MSTensors of the MindSpore model.
Returns
The MindSpore MSTensor list.
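After predict() succeeds, the output tensors can be post-processed in plain Java. As an illustration, assume a classifier whose first output holds one score per class, read as a float[] (for example via MSTensor.getFloatData(), assumed here). The predicted class is then the index of the largest score; OutputUtils is a hypothetical name.

```java
// Hypothetical post-processing helper (not part of MindSpore).
final class OutputUtils {
    private OutputUtils() {}

    // Returns the index of the largest score; on ties the first index wins.
    static int argMax(float[] scores) {
        if (scores.length == 0) {
            throw new IllegalArgumentException("scores must not be empty");
        }
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        return best;
    }
}
```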
getInputsByTensorName
public MSTensor getInputsByTensorName(String tensorName)
Gets the input MSTensor of the MindSpore model by tensor name.
Parameters
tensorName
: The tensor name.
Returns
MindSpore MSTensor.
getOutputsByNodeName
public List<MSTensor> getOutputsByNodeName(String nodeName)
Gets the output MSTensors of the MindSpore model by node name.
Parameters
nodeName
: The node name.
Returns
The MindSpore MSTensor list.
getOutputTensorNames
public List<String> getOutputTensorNames()
Gets the names of the output tensors of the compiled model.
Returns
The list of output tensor names, in order.
getOutputByTensorName
public MSTensor getOutputByTensorName(String tensorName)
Gets the output MSTensor of the MindSpore model by tensor name.
Parameters
tensorName
: The tensor name.
Returns
MindSpore MSTensor.
export
public boolean export(String fileName, int quantizationType, boolean isOnlyExportInfer, List<String> outputTensorNames)
Exports the model.
Parameters
fileName
: Model file name.
quantizationType
: The quantization type.
isOnlyExportInfer
: Whether to export only the inference part of the model.
outputTensorNames
: The output tensor names to export.
Returns
Whether the export is successful.
getFeatureMaps
public List<MSTensor> getFeatureMaps()
Gets the feature maps.
Returns
The list of feature map tensors.
updateFeatureMaps
public boolean updateFeatureMaps(List<MSTensor> features)
Updates the model feature maps.
Parameters
features
: The new list of feature map tensors.
Returns
Whether the feature maps are successfully updated.
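setTrainMode
public boolean setTrainMode(boolean isTrain)
Sets the train mode.
Parameters
isTrain
: Whether the model runs in train mode (true) or inference mode (false).
Returns
Whether the train mode is successfully set.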
getTrainMode
public boolean getTrainMode()
Gets the train mode.
Returns
Whether the model works in train mode.
setLearningRate
public boolean setLearningRate(float learning_rate)
Sets the learning rate.
Parameters
learning_rate
: The learning rate.
Returns
Whether the learning rate is successfully set.
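setLearningRate takes a single value, so any schedule has to be computed by the caller. The sketch below assumes a step-decay schedule, where the rate is multiplied by gamma every stepSize epochs; StepDecay and its parameters are hypothetical and not part of MindSpore.

```java
// Hypothetical step-decay schedule (not part of MindSpore).
final class StepDecay {
    private final float baseRate;
    private final float gamma;
    private final int stepSize;

    StepDecay(float baseRate, float gamma, int stepSize) {
        if (stepSize <= 0) {
            throw new IllegalArgumentException("stepSize must be positive");
        }
        this.baseRate = baseRate;
        this.gamma = gamma;
        this.stepSize = stepSize;
    }

    // Rate for the given epoch: baseRate * gamma^(epoch / stepSize), integer division.
    float rateAt(int epoch) {
        return (float) (baseRate * Math.pow(gamma, epoch / stepSize));
    }
}
```

Inside a training loop one might call `model.setLearningRate(schedule.rateAt(epoch))` at the start of each epoch, before running runStep() for that epoch's batches.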
setupVirtualBatch
public boolean setupVirtualBatch(int virtualBatchMultiplier, float learningRate, float momentum)
Sets up the virtual batch for training.
Parameters
virtualBatchMultiplier
: virtual batch multiplier.
learningRate
: learning rate.
momentum
: momentum.
Returns
Whether the virtual batch is successfully set.
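Assuming virtual batching accumulates gradients over virtualBatchMultiplier mini-batches before each weight update, the arithmetic below shows how the multiplier changes the effective batch size and the number of weight updates per epoch. The helper names are hypothetical.

```java
// Hypothetical arithmetic helper (not part of MindSpore), assuming gradients
// are accumulated over `multiplier` mini-batches per weight update.
final class VirtualBatch {
    private VirtualBatch() {}

    // Samples that contribute to one weight update.
    static int effectiveBatchSize(int batchSize, int multiplier) {
        requireMultiplier(multiplier);
        return batchSize * multiplier;
    }

    // Weight updates per epoch, counting only complete virtual batches.
    static int updatesPerEpoch(int stepsPerEpoch, int multiplier) {
        requireMultiplier(multiplier);
        return stepsPerEpoch / multiplier;
    }

    private static void requireMultiplier(int multiplier) {
        if (multiplier < 1) {
            throw new IllegalArgumentException("multiplier must be at least 1");
        }
    }
}
```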
free
public void free()
Frees the model.
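Since free() has to be called explicitly, a caller may want to tie it to try-with-resources. The generic wrapper below is a hypothetical sketch (Owned is not a MindSpore class): it holds any resource and runs the given release action exactly once when closed.

```java
import java.util.Objects;
import java.util.function.Consumer;

// Hypothetical wrapper (not part of MindSpore) that releases a resource once on close().
final class Owned<T> implements AutoCloseable {
    private final T resource;
    private final Consumer<? super T> release;
    private boolean closed;

    Owned(T resource, Consumer<? super T> release) {
        this.resource = Objects.requireNonNull(resource);
        this.release = Objects.requireNonNull(release);
    }

    // Returns the wrapped resource; fails if it has already been released.
    T get() {
        if (closed) {
            throw new IllegalStateException("resource already released");
        }
        return resource;
    }

    // Runs the release action the first time only; later calls do nothing.
    @Override
    public void close() {
        if (!closed) {
            closed = true;
            release.accept(resource);
        }
    }
}
```

With MindSpore this might be used as `try (Owned<Model> owned = new Owned<>(new Model(), Model::free)) { ... }`, so the model is freed even if build or predict throws.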
ModelType
import com.mindspore.config.ModelType;
Model file type.
public static final int MT_MINDIR = 0;
public static final int MT_AIR = 1;
public static final int MT_OM = 2;
public static final int MT_ONNX = 3;
public static final int MT_MINDIR_OPT = 4;
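A loader that accepts several file formats may need to choose the modelType argument for build. The sketch below maps file extensions to the constant values listed above; the extension convention, including treating .ms files as MT_MINDIR, is an assumption and not something ModelType defines.

```java
import java.util.Locale;

// Hypothetical helper (not part of MindSpore) that picks a model type from a file name.
final class ModelTypes {
    // Values copied from com.mindspore.config.ModelType.
    static final int MT_MINDIR = 0;
    static final int MT_ONNX = 3;

    private ModelTypes() {}

    // Assumed convention: *.onnx -> MT_ONNX; *.mindir and *.ms -> MT_MINDIR.
    static int fromFileName(String fileName) {
        String name = fileName.toLowerCase(Locale.ROOT);
        if (name.endsWith(".onnx")) {
            return MT_ONNX;
        }
        if (name.endsWith(".mindir") || name.endsWith(".ms")) {
            return MT_MINDIR;
        }
        throw new IllegalArgumentException("unrecognized model file: " + fileName);
    }
}
```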