{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 自动数据增强\n",
    "\n",
    "[![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_notebook.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/master/tutorials/zh_cn/dataset/mindspore_augment.ipynb) [![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_download_code.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/master/tutorials/zh_cn/dataset/mindspore_augment.py) [![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/tutorials/source_zh_cn/dataset/augment.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "MindSpore除了可以让用户自定义数据增强的使用,还提供了一种自动数据增强方式,可以基于特定策略自动对图像进行数据增强处理。\n",
    "\n",
    "下面分为**基于概率**和**基于回调参数**两种不同的自动数据增强方式进行介绍。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 基于概率的数据增强\n",
    "\n",
    "MindSpore提供了一系列基于概率的自动数据增强API,用户可以对各种数据增强操作进行随机选择与组合,使数据增强更加灵活。\n",
    "\n",
    "### RandomApply操作\n",
    "\n",
    "`RandomApply`操作接收一个数据增强操作列表,以一定的概率顺序执行列表中各数据增强操作,默认概率为0.5,否则都不执行。\n",
    "\n",
    "在下面的代码示例中,通过调用`RandomApply`接口来以0.5的概率来顺序执行`RandomCrop`和`RandomColorAdjust`操作。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mindspore.dataset.vision as vision\n",
    "from mindspore.dataset.transforms import RandomApply\n",
    "\n",
    "transforms_list = [vision.RandomCrop(512), vision.RandomColorAdjust()]\n",
    "rand_apply = RandomApply(transforms_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### RandomChoice\n",
    "\n",
    "`RandomChoice`操作接收一个数据增强操作列表`transforms`,从中随机选择一个数据增强操作执行。\n",
    "\n",
    "在下面的代码示例中,通过调用`RandomChoice`操作等概率地在`CenterCrop`和`RandomCrop`中选择一个操作执行。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mindspore.dataset.vision as vision\n",
    "from mindspore.dataset.transforms import RandomChoice\n",
    "\n",
    "transforms_list = [vision.CenterCrop(512), vision.RandomCrop(512)]\n",
    "rand_choice = RandomChoice(transforms_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### RandomSelectSubpolicy\n",
    "\n",
    "`RandomSelectSubpolicy`操作接收一个预置策略列表,包含一系列子策略组合,每一子策略由若干个顺序执行的数据增强操作及其执行概率组成。\n",
    "\n",
    "对各图像先等概率随机选择一种子策略,再依照子策略中的概率顺序执行各个操作。\n",
    "\n",
    "在下面的代码示例中,预置了两条子策略:\n",
    "\n",
    "- 子策略1中包含`RandomRotation`、`RandomVerticalFlip`两个操作,概率分别为0.5、1.0。\n",
    "\n",
    "- 子策略2中包含`RandomRotation`和`RandomColorAdjust`两个操作,概率分别为1.0和0.2。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mindspore.dataset.vision as vision\n",
    "from mindspore.dataset.vision import RandomSelectSubpolicy\n",
    "\n",
    "policy_list = [\n",
    "    # policy 1: (transforms, probability)\n",
    "    [(vision.RandomRotation((45, 45)), 0.5),\n",
    "     (vision.RandomVerticalFlip(), 1.0)],\n",
    "    # policy 2: (transforms, probability)\n",
    "    [(vision.RandomRotation((90, 90)), 1.0),\n",
    "     (vision.RandomColorAdjust(), 0.2)]\n",
    "]\n",
    "\n",
    "policy = RandomSelectSubpolicy(policy_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 基于回调参数的数据增强\n",
    "\n",
    "MindSpore的`sync_wait`接口支持按训练数据的batch或epoch粒度,在训练过程中动态调整数据增强策略,用户可以设定阻塞条件来触发特定的数据增强操作。\n",
    "\n",
    "`sync_wait`将阻塞整个数据处理pipeline,直到`sync_update`触发用户预先定义的`callback`函数,两者需配合使用,对应说明如下:\n",
    "\n",
    "- sync_wait(condition_name, num_batch=1, callback=None):为数据集添加一个阻塞条件`condition_name`,当`sync_update`调用时执行指定的`callback`函数。\n",
    "\n",
    "- sync_update(condition_name, num_batch=None, data=None):用于释放对应`condition_name`的阻塞,并对`data`触发指定的`callback`函数。\n",
    "\n",
    "下面将演示基于回调参数的自动数据增强的用法。\n",
    "\n",
    "1. 用户预先定义`Augment`类,其中`preprocess`为自定义的数据增强函数,`update`为更新数据增强策略的回调函数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "class Augment:\n",
    "    def __init__(self):\n",
    "        self.ep_num = 0\n",
    "        self.step_num = 0\n",
    "\n",
    "    def preprocess(self, input_):\n",
    "        return np.array((input_ + self.step_num ** self.ep_num - 1),)\n",
    "\n",
    "    def update(self, data):\n",
    "        self.ep_num = data['ep_num']\n",
    "        self.step_num = data['step_num']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "2. 数据处理pipeline先回调自定义的增强策略更新函数`update`,然后在`map`操作中按更新后的策略来执行`preprocess`中定义的数据增强操作。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mindspore.dataset as ds\n",
    "\n",
    "arr = list(range(1, 4))\n",
    "dataset = ds.NumpySlicesDataset(arr, shuffle=False)\n",
    "\n",
    "aug = Augment()\n",
    "dataset = dataset.sync_wait(condition_name=\"policy\", callback=aug.update)\n",
    "dataset = dataset.map(operations=[aug.preprocess])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "3. 在每个step中调用`sync_update`进行数据增强策略的更新。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epcoh: 0, step:0, data :[Tensor(shape=[], dtype=Int64, value= 1)]\n",
      "epcoh: 0, step:1, data :[Tensor(shape=[], dtype=Int64, value= 2)]\n",
      "epcoh: 0, step:2, data :[Tensor(shape=[], dtype=Int64, value= 3)]\n",
      "epcoh: 1, step:3, data :[Tensor(shape=[], dtype=Int64, value= 1)]\n",
      "epcoh: 1, step:4, data :[Tensor(shape=[], dtype=Int64, value= 5)]\n",
      "epcoh: 1, step:5, data :[Tensor(shape=[], dtype=Int64, value= 7)]\n",
      "epcoh: 2, step:6, data :[Tensor(shape=[], dtype=Int64, value= 6)]\n",
      "epcoh: 2, step:7, data :[Tensor(shape=[], dtype=Int64, value= 50)]\n",
      "epcoh: 2, step:8, data :[Tensor(shape=[], dtype=Int64, value= 66)]\n",
      "epcoh: 3, step:9, data :[Tensor(shape=[], dtype=Int64, value= 81)]\n",
      "epcoh: 3, step:10, data :[Tensor(shape=[], dtype=Int64, value= 1001)]\n",
      "epcoh: 3, step:11, data :[Tensor(shape=[], dtype=Int64, value= 1333)]\n",
      "epcoh: 4, step:12, data :[Tensor(shape=[], dtype=Int64, value= 1728)]\n",
      "epcoh: 4, step:13, data :[Tensor(shape=[], dtype=Int64, value= 28562)]\n",
      "epcoh: 4, step:14, data :[Tensor(shape=[], dtype=Int64, value= 38418)]\n"
     ]
    }
   ],
   "source": [
    "epochs = 5\n",
    "itr = dataset.create_tuple_iterator(num_epochs=epochs)\n",
    "\n",
    "step_num = 0\n",
    "for ep_num in range(epochs):\n",
    "    for data in itr:\n",
    "        print(\"epcoh: {}, step:{}, data :{}\".format(ep_num, step_num, data))\n",
    "        step_num += 1\n",
    "        dataset.sync_update(condition_name=\"policy\",\n",
    "                            data={'ep_num': ep_num, 'step_num': step_num})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ImageNet自动数据增强\n",
    "\n",
    "下面以ImageNet数据集上实现AutoAugment作为示例。\n",
    "\n",
    "针对ImageNet数据集的数据增强策略包含25条子策略,每条子策略中包含两种变换,针对一个batch中的每张图像随机挑选一个子策略的组合,以预定的概率来决定是否执行子策略中的每种变换。\n",
    "\n",
    "用户可以使用MindSpore中`mindspore.dataset.vision`模块的`RandomSelectSubpolicy`接口来实现AutoAugment,在ImageNet分类训练中标准的数据增强方式分以下几个步骤:\n",
    "\n",
    "- `RandomCropDecodeResize`:随机裁剪后进行解码。\n",
    "- `RandomHorizontalFlip`:水平方向上随机翻转。\n",
    "- `Normalize`:归一化。\n",
    "- `HWC2CHW`:图片通道变化。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. 定义MindSpore操作到AutoAugment操作的映射:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mindspore.dataset.vision as vision\n",
    "import mindspore.dataset.transforms as transforms\n",
    "\n",
    "# define Auto Augmentation operations\n",
    "PARAMETER_MAX = 10\n",
    "\n",
    "def float_parameter(level, maxval):\n",
    "    return float(level) * maxval /  PARAMETER_MAX\n",
    "\n",
    "def int_parameter(level, maxval):\n",
    "    return int(level * maxval / PARAMETER_MAX)\n",
    "\n",
    "def shear_x(level):\n",
    "    transforms_list = []\n",
    "    v = float_parameter(level, 0.3)\n",
    "\n",
    "    transforms_list.append(vision.RandomAffine(degrees=0, shear=(-v, -v)))\n",
    "    transforms_list.append(vision.RandomAffine(degrees=0, shear=(v, v)))\n",
    "    return transforms.RandomChoice(transforms_list)\n",
    "\n",
    "def shear_y(level):\n",
    "    transforms_list = []\n",
    "    v = float_parameter(level, 0.3)\n",
    "\n",
    "    transforms_list.append(vision.RandomAffine(degrees=0, shear=(0, 0, -v, -v)))\n",
    "    transforms_list.append(vision.RandomAffine(degrees=0, shear=(0, 0, v, v)))\n",
    "    return transforms.RandomChoice(transforms_list)\n",
    "\n",
    "def translate_x(level):\n",
    "    transforms_list = []\n",
    "    v = float_parameter(level, 150 / 331)\n",
    "\n",
    "    transforms_list.append(vision.RandomAffine(degrees=0, translate=(-v, -v)))\n",
    "    transforms_list.append(vision.RandomAffine(degrees=0, translate=(v, v)))\n",
    "    return transforms.RandomChoice(transforms_list)\n",
    "\n",
    "def translate_y(level):\n",
    "    transforms_list = []\n",
    "    v = float_parameter(level, 150 / 331)\n",
    "\n",
    "    transforms_list.append(vision.RandomAffine(degrees=0, translate=(0, 0, -v, -v)))\n",
    "    transforms_list.append(vision.RandomAffine(degrees=0, translate=(0, 0, v, v)))\n",
    "    return transforms.RandomChoice(transforms_list)\n",
    "\n",
    "def color_impl(level):\n",
    "    v = float_parameter(level, 1.8) + 0.1\n",
    "    return vision.RandomColor(degrees=(v, v))\n",
    "\n",
    "def rotate_impl(level):\n",
    "    transforms_list = []\n",
    "    v = int_parameter(level, 30)\n",
    "\n",
    "    transforms_list.append(vision.RandomRotation(degrees=(-v, -v)))\n",
    "    transforms_list.append(vision.RandomRotation(degrees=(v, v)))\n",
    "    return transforms.RandomChoice(transforms_list)\n",
    "\n",
    "def solarize_impl(level):\n",
    "    level = int_parameter(level, 256)\n",
    "    v = 256 - level\n",
    "    return vision.RandomSolarize(threshold=(0, v))\n",
    "\n",
    "def posterize_impl(level):\n",
    "    level = int_parameter(level, 4)\n",
    "    v = 4 - level\n",
    "    return vision.RandomPosterize(bits=(v, v))\n",
    "\n",
    "def contrast_impl(level):\n",
    "    v = float_parameter(level, 1.8) + 0.1\n",
    "    return vision.RandomColorAdjust(contrast=(v, v))\n",
    "\n",
    "def autocontrast_impl(level):\n",
    "    return vision.AutoContrast()\n",
    "\n",
    "def sharpness_impl(level):\n",
    "    v = float_parameter(level, 1.8) + 0.1\n",
    "    return vision.RandomSharpness(degrees=(v, v))\n",
    "\n",
    "def brightness_impl(level):\n",
    "    v = float_parameter(level, 1.8) + 0.1\n",
    "    return vision.RandomColorAdjust(brightness=(v, v))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "2. 定义ImageNet数据集的AutoAugment策略:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define the Auto Augmentation policy\n",
    "imagenet_policy = [\n",
    "    [(posterize_impl(8), 0.4), (rotate_impl(9), 0.6)],\n",
    "    [(solarize_impl(5), 0.6), (autocontrast_impl(5), 0.6)],\n",
    "    [(vision.Equalize(), 0.8), (vision.Equalize(), 0.6)],\n",
    "    [(posterize_impl(7), 0.6), (posterize_impl(6), 0.6)],\n",
    "\n",
    "    [(vision.Equalize(), 0.4), (solarize_impl(4), 0.2)],\n",
    "    [(vision.Equalize(), 0.4), (rotate_impl(8), 0.8)],\n",
    "    [(solarize_impl(3), 0.6), (vision.Equalize(), 0.6)],\n",
    "    [(posterize_impl(5), 0.8), (vision.Equalize(), 1.0)],\n",
    "    [(rotate_impl(3), 0.2), (solarize_impl(8), 0.6)],\n",
    "    [(vision.Equalize(), 0.6), (posterize_impl(6), 0.4)],\n",
    "\n",
    "    [(rotate_impl(8), 0.8), (color_impl(0), 0.4)],\n",
    "    [(rotate_impl(9), 0.4), (vision.Equalize(), 0.6)],\n",
    "    [(vision.Equalize(), 0.0), (vision.Equalize(), 0.8)],\n",
    "    [(vision.Invert(), 0.6), (vision.Equalize(), 1.0)],\n",
    "    [(color_impl(4), 0.6), (contrast_impl(8), 1.0)],\n",
    "\n",
    "    [(rotate_impl(8), 0.8), (color_impl(2), 1.0)],\n",
    "    [(color_impl(8), 0.8), (solarize_impl(7), 0.8)],\n",
    "    [(sharpness_impl(7), 0.4), (vision.Invert(), 0.6)],\n",
    "    [(shear_x(5), 0.6), (vision.Equalize(), 1.0)],\n",
    "    [(color_impl(0), 0.4), (vision.Equalize(), 0.6)],\n",
    "\n",
    "    [(vision.Equalize(), 0.4), (solarize_impl(4), 0.2)],\n",
    "    [(solarize_impl(5), 0.6), (autocontrast_impl(5), 0.6)],\n",
    "    [(vision.Invert(), 0.6), (vision.Equalize(), 1.0)],\n",
    "    [(color_impl(4), 0.6), (contrast_impl(8), 1.0)],\n",
    "    [(vision.Equalize(), 0.8), (vision.Equalize(), 0.6)],\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "3. 在`RandomCropDecodeResize`操作后插入AutoAugment变换。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mindspore.dataset as ds\n",
    "import mindspore as ms\n",
    "\n",
    "\n",
    "def create_dataset(dataset_path, train, repeat_num=1,\n",
    "                   batch_size=32, shuffle=False, num_samples=5):\n",
    "    # create a train or eval imagenet2012 dataset for ResNet-50\n",
    "    dataset = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8,\n",
    "                                    shuffle=shuffle, decode=True)\n",
    "\n",
    "    image_size = 224\n",
    "\n",
    "    # define map operations\n",
    "    if train:\n",
    "        trans = RandomSelectSubpolicy(imagenet_policy)\n",
    "    else:\n",
    "        trans = [vision.Resize(256),\n",
    "                 vision.CenterCrop(image_size)]\n",
    "    type_cast_op = transforms.TypeCast(ms.int32)\n",
    "\n",
    "    # map images and labes\n",
    "    dataset = dataset.map(operations=[vision.Resize(256), vision.CenterCrop(image_size)], input_columns=\"image\")\n",
    "    dataset = dataset.map(operations=trans, input_columns=\"image\")\n",
    "    dataset = dataset.map(operations=type_cast_op, input_columns=\"label\")\n",
    "\n",
    "    # apply the batch and repeat operation\n",
    "    dataset = dataset.batch(batch_size, drop_remainder=True)\n",
    "    dataset = dataset.repeat(repeat_num)\n",
    "\n",
    "    return dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "4. 验证自动数据增强效果:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from download import download\n",
    "\n",
    "# Define the path to image folder directory.\n",
    "url = \"https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/ImageNetSimilar.tar.gz\"\n",
    "download(url, \"./\", kind=\"tar.gz\", replace=True)\n",
    "dataset = create_dataset(dataset_path=\"ImageNetSimilar\",\n",
    "                         train=True,\n",
    "                         batch_size=5,\n",
    "                         shuffle=False)\n",
    "\n",
    "epochs = 5\n",
    "columns = 5\n",
    "rows = 5\n",
    "fig = plt.figure(figsize=(8, 8))\n",
    "itr = dataset.create_dict_iterator(num_epochs=epochs)\n",
    "\n",
    "for ep_num in range(epochs):\n",
    "    step_num = 0\n",
    "    for data in itr:\n",
    "        for index in range(rows):\n",
    "            fig.add_subplot(rows, columns, step_num * rows + index + 1)\n",
    "            plt.imshow(data['image'].asnumpy()[index])\n",
    "        step_num += 1\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> 为了更好地演示效果,此处只加载5张图片,且读取时不进行`shuffle`操作,自动数据增强时也不进行`Normalize`和`HWC2CHW`操作。\n",
    "\n",
    "![augment](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/source_zh_cn/dataset/images/auto_augmentation.png)\n",
    "\n",
    "运行结果可以看到,batch中每张图像的增强效果,垂直方向表示1个batch的5张图像,水平方向表示5个batch。\n",
    "\n",
    "## 参考文献\n",
    "\n",
    "[1] Ekin D. Cubuk, Barret Zoph, Dandelion Mane, Vijay Vasudevan, Quoc V. Le [AutoAugment: Learning Augmentation Policies from Data](https://arxiv.org/abs/1805.09501)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MindSpore",
   "language": "python",
   "name": "mindspore"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  },
  "vscode": {
   "interpreter": {
    "hash": "e92e3b0e72260407a1e4d16fabe2efc1463db1c235b8d61a4b02ddd7ca8a9a6a"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}