{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 应用RoundToNearest后量化算法\n",
    "\n",
    "[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.5.0/resource/_static/logo_source.svg)](https://gitee.com/mindspore/golden-stick/blob/r1.0.0/mindspore_gs/ptq/round_to_nearest/README_CN.ipynb)\n",
    "\n",
    "## RoundToNearest后量化算法简介\n",
    "\n",
    "RoundToNearest算法是一类较朴素的后量化算法,其取整方式使用了Round to nearest,即四舍五入的方式。\n",
    "\n",
    "当前金箍棒中的RoundToNearest后量化(后面使用RTN来简称)主要针对LLM(大语言模型场景),使用MinMax校正器对线性层(Linear)进行量化。伪量化的网络结构示意如下:\n",
    "\n",
    "![fakequantizer](images/zh_cn/rtn-fakequantizer.png)\n",
    "\n",
    "表1:RTN算法规格\n",
    "\n",
    "| 规格 | 规格说明 |\n",
    "| --- | --- |\n",
    "| 硬件支持 | 量化阶段运行在CPU,量化模型推理仅支持Ascend |\n",
    "| 网络支持 | Llama2系列网络,具体请参见[Llama2网络](https://gitee.com/mindspore/mindformers/tree/v1.3.2/mindformers/models/llama) |\n",
    "| 运行模式支持 | Graph模式和PyNative模式 |\n",
    "\n",
    "表2:网络使用RTN算法量化前后对比\n",
    "\n",
    "<table>\n",
    "<tr>\n",
    "    <th rowspan='2'>指标</th>\n",
    "    <th colspan='3'>llama2-7B</th>\n",
    "    <th colspan='3'>llama2-13B</th>\n",
    "    <th colspan='3'>llama2-70B</th>\n",
    "    <th colspan='3'>baichuan2-13B</th>\n",
    "    <th colspan='3'>chatGLM3-6B</th>\n",
    "</tr>\n",
    "<tr>\n",
    "    <th>FP16</th><th>W8A16</th><th>收益</th>\n",
    "    <th>FP16</th><th>W8A16</th><th>收益</th>\n",
    "    <th>FP16</th><th>W8A16</th><th>收益</th>\n",
    "    <th>FP16</th><th>W8A16</th><th>收益</th>\n",
    "    <th>FP16</th><th>W8A16</th><th>收益</th>\n",
    "</tr>\n",
    "<tr>\n",
    "    <td>ckpt-size(GB)↓</td>\n",
    "    <td>13</td><td>7.1</td><td>-45.38%</td>\n",
    "    <td>25</td><td>14</td><td>-44.00%</td>\n",
    "    <td>129</td><td>65</td><td>-49.61%</td>\n",
    "    <td>26</td><td>15</td><td>-42.31%</td>\n",
    "    <td>12</td><td>6.1</td><td>-49.17%</td>\n",
    "</tr>\n",
    "<tr>\n",
    "    <td>wikitext2-Perplexity↓</td>\n",
    "    <td>15.130</td><td>15.129</td><td>0.00</td>\n",
    "    <td>14.18</td><td>14.203</td><td bgcolor='#FA8072'>0.02</td>\n",
    "    <td>10.379</td><td>10.435</td><td bgcolor='#FA8072'>0.046</td>\n",
    "    <td>23.955</td><td>23.912</td><td>-0.043</td>\n",
    "    <td>-</td><td>-</td><td>-</td>\n",
    "</tr>\n",
    "<tr>\n",
    "    <td>squad1.1-F1↑</td>\n",
    "    <td>60.48</td><td>60.76</td><td>0.28</td>\n",
    "    <td>-</td><td>-</td><td>-</td>\n",
    "    <td>-</td><td>-</td><td>-</td>\n",
    "    <td>-</td><td>-</td><td>-</td>\n",
    "    <td>-</td><td>-</td><td>-</td>\n",
    "</tr>\n",
    "<tr>\n",
    "    <td>squad1.1-EM↑</td>\n",
    "    <td>39.62</td><td>39.57</td><td bgcolor='#FA8072'>-0.05</td>\n",
    "    <td>-</td><td>-</td><td>-</td>\n",
    "    <td>-</td><td>-</td><td>-</td>\n",
    "    <td>-</td><td>-</td><td>-</td>\n",
    "    <td>-</td><td>-</td><td>-</td>\n",
    "</tr>\n",
    "<tr>\n",
    "    <td>全量性能(tokens/s)</td>\n",
    "    <td>9.08</td><td>9.04</td><td>0</td>\n",
    "    <td>-</td><td>-</td><td>-</td>\n",
    "    <td>-</td><td>-</td><td>-</td>\n",
    "    <td>-</td><td>-</td><td>-</td>\n",
    "    <td>-</td><td>-</td><td>-</td>\n",
    "</tr>\n",
    "<tr>\n",
    "    <td>增量性能(tokens/s)</td>\n",
    "    <td>30.24</td><td>21.08</td><td bgcolor='#FA8072'>-30.29%</td>\n",
    "    <td>-</td><td>-</td><td>-</td>\n",
    "    <td>-</td><td>-</td><td>-</td>\n",
    "    <td>-</td><td>-</td><td>-</td>\n",
    "    <td>-</td><td>-</td><td>-</td>\n",
    "</tr>\n",
    "<tr>\n",
    "    <td>显存(GB)</td>\n",
    "    <td>-</td><td>-</td><td>-</td>\n",
    "    <td>27</td><td>16</td><td>-40.7%</td>\n",
    "    <td>-</td><td>-</td><td>-</td>\n",
    "    <td>-</td><td>-</td><td>-</td>\n",
    "    <td>-</td><td>-</td><td>-</td>\n",
    "</tr>\n",
    "</table>\n",
    "\n",
    "## 示例\n",
    "\n",
    "跟金箍棒仓所有算法一样,RTN算法的应用主要可以分为两个阶段:量化阶段和部署阶段。\n",
    "\n",
    "量化阶段是部署前提前完成的,主要的工作是:收集权重的分布、计算量化参数、量化权重数据、插入反量化节点。\n",
    "\n",
    "部署阶段通常是指用户在生产环境,使用MindSpore框架对量化后的模型进行推理的过程。\n",
    "\n",
    "本用例使用Llama2网络进行演示,主要分四个步骤:环境准备、模型量化、模型部署评估、效果分析。\n",
    "\n",
    "### 步骤1. 环境准备\n",
    "\n",
    "#### 1.1. Ascend环境\n",
    "\n",
    "RTN算法需要运行在Ascend硬件上,Ascend的环境配置可以参考[MindSpore安装指南](https://www.mindspore.cn/install)安装昇腾AI处理器配套软件包小节和配置环境变量小节。\n",
    "\n",
    "#### 1.2. MindSpore环境\n",
    "\n",
    "金箍棒依赖于MindSpore,需要提前安装合适的MindSpore。可以从MindSpore官网下载预编译好的[v2.3.1版本安装包](https://www.mindspore.cn/versions)并安装。\n",
    "\n",
    "#### 1.3. MindFormers环境\n",
    "\n",
    "金箍棒依赖于MindFormers,需要提前安装合适的MindFormers。可以从MindSpore官网下载预编译好的MindFormers[v1.2.0版本安装包](https://www.mindspore.cn/versions)并安装。\n",
    "\n",
    "#### 1.4. 金箍棒环境\n",
    "\n",
    "从MindSpore官网下载预编译好的[MindSpore GoldenStick v0.5.0 版本安装包](https://www.mindspore.cn/versions)并安装。\n",
    "\n",
    "#### 1.5. 相关文件准备\n",
    "\n",
    "需要预先下载MindSpore Transformers Llama2网络相关的文件以及评估使用的数据集,包括:wikitext2数据集和Llama2 7B网络相关文件。\n",
    "\n",
    "第一步创建工作目录:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "!mkdir workspace"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "第二步准备数据集:\n",
    "\n",
    "数据集下载地址:WikiText2数据集\n",
    "\n",
    "下载好数据集后,需要将数据集文件拷贝到上一步中创建的workspace目录下,并确保数据集文件名为wikitext-2-v1.zip,然后运行解压代码:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Archive:  wikitext-2-v1.zip\n",
      "   creating: wikitext-2/\n",
      "  inflating: wikitext-2/wiki.test.tokens  \n",
      "  inflating: wikitext-2/wiki.valid.tokens  \n",
      "  inflating: wikitext-2/wiki.train.tokens  \n"
     ]
    }
   ],
   "source": [
    "!cd workspace; unzip wikitext-2-v1.zip"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "第三步准备Llama2 7B网络checkpoint文件,Llama2分词器文件,Llama2模型配置文件:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2024-03-19 17:29:17--  https://ascend-repo-modelzoo.obs.cn-east-2.myhuaweicloud.com/MindFormers/llama2/llama2_7b.ckpt\n",
      "Length: 13476850247 (13G) [binary/octet-stream]\n",
      "Saving to: ‘llama2_7b.ckpt’\n",
      "\n",
      "llama2_7b.ckpt      100%[===================>]  12.55G  27.5MB/s    in 7m 39s  \n",
      "\n",
      "2024-03-19 17:36:57 (28.0 MB/s) - ‘llama2_7b.ckpt’ saved [13476850247/13476850247]\n",
      "\n",
      "--2024-03-19 17:36:57--  https://ascend-repo-modelzoo.obs.cn-east-2.myhuaweicloud.com/MindFormers/llama2/tokenizer.model\n",
      "Length: 499723 (488K) [binary/octet-stream]\n",
      "Saving to: ‘tokenizer.model’\n",
      "\n",
      "tokenizer.model     100%[===================>] 488.01K  --.-KB/s    in 0.1s    \n",
      "\n",
      "2024-03-19 17:36:57 (3.37 MB/s) - ‘tokenizer.model’ saved [499723/499723]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "!cd workspace; wget --no-check-certificate -O llama2_7b.ckpt https://ascend-repo-modelzoo.obs.cn-east-2.myhuaweicloud.com/MindFormers/llama2/llama2_7b.ckpt\n",
    "!cd workspace; wget --no-check-certificate -O tokenizer.model https://ascend-repo-modelzoo.obs.cn-east-2.myhuaweicloud.com/MindFormers/llama2/tokenizer.model\n",
    "!cd workspace; cp ../configs/predict_llama2_7b.yaml ./"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> 下载时如果遇到网络问题,可以尝试使用浏览器手动下载相应文件,并放到相应目录下\n",
    "\n",
    "第四步修改predict_llama2_7b.yaml文件,将checkpoint、tokenizer的路径分别覆盖到yaml配置文件中load_checkpoint字段和processor章节的vocab_file字段。同时修改context章节的device_id为当前机器空闲的设备id。\n",
    "\n",
    "完成上述准备后,检查目录结构:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[01;34m.\u001b[00m\r\n",
      "├── llama2_7b.ckpt\r\n",
      "├── \u001b[01;34mwikitext-2\u001b[00m\r\n",
      "│   ├── wiki.train.tokens\r\n",
      "│   ├── wiki.test.tokens\r\n",
      "│   └── wiki.valid.tokens\r\n",
      "├── tokenizer.model\r\n",
      "├── \u001b[01;31mwikitext-2-v1.zip\u001b[00m\r\n",
      "└── predict_llama2_7b.yaml\r\n",
      "\r\n",
      "1 directory, 7 files\r\n"
     ]
    }
   ],
   "source": [
    "!cd workspace; tree -L 2 -U"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 步骤2. 模型量化\n",
    "\n",
    "构造MindSpore Transformers仓的Llama2网络,然后使用金箍棒RoundToNearest算法对网络进行量化,最终保存量化后的checkpoint文件:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "------------------------- Quantize-ing network...\n",
      "2024-03-19 19:19:37,711 - mindformers[mindformers/models/llama/llama.py:359] - INFO - Predict run mode:True\n",
      "------------------------- Quantize-ing network...\n",
      "2024-03-19 19:37:11,103 - mindformers[mindformers/generation/text_generator.py:914] - INFO - total time: 53.80364537239075 s; generated tokens: 1 tokens; generate speed: 0.018586101240514612 tokens/save_checkpoint\n",
      "------------------------- Saving checkpoint...\n",
      "------------------------- Checkpoint saved to ./output/w8a16_ckpt/rank_0/..."
     ]
    }
   ],
   "source": [
    "import os\n",
    "import time\n",
    "\n",
    "import mindspore as ms\n",
    "from mindformers import LlamaForCausalLM, MindFormerConfig, LlamaConfig, init_context\n",
    "from mindspore_gs.ptq import PTQMode, PTQConfig\n",
    "from mindspore_gs.common import BackendTarget, logger\n",
    "from mindspore_gs.ptq import RoundToNearest as RTN\n",
    "from mindspore_gs.ptq.network_helpers.mf_net_helpers import MFLlama2Helper\n",
    "\n",
    "\n",
    "class Llama2Network:\n",
    "    \"\"\"Llama2Network.\"\"\"\n",
    "    @staticmethod\n",
    "    def create_mfconfig(config_path):\n",
    "        \"\"\"Create mindformers config for llama2 network for example.\"\"\"\n",
    "        config = MindFormerConfig(config_path)\n",
    "        config.model.model_config = LlamaConfig(**config.model.model_config)\n",
    "        init_context(use_parallel=config.use_parallel, context_config=config.context, parallel_config=config.parallel)\n",
    "        return config\n",
    "\n",
    "    @staticmethod\n",
    "    def create_network(mindformers_config):\n",
    "        network = LlamaForCausalLM(mindformers_config.model.model_config)\n",
    "        network.set_train(False)\n",
    "        network.phase = 'predict'\n",
    "        return network\n",
    "\n",
    "\n",
    "def quant_network(net: LlamaForCausalLM, mode=PTQMode.QUANTIZE, backend=BackendTarget.ASCEND, **kwargs):\n",
    "    \"\"\"Quant llama2 model to w8a16 with RTN algorithm.\"\"\"\n",
    "    start_time = time.time()\n",
    "    if mode == PTQMode.QUANTIZE:\n",
    "        logger.info(\"Use RTN algo to quant network and weight.\")\n",
    "    else:\n",
    "        logger.info(\"Use RTN algo to quant network.\")\n",
    "    cfg = PTQConfig(mode=mode, backend=backend, opname_blacklist=[\"lm_head\"])\n",
    "    ptq = RTN(config=cfg)\n",
    "    logger.info(f'Create PTQ cost time is {time.time() - start_time} s.')\n",
    "    start_time = time.time()\n",
    "    mfconfig = kwargs.get(\"mfconfig\", None)\n",
    "    if not mfconfig:\n",
    "        raise ValueError(\"Please provide mfconfig for calibrating.\")\n",
    "    network_helper = MFLlama2Helper(mfconfig)\n",
    "    net = ptq.apply(net, network_helper)\n",
    "    logger.info(f'Apply PTQ cost time is {time.time() - start_time} s.')\n",
    "    start_time = time.time()\n",
    "    net.phase = \"quant_convert\"\n",
    "    net = ptq.convert(net)\n",
    "    logger.info(f'Convert to real quantize cost time is {time.time() - start_time} s.')\n",
    "    return net\n",
    "\n",
    "\n",
    "start = time.time()\n",
    "print('------------------------- Creating network...', flush=True)\n",
    "net_mgr: Llama2Network = Llama2Network()\n",
    "config = net_mgr.create_mfconfig(\"./workspace/predict_llama2_7b.yaml\")\n",
    "network = net_mgr.create_network(config)\n",
    "logger.info(f'Create Network cost time is {time.time() - start} s.')\n",
    "start = time.time()\n",
    "ckpt_path = config.load_checkpoint\n",
    "logger.info(f'Loading ckpt :{ckpt_path}.')\n",
    "ms.load_checkpoint(ckpt_path, network)\n",
    "ms.ms_memory_recycle()\n",
    "logger.info(f'Load ckpt cost time is {time.time() - start} s.')\n",
    "print('------------------------- Quantize-ing network...', flush=True)\n",
    "start = time.time()\n",
    "network = quant_network(network, mode=PTQMode.QUANTIZE, backend=BackendTarget.ASCEND, mfconfig=config)\n",
    "logger.info(f'Quant Network cost time is {time.time() - start} s.')\n",
    "print('------------------------- Saving checkpoint...', flush=True)\n",
    "start = time.time()\n",
    "save_ckpt_path = os.path.join(config.output_dir, \"w8a16_ckpt\")\n",
    "save_path = os.path.join(save_ckpt_path, f\"rank_0\")\n",
    "os.makedirs(save_path, exist_ok=True)\n",
    "ms.save_checkpoint(network.parameters_dict(), os.path.join(save_path, \"w8a16.ckpt\"),\n",
    "                   choice_func=lambda x: \"key_cache\" not in x and \"value_cache\" not in x)\n",
    "logger.info(f'Save checkpoint cost time is {time.time() - start} s.')\n",
    "print(f'------------------------- Checkpoint saved to {save_path}...', flush=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "成功运行后,量化后的checkpoint文件会保存在 `./output/w8a16_ckpt/rank_0/w8a16.ckpt` 路径下。\n",
    "\n",
    "### 步骤3. 模型部署\n",
    "\n",
    "#### 3.1. 评估FP16网络的Perplexity指标\n",
    "\n",
    "使用WikiText2数据集评估Llama2-7B网络的混淆度指标。在评估前,需要注意环境中没有配置RUN_MODE环境变量。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "......\n",
      "2024-03-19 19:41:52,132 - mindformers[mindformers/models/modeling_utils.py:1413] - INFO - weights in ./workspace/llama2_7b.ckpt are loaded\n",
      "[INFO] GE(2617230,python):2024-03-19-19:42:00.316.847 [ge_api.cc:523][status:INIT]2617230 AddGraph:Start to add graph in Session. graph_id: 1, graph_name: kernel_graph0, session_id: 0.\n",
      "[INFO] GE(2617230,python):2024-03-19-19:42:00.317.200 [ge_api.cc:1154][status:INIT]2617230 CompileGraph:Start to compile graph, graph_id: 1\n",
      "[INFO] GE(2617230,python):2024-03-19-19:42:00.317.282 [graph_manager.cc:1264][EVENT]2617230 PreRun:PreRun start: graph node size 1, session id 0, graph id 1, graph name kernel_graph0.\n",
      "......\n",
      "[INFO] GE(2617230,python):2024-03-19-19:43:17.424.380 [ge_api.cc:787][status:INIT]2654383 RunGraphWithStreamAsync:Session run graph with stream async, session_id: 0, graph_id: 2, input size 291, output size 3\n",
      "[INFO] GE(2617230,python):2024-03-19-19:43:17.424.812 [ge_api.cc:799][status:STOP]2654383 RunGraphWithStreamAsync:Session run graph with stream async finished.\n",
      "[INFO] GE(2617230,python):2024-03-19-19:43:17.464.158 [ge_api.cc:787][status:INIT]2654383 RunGraphWithStreamAsync:Session run graph with stream async, session_id: 0, graph_id: 3, input size 3, output size 1\n",
      "[INFO] GE(2617230,python):2024-03-19-19:43:17.464.296 [ge_api.cc:799][status:STOP]2654383 RunGraphWithStreamAsync:Session run graph with stream async finished.\n",
      "FP16 PPL: {'PerplexityMetric': {'loss': 2.247190694278072, 'PPL': 9.460511724873594}}\n"
     ]
    }
   ],
   "source": [
    "import mindspore as ms\n",
    "from mindformers.core.metric import PerplexityMetric\n",
    "from mindspore_gs.datasets import create_wikitext_dataset\n",
    "from mindspore_gs.ptq.network_helpers.mf_net_helpers import MFLlama2Helper\n",
    "\n",
    "\n",
    "net_mgr: Llama2Network = Llama2Network()\n",
    "fp16_network_config = net_mgr.create_mfconfig(\"./workspace/predict_llama2_7b.yaml\")\n",
    "fp16_network_config.model.model_config.use_past = False\n",
    "pad_token_id = fp16_network_config.model.model_config.pad_token_id\n",
    "fp16_network = net_mgr.create_network(fp16_network_config)\n",
    "ms.load_checkpoint(fp16_network_config.load_checkpoint, fp16_network)\n",
    "\n",
    "\n",
    "net_helper = MFLlama2Helper(fp16_network_config)\n",
    "bs = net_helper.get_spec(\"batch_size\")\n",
    "seq_len = net_helper.get_spec(\"seq_length\")\n",
    "tokenizer = net_helper.create_tokenizer()\n",
    "\n",
    "fp16_ds = create_wikitext_dataset(\"./workspace/wikitext-2/wiki.valid.tokens\", bs, seq_len, max_new_tokens=1, tokenizer=tokenizer)\n",
    "metric = PerplexityMetric()\n",
    "metric.clear()\n",
    "data_count = 0\n",
    "total_count = fp16_ds.get_dataset_size()\n",
    "for _, ds_item in enumerate(fp16_ds.create_dict_iterator()):\n",
    "    data_count += 1\n",
    "    logger.info(f\"Dataset count: {data_count}/{total_count}\")\n",
    "    input_ids = ds_item['input_ids'].asnumpy()\n",
    "    net_inputs = net_helper.assemble_inputs(input_ids)\n",
    "    outputs = fp16_network(*net_inputs)\n",
    "    metric.update(*outputs)\n",
    "print('...........Evaluate Over!...............', flush=True)\n",
    "print(f\"FP16 PPL: {metric.eval()}\", flush=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 3.2. 实例化非量化Llama2网络"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2024-03-19 19:19:37,710 - mindformers[mindformers/version_control.py:62] - INFO - The Cell Reuse compilation acceleration feature is not supported when the environment variable ENABLE_CELL_REUSE is 0 or MindSpore version is earlier than 2.1.0 or stand_alone mode or pipeline_stages <= 1\n",
      "2024-03-19 19:19:37,710 - mindformers[mindformers/version_control.py:66] - INFO - \n",
      "The current ENABLE_CELL_REUSE=0, please set the environment variable as follows: \n",
      "export ENABLE_CELL_REUSE=1 to enable the Cell Reuse compilation acceleration feature.\n",
      "2024-03-19 19:19:37,711 - mindformers[mindformers/version_control.py:72] - INFO - The Cell Reuse compilation acceleration feature does not support single-card mode.This feature is disabled by default. ENABLE_CELL_REUSE=1 does not take effect.\n",
      "2024-03-19 19:19:37,712 - mindformers[mindformers/version_control.py:75] - INFO - The Cell Reuse compilation acceleration feature only works in pipeline parallel mode(pipeline_stage>1).Current pipeline stage=1, the feature is disabled by default.\n",
      "2024-03-19 19:21:07,859 - mindformers[mindformers/models/modeling_utils.py:1415] - INFO - model built, but weights is unloaded, since the config has no checkpoint_name_or_path attribute or checkpoint_name_or_path is None.\n"
     ]
    }
   ],
   "source": [
    "import mindspore as ms\n",
    "from mindformers.core.metric import PerplexityMetric\n",
    "from mindspore_gs.datasets import create_wikitext_dataset\n",
    "from mindspore_gs.ptq.network_helpers.mf_net_helpers import MFLlama2Helper\n",
    "\n",
    "\n",
    "net_mgr: Llama2Network = Llama2Network()\n",
    "network_config = net_mgr.create_mfconfig(\"./workspace/predict_llama2_7b.yaml\")\n",
    "network_config.model.model_config.use_past = False\n",
    "pad_token_id = network_config.model.model_config.pad_token_id\n",
    "w8a16_network = net_mgr.create_network(network_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 3.3. 加载量化后的ckpt\n",
    "\n",
    "由于MindSpore当前不支持保存修改后的网络,所以在加载量化ckpt之前,需要先用算法恢复带量化结构的网络,然后再加载ckpt到网络。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'model.tok_embeddings.embedding_weight': Parameter (name=model.tok_embeddings.embedding_weight, shape=(32000, 4096), dtype=Float16, requires_grad=True),\n",
       " ......\n",
       " 'model.layers.31.feed_forward.w3._weight_quantizer.scale': Parameter (name=model.layers.31.feed_forward.w3._weight_quantizer.scale, shape=(11008,), dtype=Float16, requires_grad=True),\n",
       " 'model.layers.31.feed_forward.w3._weight_quantizer.zp_neg': Parameter (name=model.layers.31.feed_forward.w3._weight_quantizer.zp_neg, shape=(11008,), dtype=Float16, requires_grad=True),\n",
       " 'model.norm_out.weight': Parameter (name=model.norm_out.weight, shape=(4096,), dtype=Float32, requires_grad=True),\n",
       " 'lm_head.weight': Parameter (name=lm_head.weight, shape=(32000, 4096), dtype=Float16, requires_grad=True)}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "deploy_cfg = PTQConfig(mode=PTQMode.DEPLOY, backend=BackendTarget.ASCEND, opname_blacklist=[\"lm_head\"])\n",
    "deploy_ptq = RTN(config=deploy_cfg)\n",
    "w8a16_network = deploy_ptq.apply(w8a16_network)\n",
    "w8a16_network = deploy_ptq.convert(w8a16_network)\n",
    "ms.load_checkpoint(\"./output/w8a16_ckpt/rank_0/w8a16.ckpt\", w8a16_network)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 3.4. 评估量化后的网络\n",
    "\n",
    "本示例对Llama2在wikitext2数据集上评估Perplexity指标。使用步骤1中下载好的分词器和数据集文件分别实例化分词器对象和数据集对象,并实例化PerplexityMetric对象作为metric。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[INFO] GE(1746443,python):2024-03-19-19:25:18.990.947 [ge_api.cc:523][status:INIT]1746443 AddGraph:Start to add graph in Session. graph_id: 1, graph_name: kernel_graph224, session_id: 0.\n",
      "[INFO] GE(1746443,python):2024-03-19-19:25:18.991.481 [ge_api.cc:1154][status:INIT]1746443 CompileGraph:Start to compile graph, graph_id: 1\n",
      "[INFO] GE(1746443,python):2024-03-19-19:25:18.991.586 [graph_manager.cc:1264][EVENT]1746443 PreRun:PreRun start: graph node size 1, session id 0, graph id 1, graph name kernel_graph224.\n",
      "[INFO] GE(1746443,python):2024-03-19-19:25:19.065.657 [ge_api.cc:1160][status:STOP]1746443 CompileGraph:Compile graph success.\n",
      "[INFO] GE(1746443,python):2024-03-19-19:25:19.067.797 [ge_api.cc:787][status:INIT]2453595 RunGraphWithStreamAsync:Session run graph with stream async, session_id: 0, graph_id: 1, input size 0, output size 0\n",
      "[INFO] GE(1746443,python):2024-03-19-19:25:19.079.152 [ge_api.cc:799][status:STOP]2453595 RunGraphWithStreamAsync:Session run graph with stream async finished.\n",
      "[INFO] GE(1746443,python):2024-03-19-19:26:40.520.923 [ge_api.cc:523][status:INIT]1746443 AddGraph:Start to add graph in Session. graph_id: 2, graph_name: kernel_graph225, session_id: 0.\n",
      "[INFO] GE(1746443,python):2024-03-19-19:26:40.581.045 [ge_api.cc:1154][status:INIT]1746443 CompileGraph:Start to compile graph, graph_id: 2\n",
      "[INFO] GE(1746443,python):2024-03-19-19:26:40.633.523 [graph_manager.cc:1264][EVENT]1746443 PreRun:PreRun start: graph node size 3025, session id 0, graph id 2, graph name kernel_graph225.\n",
      "[INFO] GE(1746443,python):2024-03-19-19:28:24.659.856 [ge_api.cc:799][status:STOP]2453595 RunGraphWithStreamAsync:Session run graph with stream async finished.\n",
      "[INFO] GE(1746443,python):2024-03-19-19:28:24.665.855 [ge_api.cc:787][status:INIT]2453595 RunGraphWithStreamAsync:Session run graph with stream async, session_id: 0, graph_id: 2, input size 739, output size 3\n",
      "[INFO] GE(1746443,python):2024-03-19-19:28:24.667.497 [ge_api.cc:799][status:STOP]2453595 RunGraphWithStreamAsync:Session run graph with stream async finished.\n",
      "[INFO] GE(1746443,python):2024-03-19-19:28:25.267.844 [ge_api.cc:787][status:INIT]2453595 RunGraphWithStreamAsync:Session run graph with stream async, session_id: 0, graph_id: 3, input size 3, output size 1\n",
      "......\n",
      "[INFO] GE(1746443,python):2024-03-19-19:29:18.708.299 [ge_api.cc:799][status:STOP]2453595 RunGraphWithStreamAsync:Session run graph with stream async finished.\n",
      "W8A16 PPL: {'PerplexityMetric': {'loss': 2.237087654840339, 'PPL': 9.365711954412435}}\n"
     ]
    }
   ],
   "source": [
    "net_helper = MFLlama2Helper(network_config)\n",
    "bs = net_helper.get_spec(\"batch_size\")\n",
    "seq_len = net_helper.get_spec(\"seq_length\")\n",
    "tokenizer = net_helper.create_tokenizer()\n",
    "\n",
    "ds = create_wikitext_dataset(\"./workspace/wikitext-2/wiki.valid.tokens\", bs, seq_len, max_new_tokens=1, tokenizer=tokenizer)\n",
    "metric = PerplexityMetric()\n",
    "metric.clear()\n",
    "data_count = 0\n",
    "total_count = ds.get_dataset_size()\n",
    "for _, ds_item in enumerate(ds.create_dict_iterator()):\n",
    "    data_count += 1\n",
    "    logger.info(f\"Dataset count: {data_count}/{total_count}\")\n",
    "    input_ids = ds_item['input_ids'].asnumpy()\n",
    "    net_inputs = net_helper.assemble_inputs(input_ids)\n",
    "    outputs = w8a16_network(*net_inputs)\n",
    "    metric.update(*outputs)\n",
    "print('...........Evaluate Over!...............', flush=True)\n",
    "print(f\"W8A16 PPL: {metric.eval()}\", flush=True)"
   ]
  },
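  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The logged `loss` and `PPL` values are related through the standard definition of perplexity: the exponential of the mean per-token cross-entropy loss. The sketch below applies that definition to the two logged losses. It assumes that PerplexityMetric follows this definition; the helper name `perplexity` is hypothetical. Rounded to two decimals, the results agree with the logged PPL values (9.46 and 9.37).\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def perplexity(mean_loss):\n",
    "    # Standard definition: exp of the mean per-token cross-entropy (natural log).\n",
    "    return math.exp(mean_loss)\n",
    "\n",
    "\n",
    "fp16_loss = 2.247190694278072   # 'loss' from the FP16 evaluation log\n",
    "w8a16_loss = 2.237087654840339  # 'loss' from the W8A16 evaluation log\n",
    "fp16_ppl = perplexity(fp16_loss)\n",
    "w8a16_ppl = perplexity(w8a16_loss)\n",
    "print(f'FP16 PPL: {fp16_ppl:.2f}, W8A16 PPL: {w8a16_ppl:.2f}')\n",
    "```"
   ]
  },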
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 步骤4. 效果分析\n",
    "\n",
    "表3:Llama2 7B网络RTN算法量化前后对比\n",
    "\n",
    "| 指标 | FP16 | W8A16 | 收益 |\n",
    "| --- | --- | --- | --- |\n",
    "| ckpt-size(GB)↓ | 13 | 7.1 | -5.9 |\n",
    "| wikitext2-Perplexity↓ | 9.46 | 9.37 | -0.09 |\n",
    "\n",
    "可以看到,经过RTN量化算法处理后:\n",
    "\n",
    "1. 量化后网络的参数量缩减了5.9GB,只剩下原Float16时的54.6%,即网络部署时,用于静态权重存储的显存下降到Float16时的54.6%。因而量化后的网络可以在资源更紧张的环境上部署,或者在相同的环境中提供更大的吞吐量。\n",
    "2. 量化后网络的在wikitext2数据集上的混淆度上降0.09,即量化后网络在wikitext2上生成式任务效果更好。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py37-0308",
   "language": "python",
   "name": "py37-0308"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}