Applying the RoundToNearest Post-Training Quantization Algorithm
Introduction to the RoundToNearest Post-Training Quantization Algorithm
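RoundToNearest is a simple class of post-training quantization algorithms. As the name suggests, it rounds each value to the nearest point on the quantization grid, that is, ordinary rounding.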
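The rounding step can be illustrated with a few lines of plain Python. This is a minimal sketch, not Golden Stick code: `round_half_up` and `quantize` are hypothetical helpers, and whether the library breaks ties upward or to the nearest even integer is not stated in this document. The example contrasts schoolbook rounding with Python's built-in `round`, which sends ties to the even neighbour.

```python
import math

def round_half_up(x: float) -> int:
    """Round to the nearest integer; ties move away from zero (schoolbook rounding)."""
    return int(math.copysign(math.floor(abs(x) + 0.5), x))

def quantize(values, scale):
    """Map floats onto the integer grid whose step is `scale` (hypothetical helper)."""
    return [round_half_up(v / scale) for v in values]

weights = [0.3, -0.7, 0.125, 0.375]          # 0.125 and 0.375 fall exactly between grid points
print(quantize(weights, 0.25))               # [1, -3, 1, 2]
print([round(w / 0.25) for w in weights])    # [1, -3, 0, 2]  (built-in round: ties to even)
```

The two results differ only for values that sit exactly halfway between two grid points.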
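The RoundToNearest post-training quantization in MindSpore Golden Stick (abbreviated as RTN below) currently targets LLM (large language model) scenarios and quantizes linear layers (Linear) using a MinMax calibrator.
[Figure: schematic of the fake-quantized network structure]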
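Table 1: RTN algorithm specifications
| Specification | Description |
|---|---|
| Hardware support | The quantization stage runs on CPU; inference with the quantized model is supported only on Ascend |
| Network support | Llama2 13B/70B; see the Llama2 network for details |
| Running mode support | Graph mode and PyNative mode |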
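To make the MinMax calibration mentioned above concrete, here is a small sketch in plain Python. It assumes per-output-channel, asymmetric int8 quantization; the document does not say whether Golden Stick uses a symmetric or asymmetric scheme, and the function names are hypothetical.

```python
import math

def minmax_params(channel, qmin=-128, qmax=127):
    """Derive (scale, zero_point) for one output channel from its min and max."""
    lo = min(min(channel), 0.0)                 # widen the range so 0.0 stays representable
    hi = max(max(channel), 0.0)
    scale = (hi - lo) / (qmax - qmin) or 1.0    # fall back to 1.0 for an all-zero channel
    zero_point = math.floor(qmin - lo / scale + 0.5)
    return scale, zero_point

def quantize_channel(channel, scale, zero_point, qmin=-128, qmax=127):
    """Round to nearest, shift by the zero point, and clamp into the int8 range."""
    return [max(qmin, min(qmax, math.floor(v / scale + 0.5) + zero_point)) for v in channel]

def dequantize_channel(q, scale, zero_point):
    """Map integers back to approximate float values."""
    return [(x - zero_point) * scale for x in q]

weight = [[-1.0, 0.0, 0.5, 1.0], [0.0, 0.1, 0.2, 0.4]]   # two output channels
for row in weight:
    scale, zp = minmax_params(row)
    q = quantize_channel(row, scale, zp)
    restored = dequantize_channel(q, scale, zp)
    worst = max(abs(a - b) for a, b in zip(row, restored))
    print(q, f"max error {worst:.4f} (scale {scale:.4f})")
```

Each channel gets its own scale, so a channel with a narrow value range is quantized more finely than one with a wide range.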
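Example
As with every algorithm in the Golden Stick repository, applying RTN can be divided into two stages: the quantization stage and the deployment stage.
The quantization stage is completed ahead of deployment. Its main tasks are collecting the weight distribution, computing the quantization parameters, quantizing the weight data, and inserting dequantization nodes.
The deployment stage usually refers to running inference on the quantized model with the MindSpore framework in the user's production environment.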
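The two stages can be sketched on a toy linear layer. This is an illustration under stated assumptions rather than the Golden Stick implementation: it uses a symmetric per-channel scheme for brevity, plain Python lists instead of tensors, and hypothetical function names. Weights are stored as 8-bit integers while activations stay in floating point, which is what the W8A16 label used later in this document refers to.

```python
import math

def quantize_linear(weight):
    """Quantization stage: return int8 weights plus one scale per output channel."""
    q_weight, scales = [], []
    for row in weight:                                  # one row per output channel
        max_abs = max(abs(v) for v in row)              # 1. collect the weight range
        scale = max_abs / 127 if max_abs else 1.0       # 2. compute the quantization parameter
        q_row = [max(-127, min(127, math.floor(v / scale + 0.5))) for v in row]  # 3. quantize
        q_weight.append(q_row)
        scales.append(scale)
    return q_weight, scales

def w8a16_linear(x, q_weight, scales):
    """Deployment stage: dequantize each weight on the fly, then do a float matmul."""
    return [sum(xi * (qi * s) for xi, qi in zip(x, row))   # 4. the dequantization step
            for row, s in zip(q_weight, scales)]

weight = [[0.2, -0.5, 0.1], [1.5, 0.3, -0.9]]
x = [1.0, 2.0, 3.0]
q_weight, scales = quantize_linear(weight)
reference = [sum(xi * wi for xi, wi in zip(x, row)) for row in weight]
print("float   :", reference)
print("w8a16   :", w8a16_linear(x, q_weight, scales))
```

The quantized output tracks the float output closely because each weight is off by at most half a quantization step.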
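This example uses the Llama2 network for demonstration and consists of four steps: environment preparation, model quantization, model deployment and evaluation, and result analysis.
Step 1. Environment Preparation
1.1. Ascend environment
Inference with the RTN-quantized model requires Ascend hardware (the quantization stage itself runs on CPU, see Table 1). To set up the Ascend environment, refer to the sections on installing the Ascend AI processor software package and on configuring environment variables in the MindSpore installation guide.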
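1.2. MindSpore environment
Golden Stick depends on MindSpore, so a suitable MindSpore version must be installed first. A prebuilt v2.3.0-rc1 package can be downloaded from the MindSpore website and installed.
1.3. MindSpore Transformers environment
Golden Stick depends on MindSpore Transformers, so a suitable MindSpore Transformers version must be installed first. A prebuilt v1.1.0-rc1 package can be downloaded from the MindSpore website and installed.
1.4. Golden Stick environment
Download the prebuilt MindSpore GoldenStick v0.4.0 package from the MindSpore website and install it.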
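1.5. Preparing the required files
The files for the MindSpore Transformers Llama2 network and the dataset used for evaluation must be downloaded in advance: the wikitext2 dataset and the Llama2 7B network files.
First, create a working directory: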
[1]:
!mkdir workspace
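Second, prepare the dataset. Because of access restrictions, the wikitext2 dataset has to be downloaded manually:
Dataset download address: WikiText2 dataset
After downloading, copy the dataset file into the workspace directory created in the previous step, make sure it is named wikitext-2-v1.zip, and then run the following to extract it: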
[2]:
!cd workspace; unzip wikitext-2-v1.zip
Archive: wikitext-2-v1.zip
creating: wikitext-2/
inflating: wikitext-2/wiki.test.tokens
inflating: wikitext-2/wiki.valid.tokens
inflating: wikitext-2/wiki.train.tokens
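If the `unzip` command is not available in the environment, the same extraction can be done with Python's standard `zipfile` module. The helper name below is hypothetical; the archive path follows the layout used in this tutorial.

```python
import zipfile
from pathlib import Path

def extract_dataset(archive, target_dir):
    """Extract `archive` into `target_dir` and return the names of the archive members."""
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(target_dir)
        return zf.namelist()

archive = Path("workspace") / "wikitext-2-v1.zip"
if archive.is_file():
    print(extract_dataset(archive, archive.parent))
else:
    print(f"{archive} not found; download it manually first")
```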
Third, prepare the Llama2 7B checkpoint file, the Llama2 tokenizer file, and the Llama2 model configuration file:
[3]:
!cd workspace; wget --no-check-certificate -O llama2_7b.ckpt https://ascend-repo-modelzoo.obs.cn-east-2.myhuaweicloud.com/MindFormers/llama2/llama2_7b.ckpt
!cd workspace; wget --no-check-certificate -O tokenizer.model https://ascend-repo-modelzoo.obs.cn-east-2.myhuaweicloud.com/MindFormers/llama2/tokenizer.model
!cd workspace; cp ../configs/run_llama2_7b_910b.yaml ./
--2024-03-19 17:29:17-- https://ascend-repo-modelzoo.obs.cn-east-2.myhuaweicloud.com/MindFormers/llama2/llama2_7b.ckpt
Length: 13476850247 (13G) [binary/octet-stream]
Saving to: ‘llama2_7b.ckpt’
llama2_7b.ckpt 100%[===================>] 12.55G 27.5MB/s in 7m 39s
2024-03-19 17:36:57 (28.0 MB/s) - ‘llama2_7b.ckpt’ saved [13476850247/13476850247]
--2024-03-19 17:36:57-- https://ascend-repo-modelzoo.obs.cn-east-2.myhuaweicloud.com/MindFormers/llama2/tokenizer.model
Length: 499723 (488K) [binary/octet-stream]
Saving to: ‘tokenizer.model’
tokenizer.model 100%[===================>] 488.01K --.-KB/s in 0.1s
2024-03-19 17:36:57 (3.37 MB/s) - ‘tokenizer.model’ saved [499723/499723]
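If you run into network problems while downloading, try downloading the files manually in a browser and placing them in the corresponding directory.
After completing the preparation above, check the directory structure: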
[4]:
!cd workspace; tree -L 2 -U
.
├── llama2_7b.ckpt
├── wikitext-2
│ ├── wiki.train.tokens
│ ├── wiki.test.tokens
│ └── wiki.valid.tokens
├── tokenizer.model
├── wikitext-2-v1.zip
└── run_llama2_7b_910b.yaml
1 directory, 7 files
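Especially after a manual download, it can help to confirm that the files are complete before starting the long-running quantization step. The sketch below checks that the expected files exist and, for the two downloaded files, that their sizes match the `Length` values printed in the wget log above. The function name is hypothetical.

```python
from pathlib import Path

# Byte sizes come from the wget log above; the other files are only checked for existence.
EXPECTED = {
    "llama2_7b.ckpt": 13476850247,
    "tokenizer.model": 499723,
    "run_llama2_7b_910b.yaml": None,
    "wikitext-2/wiki.valid.tokens": None,
}

def check_workspace(root):
    """Return a list of human-readable problems; an empty list means nothing was found wrong."""
    problems = []
    for name, size in EXPECTED.items():
        path = Path(root) / name
        if not path.is_file():
            problems.append(f"missing: {name}")
        elif size is not None and path.stat().st_size != size:
            problems.append(f"unexpected size for {name}: {path.stat().st_size} bytes")
    return problems

print(check_workspace("workspace") or "workspace looks complete")
```

A size mismatch on the checkpoint usually points to an interrupted download.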
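Step 2. Model Quantization
2.1. Instantiating MindFormerConfig
To build the Llama2 network from the MindSpore Transformers repository, a MindFormerConfig must be constructed first. The cell below implements a helper function that creates a MindFormerConfig and defines a few constants: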
[5]:
import mindspore as ms
from mindspore import context
from mindformers import LlamaForCausalLM, MindFormerConfig, LlamaConfig, init_context, TransformerOpParallelConfig

def _set_config(config_path, device_target, device_id):
    """setup MindFormerConfig"""
    mfconfig = MindFormerConfig(config_path)
    if device_id != -1:
        mfconfig.context.device_id = device_id
    mfconfig.context.device_target = device_target
    mfconfig.model.model_config = LlamaConfig(**mfconfig.model.model_config)

    init_context(use_parallel=mfconfig.use_parallel, context_config=mfconfig.context, parallel_config=mfconfig.parallel)

    parallel_config = TransformerOpParallelConfig(**mfconfig.parallel_config)
    mfconfig.model.model_config.parallel_config = parallel_config
    mfconfig.model.model_config.checkpoint_name_or_path = mfconfig.load_checkpoint
    return mfconfig


def create_mfconfig(config_path, device_target, device_id, bs, seq_len, tokenizer_path="", ckpt_path="", model_parallel=1):
    """Create mindformers config for llama2 network for example."""
    # Parallel mode is enabled only when more than one model-parallel slice is requested.
    if model_parallel > 1:
        use_parallel = True
    else:
        use_parallel = False
        model_parallel = 1
    compute_dtype = ms.float16
    config = _set_config(config_path, device_target, device_id)
    config.model.model_config.batch_size = bs
    config.model.model_config.seq_length = seq_len
    config.model.model_config.compute_dtype = compute_dtype
    config.model.model_config.layernorm_compute_type = ms.float32
    config.model.model_config.softmax_compute_type = ms.float16
    config.model.model_config.rotary_dtype = ms.float16
    config.model.model_config.param_init_type = ms.float16
    config.processor.tokenizer.vocab_file = tokenizer_path
    config.load_checkpoint = ckpt_path
    config.model.model_config.checkpoint_name_or_path = ckpt_path
    config.use_parallel = use_parallel
    config.parallel_config.model_parallel = model_parallel
    return config

llama2_config_file = "./workspace/run_llama2_7b_910b.yaml"
llama2_w16a16_ckpt_file = "./workspace/llama2_7b.ckpt"
llama2_w8a16_ckpt_file = "./workspace/llama2-7b-w8a16.ckpt"
vocab_file = "./workspace/tokenizer.model"
wikitext2_ds_path = "./workspace/wikitext-2/wiki.valid.tokens"
bs = 1
seq_len = 256
device_id = 1
The device_id variable in the code can be changed to whichever Ascend device is idle in your environment.
Instantiate a MindFormerConfig object:
[6]:
import mindspore as ms
from mindspore import context
context.set_context(device_target="CPU", mode=ms.GRAPH_MODE)
quant_network_config = create_mfconfig(llama2_config_file, "CPU", device_id, bs, seq_len, ckpt_path=llama2_w16a16_ckpt_file)
2.2. Instantiating the Llama2 network
[7]:
import mindspore as ms
from mindformers import LlamaForCausalLM
network = LlamaForCausalLM(quant_network_config.model.model_config)
network.set_train(False)
network.phase = 'predict'
2024-03-19 18:52:23,380 - mindformers[mindformers/version_control.py:62] - INFO - The Cell Reuse compilation acceleration feature is not supported when the environment variable ENABLE_CELL_REUSE is 0 or MindSpore version is earlier than 2.1.0 or stand_alone mode or pipeline_stages <= 1
2024-03-19 18:52:23,382 - mindformers[mindformers/version_control.py:66] - INFO -
The current ENABLE_CELL_REUSE=0, please set the environment variable as follows:
export ENABLE_CELL_REUSE=1 to enable the Cell Reuse compilation acceleration feature.
2024-03-19 18:52:23,383 - mindformers[mindformers/version_control.py:72] - INFO - The Cell Reuse compilation acceleration feature does not support single-card mode.This feature is disabled by default. ENABLE_CELL_REUSE=1 does not take effect.
2024-03-19 18:52:23,384 - mindformers[mindformers/version_control.py:75] - INFO - The Cell Reuse compilation acceleration feature only works in pipeline parallel mode(pipeline_stage>1).Current pipeline stage=1, the feature is disabled by default.
......
[WARNING] ME(1746443:281473469290880,MainProcess):2024-03-19-18:55:07.118.525 [mindspore/train/serialization.py:185] The type of model.layers.31.attention_norm.weight:Float16 in 'parameter_dict' is different from the type of it in 'net':Float32, then the type convert from Float16 to Float32 in the network.
[WARNING] ME(1746443:281473469290880,MainProcess):2024-03-19-18:55:07.123.086 [mindspore/train/serialization.py:185] The type of model.layers.31.ffn_norm.weight:Float16 in 'parameter_dict' is different from the type of it in 'net':Float32, then the type convert from Float16 to Float32 in the network.
[WARNING] ME(1746443:281473469290880,MainProcess):2024-03-19-18:55:07.751.733 [mindspore/train/serialization.py:185] The type of model.norm_out.weight:Float16 in 'parameter_dict' is different from the type of it in 'net':Float32, then the type convert from Float16 to Float32 in the network.
2024-03-19 18:55:08,105 - mindformers[mindformers/models/modeling_utils.py:1413] - INFO - weights in ./workspace/llama2_7b.ckpt are loaded
2.3. Instantiating the RTN algorithm
[8]:
from mindspore_gs.common import BackendTarget
from mindspore_gs.ptq import PTQConfig, PTQMode
from mindspore_gs.ptq import RoundToNearest as RTN
cfg = PTQConfig(mode=PTQMode.QUANTIZE, backend=BackendTarget.ASCEND)
ptq = RTN(config=cfg)
2.4. Quantizing the Llama2 network and saving the ckpt
[9]:
qnet = ptq.apply(network.model)
qnet = ptq.convert(qnet)
network.model = qnet
ms.save_checkpoint(network, llama2_w8a16_ckpt_file)
After a successful run, the file llama2-7b-w8a16.ckpt is generated in the workspace directory (the path set in llama2_w8a16_ckpt_file).
Step 3. Model Deployment
3.1. Instantiating MindFormerConfig and the Llama2 network
[10]:
context.set_context(device_target="Ascend")
deploy_network_config = create_mfconfig(llama2_config_file, "Ascend", device_id, bs, seq_len)
deploy_network = LlamaForCausalLM(deploy_network_config.model.model_config)
deploy_network.set_train(False)
deploy_network.phase = 'predict'
2024-03-19 19:19:37,710 - mindformers[mindformers/version_control.py:62] - INFO - The Cell Reuse compilation acceleration feature is not supported when the environment variable ENABLE_CELL_REUSE is 0 or MindSpore version is earlier than 2.1.0 or stand_alone mode or pipeline_stages <= 1
2024-03-19 19:19:37,710 - mindformers[mindformers/version_control.py:66] - INFO -
The current ENABLE_CELL_REUSE=0, please set the environment variable as follows:
export ENABLE_CELL_REUSE=1 to enable the Cell Reuse compilation acceleration feature.
2024-03-19 19:19:37,711 - mindformers[mindformers/version_control.py:72] - INFO - The Cell Reuse compilation acceleration feature does not support single-card mode.This feature is disabled by default. ENABLE_CELL_REUSE=1 does not take effect.
2024-03-19 19:19:37,712 - mindformers[mindformers/version_control.py:75] - INFO - The Cell Reuse compilation acceleration feature only works in pipeline parallel mode(pipeline_stage>1).Current pipeline stage=1, the feature is disabled by default.
2024-03-19 19:21:07,859 - mindformers[mindformers/models/modeling_utils.py:1415] - INFO - model built, but weights is unloaded, since the config has no checkpoint_name_or_path attribute or checkpoint_name_or_path is None.
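3.2. Loading the quantized ckpt
MindSpore does not currently support saving a modified network structure. Before loading the quantized ckpt, the algorithm must therefore be applied again to restore the network with its quantization structure; the ckpt is then loaded into that network.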
[11]:
deploy_cfg = PTQConfig(mode=PTQMode.DEPLOY, backend=BackendTarget.ASCEND)
deploy_ptq = RTN(config=deploy_cfg)
deploy_network.model = deploy_ptq.apply(deploy_network.model)
deploy_network.model = deploy_ptq.convert(deploy_network.model)
ms.load_checkpoint(llama2_w8a16_ckpt_file, deploy_network)
[11]:
{'model.tok_embeddings.embedding_weight': Parameter (name=model.tok_embeddings.embedding_weight, shape=(32000, 4096), dtype=Float16, requires_grad=True),
......
'model.layers.31.feed_forward.w3._weight_quantizer.scale': Parameter (name=model.layers.31.feed_forward.w3._weight_quantizer.scale, shape=(11008,), dtype=Float16, requires_grad=True),
'model.layers.31.feed_forward.w3._weight_quantizer.zp_neg': Parameter (name=model.layers.31.feed_forward.w3._weight_quantizer.zp_neg, shape=(11008,), dtype=Float16, requires_grad=True),
'model.norm_out.weight': Parameter (name=model.norm_out.weight, shape=(4096,), dtype=Float32, requires_grad=True),
'lm_head.weight': Parameter (name=lm_head.weight, shape=(32000, 4096), dtype=Float16, requires_grad=True)}
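3.3. Evaluating the quantized network
This example evaluates the Perplexity metric of Llama2 on the wikitext2 dataset. The tokenizer and dataset files downloaded in Step 1 are used to instantiate the tokenizer and dataset objects, and a PerplexityMetric object is instantiated as the metric.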
[12]:
import mindspore as ms
from mindformers import LlamaTokenizer
from mindformers.core.metric import PerplexityMetric
from mindspore_gs.datasets import create_wikitext_dataset
tokenizer = LlamaTokenizer(vocab_file=vocab_file)
deploy_ds = create_wikitext_dataset(wikitext2_ds_path, bs, seq_len, tokenizer)
deploy_metrics = {"PerplexityMetric": PerplexityMetric()}
deploy_model = ms.Model(deploy_network, metrics=deploy_metrics, eval_network=deploy_network)
quant_ppl = deploy_model.eval(deploy_ds, dataset_sink_mode=deploy_network_config.runner_config.sink_mode)
print(f"W8A16 Perplexity: {quant_ppl}")
[INFO] GE(1746443,python):2024-03-19-19:25:18.990.947 [ge_api.cc:523][status:INIT]1746443 AddGraph:Start to add graph in Session. graph_id: 1, graph_name: kernel_graph224, session_id: 0.
[INFO] GE(1746443,python):2024-03-19-19:25:18.991.481 [ge_api.cc:1154][status:INIT]1746443 CompileGraph:Start to compile graph, graph_id: 1
[INFO] GE(1746443,python):2024-03-19-19:25:18.991.586 [graph_manager.cc:1264][EVENT]1746443 PreRun:PreRun start: graph node size 1, session id 0, graph id 1, graph name kernel_graph224.
[INFO] GE(1746443,python):2024-03-19-19:25:19.065.657 [ge_api.cc:1160][status:STOP]1746443 CompileGraph:Compile graph success.
[INFO] GE(1746443,python):2024-03-19-19:25:19.067.797 [ge_api.cc:787][status:INIT]2453595 RunGraphWithStreamAsync:Session run graph with stream async, session_id: 0, graph_id: 1, input size 0, output size 0
[INFO] GE(1746443,python):2024-03-19-19:25:19.079.152 [ge_api.cc:799][status:STOP]2453595 RunGraphWithStreamAsync:Session run graph with stream async finished.
[INFO] GE(1746443,python):2024-03-19-19:26:40.520.923 [ge_api.cc:523][status:INIT]1746443 AddGraph:Start to add graph in Session. graph_id: 2, graph_name: kernel_graph225, session_id: 0.
[INFO] GE(1746443,python):2024-03-19-19:26:40.581.045 [ge_api.cc:1154][status:INIT]1746443 CompileGraph:Start to compile graph, graph_id: 2
[INFO] GE(1746443,python):2024-03-19-19:26:40.633.523 [graph_manager.cc:1264][EVENT]1746443 PreRun:PreRun start: graph node size 3025, session id 0, graph id 2, graph name kernel_graph225.
[INFO] GE(1746443,python):2024-03-19-19:28:24.659.856 [ge_api.cc:799][status:STOP]2453595 RunGraphWithStreamAsync:Session run graph with stream async finished.
[INFO] GE(1746443,python):2024-03-19-19:28:24.665.855 [ge_api.cc:787][status:INIT]2453595 RunGraphWithStreamAsync:Session run graph with stream async, session_id: 0, graph_id: 2, input size 739, output size 3
[INFO] GE(1746443,python):2024-03-19-19:28:24.667.497 [ge_api.cc:799][status:STOP]2453595 RunGraphWithStreamAsync:Session run graph with stream async finished.
[INFO] GE(1746443,python):2024-03-19-19:28:25.267.844 [ge_api.cc:787][status:INIT]2453595 RunGraphWithStreamAsync:Session run graph with stream async, session_id: 0, graph_id: 3, input size 3, output size 1
......
[INFO] GE(1746443,python):2024-03-19-19:29:18.708.299 [ge_api.cc:799][status:STOP]2453595 RunGraphWithStreamAsync:Session run graph with stream async finished.
W8A16 Perplexity: {'PerplexityMetric': {'loss': 2.910757654840339, 'PPL': 18.370711954412435}}
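The two numbers in the result are related: perplexity is the exponential of the mean per-token negative log-likelihood, which is the `loss` value. The sketch below computes perplexity for a toy sequence and checks that the logged pair is consistent with PPL = exp(loss). The `perplexity` helper is illustrative and is not the PerplexityMetric implementation.

```python
import math

def perplexity(token_probs):
    """exp of the mean negative log-likelihood assigned to the correct tokens."""
    nll = [-math.log(p) for p in token_probs]
    return math.exp(sum(nll) / len(nll))

# A model that guesses uniformly among 4 tokens has a perplexity of about 4.
print(perplexity([0.25, 0.25, 0.25, 0.25]))

# Values copied from the evaluation output above.
logged_loss, logged_ppl = 2.910757654840339, 18.370711954412435
print(math.isclose(math.exp(logged_loss), logged_ppl, rel_tol=1e-4))
```

Lower perplexity means the model assigns higher probability to the reference text.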
Step 4. Result Analysis
4.1. Evaluating the Perplexity metric of the FP16 network
[13]:
import mindspore as ms
from mindformers import LlamaForCausalLM, LlamaTokenizer
from mindformers.core.metric import PerplexityMetric
from mindspore_gs.datasets import create_wikitext_dataset
fp16_network_config = create_mfconfig(llama2_config_file, "Ascend", device_id, bs, seq_len, ckpt_path=llama2_w16a16_ckpt_file)
fp16_network = LlamaForCausalLM(fp16_network_config.model.model_config)
fp16_network.set_train(False)
fp16_network.phase = 'predict'
tokenizer = LlamaTokenizer(vocab_file=vocab_file)
fp16_ds = create_wikitext_dataset(wikitext2_ds_path, bs, seq_len, tokenizer)
fp16_metrics = {"PerplexityMetric": PerplexityMetric()}
fp16_model = ms.Model(fp16_network, metrics=fp16_metrics, eval_network=fp16_network)
fp16_ppl = fp16_model.eval(fp16_ds, dataset_sink_mode=fp16_network_config.runner_config.sink_mode)
print(f"FP16 Perplexity: {fp16_ppl}")
......
[WARNING] ME(2617230:281472981539200,MainProcess):2024-03-19-19:41:51.626.9 [mindspore/train/serialization.py:185] The type of model.layers.31.attention_norm.weight:Float16 in 'parameter_dict' is different from the type of it in 'net':Float32, then the type convert from Float16 to Float32 in the network.
[WARNING] ME(2617230:281472981539200,MainProcess):2024-03-19-19:41:51.881.3 [mindspore/train/serialization.py:185] The type of model.layers.31.ffn_norm.weight:Float16 in 'parameter_dict' is different from the type of it in 'net':Float32, then the type convert from Float16 to Float32 in the network.
[WARNING] ME(2617230:281472981539200,MainProcess):2024-03-19-19:41:51.695.148 [mindspore/train/serialization.py:185] The type of model.norm_out.weight:Float16 in 'parameter_dict' is different from the type of it in 'net':Float32, then the type convert from Float16 to Float32 in the network.
2024-03-19 19:41:52,132 - mindformers[mindformers/models/modeling_utils.py:1413] - INFO - weights in ./workspace/llama2_7b.ckpt are loaded
[INFO] GE(2617230,python):2024-03-19-19:42:00.316.847 [ge_api.cc:523][status:INIT]2617230 AddGraph:Start to add graph in Session. graph_id: 1, graph_name: kernel_graph0, session_id: 0.
[INFO] GE(2617230,python):2024-03-19-19:42:00.317.200 [ge_api.cc:1154][status:INIT]2617230 CompileGraph:Start to compile graph, graph_id: 1
[INFO] GE(2617230,python):2024-03-19-19:42:00.317.282 [graph_manager.cc:1264][EVENT]2617230 PreRun:PreRun start: graph node size 1, session id 0, graph id 1, graph name kernel_graph0.
......
[INFO] GE(2617230,python):2024-03-19-19:43:17.424.380 [ge_api.cc:787][status:INIT]2654383 RunGraphWithStreamAsync:Session run graph with stream async, session_id: 0, graph_id: 2, input size 291, output size 3
[INFO] GE(2617230,python):2024-03-19-19:43:17.424.812 [ge_api.cc:799][status:STOP]2654383 RunGraphWithStreamAsync:Session run graph with stream async finished.
[INFO] GE(2617230,python):2024-03-19-19:43:17.464.158 [ge_api.cc:787][status:INIT]2654383 RunGraphWithStreamAsync:Session run graph with stream async, session_id: 0, graph_id: 3, input size 3, output size 1
[INFO] GE(2617230,python):2024-03-19-19:43:17.464.296 [ge_api.cc:799][status:STOP]2654383 RunGraphWithStreamAsync:Session run graph with stream async finished.
FP16 Perplexity: {'PerplexityMetric': {'loss': 2.909490694278072, 'PPL': 18.347451724873594}}
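With both results available, the size of the accuracy change can be put in relative terms. The snippet below is simple arithmetic on the two PPL values printed by the evaluation cells in this document.

```python
fp16_ppl = 18.347451724873594    # from the FP16 evaluation output
w8a16_ppl = 18.370711954412435   # from the W8A16 evaluation output

abs_increase = w8a16_ppl - fp16_ppl
rel_increase = abs_increase / fp16_ppl
print(f"absolute increase: {abs_increase:.3f}")
print(f"relative increase: {rel_increase:.3%}")
```

The relative increase is on the order of a tenth of a percent.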
4.2. Comparing the results
Table 2: Llama2 7B network before and after RTN quantization
| Metric | FP16 | W8A16 | Change |
|---|---|---|---|
| ckpt-size (GB) ↓ | 13 | 7.1 | -5.9 |
| wikitext2-Perplexity ↓ | 18.347 | 18.370 | +0.023 |
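After processing with the RTN quantization algorithm:
The checkpoint of the quantized network shrinks by 5.9 GB to 54.6% of its Float16 size, so the device memory used for static weight storage at deployment drops to 54.6% of the Float16 figure. The quantized network can therefore be deployed in more resource-constrained environments, or deliver higher throughput in the same environment.
The perplexity of the quantized network on the wikitext2 dataset rises only marginally (from 18.347 to 18.370), so the quantized network is almost lossless on generative tasks on wikitext2.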
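A back-of-the-envelope estimate shows why the quantized checkpoint ends up somewhat above half the Float16 size rather than at exactly half. The sketch assumes the standard Llama2 7B layout: 32 layers, four 4096×4096 attention projections per layer, and three feed-forward projections between 4096 and 11008. It takes from the parameter listing above that the embedding and `lm_head` stay in Float16 and that each quantized layer carries a Float16 scale and zero-point parameter per output channel. The small norm weights are ignored.

```python
HIDDEN, FFN, VOCAB, LAYERS = 4096, 11008, 32000, 32

# Weights of the Linear layers that RTN quantizes to int8.
linear_params = LAYERS * (4 * HIDDEN * HIDDEN + 3 * HIDDEN * FFN)
# Embedding and lm_head remain Float16 in the quantized checkpoint.
fp16_only_params = 2 * VOCAB * HIDDEN
# Output channels over all quantized layers: 4 attention projections, w1/w3 (FFN wide), w2.
channels = LAYERS * (4 * HIDDEN + 2 * FFN + HIDDEN)
quant_param_bytes = channels * 2 * 2          # Float16 scale + zero-point parameter per channel

fp16_bytes = 2 * (linear_params + fp16_only_params)
w8a16_bytes = linear_params + 2 * fp16_only_params + quant_param_bytes

print(f"FP16  estimate: {fp16_bytes / 1e9:.2f} GB")
print(f"W8A16 estimate: {w8a16_bytes / 1e9:.2f} GB ({w8a16_bytes / fp16_bytes:.1%} of FP16)")
```

The Float16 estimate agrees closely with the 13476850247-byte checkpoint reported by the download log, and the quantized estimate lands slightly above one half, in the same range as the ratio in Table 2, which is computed from rounded sizes.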