Saving and Loading the Model

The previous section describes how to adjust hyperparameters and train network models. During training, we usually want to save intermediate and final results so that the model can later be fine-tuned, deployed, or used for inference. This section describes how to save and load a model.

import numpy as np
import mindspore
from mindspore import nn
from mindspore import Tensor

def network():
    # Flatten each 1x28x28 input, then apply three Dense layers
    # (with ReLU in between) that produce 10 output values
    model = nn.SequentialCell(
        nn.Flatten(),
        nn.Dense(28*28, 512),
        nn.ReLU(),
        nn.Dense(512, 512),
        nn.ReLU(),
        nn.Dense(512, 10))
    return model
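To get a rough sense of what a checkpoint of this network contains, the parameter count can be worked out by hand: each nn.Dense(in, out) layer holds an in × out weight matrix plus an out-element bias (bias is enabled by default), while Flatten and ReLU hold no parameters. The sketch below does this arithmetic in plain Python; the byte estimate assumes every value is stored as 4-byte float32 and ignores any file-format overhead.

```python
# Layer sizes of the network defined above: (in_features, out_features)
dense_layers = [(28 * 28, 512), (512, 512), (512, 10)]

def dense_param_count(in_features, out_features):
    """Weights (in x out) plus one bias per output unit."""
    return in_features * out_features + out_features

total_params = sum(dense_param_count(i, o) for i, o in dense_layers)
print(total_params)        # 669706

# Rough raw size if every value is stored as float32 (4 bytes each)
print(total_params * 4)    # 2678824
```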

Saving and Loading the Model Weight

To save the model weights, use the save_checkpoint interface, passing in the network and the path of the file to save to:

model = network()
mindspore.save_checkpoint(model, "model.ckpt")
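Conceptually, a checkpoint is a mapping from parameter names to array values that is written to disk. The sketch below illustrates that idea with NumPy's np.savez as a stand-in; it does not produce MindSpore's .ckpt format, and the parameter names and the (out, in) weight layout are assumptions chosen for illustration.

```python
import os
import tempfile

import numpy as np

rng = np.random.default_rng(0)

# Hypothetical parameter names; weight shape assumes an (out, in) layout
weights = {
    "dense1.weight": rng.standard_normal((512, 784)).astype(np.float32),
    "dense1.bias": np.zeros(512, dtype=np.float32),
}

# Write every named array into a single .npz archive
path = os.path.join(tempfile.mkdtemp(), "weights.npz")
np.savez(path, **weights)

# Read the archive back into a name -> array dictionary
with np.load(path) as restored:
    loaded = {name: restored[name] for name in restored.files}

print(sorted(loaded))  # ['dense1.bias', 'dense1.weight']
print(all(np.array_equal(weights[k], loaded[k]) for k in weights))  # True
```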

To load the model weights, first create an instance of the same model, then load the parameters by using the load_checkpoint and load_param_into_net methods.

model = network()
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load)
[]

param_not_load is the list of network parameters that were not loaded; an empty list means that all parameters were loaded successfully.
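To make the meaning of the returned lists concrete, here is a simplified model of name-based loading using plain NumPy dictionaries. It is an illustration only, not MindSpore's implementation: the real load_param_into_net performs additional checks that this sketch omits, and the function name load_params and the entries w, b, and extra are hypothetical. In this sketch, the second returned list plays the role of the value discarded with _ in the call above, holding checkpoint entries the network did not use.

```python
import numpy as np

def load_params(net_params, ckpt_params):
    """Simplified stand-in for name-based parameter loading.

    Copies each checkpoint array into the network parameter with the same
    name and returns two lists:
      params_not_loaded: network parameters with no matching checkpoint entry
      ckpt_not_loaded:   checkpoint entries that no network parameter used
    """
    params_not_loaded = []
    for name, value in net_params.items():
        if name in ckpt_params:
            value[...] = ckpt_params[name]  # overwrite the parameter in place
        else:
            params_not_loaded.append(name)
    ckpt_not_loaded = [name for name in ckpt_params if name not in net_params]
    return params_not_loaded, ckpt_not_loaded

# Hypothetical network parameters and checkpoint contents
net = {"w": np.zeros((2, 2), np.float32), "b": np.zeros(2, np.float32)}
ckpt = {"w": np.ones((2, 2), np.float32), "extra": np.ones(3, np.float32)}

not_loaded, unused = load_params(net, ckpt)
print(not_loaded)  # ['b']
print(unused)      # ['extra']
```

Here "b" stays at its initial value because the checkpoint has no entry for it, which is exactly the situation a non-empty param_not_load signals.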

Saving and Loading MindIR

In addition to checkpoints, MindSpore provides MindIR, a unified Intermediate Representation (IR) shared by the cloud side (training) and the device side (inference). A model can be saved directly as MindIR by using the export interface.

model = network()
inputs = Tensor(np.ones([1, 1, 28, 28]).astype(np.float32))
mindspore.export(model, inputs, file_name="model", file_format="MINDIR")

MindIR stores both the model weights and the model structure, so export requires a sample input Tensor from which the input shape is determined.
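The sample input fixes the shapes that flow through the exported graph. As a rough illustration, the NumPy sketch below re-implements the network's forward computation with random weights and shows how a [1, 1, 28, 28] input propagates to a (1, 10) output. It is a stand-in written for this explanation, not what export does internally; the random weights and the (out, in) weight layout are assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical random weights with the same shapes as the network above,
# stored as (out_features, in_features) with a bias per output unit
layers = [
    (rng.standard_normal((512, 784)).astype(np.float32), np.zeros(512, np.float32)),
    (rng.standard_normal((512, 512)).astype(np.float32), np.zeros(512, np.float32)),
    (rng.standard_normal((10, 512)).astype(np.float32), np.zeros(10, np.float32)),
]

def forward(x):
    x = x.reshape(x.shape[0], -1)   # Flatten: (N, 1, 28, 28) -> (N, 784)
    for i, (w, b) in enumerate(layers):
        x = x @ w.T + b             # Dense layer
        if i < len(layers) - 1:
            x = np.maximum(x, 0)    # ReLU between Dense layers
    return x

inputs = np.ones([1, 1, 28, 28], dtype=np.float32)
print(forward(inputs).shape)  # (1, 10)
```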

An existing MindIR model can be loaded with the load interface and passed to nn.GraphCell for inference.

nn.GraphCell only supports graph mode, so set the execution mode to GRAPH_MODE first.

mindspore.set_context(mode=mindspore.GRAPH_MODE)

graph = mindspore.load("model.mindir")
model = nn.GraphCell(graph)
outputs = model(inputs)
print(outputs.shape)
(1, 10)