使用MindConverter迁移脚本

image0image1image2

概述

PyTorch模型转换为MindSpore脚本和权重,首先需要将PyTorch模型导出为ONNX模型,然后使用MindConverter CLI工具进行脚本和权重迁移。 HuggingFace Transformers是PyTorch框架下主流的自然语言处理三方库,我们以Transformer中的BertForMaskedLM为例,演示迁移过程。

环境准备

本案例需安装以下Python三方库:

pip install torch==1.5.1
pip install transformers==4.2.2
pip install mindspore==1.2.0
pip install mindinsight==1.2.0
pip install onnx

以上安装命令可选用国内的清华源途径进行安装,可加快文件下载速度,即在上述命令后面添加-i https://pypi.tuna.tsinghua.edu.cn/simple

安装ONNX第三方库时,需要提前安装protobuf-compilerlibprotoc-dev,如果没有以上两个库,可以使用命令apt-get install protobuf-compiler libprotoc-dev进行安装。

ONNX模型导出

首先实例化HuggingFace中的BertForMaskedLM,以及相应的分词器(首次使用时需要下载模型权重、词表、模型配置等数据)。

关于HuggingFace的使用,本文不做过多介绍,详细使用请参考HuggingFace使用文档

该模型可对句子中被掩蔽(mask)的词进行预测。

[1]:
from transformers.models.bert import BertForMaskedLM, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForMaskedLM.from_pretrained("bert-base-uncased")

我们使用该模型进行推理,生成若干组测试用例,以验证模型迁移的正确性。

这里我们以一条句子为例china is a poworful country, its capital is beijing.

我们对beijing进行掩蔽(mask),输入china is a poworful country, its capital is [MASK].至模型,模型预期输出应为beijing

[2]:
import numpy as np
import torch

text = "china is a poworful country, its capital is [MASK]."
tokenized_sentence = tokenizer(text)

mask_idx = tokenized_sentence["input_ids"].index(tokenizer.convert_tokens_to_ids("[MASK]"))
input_ids = np.array([tokenized_sentence["input_ids"]])
attention_mask = np.array([tokenized_sentence["attention_mask"]])
token_type_ids = np.array([tokenized_sentence["token_type_ids"]])

# Get [MASK] token id.
print(f"MASK TOKEN id: {mask_idx}")
print(f"Tokens: {input_ids}")
print(f"Attention mask: {attention_mask}")
print(f"Token type ids: {token_type_ids}")

model.eval()
with torch.no_grad():
    predictions = model(input_ids=torch.tensor(input_ids),
                        attention_mask=torch.tensor(attention_mask),
                        token_type_ids=torch.tensor(token_type_ids))
    predicted_index = torch.argmax(predictions[0][0][mask_idx])
    predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
    print(f"Pred id: {predicted_index}")
    print(f"Pred token: {predicted_token}")
    assert predicted_token == "beijing"
MASK TOKEN id: 12
Tokens: [[  101  2859  2003  1037 23776 16347  5313  2406  1010  2049  3007  2003
    103  1012   102]]
Attention mask: [[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]]
Token type ids: [[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]
Pred id: 7211
Pred token: beijing

HuggingFace提供了导出ONNX模型的工具,可使用如下方法将HuggingFace的预训练模型导出为ONNX模型:

[3]:
from pathlib import Path
from transformers.convert_graph_to_onnx import convert

# Exported onnx model path.
saved_onnx_path = "./exported_bert_base_uncased/bert_base_uncased.onnx"
convert("pt", model, Path(saved_onnx_path), 11, tokenizer)
Creating folder exported_bert_base_uncased
Using framework PyTorch: 1.5.1+cu101
Found input input_ids with shape: {0: 'batch', 1: 'sequence'}
Found input token_type_ids with shape: {0: 'batch', 1: 'sequence'}
Found input attention_mask with shape: {0: 'batch', 1: 'sequence'}
Found output output_0 with shape: {0: 'batch', 1: 'sequence'}
Ensuring inputs are in correct order
position_ids is not present in the generated input list.
Generated inputs order: ['input_ids', 'attention_mask', 'token_type_ids']

根据打印的信息,我们可以看到导出的ONNX模型输入节点有3个:input_idstoken_type_idsattention_mask,以及相应的输入轴, 输出节点有一个output_0

至此ONNX模型导出成功,接下来对导出的ONNX模型精度进行验证(ONNX模型导出过程在ARM机器上执行,可能需要用户自行编译安装PyTorch以及Transformers三方库)。

ONNX模型验证

我们仍然使用PyTorch模型推理时的句子china is a poworful country, its capital is [MASK].作为输入,观测ONNX模型表现是否符合预期。

[4]:
import onnx
import onnxruntime as ort

model = onnx.load(saved_onnx_path)
sess = ort.InferenceSession(bytes(model.SerializeToString()))
result = sess.run(
    output_names=None,
    input_feed={"input_ids": input_ids,
                "attention_mask": attention_mask,
                "token_type_ids": token_type_ids}
)[0]
predicted_index = np.argmax(result[0][mask_idx])
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]

print(f"ONNX Pred id: {predicted_index}")
print(f"ONNX Pred token: {predicted_token}")
assert predicted_token == "beijing"
ONNX Pred id: 7211
ONNX Pred token: beijing

可以看到,导出的ONNX模型功能与原PyTorch模型完全一致,接下来可以使用MindConverter进行脚本和权重迁移了!

MindConverter进行模型脚本和权重迁移

MindConverter进行模型转换时,需要给定模型路径(--model_file)、输入节点(--input_nodes)、输入节点尺寸(--shape)、输出节点(--output_nodes)。

生成的脚本输出路径(--output)、转换报告路径(--report)为可选参数,默认为当前路径下的output目录,若输出目录不存在将自动创建。

[5]:
!mindconverter --model_file ./exported_bert_base_uncased/bert_base_uncased.onnx --shape 1,128 1,128 1,128  \
               --input_nodes input_ids token_type_ids attention_mask  \
               --output_nodes output_0  \
               --output ./converted_bert_base_uncased  \
               --report ./converted_bert_base_uncased

MindConverter: conversion is completed.

看到“MindConverter: conversion is completed.”即代表模型已成功转换!

转换完成后,该目录下生成如下文件: - 模型定义脚本(后缀为.py) - 权重ckpt文件(后缀为.ckpt) - 迁移前后权重映射(后缀为.json) - 转换报告(后缀为.txt)

通过ls命令检查一下转换结果。

[6]:
!ls ./converted_bert_base_uncased
bert_base_uncased.ckpt  report_of_bert_base_uncased.txt
bert_base_uncased.py    weight_map_of_bert_base_uncased.json

可以看到所有文件已生成。

迁移完成,接下来我们对迁移后模型精度进行验证。

MindSpore模型验证

我们仍然使用china is a poworful country, its capital is [MASK].作为输入,观测迁移后模型表现是否符合预期。

由于工具在转换时,需要将模型尺寸冻结,因此在使用MindSpore进行推理验证时,需要将句子补齐(Pad)到固定长度,可通过如下函数实现句子补齐。

推理时,句子长度需小于转换时的最大句长(这里我们最长句子长度为128,即在转换阶段通过--shape 1,128指定)。

[7]:
def padding(input_ids, attn_mask, token_type_ids, target_len=128):
    length = len(input_ids)
    for i in range(target_len - length):
        input_ids.append(0)
        attn_mask.append(0)
        token_type_ids.append(0)
    return np.array([input_ids]), np.array([attn_mask]), np.array([token_type_ids])
[8]:
from converted_bert_base_uncased.bert_base_uncased import Model as MsBert
from mindspore import load_checkpoint, load_param_into_net, context, Tensor


context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
padded_input_ids, padded_attention_mask, padded_token_type = padding(tokenized_sentence["input_ids"],
                                                                     tokenized_sentence["attention_mask"],
                                                                     tokenized_sentence["token_type_ids"],
                                                                     target_len=128)
padded_input_ids = Tensor(padded_input_ids)
padded_attention_mask = Tensor(padded_attention_mask)
padded_token_type = Tensor(padded_token_type)

model = MsBert()
param_dict = load_checkpoint("./converted_bert_base_uncased/bert_base_uncased.ckpt")
not_load_params = load_param_into_net(model, param_dict)
output = model(padded_attention_mask, padded_input_ids, padded_token_type)

assert not not_load_params

predicted_index = np.argmax(output.asnumpy()[0][mask_idx])
print(f"ONNX Pred id: {predicted_index}")
assert predicted_index == 7211
ONNX Pred id: 7211

至此,使用MindConverter进行脚本和权重迁移完成。

用户可根据使用场景编写训练、推理、部署脚本,实现个人业务逻辑。

常见问题

Q:如何修改迁移后脚本的批次大小(Batch size)、句子长度(Sequence length)等尺寸(shape)规格,以实现模型可支持任意尺寸的数据推理、训练?

A:迁移后脚本存在shape限制,通常是由于Reshape算子导致,或其他涉及张量排布变化的算子导致。以上述Bert迁移为例,首先创建两个全局变量,表示预期的批次大小、句子长度,而后修改Reshape操作的目标尺寸,替换成相应的批次大小、句子长度的全局变量即可。

Q:生成后的脚本中类名的定义不符合开发者的习惯,如``class Module0(nn.Cell)``,人工修改是否会影响转换后的权重加载?

A:权重的加载仅与变量名、类结构有关,因此类名可以修改,不影响权重加载。若需要调整类的结构,则相应的权重命名需要同步修改以适应迁移后模型的结构。