{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "77eddb36",
   "metadata": {},
   "source": [
    "# 单节点数据缓存\n",
    "\n",
    "[![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.2/resource/_static/logo_notebook.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.2/tutorials/experts/zh_cn/dataset/mindspore_cache.ipynb) [![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.2/resource/_static/logo_download_code.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.2/tutorials/experts/zh_cn/dataset/mindspore_cache.py) [![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.2/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/r2.2/tutorials/experts/source_zh_cn/dataset/cache.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41f30bb8",
   "metadata": {},
   "source": [
    "对于需要重复访问远程的数据集或需要重复从磁盘中读取数据集的情况,可以使用单节点缓存将数据集缓存于本地内存中,以加速数据集的读取。\n",
    "\n",
    "缓存算子依赖于在当前节点启动的缓存服务器,缓存服务器作为守护进程独立于用户的训练脚本而存在,主要用于提供缓存数据的管理,支持包括存储、查找、读取以及发生缓存未命中时对于缓存数据的写入等操作。\n",
    "\n",
    "若用户的内存空间不足以缓存所有数据集,则用户可以配置缓存算子使其将剩余数据缓存至磁盘。\n",
    "\n",
    "目前,缓存服务只支持单节点缓存,即客户端和服务器均在同一台机器上。该服务支持以下两类使用场景:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "628faa14",
   "metadata": {},
   "source": [
    "- 缓存加载好的原始数据集\n",
    "\n",
    "    用户可以在数据集加载操作中使用缓存。这将把加载完成的数据存到缓存服务器中,后续若需相同数据则可直接从中读取,避免从磁盘中重复加载。\n",
    "\n",
    "    ![cache on leaf pipeline](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.2/tutorials/experts/source_zh_cn/dataset/images/cache_dataset.png)\n",
    "- 缓存经过数据增强处理后的数据\n",
    "\n",
    "    用户也可在`map`操作中使用缓存。这将允许直接缓存数据增强(如图像裁剪、缩放等)处理后的数据,避免数据增强操作重复进行,减少了不必要的计算量。\n",
    "\n",
    "    ![cache on map pipeline](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.2/tutorials/experts/source_zh_cn/dataset/images/cache_processed_data.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06760f89",
   "metadata": {},
   "source": [
    "## 数据缓存流程\n",
    "\n",
    "使用缓存服务前,需要安装MindSpore,并设置相关环境变量。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4aab922b",
   "metadata": {},
   "source": [
    "> 目前数据缓存只能在linux环境下执行,Ubuntu、EulerOS以及CentOS均可参考[相关教程](https://help.ubuntu.com/community/SwapFaq#How_do_I_add_a_swap_file.3F),了解如何增大交换内存空间。此外,由于使用缓存可能会造成服务器的内存紧张,因此建议用户在使用缓存前增大服务器的交换内存空间至100GB以上。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "69f1d690",
   "metadata": {},
   "source": [
    "### 1. 启动缓存服务器\n",
    "\n",
    "在使用单节点缓存服务之前,首先需要在命令行输入以下命令,启动缓存服务器:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d251efcf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cache server startup completed successfully!\n",
      "The cache server daemon has been created as process id 14678 and listening on port 50052.\n",
      "\n",
      "Recommendation:\n",
      "Since the server is detached into its own daemon process, monitor the server logs (under /tmp/mindspore/cache/log) for any issues that may happen after startup\n",
      "\n"
     ]
    }
   ],
   "source": [
    "!cache_admin --start"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14dc4c56",
   "metadata": {},
   "source": [
    "若输出以上信息,则表示缓存服务器启动成功。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "012c4ae4",
   "metadata": {},
   "source": [
    "以上命令均可使用`-h`和`-p`参数来指定服务器,用户也可通过配置环境变量`MS_CACHE_HOST`和`MS_CACHE_PORT`来指定。若未指定则默认对ip为127.0.0.1且端口号为50052的服务器执行操作。\n",
    "\n",
    "可通过`ps -ef|grep cache_server`命令来检查服务器是否已启动以及查询服务器参数。\n",
    "\n",
    "也可通过`cache_admin --server_info`命令查看服务器的详细参数列表。\n",
    "\n",
    "若要启用数据溢出功能,则用户在启动缓存服务器时必须使用`-s`参数对溢出路径进行设置,否则该功能默认关闭。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d2ef5777",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cache Server Configuration: \n",
      "----------------------------------------\n",
      "         config name          value\n",
      "----------------------------------------\n",
      "            hostname      127.0.0.1\n",
      "                port          50052\n",
      "   number of workers              8\n",
      "           log level              1\n",
      "           spill dir           None\n",
      "----------------------------------------\n",
      "Active sessions: \n",
      "No active sessions.\n"
     ]
    }
   ],
   "source": [
    "!cache_admin --server_info"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "33923b67",
   "metadata": {},
   "source": [
    "其中,Cache Server Configuration表格分别列出了当前服务器的IP地址、端口号、工作线程数、日志等级、溢出路径等详细配置信息。Active sessions模块展示了当前服务器中已启用的session ID列表。\n",
    "\n",
    "缓存服务器日志文件的命名格式为 \"cache_server.\\<主机名\\>.\\<用户名\\>.log.\\<日志等级\\>.\\<日期-时间\\>.\\<进程号\\>\"。\n",
    "\n",
    "当`GLOG_v=0`时,可能会屏显有DEBUG日志。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b4370bb",
   "metadata": {},
   "source": [
    "### 2. 创建缓存会话\n",
    "\n",
    "若缓存服务器中不存在缓存会话,则需要创建一个缓存会话,得到缓存会话id:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4b9e5d94",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Session created for server on port 50052: 780643335\n"
     ]
    }
   ],
   "source": [
    "!cache_admin -g"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cfc8ce69",
   "metadata": {},
   "source": [
    "其中780643335为端口50052的服务器分配的缓存会话id,缓存会话id由服务器分配。\n",
    "\n",
    "通过`cache_admin --list_sessions`命令可以查看当前服务器中现存的所有缓存会话信息。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c1c64a7b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Listing sessions for server on port 50052\n",
      "\n",
      "     Session    Cache Id  Mem cached Disk cached  Avg cache size  Numa hit\n",
      "   780643335         n/a         n/a         n/a             n/a       n/a\n"
     ]
    }
   ],
   "source": [
    "!cache_admin --list_sessions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed8116b6",
   "metadata": {},
   "source": [
    "输出参数说明:\n",
    "\n",
    "- `Session`: 缓存会话id。\n",
    "- `Cache Id`: 当前缓存会话中的cache实例id,`n/a`表示当前尚未创建缓存实例。\n",
    "- `Mem cached`: 缓存在内存中的数据量。\n",
    "- `Disk cached`: 缓存在磁盘中的数据量。\n",
    "- `Avg cache size`:当前缓存的每行数据的平均大小。\n",
    "- `Numa hit`:Numa命中数,该值越高将获得越好的时间性能。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "817c8745",
   "metadata": {},
   "source": [
    "### 3. 创建缓存实例\n",
    "\n",
    "在Python训练脚本中使用`DatasetCache` 来定义一个名为`test_cache`的缓存实例,并把上一步中创建的缓存会话id传入`session_id`参数:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e7b7f10e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import mindspore.dataset as ds\n",
    "\n",
    "# define a variable named `session_id` to receive the cache session ID created in the previous step\n",
    "session_id = int(os.popen('cache_admin --list_sessions | tail -1 | awk -F \" \" \\'{{print $1;}}\\'').read())\n",
    "test_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db80459a",
   "metadata": {},
   "source": [
    "`DatasetCache`支持以下参数:\n",
    "\n",
    "- `session_id`:缓存会话的id,通过`cache_admin -g`命令来创建并获取。\n",
    "- `size`:缓存最大内存空间占用,该参数以MB为单位,例如512GB的缓存空间应设置`size=524288`,默认为0。\n",
    "- `spilling`:当内存空间超出所设置的最大内存空间占用时,是否允许将剩余的数据溢出至磁盘,默认为False。\n",
    "- `hostname`:连接至缓存服务器的ip地址,默认为127.0.0.1。\n",
    "- `port`:连接至缓存服务器的端口号,默认为50052。\n",
    "- `num_connections`:建立的TCP/IP连接数,默认为12。\n",
    "- `prefetch_size`:每次预取的数据行数,默认为20。\n",
    "\n",
    "其中值得注意的事情是:\n",
    "\n",
    "在实际使用中,通常应当首先使用`cache_admin -g`命令从缓存服务器处获得一个缓存会话id并作为`session_id`的参数,防止发生缓存会话不存在而报错的情况。\n",
    "\n",
    "设置`size=0`代表不限制缓存所使用的内存空间,缓存服务器会根据系统的内存资源状况,自动控制缓存服务器的内存空间占用,使其不超过系统总内存的80%。\n",
    "\n",
    "用户也可以根据机器本身的空闲内存大小,给`size`参数设置一个合理的取值。注意,当用户自主设置`size`参数时,要先确认系统可用内存和待加载数据集大小,若cache_server的内存空间占用或待加载数据集空间占耗超过系统可用内存时,有可能导致机器宕机/重启、cache_server自动关闭、训练流程执行失败等问题。\n",
    "\n",
    "若设置`spilling=True`,则当内存空间不足时,多余数据将写入磁盘中。因此,用户需确保所设置的磁盘路径具有写入权限以及足够的磁盘空间,以存储溢出至磁盘的缓存数据。注意,若启动服务器时未指定溢出路径,则在调用API时设置`spilling=True`将会导致报错。\n",
    "\n",
    "若设置`spilling=False`,则缓存服务器在耗尽所设置的内存空间后将不再写入新的数据。\n",
    "\n",
    "当使用不支持随机访问的数据集(如`TFRecordDataset`)进行数据加载并启用缓存服务时,需要保证整个数据集均存放于本地。在该场景下,若本地内存空间不足以存放所有数据,则必须启用溢出,将数据溢出至磁盘。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2eddce85",
   "metadata": {},
   "source": [
    "### 4. 插入缓存实例\n",
    "\n",
    "当前缓存服务既支持对原始数据集的缓存,也可以用于缓存经过数据增强处理后的数据。下例分别展示了两种使用方式。\n",
    "\n",
    "需要注意的是,两个例子均需要按照步骤3中的方法分别创建一个缓存实例,并在数据集加载或map操作中将所创建的`test_cache`作为`cache`参数分别传入。\n",
    "\n",
    "下面样例中使用到CIFAR-10数据集。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4098f0a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from download import download\n",
    "import os\n",
    "import shutil\n",
    "\n",
    "url = \"https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/cifar-10-binary.tar.gz\"\n",
    "path = download(url, \"./datasets\", kind=\"tar.gz\", replace=True)\n",
    "\n",
    "test_path = \"./datasets/cifar-10-batches-bin/test\"\n",
    "train_path = \"./datasets/cifar-10-batches-bin/train\"\n",
    "os.makedirs(test_path, exist_ok=True)\n",
    "os.makedirs(train_path, exist_ok=True)\n",
    "if not os.path.exists(os.path.join(test_path, \"test_batch.bin\")):\n",
    "    shutil.move(\"./datasets/cifar-10-batches-bin/test_batch.bin\", test_path)\n",
    "[shutil.move(\"./datasets/cifar-10-batches-bin/\"+i, train_path) for i in os.listdir(\"./datasets/cifar-10-batches-bin/\") if os.path.isfile(\"./datasets/cifar-10-batches-bin/\"+i) and not i.endswith(\".html\") and not os.path.exists(os.path.join(train_path, i))]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "970ab371",
   "metadata": {},
   "source": [
    "解压后的数据集文件的目录结构如下:\n",
    "\n",
    "```text\n",
    "./datasets/cifar-10-batches-bin\n",
    "├── readme.html\n",
    "├── test\n",
    "│   └── test_batch.bin\n",
    "└── train\n",
    "    ├── batches.meta.txt\n",
    "    ├── data_batch_1.bin\n",
    "    ├── data_batch_2.bin\n",
    "    ├── data_batch_3.bin\n",
    "    ├── data_batch_4.bin\n",
    "    └── data_batch_5.bin\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "51256547",
   "metadata": {},
   "source": [
    "#### 缓存原始数据集数据\n",
    "\n",
    "缓存原始数据集,经过MindSpore系统加载后的数据。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "68181766",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 image shape: (32, 32, 3)\n",
      "1 image shape: (32, 32, 3)\n",
      "2 image shape: (32, 32, 3)\n",
      "3 image shape: (32, 32, 3)\n"
     ]
    }
   ],
   "source": [
    "dataset_dir = \"./datasets/cifar-10-batches-bin/train\"\n",
    "\n",
    "# apply cache to dataset\n",
    "data = ds.Cifar10Dataset(dataset_dir=dataset_dir, num_samples=4, shuffle=False, num_parallel_workers=1, cache=test_cache)\n",
    "\n",
    "num_iter = 0\n",
    "for item in data.create_dict_iterator(num_epochs=1):\n",
    "    # in this example, each dictionary has a key \"image\"\n",
    "    print(\"{} image shape: {}\".format(num_iter, item[\"image\"].shape))\n",
    "    num_iter += 1"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45a69948",
   "metadata": {},
   "source": [
    "通过`cache_admin --list_sessions`命令可以查看当前会话有四条数据,说明数据缓存成功。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e363eb7e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Listing sessions for server on port 50052\n",
      "\n",
      "     Session    Cache Id  Mem cached Disk cached  Avg cache size  Numa hit\n",
      "   780643335  2044459912           4         n/a            3226         4\n"
     ]
    }
   ],
   "source": [
    "!cache_admin --list_sessions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3b7454dd",
   "metadata": {},
   "source": [
    "#### 缓存经过增强后数据\n",
    "\n",
    "缓存经过数据增强处理`transforms`后的数据。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "21d636ea",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 image shape: (32, 32, 3)\n",
      "1 image shape: (32, 32, 3)\n",
      "2 image shape: (32, 32, 3)\n",
      "3 image shape: (32, 32, 3)\n",
      "4 image shape: (32, 32, 3)\n"
     ]
    }
   ],
   "source": [
    "import mindspore.dataset.vision as vision\n",
    "\n",
    "dataset_dir = \"./datasets/cifar-10-batches-bin/train\"\n",
    "\n",
    "# apply cache to dataset\n",
    "data = ds.Cifar10Dataset(dataset_dir=dataset_dir, num_samples=5, shuffle=False, num_parallel_workers=1)\n",
    "\n",
    "# apply cache to map\n",
    "rescale_op = vision.Rescale(1.0 / 255.0, -1.0)\n",
    "\n",
    "test_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=False)\n",
    "\n",
    "data = data.map(input_columns=[\"image\"], operations=rescale_op, cache=test_cache)\n",
    "\n",
    "num_iter = 0\n",
    "for item in data.create_dict_iterator(num_epochs=1):\n",
    "    # in this example, each dictionary has a keys \"image\"\n",
    "    print(\"{} image shape: {}\".format(num_iter, item[\"image\"].shape))\n",
    "    num_iter += 1"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1233b33b",
   "metadata": {},
   "source": [
    "通过`cache_admin --list_sessions`命令可以查看当前会话有五条数据,说明数据缓存成功。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "2fc3c0c3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Listing sessions for server on port 50052\n",
      "\n",
      "     Session    Cache Id  Mem cached Disk cached  Avg cache size  Numa hit\n",
      "   780643335   112867845           5         n/a           12442         5\n",
      "   780643335  2044459912           4         n/a            3226         4\n"
     ]
    }
   ],
   "source": [
    "!cache_admin --list_sessions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6b9d100b",
   "metadata": {},
   "source": [
    "### 5. 销毁缓存会话\n",
    "\n",
    "在训练结束后,可以选择将当前的缓存销毁并释放内存:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "destroy_session = 'cache_admin --destroy_session' + str(session_id)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "3a475b41",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Drop session successfully for server on port 50052\n"
     ]
    }
   ],
   "source": [
    "!destroy_session"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4fea88d0",
   "metadata": {},
   "source": [
    "以上命令将销毁端口50052服务器中缓存会话id为1456416665的缓存。\n",
    "\n",
    "若选择不销毁缓存,则该缓存会话中的缓存数据将继续存在,用户下次启动训练脚本时可以继续使用该缓存。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c63672cc",
   "metadata": {},
   "source": [
    "### 6. 关闭缓存服务器\n",
    "\n",
    "使用完毕后,可以通过以下命令关闭缓存服务器,该操作将销毁当前服务器中存在的所有缓存会话并释放内存。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "c92bd0ef",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cache server on port 50052 has been stopped successfully.\n"
     ]
    }
   ],
   "source": [
    "!cache_admin --stop"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13b48189",
   "metadata": {},
   "source": [
    "以上命令将关闭端口50052的服务器。\n",
    "\n",
    "若选择不关闭服务器,则服务器中已创建的缓存会话将保留,并供下次使用。下次训练时,用户可以新建缓存会话或重复使用已有缓存。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "480ac397",
   "metadata": {},
   "source": [
    "## 缓存共享\n",
    "\n",
    "对于单机多卡的分布式训练的场景,缓存还允许多个相同的训练脚本共享同一个缓存,共同从缓存中读写数据。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b30769cd",
   "metadata": {},
   "source": [
    "1. 启动缓存服务器"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f558cf45",
   "metadata": {},
   "source": [
    "```bash\n",
    "$cache_admin --start\n",
    "Cache server startup completed successfully!\n",
    "The cache server daemon has been created as process id 39337 and listening on port 50052\n",
    "Recommendation:\n",
    "Since the server is detached into its own daemon process, monitor the server logs (under /tmp/mindspore/cache/log) for any issues that may happen after startup\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ee9c725",
   "metadata": {},
   "source": [
    "2. 创建缓存会话\n",
    "\n",
    "创建启动Python训练的Shell脚本`cache.sh`,通过以下命令生成一个缓存会话id:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37ff6430",
   "metadata": {},
   "source": [
    "```shell\n",
    "#!/bin/bash\n",
    "# This shell script will launch parallel pipelines\n",
    "\n",
    "# get path to dataset directory\n",
    "if [ $# != 1 ]\n",
    "then\n",
    "        echo \"Usage: sh cache.sh DATASET_PATH\"\n",
    "exit 1\n",
    "fi\n",
    "dataset_path=$1\n",
    "\n",
    "# generate a session id that these parallel pipelines can share\n",
    "result=$(cache_admin -g 2>&1)\n",
    "rc=$?\n",
    "if [ $rc -ne 0 ]; then\n",
    "    echo \"some error\"\n",
    "    exit 1\n",
    "fi\n",
    "\n",
    "# grab the session id from the result string\n",
    "session_id=$(echo $result | awk '{print $NF}')\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27d9e2f2",
   "metadata": {},
   "source": [
    "3. 会话id传入训练脚本\n",
    "\n",
    "继续编写Shell脚本,添加以下命令在启动Python训练时将`session_id`以及其他参数传入:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0edd79d7",
   "metadata": {},
   "source": [
    "```bash\n",
    "# make the session_id available to the python scripts\n",
    "num_devices=4\n",
    "\n",
    "for p in $(seq 0 $((${num_devices}-1))); do\n",
    "    python my_training_script.py --num_devices \"$num_devices\" --device \"$p\" --session_id $session_id --dataset_path $dataset_path\n",
    "done\n",
    "```\n",
    "\n",
    "> 直接获取完整样例代码:[cache.sh](https://gitee.com/mindspore/docs/blob/r2.2/docs/sample_code/cache/cache.sh)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7b4ab30",
   "metadata": {},
   "source": [
    "4. 创建并应用缓存实例\n",
    "\n",
    "下面样例中使用到CIFAR-10数据集。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ba922b9",
   "metadata": {},
   "source": [
    "```text\n",
    "├─cache.sh\n",
    "├─my_training_script.py\n",
    "└─cifar-10-batches-bin\n",
    "    ├── batches.meta.txt\n",
    "    ├── data_batch_1.bin\n",
    "    ├── data_batch_2.bin\n",
    "    ├── data_batch_3.bin\n",
    "    ├── data_batch_4.bin\n",
    "    ├── data_batch_5.bin\n",
    "    ├── readme.html\n",
    "    └── test_batch.bin\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ab50904",
   "metadata": {},
   "source": [
    "创建并编写Python脚本`my_training_script.py`,通过以下代码接收传入的`session_id`,并在定义缓存实例时将其作为参数传入。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a4e67457",
   "metadata": {},
   "source": [
    "```python\n",
    "    import argparse\n",
    "    import mindspore.dataset as ds\n",
    "\n",
    "    parser = argparse.ArgumentParser(description='Cache Example')\n",
    "    parser.add_argument('--num_devices', type=int, default=1, help='Device num.')\n",
    "    parser.add_argument('--device', type=int, default=0, help='Device id.')\n",
    "    parser.add_argument('--session_id', type=int, default=1, help='Session id.')\n",
    "    parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')\n",
    "    args_opt = parser.parse_args()\n",
    "\n",
    "    # apply cache to dataset\n",
    "    test_cache = ds.DatasetCache(session_id=args_opt.session_id, size=0, spilling=False)\n",
    "    dataset = ds.Cifar10Dataset(dataset_dir=args_opt.dataset_path, num_samples=4, shuffle=False, num_parallel_workers=1,\n",
    "                                num_shards=args_opt.num_devices, shard_id=args_opt.device, cache=test_cache)\n",
    "    num_iter = 0\n",
    "    for _ in dataset.create_dict_iterator():\n",
    "        num_iter += 1\n",
    "    print(\"Got {} samples on device {}\".format(num_iter, args_opt.device))\n",
    "```\n",
    "\n",
    "> 直接获取完整样例代码:[my_training_script.py](https://gitee.com/mindspore/docs/blob/r2.2/docs/sample_code/cache/my_training_script.py)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6915e0bf",
   "metadata": {},
   "source": [
    "5. 运行训练脚本\n",
    "\n",
    "运行Shell脚本`cache.sh`开启分布式训练:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad18f5d0",
   "metadata": {},
   "source": [
    "```bash\n",
    "$ sh cache.sh cifar-10-batches-bin/\n",
    "Got 4 samples on device 0\n",
    "Got 4 samples on device 1\n",
    "Got 4 samples on device 2\n",
    "Got 4 samples on device 3\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "651a090f",
   "metadata": {},
   "source": [
    "通过`cache_admin --list_sessions`命令可以查看当前会话中只有一组数据,说明缓存共享成功。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5570b305",
   "metadata": {},
   "source": [
    "```bash\n",
    "$ cache_admin --list_sessions\n",
    "Listing sessions for server on port 50052\n",
    "\n",
    "Session    Cache Id  Mem cached Disk cached  Avg cache size  Numa hit\n",
    "3392558708   821590605          16         n/a            3227        16\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54845c7d",
   "metadata": {},
   "source": [
    "6. 销毁缓存会话\n",
    "\n",
    "在训练结束后,可以选择将当前的缓存销毁并释放内存:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42db61ca",
   "metadata": {},
   "source": [
    "```bash\n",
    "$ cache_admin --destroy_session 3392558708\n",
    "Drop session successfully for server on port 50052\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1e8e0aa6",
   "metadata": {},
   "source": [
    "7. 关闭缓存服务器\n",
    "\n",
    "使用完毕后,可以选择关闭缓存服务器:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9fbc03e",
   "metadata": {},
   "source": [
    "```bash\n",
    "$ cache_admin --stop\n",
    "Cache server on port 50052 has been stopped successfully.\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "386e5187",
   "metadata": {},
   "source": [
    "## 缓存加速\n",
    "\n",
    "为了使较大的数据集在多台服务器之间共享,缓解单台服务器的磁盘空间需求,用户通常可以选择使用NFS(Network File System)即网络文件系统来存储数据集,如华为云-NFS存储服务器。\n",
    "\n",
    "然而,对于NFS数据集的访问通常开销较大,导致使用NFS数据集进行的训练用时较长。\n",
    "\n",
    "为了提高NFS数据集的训练性能,我们可以选择使用缓存服务,将数据集以Tensor的形式缓存在内存中。\n",
    "\n",
    "经过缓存后,后序的epoch就可以直接从内存中读取数据,避免了访问远程网络存储的开销。\n",
    "\n",
    "需要注意的是,在训练过程的数据处理流程中,数据集经加载后通常还需要进行一些带有随机性的增强操作,如`RandomCropDecodeResize`,若将缓存添加到该具有随机性的操作之后,将会导致第一次的增强操作结果被缓存下来,后序从缓存服务器中读取的结果均为第一次已缓存的数据,导致数据的随机性丢失,影响训练网络的精度。\n",
    "\n",
    "因此我们可以选择直接在数据集读取操作之后添加缓存。本节将采用这种方法,以MobileNetV2网络为样本,进行示例。\n",
    "\n",
    "完整示例代码请参考ModelZoo的[MobileNetV2](https://gitee.com/mindspore/models/tree/r2.2/official/cv/MobileNet/mobilenetv2)。\n",
    "\n",
    "1. 创建管理缓存的Shell脚本`cache_util.sh`:\n",
    "\n",
    "    ```bash\n",
    "    bootup_cache_server()\n",
    "    {\n",
    "      echo \"Booting up cache server...\"\n",
    "      result=$(cache_admin --start 2>&1)\n",
    "      echo \"${result}\"\n",
    "    }\n",
    "\n",
    "    generate_cache_session()\n",
    "    {\n",
    "      result=$(cache_admin -g | awk 'END {print $NF}')\n",
    "      echo \"${result}\"\n",
    "    }\n",
    "    ```\n",
    "\n",
    "    > 直接获取完整样例代码:[cache_util.sh](https://gitee.com/mindspore/docs/blob/r2.2/docs/sample_code/cache/cache_util.sh)\n",
    "\n",
    "2. 在启动NFS数据集训练的Shell脚本`run_train_nfs_cache.sh`中,为使用位于NFS上的数据集训练的场景开启缓存服务器并生成一个缓存会话保存在Shell变量`CACHE_SESSION_ID`中:\n",
    "\n",
    "    ```bash\n",
    "    CURPATH=\"${dirname \"$0\"}\"\n",
    "    source ${CURPATH}/cache_util.sh\n",
    "\n",
    "    bootup_cache_server\n",
    "    CACHE_SESSION_ID=$(generate_cache_session)\n",
    "    ```\n",
    "\n",
    "3. 在启动Python训练时将`CACHE_SESSION_ID`以及其他参数传入:\n",
    "\n",
    "    ```text\n",
    "    python train.py \\\n",
    "    --platform=$1 \\\n",
    "    --dataset_path=$5 \\\n",
    "    --pretrain_ckpt=$PRETRAINED_CKPT \\\n",
    "    --freeze_layer=$FREEZE_LAYER \\\n",
    "    --filter_head=$FILTER_HEAD \\\n",
    "    --enable_cache=True \\\n",
    "    --cache_session_id=$CACHE_SESSION_ID \\\n",
    "    &> log$i.log &\n",
    "    ```\n",
    "\n",
    "4. 在Python的参数解析脚本`args.py`的`train_parse_args()`函数中,通过以下代码接收传入的`cache_session_id`:\n",
    "\n",
    "    ```python\n",
    "    import argparse\n",
    "\n",
    "    def train_parse_args():\n",
    "    ...\n",
    "        train_parser.add_argument('--enable_cache',\n",
    "            type=ast.literal_eval,\n",
    "            default=False,\n",
    "            help='Caching the dataset in memory to speedup dataset processing, default is False.')\n",
    "        train_parser.add_argument('--cache_session_id',\n",
    "             type=str,\n",
    "             default=\"\",\n",
    "             help='The session id for cache service.')\n",
    "    train_args = train_parser.parse_args()\n",
    "    ```\n",
    "\n",
    "    并在Python的训练脚本`train.py`中调用`train_parse_args()`函数解析传入的`cache_session_id`等参数,并在定义数据集`dataset`时将其作为参数传入。\n",
    "\n",
    "    ```python\n",
    "    from src.args import train_parse_args\n",
    "    args_opt = train_parse_args()\n",
    "\n",
    "    dataset = create_dataset(\n",
    "        dataset_path=args_opt.dataset_path,\n",
    "        do_train=True,\n",
    "        config=config,\n",
    "        enable_cache=args_opt.enable_cache,\n",
    "        cache_session_id=args_opt.cache_session_id)\n",
    "    ```\n",
    "\n",
    "5. 在定义数据处理流程的Python脚本`dataset.py`中,根据传入的`enable_cache`以及`cache_session_id`参数,创建一个`DatasetCache`的实例并将其插入至`ImageFolderDataset`之后:\n",
    "\n",
    "    ```python\n",
    "    def create_dataset(dataset_path, do_train, config, repeat_num=1, enable_cache=False, cache_session_id=None):\n",
    "    ...\n",
    "        if enable_cache:\n",
    "            nfs_dataset_cache = ds.DatasetCache(session_id=int(cache_session_id), size=0)\n",
    "        else:\n",
    "            nfs_dataset_cache = None\n",
    "\n",
    "        if config.platform == \"Ascend\":\n",
    "            rank_size = int(os.getenv(\"RANK_SIZE\", '1'))\n",
    "            rank_id = int(os.getenv(\"RANK_ID\", '0'))\n",
    "            if rank_size == 1:\n",
    "                data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True, cache=nfs_dataset_cache)\n",
    "            else:\n",
    "                data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True, num_shards=rank_size, shard_id=rank_id, cache=nfs_dataset_cache)\n",
    "    ```\n",
    "\n",
    "6. 运行`run_train_nfs_cache.sh`,得到以下结果:\n",
    "\n",
    "    ```text\n",
    "    epoch: [  0/ 200], step:[ 2134/ 2135], loss:[4.682/4.682], time:[3364893.166], lr:[0.780]\n",
    "    epoch time: 3384387.999, per step time: 1585.193, avg loss: 4.682\n",
    "    epoch: [  1/ 200], step:[ 2134/ 2135], loss:[3.750/3.750], time:[430495.242], lr:[0.724]\n",
    "    epoch time: 431005.885, per step time: 201.876, avg loss: 4.286\n",
    "    epoch: [  2/ 200], step:[ 2134/ 2135], loss:[3.922/3.922], time:[420104.849], lr:[0.635]\n",
    "    epoch time: 420669.174, per step time: 197.035, avg loss: 3.534\n",
    "    epoch: [  3/ 200], step:[ 2134/ 2135], loss:[3.581/3.581], time:[420825.587], lr:[0.524]\n",
    "    epoch time: 421494.842, per step time: 197.421, avg loss: 3.417\n",
    "    ...\n",
    "    ```\n",
    "\n",
    "    下表展示了GPU服务器上使用缓存与不使用缓存的平均每个epoch时间对比:\n",
    "\n",
    "    ```text\n",
    "    | 4p, MobileNetV2, imagenet2012            | without cache | with cache |\n",
    "    | ---------------------------------------- | ------------- | ---------- |\n",
    "    | first epoch time                         | 1649s         | 3384s      |\n",
    "    | average epoch time (exclude first epoch) | 458s          | 421s       |\n",
    "    ```\n",
    "\n",
    "    可以看到使用缓存后,相比于不使用缓存的情况第一个epoch的完成时间增加了较多,这主要是由于缓存数据写入至缓存服务器的开销导致的。但是,在缓存数据写入之后随后的每个epoch都可以获得较大的性能提升。因此,训练的总epoch数目越多,使用缓存的收益将越明显。\n",
    "\n",
    "    以运行200个epoch为例,使用缓存可以使端到端的训练总用时从92791秒降低至87163秒,共计节省约5628秒。\n",
    "\n",
    "7. 使用完毕后,可以选择关闭缓存服务器:\n",
    "\n",
    "    ```text\n",
    "    $ cache_admin --stop\n",
    "    Cache server on port 50052 has been stopped successfully.\n",
    "    ```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9147850c",
   "metadata": {},
   "source": [
    "## 缓存性能调优\n",
    "\n",
    "使用缓存服务能够在一些场景下获得显著的性能提升,例如:\n",
    "\n",
    "- 缓存经过数据增强处理后的数据,尤其是当数据预处理管道中包含decode等高复杂度操作时。在该场景下,用户不需要在每个epoch重复执行数据增强操作,可节省较多时间。\n",
    "\n",
    "- 在简单网络的训练和推理过程中使用缓存服务。相比于复杂网络,简单网络的训练耗时占比更小,因此在该场景下应用缓存,能获得更显著的时间性能提升。\n",
    "\n",
    "然而,在以下场景中使用缓存可能不会获得明显的性能收益,例如:\n",
    "\n",
    "- 系统内存不足、缓存未命中等因素将导致缓存服务在时间性能上提升不明显。因此,可在使用缓存前检查可用系统内存是否充足,选择一个适当的缓存大小。\n",
    "\n",
    "- 过多缓存溢出会导致时间性能变差。因此,在使用可随机访问的数据集(如`ImageFolderDataset`)进行数据加载的场景,尽量不要允许缓存溢出至磁盘。\n",
    "\n",
    "- 在Bert等NLP类网络中使用缓存,通常不会取得性能提升。因为在NLP场景下通常不会使用到decode等高复杂度的数据增强操作。\n",
    "\n",
    "- 使用non-mappable数据集(如`TFRecordDataset`)的pipeline在第一个epoch的时间开销较大。根据当前的缓存机制,non-mappable数据集需要在第一个epoch训练开始前将所有数据写入缓存服务器中,因此这使得第一个epoch时间较长。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "90db0fa9",
   "metadata": {},
   "source": [
    "## 缓存限制\n",
    "\n",
    "- 当前`GeneratorDataset`、`PaddedDataset`和`NumpySlicesDataset`等数据集类不支持缓存。其中,`GeneratorDataset`、`PaddedDataset`和`NumpySlicesDataset`属于`GeneratorOp`,在不支持的报错信息中会呈现“There is currently no support for GeneratorOp under cache”。\n",
    "- 经过`batch`、`concat`、`filter`、`repeat`、`skip`、`split`、`take`和`zip`处理后的数据不支持缓存。\n",
    "- 经过随机数据增强操作(如`RandomCrop`)后的数据不支持缓存。\n",
    "- 不支持在同个数据管道的不同位置嵌套使用同一个缓存实例。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MindSpore",
   "language": "python",
   "name": "mindspore"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}