{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# 自动数据增强\n",
    "\n",
    "`Ascend` `GPU` `CPU` `数据准备`\n",
    "\n",
    "[![](https://gitee.com/mindspore/docs/raw/r1.5/resource/_static/logo_modelarts.png)](https://authoring-modelarts-cnnorth4.huaweicloud.com/console/lab?share-url-b64=aHR0cHM6Ly9taW5kc3BvcmUtd2Vic2l0ZS5vYnMuY24tbm9ydGgtNC5teWh1YXdlaWNsb3VkLmNvbS9ub3RlYm9vay9tYXN0ZXIvcHJvZ3JhbW1pbmdfZ3VpZGUvemhfY24vbWluZHNwb3JlX2F1dG9fYXVnbWVudGF0aW9uLmlweW5i&imageid=65f636a0-56cf-49df-b941-7d2a07ba8c8c) [![](https://gitee.com/mindspore/docs/raw/r1.5/resource/_static/logo_notebook.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/r1.5/programming_guide/zh_cn/mindspore_auto_augmentation.ipynb) [![](https://gitee.com/mindspore/docs/raw/r1.5/resource/_static/logo_download_code.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/r1.5/programming_guide/zh_cn/mindspore_auto_augmentation.py) [![](https://gitee.com/mindspore/docs/raw/r1.5/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r1.5/docs/mindspore/programming_guide/source_zh_cn/auto_augmentation.ipynb)"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 概述\n",
    "\n",
    "MindSpore除了可以让用户自定义数据增强的使用,还提供了一种自动数据增强方式,可以基于特定策略自动对图像进行数据增强处理。\n",
    "\n",
    "自动数据增强主要分为基于概率的自动数据增强和基于回调参数的自动数据增强。\n",
    "\n",
    "> 完整示例参见[应用自动数据增强](https://www.mindspore.cn/docs/programming_guide/zh-CN/r1.5/enable_auto_augmentation.html)。"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 基于概率的自动数据增强\n",
    "\n",
    "MindSpore提供了一系列基于概率的自动数据增强API,用户可以对各种数据增强操作进行随机选择与组合,使数据增强更加灵活。\n",
    "\n",
    "关于API的详细说明,可以参见[API文档](https://www.mindspore.cn/docs/api/zh-CN/r1.5/api_python/mindspore.dataset.transforms.html)。\n",
    "\n",
    "### RandomApply\n",
    "\n",
    "API接收一个数据增强操作列表`transforms`,以一定的概率顺序执行列表中各数据增强操作,默认概率为0.5,否则都不执行。\n",
    "\n",
    "在下面的代码示例中,以0.5的概率来顺序执行`RandomCrop`和`RandomColorAdjust`操作,否则都不执行。"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "source": [
    "import mindspore.dataset.vision.c_transforms as c_vision\n",
    "from mindspore.dataset.transforms.c_transforms import RandomApply\n",
    "\n",
    "rand_apply_list = RandomApply([c_vision.RandomCrop(512), c_vision.RandomColorAdjust()])"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "### RandomChoice\n",
    "\n",
    "API接收一个数据增强操作列表`transforms`,从中随机选择一个数据增强操作执行。\n",
    "\n",
    "在下面的代码示例中,等概率地在`CenterCrop`和`RandomCrop`中选择一个操作执行。"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "source": [
    "import mindspore.dataset.vision.c_transforms as c_vision\n",
    "from mindspore.dataset.transforms.c_transforms import RandomChoice\n",
    "\n",
    "rand_choice = RandomChoice([c_vision.CenterCrop(512), c_vision.RandomCrop(512)])"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "### RandomSelectSubpolicy\n",
    "\n",
    "API接收一个预置策略列表,包含一系列子策略组合,每一子策略由若干个顺序执行的数据增强操作及其执行概率组成。\n",
    "\n",
    "对各图像先等概率随机选择一种子策略,再依照子策略中的概率顺序执行各个操作。\n",
    "\n",
    "在下面的代码示例中,预置了两条子策略,子策略1中包含`RandomRotation`、`RandomVerticalFlip`和`RandomColorAdjust`三个操作,概率分别为0.5、1.0和0.8;子策略2中包含`RandomRotation`和`RandomColorAdjust`两个操作,概率分别为1.0和0.2。"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "source": [
    "import mindspore.dataset.vision.c_transforms as c_vision\n",
    "from mindspore.dataset.vision.c_transforms import RandomSelectSubpolicy\n",
    "\n",
    "policy_list = [\n",
    "      [(c_vision.RandomRotation((45, 45)), 0.5), (c_vision.RandomVerticalFlip(), 1.0), (c_vision.RandomColorAdjust(), 0.8)],\n",
    "      [(c_vision.RandomRotation((90, 90)), 1.0), (c_vision.RandomColorAdjust(), 0.2)]\n",
    "      ]\n",
    "policy = RandomSelectSubpolicy(policy_list)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 基于回调参数的自动数据增强\n",
    "\n",
    "MindSpore的`sync_wait`接口支持按batch或epoch粒度在训练过程中动态调整数据增强策略,用户可以设定阻塞条件触发特定的数据增强操作。\n",
    "\n",
    "`sync_wait`将阻塞整个数据处理pipeline直到`sync_update`触发用户预先定义的`callback`函数,两者需配合使用,对应说明如下:\n",
    "\n",
    "- sync_wait(condition_name, num_batch=1, callback=None)\n",
    "\n",
    "    该API为数据集添加一个阻塞条件`condition_name`,当`sync_update`调用时执行指定的`callback`函数。\n",
    "\n",
    "- sync_update(condition_name, num_batch=None, data=None)\n",
    "\n",
    "    该API用于释放对应`condition_name`的阻塞,并对`data`触发指定的`callback`函数。\n",
    "\n",
    "下面将演示基于回调参数的自动数据增强的用法。\n",
    "\n",
    "1. 用户预先定义`Augment`类,其中`preprocess`为自定义的数据增强函数,`update`为更新数据增强策略的回调函数。"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "source": [
    "import mindspore.dataset.vision.py_transforms as transforms\n",
    "import mindspore.dataset as ds\n",
    "import numpy as np\n",
    "\n",
    "class Augment:\n",
    "    def __init__(self):\n",
    "        self.ep_num = 0\n",
    "        self.step_num = 0\n",
    "\n",
    "    def preprocess(self, input_):\n",
    "        return (np.array((input_ + self.step_num ** self.ep_num - 1), ))\n",
    "\n",
    "    def update(self, data):\n",
    "        self.ep_num = data['ep_num']\n",
    "        self.step_num = data['step_num']"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "2. 数据处理pipeline先回调自定义的增强策略更新函数`update`,然后在`map`操作中按更新后的策略来执行`preprocess`中定义的数据增强操作。"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "source": [
    "arr = list(range(1, 4))\n",
    "dataset = ds.NumpySlicesDataset(arr, shuffle=False)\n",
    "aug = Augment()\n",
    "dataset = dataset.sync_wait(condition_name=\"policy\", callback=aug.update)\n",
    "dataset = dataset.map(operations=[aug.preprocess])"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "3. 在每个step中调用`sync_update`进行数据增强策略的更新。"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "source": [
    "epochs = 5\n",
    "itr = dataset.create_tuple_iterator(num_epochs=epochs)\n",
    "step_num = 0\n",
    "for ep_num in range(epochs):\n",
    "    for data in itr:\n",
    "        print(\"epcoh: {}, step:{}, data :{}\".format(ep_num, step_num, data))\n",
    "        step_num += 1\n",
    "        dataset.sync_update(condition_name=\"policy\", data={'ep_num': ep_num, 'step_num': step_num})"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "epcoh: 0, step:0, data :[Tensor(shape=[], dtype=Int64, value= 1)]\n",
      "epcoh: 0, step:1, data :[Tensor(shape=[], dtype=Int64, value= 2)]\n",
      "epcoh: 0, step:2, data :[Tensor(shape=[], dtype=Int64, value= 3)]\n",
      "epcoh: 1, step:3, data :[Tensor(shape=[], dtype=Int64, value= 1)]\n",
      "epcoh: 1, step:4, data :[Tensor(shape=[], dtype=Int64, value= 5)]\n",
      "epcoh: 1, step:5, data :[Tensor(shape=[], dtype=Int64, value= 7)]\n",
      "epcoh: 2, step:6, data :[Tensor(shape=[], dtype=Int64, value= 6)]\n",
      "epcoh: 2, step:7, data :[Tensor(shape=[], dtype=Int64, value= 50)]\n",
      "epcoh: 2, step:8, data :[Tensor(shape=[], dtype=Int64, value= 66)]\n",
      "epcoh: 3, step:9, data :[Tensor(shape=[], dtype=Int64, value= 81)]\n",
      "epcoh: 3, step:10, data :[Tensor(shape=[], dtype=Int64, value= 1001)]\n",
      "epcoh: 3, step:11, data :[Tensor(shape=[], dtype=Int64, value= 1333)]\n",
      "epcoh: 4, step:12, data :[Tensor(shape=[], dtype=Int64, value= 1728)]\n",
      "epcoh: 4, step:13, data :[Tensor(shape=[], dtype=Int64, value= 28562)]\n",
      "epcoh: 4, step:14, data :[Tensor(shape=[], dtype=Int64, value= 38418)]\n"
     ]
    }
   ],
   "metadata": {}
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}