Inference on Ascend 910 and GPU
This document describes how to run inference on Ascend 910 and GPU hardware using MindIR and Checkpoint files. MindIR is MindSpore's unified model file format: it stores both the network structure and the weight values, defines an extensible graph structure and an IR representation of operators, and eliminates model differences between backends, so it is generally used to run inference across hardware platforms. A Checkpoint file contains the training parameters in Protocol Buffers format and is generally used to resume an interrupted training job or to fine-tune a trained model.
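As background, here is a minimal, hedged sketch of how the two file types are typically produced and consumed in Python. The tiny nn.Dense network and the file names below are placeholders for illustration only, not part of this tutorial's models:

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context, export, save_checkpoint, load

context.set_context(mode=context.GRAPH_MODE)

# A stand-in network; in practice this would be a trained model such as the LeNet5 used below
net = nn.Dense(10, 3)
sample_input = Tensor(np.zeros([1, 10]).astype(np.float32))

# Checkpoint: parameter values only, stored in Protocol Buffers format
save_checkpoint(net, "demo.ckpt")

# MindIR: graph structure plus weights, intended for cross-platform inference
export(net, sample_input, file_name="demo", file_format="MINDIR")

# A MindIR file can later be loaded and executed without the Python network definition
graph = load("demo.mindir")
print(nn.GraphCell(graph)(sample_input).shape)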
The following sections describe how to perform single-device inference with MindSpore for each of these two cases.
Single-Device Inference with a Checkpoint File
Inference with a Local Model
You can load a model and its parameters from the local file system with the load_checkpoint and load_param_into_net interfaces, pass in a validation dataset, and then call model.eval to evaluate the model or model.predict to run inference. Here we download the LeNet checkpoint pretrained by MindSpore and the MNIST dataset for the demonstration:
The following sample code downloads the dataset and the pretrained checkpoint and extracts them to the specified locations.
[ ]:
import os
import requests
import tarfile
import zipfile

requests.packages.urllib3.disable_warnings()

def download_dataset(url, target_path):
    """Download the file at `url` and extract it into `target_path` if it is an archive."""
    if not os.path.exists(target_path):
        os.makedirs(target_path)
    download_file = url.split("/")[-1]
    if not os.path.exists(download_file):
        res = requests.get(url, stream=True, verify=False)
        # Plain (non-archive) files are written directly into the target directory
        if download_file.split(".")[-1] not in ["tgz", "zip", "tar", "gz"]:
            download_file = os.path.join(target_path, download_file)
        with open(download_file, "wb") as f:
            for chunk in res.iter_content(chunk_size=512):
                if chunk:
                    f.write(chunk)
    # Extract archives into the target directory
    if download_file.endswith("zip"):
        z = zipfile.ZipFile(download_file, "r")
        z.extractall(path=target_path)
        z.close()
    if download_file.endswith(".tar.gz") or download_file.endswith(".tar") or download_file.endswith(".tgz"):
        t = tarfile.open(download_file)
        names = t.getnames()
        for name in names:
            t.extract(name, target_path)
        t.close()
    print("The {} file is downloaded and saved in the path {} after processing".format(os.path.basename(url), target_path))

download_dataset("https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/t10k-images-idx3-ubyte", "./datasets/MNIST_Data/test")
download_dataset("https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/t10k-labels-idx1-ubyte", "./datasets/MNIST_Data/test")
download_dataset("https://download.mindspore.cn/model_zoo/r1.1/lenet_ascend_v111_offical_cv_mnist_bs32_acc98/lenet_ascend_v111_offical_cv_mnist_bs32_acc98.ckpt", "./checkpoint")
After the download completes, the directory structure is as follows:
.
├── checkpoint
│   └── lenet_ascend_v111_offical_cv_mnist_bs32_acc98.ckpt
└── datasets
    └── MNIST_Data
        └── test
            ├── t10k-images-idx3-ubyte
            └── t10k-labels-idx1-ubyte
Configure the information required for running, and define the data processing for inference:

If you run in an Ascend 910 environment, change GPU in device_target="GPU" in the configuration below to Ascend.
[4]:
from mindspore import context
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
from mindspore.dataset.vision import Inter
from mindspore import dtype as mstype

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    # Define the dataset
    mnist_ds = ds.MnistDataset(data_path)
    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # Define the required map operations
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # Apply the map operations to the dataset
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    # Perform shuffle, batch and repeat operations
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(count=repeat_size)

    return mnist_ds
Create the LeNet model:
[5]:
import mindspore.nn as nn
from mindspore.common.initializer import Normal

class LeNet5(nn.Cell):
    """
    LeNet network structure
    """
    def __init__(self, num_class=10, num_channel=1):
        super(LeNet5, self).__init__()
        # Define the required operations
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        # Build the forward network with the operations defined above
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x

# Instantiate the network
net = LeNet5()
Before running inference, load the model and parameters from the local file system with the load_checkpoint and load_param_into_net interfaces; the local model can then be used for the inference steps that follow.
[6]:
from mindspore import load_checkpoint, load_param_into_net
ckpt_file_name = "./checkpoint/lenet_ascend_v111_offical_cv_mnist_bs32_acc98.ckpt"
param_dict = load_checkpoint(ckpt_file_name)
load_param_into_net(net, param_dict)
Set the loss function and optimizer, and create a Model object:
[7]:
import numpy as np
from mindspore.nn import Accuracy
from mindspore import Model, Tensor
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
model = Model(net, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
Call the model.eval interface to run the evaluation:
[8]:
mnist_path = "./datasets/MNIST_Data/test"
ds_eval = create_dataset(mnist_path)
acc = model.eval(ds_eval, dataset_sink_mode=False)
print("============== {} ==============".format(acc))
============== {'Accuracy': 0.9846754807692307} ==============
For the complete inference sample code, see https://gitee.com/mindspore/models/blob/r1.6/official/cv/lenet/eval.py.
Call the model.predict interface to run prediction; here one image is taken from the dataset:

The image to be predicted is drawn at random, so the actual output may differ from the example below; the prediction is correct as long as the predicted class matches the actual class.
[9]:
ds_eval = ds_eval.create_dict_iterator()
data = next(ds_eval)

# images are the test images; labels are their actual classes
images = data["image"].asnumpy()
labels = data["label"].asnumpy()

# Use model.predict to predict the class of each image
output = model.predict(Tensor(data['image']))
predicted = np.argmax(output.asnumpy(), axis=1)

# Print the predicted class and the actual class
print(f'Predicted: "{predicted[0]}", Actual: "{labels[0]}"')
Predicted: "6", Actual: "6"
Loading a MindSpore Hub Model for Inference
Besides loading a model from the local file system with load_checkpoint and load_param_into_net, you can install MindSpore Hub and use mindspore_hub.load to load the model parameters from the cloud for inference.

The local loading code used earlier:
from mindspore import load_checkpoint, load_param_into_net
ckpt_file_name = "./checkpoint/lenet_ascend_v111_offical_cv_mnist_bs32_acc98.ckpt"
param_dict = load_checkpoint(ckpt_file_name)
load_param_into_net(net, param_dict)
can be replaced with the mindspore_hub.load method:
import mindspore_hub
model_uid = "mindspore/ascend/1.2/lenet_v1.2_mnist"
net = mindspore_hub.load(model_uid)
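For reference, here is a sketch of how the network returned by mindspore_hub.load could plug into the same evaluation flow used above. It assumes mindspore_hub is installed, the hub asset can be downloaded, and the create_dataset function from the earlier cell is available:

import mindspore.nn as nn
import mindspore_hub
from mindspore import Model
from mindspore.nn import Accuracy

# Load the pretrained network from MindSpore Hub
net = mindspore_hub.load("mindspore/ascend/1.2/lenet_v1.2_mnist")
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
model = Model(net, net_loss, metrics={"Accuracy": Accuracy()})

# Evaluate on the same MNIST test set as before
acc = model.eval(create_dataset("./datasets/MNIST_Data/test"), dataset_sink_mode=False)
print(acc)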
Inference of a MindIR File with the C++ Interface on Ascend
This section describes how to run inference on a model in MindIR format using the C++ interface. The complete code is available in ascend910_resnet50_preprocess_sample.

The content and code in this section apply only to the Ascend environment.
Inference Code Walkthrough
The complete inference code is in the main.cc file; its main functional blocks are explained below.

Declare namespace aliases for mindspore and mindspore::dataset:
namespace ms = mindspore;
namespace ds = mindspore::dataset;
Initialize the environment, specifying Ascend 910 as the hardware and 0 as the device ID:
auto context = std::make_shared<ms::Context>();
auto ascend910_info = std::make_shared<ms::Ascend910DeviceInfo>();
ascend910_info->SetDeviceID(0);
context->MutableDeviceInfo().push_back(ascend910_info);
Load the model file:
// Load the MindIR model
ms::Graph graph;
ms::Status ret = ms::Serialization::Load(resnet_file, ms::ModelType::kMindIR, &graph);
// Build (compile) the graph
ms::Model resnet50;
ret = resnet50.Build(ms::GraphCell(graph), context);
Obtain the input information required by the model:
std::vector<ms::MSTensor> model_inputs = resnet50.GetInputs();
Load the image file:
ms::MSTensor ReadFile(const std::string &file);
auto image = ReadFile(image_file);
Preprocess the image:
// Decode the image into RGB format and resize it
std::shared_ptr<ds::TensorTransform> decode(new ds::vision::Decode());
std::shared_ptr<ds::TensorTransform> resize(new ds::vision::Resize({256}));
// Normalize the input
std::shared_ptr<ds::TensorTransform> normalize(new ds::vision::Normalize(
    {0.485 * 255, 0.456 * 255, 0.406 * 255}, {0.229 * 255, 0.224 * 255, 0.225 * 255}));
// Center-crop the image
std::shared_ptr<ds::TensorTransform> center_crop(new ds::vision::CenterCrop({224, 224}));
// Transform shape (H, W, C) to shape (C, H, W)
std::shared_ptr<ds::TensorTransform> hwc2chw(new ds::vision::HWC2CHW());
// Define the preprocessor
ds::Execute preprocessor({decode, resize, normalize, center_crop, hwc2chw});
// Run the preprocessor to obtain the processed image
ret = preprocessor(image, &image);
Run inference:
// Create the input and output vectors
std::vector<ms::MSTensor> outputs;
std::vector<ms::MSTensor> inputs;
inputs.emplace_back(model_inputs[0].Name(), model_inputs[0].DataType(), model_inputs[0].Shape(),
                    image.Data().get(), image.DataSize());
// Run inference
ret = resnet50.Predict(inputs, &outputs);
Obtain the inference result:
// Print the most likely class of the inference result (GetMax is a helper defined in the full main.cc)
std::cout << "Image: " << image_file << " infer result: " << GetMax(outputs[0]) << std::endl;
Build Script Walkthrough
The build script builds the user program; the complete code is in CMakeLists.txt and is explained below.
Add header search paths for the compiler:
option(MINDSPORE_PATH "mindspore install path" "")
include_directories(${MINDSPORE_PATH})
include_directories(${MINDSPORE_PATH}/include)
Find the required dynamic libraries in MindSpore:
find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib)
file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)
Generate the target executable from the specified source file and link the MindSpore libraries to it:
add_executable(resnet50_sample main.cc)
target_link_libraries(resnet50_sample ${MS_LIB} ${MD_LIB})
Building and Running the Inference Code
You can download the sample scripts to an Ascend 910 environment to build and run them. Enter the project directory ascend910_resnet50_preprocess_sample and set the following environment variable:
# Libraries that MindSpore depends on; modify "pip3" according to your actual environment
export LD_LIBRARY_PATH=`pip3 show mindspore-ascend | grep Location | awk '{print $2"/mindspore/lib"}' | xargs realpath`:${LD_LIBRARY_PATH}
Run the cmake command, replacing pip3 according to your environment:
cmake . -DMINDSPORE_PATH=`pip3 show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`
Then run the make command to build:
make
After a successful build, the resnet50_sample executable is produced. Create a model directory under the project directory ascend910_resnet50_preprocess_sample and place the MindIR file resnet50_imagenet.mindir in it. Also create a test_data directory to hold the images to be classified; the images can come from open-source datasets such as ImageNet2012. Then run the command below to obtain the inference results:
./resnet50_sample
The inference results look like this:
Image: ./test_data/ILSVRC2012_val_00002138.JPEG infer result: 0
Image: ./test_data/ILSVRC2012_val_00003014.JPEG infer result: 0
Image: ./test_data/ILSVRC2012_val_00006697.JPEG infer result: 0
Image: ./test_data/ILSVRC2012_val_00007197.JPEG infer result: 0
Image: ./test_data/ILSVRC2012_val_00009111.JPEG infer result: 0
Image: ./test_data/ILSVRC2012_val_00009191.JPEG infer result: 0
Image: ./test_data/ILSVRC2012_val_00009346.JPEG infer result: 0
Image: ./test_data/ILSVRC2012_val_00009379.JPEG infer result: 0
Image: ./test_data/ILSVRC2012_val_00009396.JPEG infer result: 0