MindSpore Golden Stick Network Conversion
There are three ways to convert the network and export MindIR:
Export MindIR after training;
Export MindIR from a checkpoint (ckpt) file;
Configure the algorithm before training so that MindIR is exported automatically after training.
Necessary Prerequisites
First, download the dataset and define the LeNet5 network. For demonstration convenience, we also implement the simplest possible MindSpore Golden Stick algorithm, called FooAlgo.
[ ]:
import os
import numpy as np
from download import download
import mindspore
from mindspore import nn, Model, Tensor, export
from mindspore.train import Accuracy
from mindspore.train import ModelCheckpoint
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
import mindspore.dataset.transforms as transforms
from mindspore.dataset.vision import Inter
from mindspore import dtype as mstype
from mindspore.common.initializer import Normal
from mindspore_gs import CompAlgo
# Download data from open datasets
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \
"notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)
def create_dataset(data_path, batch_size=32, num_parallel_workers=1):
    """
    create dataset for train or test
    """
    # define dataset
    mnist_ds = ds.MnistDataset(data_path)

    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # define map operations
    resize_op = vision.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)  # Bilinear mode
    rescale_nml_op = vision.Rescale(rescale_nml * rescale, shift_nml)
    hwc2chw_op = vision.HWC2CHW()
    type_cast_op = transforms.TypeCast(mstype.int32)

    # apply map operations on images
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    # apply DatasetOps
    mnist_ds = mnist_ds.shuffle(buffer_size=1024)
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)

    return mnist_ds
train_dataset = create_dataset("MNIST_Data/train", 32, 1)
print("train dataset output shape: ", train_dataset.output_shapes())
# initial network
class LeNet5(nn.Cell):
    def __init__(self, num_class=10, num_channel=1, include_top=True):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.include_top = include_top
        if self.include_top:
            self.flatten = nn.Flatten()
            self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
            self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
            self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))

    def construct(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        if not self.include_top:
            return x
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x
# set graph mode
mindspore.set_context(mode=mindspore.GRAPH_MODE)
# for demonstration convenience, we implement the simplest possible MindSpore Golden Stick algorithm, called FooAlgo
class FooAlgo(CompAlgo):
    def apply(self, network: nn.Cell) -> nn.Cell:
        # FooAlgo leaves the network unchanged; a real algorithm would rewrite it here
        return network
print("init ok.")
train dataset output shape: [[32, 1, 32, 32], [32]]
init ok.
Exporting MindIR After Training
MindSpore Golden Stick algorithms provide a convert interface that converts the network into its deployment form; you can then use mindspore.export to export MindIR.
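At its core, the conversion is just two calls: convert the algorithm-applied network into its deployment form, then export it with a dummy input. Below is a minimal sketch using the same algo and network_opt names as the full cell that follows (the input shape and file name here are illustrative):

# convert the optimized network into a network suitable for deployment
net_deploy = algo.convert(network_opt)
# export MindIR with a dummy input whose shape matches the training data
dummy_input = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32))
export(net_deploy, dummy_input, file_name="net.mindir", file_format="MINDIR")

The full cell below also covers dataset creation, training, and a quick check that the exported file loads and runs.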
[9]:
## 1) Create network and dataset.
network = LeNet5(10)
train_dataset = create_dataset("MNIST_Data/train", 32, 1)
## 2) Create an algorithm instance.
algo = FooAlgo()
## 3) Apply MindSpore Golden Stick algorithm to origin network.
network_opt = algo.apply(network)
## 4) Set up Model.
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_opt = nn.Momentum(network_opt.trainable_params(), 0.01, 0.9)
model = Model(network_opt, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
cbs = [ModelCheckpoint(prefix='network', directory='ckpt/')]
## 5) Configure callbacks in model.train and start training.
cbs.extend(algo.callbacks())
model.train(1, train_dataset, callbacks=cbs)
## 6) Convert network.
net_deploy = algo.convert(network_opt)
## 7) Export MindIR
inputs = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32))  # user-defined input, matching the training input shape
export(net_deploy, inputs, file_name="net_1.mindir", file_format="MINDIR")
## 8) Test MindIR
file_path = "./net_1.mindir"
file_path = os.path.realpath(file_path)
if not os.path.exists(file_path):
    print("Export MindIR failed!!!")
else:
    print("Export MindIR success! MindIR path is: ", file_path)
    test_inputs = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32))
    graph = mindspore.load(file_path)
    net = nn.GraphCell(graph)
    output = net(test_inputs)
    print("Test output MindIR success, result shape is: ", output.shape)
Export MindIR success! MindIR path is: /home/workspace/golden_stick/net_1.mindir
Test output MindIR success, result shape is: (32, 10)
Exporting MindIR from a Checkpoint File
Using the ckpt file saved after training, call the convert and mindspore.export interfaces to export MindIR.
Please run the sample code in the previous section first; this section requires the ckpt file generated by that training run.
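The checkpoint name used below, ckpt/network-1_1875.ckpt, is the file written by the ModelCheckpoint callback in the previous section: prefix 'network', 1 epoch, and 60000 / 32 = 1875 steps per epoch. A minimal sketch of a sanity check you can run first (the message text is illustrative):

import os
ckpt_file = "ckpt/network-1_1875.ckpt"  # written by ModelCheckpoint in the previous section
if not os.path.exists(ckpt_file):
    print("Checkpoint not found; run the training in the previous section first.")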
[ ]:
## 1) Create network and dataset.
network = LeNet5(10)
train_dataset = create_dataset("MNIST_Data/train", 32, 1)
## 2) Create an algorithm instance.
algo = FooAlgo()
## 3) Apply MindSpore Golden Stick algorithm to origin network.
network_opt = algo.apply(network)
## 4) Convert network.
net_deploy = algo.convert(network_opt, ckpt_path="ckpt/network-1_1875.ckpt") # ckpt from previous section
## 5) Export MindIR
inputs = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32))  # user-defined input, matching the training input shape
export(net_deploy, inputs, file_name="net_2.mindir", file_format="MINDIR")
## 6) Test MindIR
file_path = "./net_2.mindir"
file_path = os.path.realpath(file_path)
if not os.path.exists(file_path):
    print("Export MindIR failed!!!")
else:
    print("Export MindIR success! MindIR path is: ", file_path)
    test_inputs = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32))
    graph = mindspore.load(file_path)
    net = nn.GraphCell(graph)
    output = net(test_inputs)
    print("Test output MindIR success, result shape is: ", output.shape)
Export MindIR success! MindIR path is: /home/workspace/golden_stick/net_2.mindir
Test output MindIR success, result shape is: (32, 10)
Configuring the Algorithm to Automatically Export MindIR
Configure the algorithm's set_save_mindir interface before training to automatically export MindIR after training.
When using MindIR generated in this way for inference, the input shape of the MindIR must be consistent with the shape of the dataset used at training time.
Two operations are required to export MindIR automatically: call set_save_mindir(True), and pass the algorithm callbacks (callbacks=algo.callbacks()) when configuring the callback functions in model.train. If the MindIR output path save_mindir_path is not configured, it defaults to ./network.mindir.
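Because the exported MindIR fixes its input shape from the training dataset, you can query the dataset to confirm the shape that inference inputs must match. A minimal sketch reusing the output_shapes() call from the prerequisites (the printed value is the shape a matching inference input must have):

shapes = train_dataset.output_shapes()
print("expected MindIR input shape: ", shapes[0])  # e.g. [32, 1, 32, 32]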
[ ]:
## 1) Create network and dataset.
network = LeNet5(10)
train_dataset = create_dataset("MNIST_Data/train", 32, 1)
## 2) Create an algorithm instance.
algo = FooAlgo()
## 3) Enable automatic MindIR export after training.
algo.set_save_mindir(save_mindir=True)
## 4) Set the MindIR output path; the default is 'network.mindir'.
algo.set_save_mindir_path(save_mindir_path="net_3.mindir")
## 5) Apply MindSpore Golden Stick algorithm to origin network.
network_opt = algo.apply(network)
## 6) Set up Model.
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_opt = nn.Momentum(network_opt.trainable_params(), 0.01, 0.9)
model = Model(network_opt, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
## 7) Configure callbacks in model.train and start training; MindIR will be exported when training finishes.
model.train(1, train_dataset, callbacks=algo.callbacks())
## 8) Test MindIR
file_path = "./net_3.mindir"
file_path = os.path.realpath(file_path)
if not os.path.exists(file_path):
    print("Export MindIR failed!!!")
else:
    print("Export MindIR success! MindIR path is: ", file_path)
    test_inputs = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32))
    graph = mindspore.load(file_path)
    net = nn.GraphCell(graph)
    output = net(test_inputs)
    print("Test output MindIR success, result shape is: ", output.shape)
Export MindIR success! MindIR path is: /home/workspace/golden_stick/net_3.mindir
Test output MindIR success, result shape is: (32, 10)