{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6b0c6c44",
   "metadata": {},
   "source": [
    "# 网络搭建\n",
    "\n",
    "[![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.1/resource/_static/logo_notebook.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.1/zh_cn/migration_guide/model_development/mindspore_model_and_cell.ipynb)\n",
    "[![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.1/resource/_static/logo_download_code.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.1/zh_cn/migration_guide/model_development/mindspore_model_and_cell.py)\n",
    "[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.1/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/source_zh_cn/migration_guide/model_development/model_and_cell.ipynb)\n",
    "\n",
    "## 网络基本构成单元 Cell\n",
    "\n",
    "MindSpore的网络搭建主要使用[Cell](https://www.mindspore.cn/docs/zh-CN/r2.1/api_python/nn/mindspore.nn.Cell.html#mindspore.nn.Cell)进行图的构造,用户需要定义一个类继承 `Cell` 这个基类,在 `init` 里声明需要使用的API及子模块,在 `construct` 里进行计算, `Cell` 在 `GRAPH_MODE` (静态图模式)下将编译为一张计算图,在 `PYNATIVE_MODE` (动态图模式)下作为神经网络的基础模块。一个基本的 `Cell` 搭建过程如下所示:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7be6a0d1",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-08T08:54:11.932322Z",
     "start_time": "2022-09-08T08:54:08.719303Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Parameter (name=net.weight, shape=(240, 120, 4, 4), dtype=Float32, requires_grad=True)]\n"
     ]
    }
   ],
   "source": [
    "import mindspore.nn as nn\n",
    "import mindspore.ops as ops\n",
    "\n",
    "class MyCell(nn.Cell):\n",
    "    def __init__(self, forward_net):\n",
    "        super(MyCell, self).__init__(auto_prefix=True)\n",
    "        self.net = forward_net\n",
    "        self.relu = ops.ReLU()\n",
    "\n",
    "    def construct(self, x):\n",
    "        y = self.net(x)\n",
    "        return self.relu(y)\n",
    "\n",
    "inner_net = nn.Conv2d(120, 240, 4, has_bias=False)\n",
    "my_net = MyCell(inner_net)\n",
    "print(my_net.trainable_params())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "afddb76e",
   "metadata": {},
   "source": [
    "参数的名字一般是根据`__init__`定义的对象名字和参数定义时用的名字组成的,比如上面的例子中,卷积的参数名为`net.weight`,其中,`net`是`self.net = forward_net`中的对象名,`weight`是Conv2d中定义卷积的参数时的`name`:`self.weight = Parameter(initializer(self.weight_init, shape), name='weight')`。\n",
    "\n",
    "Cell提供了`auto_prefix`接口用来判断Cell中的参数名是否加对象名这层信息,默认是`True`,也就是加对象名。如果`auto_prefix`设置为`False`,则上面这个例子中打印的`Parameter`的`name`是`weight`。通常骨干网络`auto_prefix`应设置为True。用于训练的优化器、 :class:`mindspore.nn.TrainOneStepCell` 等,应设置为False,以避免骨干网络的权重参数名被误改。\n",
    "\n",
    "### 单元测试\n",
    "\n",
    "搭建完`Cell`之后,最好对每个`Cell`构建一个单元测试方法与对标代码比较,比如上面的例子,其PyTorch的构建代码为:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3e55faef",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-08T08:54:12.986099Z",
     "start_time": "2022-09-08T08:54:11.935032Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([240, 120, 4, 4])\n"
     ]
    }
   ],
   "source": [
    "import torch.nn as torch_nn\n",
    "\n",
    "class MyCell_pt(torch_nn.Module):\n",
    "    def __init__(self, forward_net):\n",
    "        super(MyCell_pt, self).__init__()\n",
    "        self.net = forward_net\n",
    "        self.relu = torch_nn.ReLU()\n",
    "\n",
    "    def forward(self, x):\n",
    "        y = self.net(x)\n",
    "        return self.relu(y)\n",
    "\n",
    "inner_net_pt = torch_nn.Conv2d(120, 240, kernel_size=4, bias=False)\n",
    "pt_net = MyCell_pt(inner_net_pt)\n",
    "for i in pt_net.parameters():\n",
    "    print(i.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df72ed18",
   "metadata": {},
   "source": [
    "有了构建`Cell`的脚本,需要使用相同的输入数据和参数,对输出做比较:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "80a2fad0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-08T08:34:48.798328Z",
     "start_time": "2022-09-08T08:34:47.920670Z"
    }
   },
   "source": [
    "```python\n",
    "import numpy as np\n",
    "import mindspore as ms\n",
    "import torch\n",
    "\n",
    "x = np.random.uniform(-1, 1, (2, 120, 12, 12)).astype(np.float32)\n",
    "for m in pt_net.modules():\n",
    "    if isinstance(m, torch_nn.Conv2d):\n",
    "        torch_nn.init.constant_(m.weight, 0.1)\n",
    "\n",
    "for _, cell in my_net.cells_and_names():\n",
    "    if isinstance(cell, nn.Conv2d):\n",
    "        cell.weight.set_data(ms.common.initializer.initializer(0.1, cell.weight.shape, cell.weight.dtype))\n",
    "\n",
    "y_ms = my_net(ms.Tensor(x))\n",
    "y_pt = pt_net(torch.from_numpy(x))\n",
    "diff = np.max(np.abs(y_ms.asnumpy() - y_pt.detach().numpy()))\n",
    "print(diff)\n",
    "\n",
    "# ValueError: operands could not be broadcast together with shapes (2,240,12,12) (2,240,9,9)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1483b031",
   "metadata": {},
   "source": [
    "可以发现MindSpore和PyTorch的输出不一样,什么原因呢?\n",
    "\n",
    "查询[API差异文档](https://www.mindspore.cn/docs/zh-CN/r2.1/note/api_mapping/pytorch_diff/Conv2d.html)发现,`Conv2d`的默认参数在MindSpore和PyTorch上有区别,\n",
    "MindSpore默认使用`same`模式,PyTorch默认使用`pad`模式,迁移时需要改一下MindSpore `Conv2d`的`pad_mode`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e7f3e820",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-08T08:54:13.137339Z",
     "start_time": "2022-09-08T08:54:12.988142Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.9355288e-06\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore as ms\n",
    "import torch\n",
    "\n",
    "inner_net = nn.Conv2d(120, 240, 4, has_bias=False, pad_mode=\"pad\")\n",
    "my_net = MyCell(inner_net)\n",
    "\n",
    "# 构造随机输入\n",
    "x = np.random.uniform(-1, 1, (2, 120, 12, 12)).astype(np.float32)\n",
    "for m in pt_net.modules():\n",
    "    if isinstance(m, torch_nn.Conv2d):\n",
    "        # 固定PyTorch初始化参数\n",
    "        torch_nn.init.constant_(m.weight, 0.1)\n",
    "\n",
    "for _, cell in my_net.cells_and_names():\n",
    "    if isinstance(cell, nn.Conv2d):\n",
    "        # 固定MindSpore初始化参数\n",
    "        cell.weight.set_data(ms.common.initializer.initializer(0.1, cell.weight.shape, cell.weight.dtype))\n",
    "\n",
    "y_ms = my_net(ms.Tensor(x))\n",
    "y_pt = pt_net(torch.from_numpy(x))\n",
    "diff = np.max(np.abs(y_ms.asnumpy() - y_pt.detach().numpy()))\n",
    "print(diff)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "febfeda5",
   "metadata": {},
   "source": [
    "整体误差在万分之一左右,基本符合预期。**在迁移Cell的过程中最好对每个Cell都做一次单元测试,保证迁移的一致性。**\n",
    "\n",
    "### Cell常用的方法介绍\n",
    "\n",
    "`Cell`是MindSpore中神经网络的基本构成单元,提供了很多设置标志位以及好用的方法,下面来介绍一些常用的方法。\n",
    "\n",
    "#### 手动混合精度\n",
    "\n",
    "MindSpore提供了一种自动混合精度的方法,详见[Model](https://www.mindspore.cn/docs/en/r2.1/api_python/train/mindspore.train.Model.html#mindspore.train.Model)的amp_level属性。\n",
    "\n",
    "但是有的时候开发网络时希望混合精度策略更加的灵活,MindSpore也提供了[to_float](https://mindspore.cn/docs/zh-CN/r2.1/api_python/nn/mindspore.nn.Cell.html#mindspore.nn.Cell.to_float)的方法手动地添加混合精度。\n",
    "\n",
    "`to_float(dst_type)`: 在`Cell`和所有子`Cell`的输入上添加类型转换,以使用特定的浮点类型运行。\n",
    "\n",
    "如果 `dst_type` 是 `ms.float16` ,`Cell`的所有输入(包括作为常量的input, `Parameter`, `Tensor`)都会被转换为`float16`。例如,我想将一个网络里所有的BN和loss改成`float32`类型,其余操作是`float16`类型,可以这么做:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9b479a4b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-08T08:54:15.400077Z",
     "start_time": "2022-09-08T08:54:13.140480Z"
    }
   },
   "outputs": [],
   "source": [
    "import mindspore as ms\n",
    "from mindspore import nn\n",
    "\n",
    "# 定义模型\n",
    "class Network(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.layer1 = nn.SequentialCell([\n",
    "            nn.Conv2d(3, 12, kernel_size=3, pad_mode='pad', padding=1),\n",
    "            nn.BatchNorm2d(12),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(kernel_size=2, stride=2)\n",
    "        ])\n",
    "        self.layer2 = nn.SequentialCell([\n",
    "            nn.Conv2d(12, 4, kernel_size=3, pad_mode='pad', padding=1),\n",
    "            nn.BatchNorm2d(4),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(kernel_size=2, stride=2)\n",
    "        ])\n",
    "        self.pool = nn.AdaptiveMaxPool2d((5, 5))\n",
    "        self.fc = nn.Dense(100, 10)\n",
    "\n",
    "    def construct(self, x):\n",
    "        x = self.layer1(x)\n",
    "        x = self.layer2(x)\n",
    "        x = self.pool(x)\n",
    "        x = x.view((-1, 100))\n",
    "        out = nn.Dense(x)\n",
    "        return out\n",
    "\n",
    "net = Network()\n",
    "net.to_float(ms.float16)  # 将net里所有的操作加float16的标志,框架会在编译时在输入加cast方法\n",
    "for _, cell in net.cells_and_names():\n",
    "    if isinstance(cell, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):\n",
    "        cell.to_float(ms.float32)\n",
    "\n",
    "loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean').to_float(ms.float32)\n",
    "net_with_loss = nn.WithLossCell(net, loss_fn=loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5b15c18",
   "metadata": {},
   "source": [
    "自定义的`to_float`和Model里的`amp_level`冲突,使用自定义的混合精度就不要设置Model里的`amp_level`。\n",
    "\n",
    "#### 自定义初始化参数\n",
    "\n",
    "MindSpore封装的高阶API里一般会给参数一个默认的初始化,有时候这个初始化分布与需要使用的初始化、PyTorch的初始化不一致,此时需要进行自定义初始化。[网络参数初始化](https://mindspore.cn/tutorials/zh-CN/r2.1/advanced/modules/initializer.html#自定义参数初始化)介绍了一种在使用API属性进行初始化的方法,这里介绍一种利用Cell进行参数初始化的方法。\n",
    "\n",
    "参数的相关介绍请参考[网络参数](https://www.mindspore.cn/tutorials/zh-CN/r2.1/advanced/modules/initializer.html),本节主要以`Cell`为切入口,举例获取`Cell`中的所有参数,并举例说明怎样给`Cell`里的参数进行初始化。\n",
    "\n",
    "> 注意本节的方法不能在`construct`里执行,在网络中修改参数的值请使用[assign](https://www.mindspore.cn/docs/zh-CN/r2.1/api_python/ops/mindspore.ops.assign.html)。\n",
    "\n",
    "[set_data(data, slice_shape=False)](https://www.mindspore.cn/docs/zh-CN/r2.1/api_python/mindspore/mindspore.Parameter.html?highlight=set_data#mindspore.Parameter.set_data)设置参数数据。\n",
    "\n",
    "MindSpore支持的参数初始化方法参考[mindspore.common.initializer](https://www.mindspore.cn/docs/zh-CN/r2.1/api_python/mindspore.common.initializer.html),当然也可以直接传入一个定义好的[Parameter](https://www.mindspore.cn/docs/zh-CN/r2.1/api_python/mindspore/mindspore.Parameter.html#mindspore.Parameter)对象。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "22e379c0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-08T08:54:17.292108Z",
     "start_time": "2022-09-08T08:54:15.400077Z"
    }
   },
   "outputs": [],
   "source": [
    "import math\n",
    "import mindspore as ms\n",
    "from mindspore import nn\n",
    "\n",
    "# 定义模型\n",
    "class Network(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.layer1 = nn.SequentialCell([\n",
    "            nn.Conv2d(3, 12, kernel_size=3, pad_mode='pad', padding=1),\n",
    "            nn.BatchNorm2d(12),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(kernel_size=2, stride=2)\n",
    "        ])\n",
    "        self.layer2 = nn.SequentialCell([\n",
    "            nn.Conv2d(12, 4, kernel_size=3, pad_mode='pad', padding=1),\n",
    "            nn.BatchNorm2d(4),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(kernel_size=2, stride=2)\n",
    "        ])\n",
    "        self.pool = nn.AdaptiveMaxPool2d((5, 5))\n",
    "        self.fc = nn.Dense(100, 10)\n",
    "\n",
    "    def construct(self, x):\n",
    "        x = self.layer1(x)\n",
    "        x = self.layer2(x)\n",
    "        x = self.pool(x)\n",
    "        x = x.view((-1, 100))\n",
    "        out = nn.Dense(x)\n",
    "        return out\n",
    "\n",
    "net = Network()\n",
    "for _, cell in net.cells_and_names():\n",
    "    if isinstance(cell, nn.Conv2d):\n",
    "        cell.weight.set_data(ms.common.initializer.initializer(\n",
    "            ms.common.initializer.HeNormal(negative_slope=0, mode='fan_out', nonlinearity='relu'),\n",
    "            cell.weight.shape, cell.weight.dtype))\n",
    "    elif isinstance(cell, (nn.BatchNorm2d, nn.GroupNorm)):\n",
    "        cell.gamma.set_data(ms.common.initializer.initializer(\"ones\", cell.gamma.shape, cell.gamma.dtype))\n",
    "        cell.beta.set_data(ms.common.initializer.initializer(\"zeros\", cell.beta.shape, cell.beta.dtype))\n",
    "    elif isinstance(cell, (nn.Dense)):\n",
    "        cell.weight.set_data(ms.common.initializer.initializer(\n",
    "            ms.common.initializer.HeUniform(negative_slope=math.sqrt(5)),\n",
    "            cell.weight.shape, cell.weight.dtype))\n",
    "        cell.bias.set_data(ms.common.initializer.initializer(\"zeros\", cell.bias.shape, cell.bias.dtype))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dc33e12c",
   "metadata": {},
   "source": [
    "#### 参数冻结\n",
    "\n",
    "`Parameter`有一个`requires_grad`的属性来判断是否需要做参数更新,当`requires_grad=False`时相当于PyTorch的`buffer`对象。\n",
    "\n",
    "我们可以通过Cell的`parameters_dict`、`get_parameters`和`trainable_params`来获取`Cell`中的参数列表。\n",
    "\n",
    "- parameters_dict:获取网络结构中所有参数,返回一个以key为参数名,value为参数值的`OrderedDict`。\n",
    "\n",
    "- get_parameters:获取网络结构中的所有参数,返回`Cell`中`Parameter`的迭代器。\n",
    "\n",
    "- trainable_params:获取`Parameter`中`requires_grad`为`True`的属性,返回可训参数的列表。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "15e4bf43",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-08T08:54:17.301947Z",
     "start_time": "2022-09-08T08:54:17.294156Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Parameter (name=weight, shape=(1, 2), dtype=Float32, requires_grad=True), Parameter (name=bias, shape=(1,), dtype=Float32, requires_grad=True)]\n",
      "[Parameter (name=weight, shape=(1, 2), dtype=Float32, requires_grad=True)]\n"
     ]
    }
   ],
   "source": [
    "import mindspore.nn as nn\n",
    "\n",
    "net = nn.Dense(2, 1, has_bias=True)\n",
    "print(net.trainable_params())\n",
    "\n",
    "for param in net.trainable_params():\n",
    "    param_name = param.name\n",
    "    if \"bias\" in param_name:\n",
    "        param.requires_grad = False\n",
    "print(net.trainable_params())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b055d7b3",
   "metadata": {},
   "source": [
    "在定义优化器时,使用`net.trainable_params()`获取需要进行参数更新的参数列表。\n",
    "\n",
    "除了使用给参数设置`requires_grad=False`来不更新参数外,还可以使用`stop_gradient`来阻断梯度计算以达到冻结参数的作用。那什么时候使用`requires_grad=False`,什么时候使用`stop_gradient`呢?\n",
    "\n",
    "![parameter-freeze](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.1/docs/mindspore/source_zh_cn/migration_guide/model_development/images/parameter_freeze.png)\n",
    "\n",
    "如上图所示,`requires_grad=False`不更新部分参数,但是反向的梯度计算还是正常执行的;\n",
    "`stop_gradient`会直接截断反向梯度,当需要冻结的参数之前没有需要训练的参数时,两者在功能上是等价的。\n",
    "但是`stop_gradient`会更快(少执行了一部分反向梯度计算)。\n",
    "当冻结的参数之前有需要训练的参数时,只能使用`requires_grad=False`。\n",
    "另外,`stop_gradient`需要加在网络的计算链路里,作用的对象是Tensor:\n",
    "\n",
    "```python\n",
    "a = A(x)\n",
    "a = ops.stop_gradient(a)\n",
    "y = B(a)\n",
    "```\n",
    "\n",
    "#### 参数保存和加载\n",
    "\n",
    "MindSpore提供了`load_checkpoint`和`save_checkpoint`方法用来参数的保存和加载,需要注意的是参数保存时,保存的是参数列表,参数加载时对象必须是Cell。\n",
    "在参数加载时,可能参数名对不上需要做一些修改,可以直接构造一个新的参数列表给到`load_checkpoint`加载到Cell。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e0a84660",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-08T08:54:17.340205Z",
     "start_time": "2022-09-08T08:54:17.304020Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "weight [[-0.0042482  -0.00427286]]\n",
      "bias [0.]\n",
      "{'weight': Parameter (name=weight, shape=(1, 2), dtype=Float32, requires_grad=True), 'bias': Parameter (name=bias, shape=(1,), dtype=Float32, requires_grad=True)}\n",
      "weight [[-0.0042482  -0.00427286]]\n",
      "bias [0.]\n",
      "weight [[1. 1.]]\n",
      "bias [1.]\n"
     ]
    }
   ],
   "source": [
    "import mindspore as ms\n",
    "import mindspore.ops as ops\n",
    "import mindspore.nn as nn\n",
    "\n",
    "net = nn.Dense(2, 1, has_bias=True)\n",
    "for param in net.get_parameters():\n",
    "    print(param.name, param.data.asnumpy())\n",
    "\n",
    "ms.save_checkpoint(net, \"dense.ckpt\")\n",
    "dense_params = ms.load_checkpoint(\"dense.ckpt\")\n",
    "print(dense_params)\n",
    "new_params = {}\n",
    "for param_name in dense_params:\n",
    "    print(param_name, dense_params[param_name].data.asnumpy())\n",
    "    new_params[param_name] = ms.Parameter(ops.ones_like(dense_params[param_name].data), name=param_name)\n",
    "\n",
    "ms.load_param_into_net(net, new_params)\n",
    "for param in net.get_parameters():\n",
    "    print(param.name, param.data.asnumpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b9011e7",
   "metadata": {},
   "source": [
    "### 动态图与静态图\n",
    "\n",
    "对于`Cell`,MindSpore提供`GRAPH_MODE`(静态图)和`PYNATIVE_MODE`(动态图)两种模式,详情请参考[动态图和静态图](https://www.mindspore.cn/tutorials/zh-CN/r2.1/advanced/compute_graph.html)。\n",
    "\n",
    "`PyNative`模式下模型进行**推理**的行为与一般Python代码无异。但是在训练过程中,注意**一旦将Tensor转换成numpy做其他的运算后将会截断网络的梯度,相当于PyTorch的detach**。\n",
    "\n",
    "而在使用`GRAPH_MODE`时,通常会出现语法限制。在这种情况下,需要对Python代码进行图编译操作,而这一步操作中MindSpore目前还未能支持完整的Python语法全集,所以`construct`函数的编写会存在部分限制。具体限制内容可以参考[MindSpore静态图语法](https://www.mindspore.cn/docs/zh-CN/r2.1/note/static_graph_syntax_support.html)。\n",
    "\n",
    "#### 常见限制\n",
    "\n",
    "相较于详细的语法说明,常见的限制可以归结为以下几点:\n",
    "\n",
    "- 场景1\n",
    "\n",
    "    限制:构图时(`construct`函数部分或者用`@jit`修饰的函数),不要调用其他Python库,例如numpy、scipy,相关的处理应该前移到`__init__`阶段。\n",
    "    措施:使用MindSpore内部提供的API替换其他Python库的功能。常量的处理可以前移到`__init__`阶段。\n",
    "\n",
    "- 场景2\n",
    "\n",
    "    限制:构图时不要使用自定义类型,而应该使用MindSpore提供的数据类型和Python基础类型,可以使用基于这些类型的tuple/list组合。\n",
    "    措施:使用基础类型进行组合,可以考虑增加函数参数量。函数入参数没有限制,并且可以使用不定长输入。\n",
    "\n",
    "- 场景3\n",
    "\n",
    "    限制:构图时不要对数据进行多线程或多进程处理。\n",
    "    措施:避免网络中出现多线程处理。\n",
    "\n",
    "### 自定义反向\n",
    "\n",
    "但是有的时候MindSpore不支持某些处理,需要使用一些三方的库的方法,但是我们又不想截断网络的梯度,这时该怎么办呢?这里介绍一种在`PYNATIVE_MODE`模式下,通过自定义反向规避此问题的方法:\n",
    "\n",
    "有这么一个场景,需要随机有放回的选取大于0.5的值,且每个batch的shape固定是max_num。但是这个随机有放回的操作目前没有MindSpore的API支持,这时我们在`PYNATIVE_MODE`下使用numpy的方法来计算,然后自己构造一个梯度传播的过程。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "e2113f92",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-08T08:54:17.467041Z",
     "start_time": "2022-09-08T08:54:17.342271Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x [[1.2510660e+00 2.1609735e+00 3.4312444e-04 9.0699774e-01 4.4026768e-01]\n",
      " [2.7701578e-01 5.5878061e-01 1.0366821e+00 1.1903024e+00 1.6164502e+00]]\n",
      "pos_values forword [[0.90699774 2.1609735  0.90699774]\n",
      " [0.5587806  1.6164502  0.5587806 ]]\n",
      "pos_indices forword [[3 1 3]\n",
      " [1 4 1]]\n",
      "pos_values forword [[0.90699774 1.251066   2.1609735 ]\n",
      " [1.1903024  1.1903024  0.5587806 ]]\n",
      "pos_indices forword [[3 0 1]\n",
      " [3 3 1]]\n",
      "dx (Tensor(shape=[2, 5], dtype=Float32, value=\n",
      "[[0.00000000e+000, 0.00000000e+000, 0.00000000e+000, 0.00000000e+000, 0.00000000e+000],\n",
      " [0.00000000e+000, 0.00000000e+000, 0.00000000e+000, 0.00000000e+000, 0.00000000e+000]]),)\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore as ms\n",
    "import mindspore.nn as nn\n",
    "import mindspore.ops as ops\n",
    "\n",
    "ms.set_context(mode=ms.PYNATIVE_MODE)\n",
    "ms.set_seed(1)\n",
    "\n",
    "class MySampler(nn.Cell):\n",
    "    # 自定义取样器,在每个batch选取max_num个大于0.5的值\n",
    "    def __init__(self, max_num):\n",
    "        super(MySampler, self).__init__()\n",
    "        self.max_num = max_num\n",
    "\n",
    "    def random_positive(self, x):\n",
    "        # 三方库numpy的方法,选取大于0.5的位置\n",
    "        pos = np.where(x > 0.5)[0]\n",
    "        pos_indice = np.random.choice(pos, self.max_num)\n",
    "        return pos_indice\n",
    "\n",
    "    def construct(self, x):\n",
    "        # 正向网络构造\n",
    "        batch = x.shape[0]\n",
    "        pos_value = []\n",
    "        pos_indice = []\n",
    "        for i in range(batch):\n",
    "            a = x[i].asnumpy()\n",
    "            pos_ind = self.random_positive(a)\n",
    "            pos_value.append(ms.Tensor(a[pos_ind], ms.float32))\n",
    "            pos_indice.append(ms.Tensor(pos_ind, ms.int32))\n",
    "        pos_values = ops.stack(pos_value, axis=0)\n",
    "        pos_indices = ops.stack(pos_indice, axis=0)\n",
    "        print(\"pos_values forword\", pos_values)\n",
    "        print(\"pos_indices forword\", pos_indices)\n",
    "        return pos_values, pos_indices\n",
    "\n",
    "x = ms.Tensor(np.random.uniform(0, 3, (2, 5)), ms.float32)\n",
    "print(\"x\", x)\n",
    "sampler = MySampler(3)\n",
    "pos_values, pos_indices = sampler(x)\n",
    "grad = ms.grad(sampler, grad_position=0)(x)\n",
    "print(\"dx\", grad)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9f1e0b1",
   "metadata": {},
   "source": [
    "当我们不构造这个反向过程时,由于使用的是numpy的方法计算的`pos_value`,梯度将会截断。\n",
    "如上面注释所示,`dx`的值全是0。另外细心的同学会发现这个过程打印了两次`pos_values forword`和`pos_indices forword`,这是因为在`PYNATIVE_MODE`下在构造反向图时会再次构造一次正向图,这使得上面的这种写法实际上跑了两次正向和一次反向,这不但浪费了训练资源,在某些情况还会造成精度问题,如有BatchNorm的情况,在运行正向时就会更新`moving_mean`和`moving_var`导致一次训练更新了两次`moving_mean`和`moving_var`。\n",
    "为了避免这种场景,MindSpore针对`Cell`有一个方法`set_grad()`,在`PYNATIVE_MODE`模式下框架会在构造正向时同步构造反向,这样在执行反向时就不会再运行正向的流程了。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "efc96a1b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-08T08:54:17.489786Z",
     "start_time": "2022-09-08T08:54:17.469116Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x [[1.2519144  1.6760695  0.42116082 0.59430444 2.4022336 ]\n",
      " [2.9047847  0.9402725  2.076968   2.6291676  2.68382   ]]\n",
      "pos_values forword [[1.2519144 1.2519144 1.6760695]\n",
      " [2.6291676 2.076968  0.9402725]]\n",
      "pos_indices forword [[0 0 1]\n",
      " [3 2 1]]\n",
      "dx (Tensor(shape=[2, 5], dtype=Float32, value=\n",
      "[[0.00000000e+000, 0.00000000e+000, 0.00000000e+000, 0.00000000e+000, 0.00000000e+000],\n",
      " [0.00000000e+000, 0.00000000e+000, 0.00000000e+000, 0.00000000e+000, 0.00000000e+000]]),)\n"
     ]
    }
   ],
   "source": [
    "x = ms.Tensor(np.random.uniform(0, 3, (2, 5)), ms.float32)\n",
    "print(\"x\", x)\n",
    "sampler = MySampler(3).set_grad()\n",
    "pos_values, pos_indices = sampler(x)\n",
    "grad = ms.grad(sampler, grad_position=0)(x)\n",
    "print(\"dx\", grad)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fcd6ef6f",
   "metadata": {},
   "source": [
    "下面,我们来演示下如何[自定义反向](https://mindspore.cn/tutorials/zh-CN/r2.1/advanced/modules/layer.html#自定义cell反向):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "8445582c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-08T08:54:17.526921Z",
     "start_time": "2022-09-08T08:54:17.494458Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x [[1.2510660e+00 2.1609735e+00 3.4312444e-04 9.0699774e-01 4.4026768e-01]\n",
      " [2.7701578e-01 5.5878061e-01 1.0366821e+00 1.1903024e+00 1.6164502e+00]]\n",
      "pos_values forword [[0.90699774 2.1609735  0.90699774]\n",
      " [0.5587806  1.6164502  0.5587806 ]]\n",
      "pos_indices forword [[3 1 3]\n",
      " [1 4 1]]\n",
      "pos_indices backward [[3 1 3]\n",
      " [1 4 1]]\n",
      "grad_x backward [[1. 1. 1.]\n",
      " [1. 1. 1.]]\n",
      "dx (Tensor(shape=[2, 5], dtype=Float32, value=\n",
      "[[0.00000000e+000, 1.00000000e+000, 0.00000000e+000, 2.00000000e+000, 0.00000000e+000],\n",
      " [0.00000000e+000, 2.00000000e+000, 0.00000000e+000, 0.00000000e+000, 1.00000000e+000]]),)\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore as ms\n",
    "import mindspore.nn as nn\n",
    "import mindspore.ops as ops\n",
    "\n",
    "ms.set_context(mode=ms.PYNATIVE_MODE)\n",
    "ms.set_seed(1)\n",
    "\n",
    "class MySampler(nn.Cell):\n",
    "    # 自定义取样器,在每个batch选取max_num个大于0.5的值\n",
    "    def __init__(self, max_num):\n",
    "        super(MySampler, self).__init__()\n",
    "        self.max_num = max_num\n",
    "\n",
    "    def random_positive(self, x):\n",
    "        # 三方库numpy的方法,选取大于0.5的位置\n",
    "        pos = np.where(x > 0.5)[0]\n",
    "        pos_indice = np.random.choice(pos, self.max_num)\n",
    "        return pos_indice\n",
    "\n",
    "    def construct(self, x):\n",
    "        # 正向网络构造\n",
    "        batch = x.shape[0]\n",
    "        pos_value = []\n",
    "        pos_indice = []\n",
    "        for i in range(batch):\n",
    "            a = x[i].asnumpy()\n",
    "            pos_ind = self.random_positive(a)\n",
    "            pos_value.append(ms.Tensor(a[pos_ind], ms.float32))\n",
    "            pos_indice.append(ms.Tensor(pos_ind, ms.int32))\n",
    "        pos_values = ops.stack(pos_value, axis=0)\n",
    "        pos_indices = ops.stack(pos_indice, axis=0)\n",
    "        print(\"pos_values forword\", pos_values)\n",
    "        print(\"pos_indices forword\", pos_indices)\n",
    "        return pos_values, pos_indices\n",
    "\n",
    "    def bprop(self, x, out, dout):\n",
    "        # 反向网络构造\n",
    "        pos_indices = out[1]\n",
    "        print(\"pos_indices backward\", pos_indices)\n",
    "        grad_x = dout[0]\n",
    "        print(\"grad_x backward\", grad_x)\n",
    "        batch = x.shape[0]\n",
    "        dx = []\n",
    "        for i in range(batch):\n",
    "            dx.append(ops.UnsortedSegmentSum()(grad_x[i], pos_indices[i], x.shape[1]))\n",
    "        return ops.stack(dx, axis=0)\n",
    "\n",
    "x = ms.Tensor(np.random.uniform(0, 3, (2, 5)), ms.float32)\n",
    "print(\"x\", x)\n",
    "sampler = MySampler(3).set_grad()\n",
    "pos_values, pos_indices = sampler(x)\n",
    "grad = ms.grad(sampler, grad_position=0)(x)\n",
    "print(\"dx\", grad)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10166131",
   "metadata": {},
   "source": [
    "我们在`MySampler`类里加入了`bprop`方法,这个方法的输入是正向的输入(展开写),正向的输出(一个tuple),输出的梯度(一个tuple)。在这个方法里构造梯度到输入的梯度反传流程。\n",
    "可以看到在第0个batch,我们随机选取第3、1、3位置的值,输出的梯度都是1,最后反传出去的梯度为`[0.00000000e+000, 1.00000000e+000, 0.00000000e+000, 2.00000000e+000, 0.00000000e+000]`,符合预期。\n",
    "\n",
    "### 动态shape规避策略\n",
    "\n",
    "一般动态shape引入的原因有:\n",
    "\n",
    "- 输入shape不固定;\n",
    "- 网络执行过程中有引发shape变化的算子;\n",
    "- 控制流不同分支引入shape上的变化。\n",
    "\n",
    "下面,我们针对这几种场景介绍一些规避策略。\n",
    "\n",
    "#### 输入shape不固定的场景\n",
    "\n",
    "1. 可以在输入数据上加pad,pad到固定的shape。如deep_speechv2的[数据处理](https://gitee.com/mindspore/models/blob/r2.1/official/audio/DeepSpeech2/src/dataset.py#L153) 规定`input_length`的最大长度,短的补0,长的随机截断,但是注意这种方法可能会影响训练的精度,需要平衡训练精度和训练性能。\n",
    "\n",
    "2. 可以设置一组固定的输入shape,将输入分别处理成几个固定的尺度。如YOLOv3_darknet53的[数据处理](https://gitee.com/mindspore/models/blob/r2.1/official/cv/YOLOv3/src/yolo_dataset.py#L177),在batch方法加处理函数`multi_scale_trans`,在其中在[MultiScaleTrans](https://gitee.com/mindspore/models/blob/r2.1/official/cv/YOLOv3/src/transforms.py#L456)中随机选取一个shape进行处理。\n",
    "\n",
    "目前对输入shape完全随机的情况支持有限,需要等待新版本支持。\n",
    "\n",
    "#### 网络执行过程中有引发shape变化的操作\n",
    "\n",
    "对于网络运行过程中生成不固定shape的Tensor的场景,最常用的方式是构造mask来过滤掉无效的位置的值。一个简单的例子,在检测场景下需要根据预测框和真实框的iou结果选取一些框。\n",
    "PyTorch的实现方式如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "fe0dbe16",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-08T08:54:17.534666Z",
     "start_time": "2022-09-08T08:54:17.528497Z"
    }
   },
   "outputs": [],
   "source": [
    "def box_select_torch(box, iou_score):\n",
    "    mask = iou_score > 0.3\n",
    "    return box[mask]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5bef9a24",
   "metadata": {},
   "source": [
    "当前MindSpore1.8之后全场景支持了masked_select,在MindSpore上可以这样实现:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "497a7593",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-08T08:54:17.552150Z",
     "start_time": "2022-09-08T08:54:17.536221Z"
    }
   },
   "outputs": [],
   "source": [
    "import mindspore as ms\n",
    "from mindspore import ops\n",
    "\n",
    "ms.set_seed(1)\n",
    "\n",
    "def box_select_ms(box, iou_score):\n",
    "    mask = (iou_score > 0.3).expand_dims(1)\n",
    "    return ops.masked_select(box, mask)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3ba0d80d",
   "metadata": {},
   "source": [
    "看一下结果对比:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "5a21a5f8",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-08T08:54:17.570773Z",
     "start_time": "2022-09-08T08:54:17.554235Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "box_select_ms [0.14675589 0.09233859 0.18626021 0.34556073]\n",
      "box_select_torch tensor([[0.1468, 0.0923, 0.1863, 0.3456]])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import mindspore as ms\n",
    "\n",
    "ms.set_seed(1)\n",
    "\n",
    "box = np.random.uniform(0, 1, (3, 4)).astype(np.float32)\n",
    "iou_score = np.random.uniform(0, 1, (3,)).astype(np.float32)\n",
    "\n",
    "print(\"box_select_ms\", box_select_ms(ms.Tensor(box), ms.Tensor(iou_score)))\n",
    "print(\"box_select_torch\", box_select_torch(torch.from_numpy(box), torch.from_numpy(iou_score)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "409376eb",
   "metadata": {},
   "source": [
    "但是这样操作后会产生动态shape,在后续的网络计算中可能会有问题,在现阶段,推荐先使用mask规避一下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "01e92adc",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-08T08:54:17.576516Z",
     "start_time": "2022-09-08T08:54:17.572354Z"
    }
   },
   "outputs": [],
   "source": [
    "def box_select_ms2(box, iou_score):\n",
    "    mask = (iou_score > 0.3).expand_dims(1)\n",
    "    return box * mask, mask"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6316ef76",
   "metadata": {},
   "source": [
    "在后续计算中,如果涉及box的一些操作,需要注意是否需要乘mask用来过滤非有效结果。\n",
    "\n",
    "对于求loss时对feature做选取,导致获取到不固定shape的Tensor的场景,处理方式基本和网络运行过程中不固定shape的处理方式相同,只是loss部分后续可能没有其他的操作,不需要返回mask。\n",
    "\n",
    "举个例子,我们想选取前70%的正样本的值求loss。\n",
    "PyTorch的实现如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "07c8deb4",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-08T08:54:17.599729Z",
     "start_time": "2022-09-08T08:54:17.578583Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as torch_nn\n",
    "\n",
    "class ClassLoss_pt(torch_nn.Module):\n",
    "    def __init__(self):\n",
    "        super(ClassLoss_pt, self).__init__()\n",
    "        self.con_loss = torch_nn.CrossEntropyLoss(reduction='none')\n",
    "\n",
    "    def forward(self, pred, label):\n",
    "        mask = label > 0\n",
    "        vaild_label = label * mask\n",
    "        pos_num = torch.clamp(mask.sum() * 0.7, 1).int()\n",
    "        con = self.con_loss(pred, vaild_label.long()) * mask\n",
    "        loss, _ = torch.topk(con, k=pos_num)\n",
    "        return loss.mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "352a8dc1",
   "metadata": {},
   "source": [
    "在里面使用了`torch.topk`来获取前70%的正样本数据,在MindSpore里目前不支持TopK的K是变量,所以需要转换下思路,获取到第K大的值,然后通过这个值获取到topk的mask,MindSpore的实现方式如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "70f5e64b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-08T08:54:17.614682Z",
     "start_time": "2022-09-08T08:54:17.601805Z"
    }
   },
   "outputs": [],
   "source": [
    "import mindspore as ms\n",
    "from mindspore import ops\n",
    "from mindspore import nn as ms_nn\n",
    "\n",
    "class ClassLoss_ms(ms_nn.Cell):\n",
    "    def __init__(self):\n",
    "        super(ClassLoss_ms, self).__init__()\n",
    "        self.con_loss = ms_nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction=\"none\")\n",
    "        self.sort_descending = ops.Sort(descending=True)\n",
    "\n",
    "    def construct(self, pred, label):\n",
    "        mask = label > 0\n",
    "        vaild_label = label * mask\n",
    "        pos_num = ops.maximum(mask.sum() * 0.7, 1).astype(ms.int32)\n",
    "        con = self.con_loss(pred, vaild_label.astype(ms.int32)) * mask\n",
    "        con_sort, _ = self.sort_descending(con)\n",
    "        con_k = con_sort[pos_num - 1]\n",
    "        con_mask = (con >= con_k).astype(con.dtype)\n",
    "        loss = con * con_mask\n",
    "        return loss.sum() / con_mask.sum()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f071f6f8",
   "metadata": {},
   "source": [
    "我们来看一下实验结果:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "668fb076",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-08T08:54:17.663461Z",
     "start_time": "2022-09-08T08:54:17.616723Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "pred [[4.17021990e-01 7.20324516e-01]\n",
      " [1.14374816e-04 3.02332580e-01]\n",
      " [1.46755889e-01 9.23385918e-02]\n",
      " [1.86260208e-01 3.45560730e-01]\n",
      " [3.96767467e-01 5.38816750e-01]]\n",
      "label [-1  0  1  1  0]\n",
      "cls_loss_pt tensor(0.7207)\n",
      "cls_loss_ms 0.7207259\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import mindspore as ms\n",
    "ms.set_seed(1)\n",
    "\n",
    "pred = np.random.uniform(0, 1, (5, 2)).astype(np.float32)\n",
    "label = np.array([-1, 0, 1, 1, 0]).astype(np.int32)\n",
    "print(\"pred\", pred)\n",
    "print(\"label\", label)\n",
    "t_loss = ClassLoss_pt()\n",
    "cls_loss_pt = t_loss(torch.from_numpy(pred), torch.from_numpy(label))\n",
    "print(\"cls_loss_pt\", cls_loss_pt)\n",
    "m_loss = ClassLoss_ms()\n",
    "cls_loss_ms = m_loss(ms.Tensor(pred), ms.Tensor(label))\n",
    "print(\"cls_loss_ms\", cls_loss_ms)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23e71621",
   "metadata": {},
   "source": [
    "#### 控制流不同分支引入shape上的变化\n",
    "\n",
    "分析下在模型分析与准备章节的例子:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "e3b02279",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-08T08:54:17.680070Z",
     "start_time": "2022-09-08T08:54:17.666087Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[4.17021990e-01 7.20324516e-01 1.14374816e-04 3.02332580e-01\n",
      " 1.46755889e-01 9.23385918e-02 1.86260208e-01 3.45560730e-01\n",
      " 3.96767467e-01 5.38816750e-01]\n",
      "True\n",
      "[0.7203245  0.53881675]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore as ms\n",
    "from mindspore import ops\n",
    "np.random.seed(1)\n",
    "x = ms.Tensor(np.random.uniform(0, 1, (10)).astype(np.float32))\n",
    "cond = (x > 0.5).any()\n",
    "\n",
    "if cond:\n",
    "    y = ops.masked_select(x, x > 0.5)\n",
    "else:\n",
    "    y = ops.zeros_like(x)\n",
    "print(x)\n",
    "print(cond)\n",
    "print(y)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "970db4c4",
   "metadata": {},
   "source": [
    "在`cond=True`的时最大的shape和x一样大,根据上面的加mask方法,可以写成:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "aef1a222",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-08T08:54:17.708442Z",
     "start_time": "2022-09-08T08:54:17.682136Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[4.17021990e-01 7.20324516e-01 1.14374816e-04 3.02332580e-01\n",
      " 1.46755889e-01 9.23385918e-02 1.86260208e-01 3.45560730e-01\n",
      " 3.96767467e-01 5.38816750e-01]\n",
      "True\n",
      "[0.         0.7203245  0.         0.         0.         0.\n",
      " 0.         0.         0.         0.53881675]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore as ms\n",
    "from mindspore import ops\n",
    "np.random.seed(1)\n",
    "x = ms.Tensor(np.random.uniform(0, 1, (10)).astype(np.float32))\n",
    "cond = (x > 0.5).any()\n",
    "\n",
    "if cond:\n",
    "    mask = (x > 0.5).astype(x.dtype)\n",
    "else:\n",
    "    mask = ops.zeros_like(x)\n",
    "y = x * mask\n",
    "print(x)\n",
    "print(cond)\n",
    "print(y)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cade2ad2",
   "metadata": {},
   "source": [
    "需要注意的是如果y在后续有参与其他的计算,需要一起传入mask对有效位置做过滤。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MindSpore",
   "language": "python",
   "name": "mindspore"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}