{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 自动向量化Vmap\n",
    "\n",
    "[![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.3.0rc2/resource/_static/logo_notebook.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.3.0rc2/tutorials/experts/zh_cn/vmap/mindspore_vmap.ipynb) [![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.3.0rc2/resource/_static/logo_download_code.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.3.0rc2/tutorials/experts/zh_cn/vmap/mindspore_vmap.py) [![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.3.0rc2/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/r2.3.0rc2/tutorials/experts/source_zh_cn/vmap/vmap.ipynb)\n",
    "\n",
    "## 概述\n",
    "\n",
    "AI融合计算的蓬勃发展,对框架能力提出了新的需求和挑战。问题场景和模型设计逐渐复杂化,使得业务数据的维度和计算逻辑的嵌套深度也相应增长。结合向量化优化手段可以有效优化性能瓶颈,但实现向量化优化对于普通用户并非易事。虽然用户可以很容易地实现低维数据运算逻辑,但随着数据维度的增长,业务逻辑也变得更为复杂,用户需要清晰了解各操作间的数据维度的逻辑映射关系,给用户的模型设计和编码带来了巨大挑战。自动向量化特性(Vmap)帮助用户解决了这个头疼的问题,该技术允许用户将特定的批处理逻辑从函数中剥离。用户在编写函数时,只需要先考虑低维的运算逻辑即可,通过调用`vmap`接口自动实现高维运算,并且支持嵌套调用,有效降低问题复杂度。\n",
    "\n",
    "本教程介绍自动向量化`vmap`接口的使用方式,将模型或函数中高度重复的运算逻辑转换为并行的向量运算逻辑,从而获得更加精简的代码逻辑以及更高效的执行性能。\n",
    "\n",
    "## 向量化思维\n",
    "\n",
    "向量化思维在提升计算性能的技术中十分常见,向量化思维可公式化表示为:\n",
    "\n",
    "$$\\begin{matrix}\n",
    "a_{1} + b_{1} = c_{1} \\\\\n",
    "a_{2} + b_{2} = c_{2} \\\\\n",
    "a_{3} + b_{3} = c_{3} \\\\\n",
    "a_{4} + b_{4} = c_{4}\n",
    "\\end{matrix} \\Rightarrow \\vec{a} + \\vec{b} = \\vec{c}$$\n",
    "\n",
    "其核心思想在于将多次for循环的计算逻辑转换为一次对于向量的计算。如果将单个操作抽象为一个函数或者一个模型的操作集合,同样可应用向量化思维方式来处理。\n",
    "\n",
    "## 手动向量化\n",
    "\n",
    "首先,我们先构造一个简单的卷积函数,适用于一维向量场景:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Tensor(shape=[3], dtype=Float32, value= [ 8.00000000e+00,  1.40000000e+01,  2.00000000e+01])"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import mindspore\n",
    "from mindspore import Tensor, ops\n",
    "import mindspore.numpy as mnp\n",
    "\n",
    "x = mnp.arange(5).astype('float32')\n",
    "w = mnp.array([1., 2., 3.])\n",
    "\n",
    "def convolve(x, w):\n",
    "    output = []\n",
    "    for i in range(1, len(x) - 1):\n",
    "        output.append(mnp.dot(x[i - 1 : i + 2], w))\n",
    "    return mnp.stack(output)\n",
    "\n",
    "convolve(x, w)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "当我们期望该函数运用于计算一批一维的卷积运算时,我们很自然地会想到调用for循环进行批处理:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Tensor(shape=[3, 3], dtype=Float32, value=\n",
       "[[ 8.00000000e+00,  1.40000000e+01,  2.00000000e+01],\n",
       " [ 8.00000000e+00,  1.40000000e+01,  2.00000000e+01],\n",
       " [ 8.00000000e+00,  1.40000000e+01,  2.00000000e+01]])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x_batch = mnp.stack([x, x, x])\n",
    "w_batch = mnp.stack([w, w, w])\n",
    "\n",
    "def manually_batch_conv(x_batch, w_batch):\n",
    "    output = []\n",
    "    for i in range(x_batch.shape[0]):\n",
    "        output.append(convolve(x_batch[i], w_batch[i]))\n",
    "    return mnp.stack(output)\n",
    "\n",
    "manually_batch_conv(x_batch, w_batch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "很显然,通过这种实现方式我们能够得到正确的计算结果,但效率并不高。\n",
    "当然,您也可以通过自己手动重写函数实现更高效率的向量化计算逻辑,但这将涉及对数据的索引、轴等信息的处理。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Tensor(shape=[3, 3], dtype=Float32, value=\n",
       "[[ 8.00000000e+00,  1.40000000e+01,  2.00000000e+01],\n",
       " [ 8.00000000e+00,  1.40000000e+01,  2.00000000e+01],\n",
       " [ 8.00000000e+00,  1.40000000e+01,  2.00000000e+01]])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def manually_vectorization_conv(x_batch, w_batch):\n",
    "    output = []\n",
    "    for i in range(1, x_batch.shape[-1] - 1):\n",
    "        output.append(mnp.sum(x_batch[:, i - 1 : i + 2] * w_batch, axis=1))\n",
    "    return mnp.stack(output, axis=1)\n",
    "\n",
    "manually_vectorization_conv(x_batch, w_batch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在低维场景下,您可以很容易把握数据索引间的映射逻辑,但随着维度的增加,计算逻辑也变得更为复杂,您或许也会为此混乱的逻辑感到头疼。\n",
    "幸运的是Vmap为我们提供了另一种实现方式。\n",
    "\n",
    "## 自动向量化\n",
    "\n",
    "Vmap可以帮助我们隐藏批处理维度,您只需要调用一个接口便可以将函数转换为向量化形式。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Tensor(shape=[3, 3], dtype=Float32, value=\n",
       "[[ 8.00000000e+00,  1.40000000e+01,  2.00000000e+01],\n",
       " [ 8.00000000e+00,  1.40000000e+01,  2.00000000e+01],\n",
       " [ 8.00000000e+00,  1.40000000e+01,  2.00000000e+01]])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from mindspore import vmap\n",
    "\n",
    "auto_vectorization_conv = vmap(convolve)\n",
    "auto_vectorization_conv(x_batch, w_batch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Vmap除了为您提供了简易的编程体验外,将循环逻辑下沉至函数的各个基元操作中,结合分布式并行优化以获得更高的执行性能。\n",
    "默认情况下,`vmap`的输入输出沿第一个轴进行批处理,如果您的输入和输出并不总是期望沿0轴批处理,可以通过`in_axes`和`out_axes`参数进行指定。您可以为所有输入或输出位置分别指定批处理轴索引,也可以为所有输入或输出指定相同的批处理轴索引。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Tensor(shape=[3, 3], dtype=Float32, value=\n",
       "[[ 8.00000000e+00,  1.40000000e+01,  2.00000000e+01],\n",
       " [ 8.00000000e+00,  1.40000000e+01,  2.00000000e+01],\n",
       " [ 8.00000000e+00,  1.40000000e+01,  2.00000000e+01]])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "w_batch_t = ops.transpose(w_batch, (1, 0))\n",
    "\n",
    "auto_vectorization_conv = vmap(convolve, in_axes=(0, 1), out_axes=1)\n",
    "output = auto_vectorization_conv(x_batch, w_batch_t)\n",
    "\n",
    "ops.transpose(output, (1, 0))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "对于多个输入的场景,您还可以指定只对其中的某些入参进行批处理,如上述场景变为求一组一维向量与某一权重的卷积,可在`in_axes`参数中的输入对应位置配置`None`即可,`None`表示不沿任何轴进行批处理。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Tensor(shape=[3, 3], dtype=Float32, value=\n",
       "[[ 8.00000000e+00,  1.40000000e+01,  2.00000000e+01],\n",
       " [ 8.00000000e+00,  1.40000000e+01,  2.00000000e+01],\n",
       " [ 8.00000000e+00,  1.40000000e+01,  2.00000000e+01]])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "auto_vectorization_conv = vmap(convolve, in_axes=(0, None), out_axes=0)\n",
    "auto_vectorization_conv(x_batch, w)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> 为保证自动向量化计算逻辑的正确性,vmap内部会根据输入的维度和轴索引以及批处理大小等信息进行校验,详细参数限制可参考[mindspore.vmap](https://www.mindspore.cn/docs/zh-CN/r2.3.0rc2/api_python/mindspore/mindspore.vmap.html#mindspore.vmap)。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 高阶函数的嵌套\n",
    "\n",
    "Vmap本质上是一种高阶函数,它将函数作为输入,并返回可应用于批处理数据的向量化函数。用法上它允许和其他框架提供的高阶函数进行嵌套组合使用。\n",
    "\n",
    "- `vmap`与`vmap`嵌套使用,应用于两层以上的批处理逻辑。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Tensor(shape=[3, 3, 3], dtype=Float32, value=\n",
       "[[[ 6.00000000e+00,  9.00000000e+00,  1.20000000e+01],\n",
       "  [ 1.20000000e+01,  1.80000000e+01,  2.40000000e+01],\n",
       "  [ 1.80000000e+01,  2.70000000e+01,  3.60000000e+01]],\n",
       " [[ 9.00000000e+00,  1.20000000e+01,  1.50000000e+01],\n",
       "  [ 1.80000000e+01,  2.40000000e+01,  3.00000000e+01],\n",
       "  [ 2.70000000e+01,  3.60000000e+01,  4.50000000e+01]],\n",
       " [[ 1.20000000e+01,  1.50000000e+01,  1.80000000e+01],\n",
       "  [ 2.40000000e+01,  3.00000000e+01,  3.60000000e+01],\n",
       "  [ 3.60000000e+01,  4.50000000e+01,  5.40000000e+01]]])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "hyper_x = Tensor([[1., 2., 3., 4., 5.], [2., 3., 4., 5., 6.], [3., 4., 5., 6., 7.]], mindspore.float32)\n",
    "hyper_w = Tensor([[1., 1., 1.], [2., 2., 2.], [3., 3., 3.]], mindspore.float32)\n",
    "\n",
    "hyper_vmap_ger = vmap(vmap(convolve, in_axes=[None, 0]), in_axes=[0, None])\n",
    "hyper_vmap_ger(hyper_x, hyper_w)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "source": [
    "- `grad`内部嵌套`vmap`使用,应用于计算向量化函数的梯度等场景。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Tensor(shape=[2, 3], dtype=Float32, value=\n",
       " [[ 2.83662200e-01, -1.45500034e-01,  4.42569796e-03],\n",
       "  [-1.45500034e-01,  4.42569796e-03,  1.36737213e-01]]),\n",
       " Tensor(shape=[2, 3], dtype=Float32, value=\n",
       " [[ 5.67324400e-01, -2.91000068e-01,  8.85139592e-03],\n",
       "  [-2.91000068e-01,  8.85139592e-03,  2.73474425e-01]]))"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from mindspore import grad\n",
    "\n",
    "def forward_fn(x, y):\n",
    "    out = x + 2 * y\n",
    "    out = ops.sin(out)\n",
    "    reduce_sum = ops.ReduceSum()\n",
    "    return reduce_sum(out)\n",
    "\n",
    "x_hat = Tensor([[1., 2., 3.], [2., 3., 4.]], mindspore.float32)\n",
    "y_hat = Tensor([[2., 3., 4.], [3., 4., 5.]], mindspore.float32)\n",
    "\n",
    "grad_vmap_ger = grad(vmap(forward_fn), grad_position=(0, 1))\n",
    "grad_vmap_ger(x_hat, y_hat)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- `vmap`内部嵌套`grad`使用,应用于计算批量梯度、高阶梯度计算等场景,如计算Jacobian矩阵。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Tensor(shape=[2, 3], dtype=Float32, value=\n",
       " [[ 2.83662200e-01, -1.45500034e-01,  4.42569796e-03],\n",
       "  [-1.45500034e-01,  4.42569796e-03,  1.36737213e-01]]),\n",
       " Tensor(shape=[2, 3], dtype=Float32, value=\n",
       " [[ 5.67324400e-01, -2.91000068e-01,  8.85139592e-03],\n",
       "  [-2.91000068e-01,  8.85139592e-03,  2.73474425e-01]]))"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vmap_grad_ger = vmap(grad(forward_fn, grad_position=(0, 1)))\n",
    "vmap_grad_ger(x_hat, y_hat)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "本教程中只简单介绍两层高阶函数组合嵌套的用法,您可以根据场景需求实现更多层次的嵌套。\n",
    "\n",
    "### Cell的自动向量化\n",
    "\n",
    "之前的用例我们都是以函数对象作为输入,下面将介绍`Cell`对象结合`vmap`的用法。这是一个简单定义的全连接层的例子。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 0.0219292  -0.01062493 -0.03378957 -0.02589925]\n",
      " [ 0.03091274 -0.04968021 -0.08098207 -0.07896652]]\n",
      "[[ 0.02492371 -0.02364336 -0.0495204  -0.04358834]\n",
      " [ 0.03390725 -0.06269865 -0.09671289 -0.09665561]]\n",
      "[[ 0.02791822 -0.03666179 -0.06525123 -0.06127743]\n",
      " [ 0.03690176 -0.07571708 -0.11244373 -0.1143447 ]]\n",
      "[[[ 0.0219292  -0.01062493 -0.03378957 -0.02589925]\n",
      "  [ 0.03091274 -0.04968021 -0.08098207 -0.07896652]]\n",
      "\n",
      " [[ 0.02492371 -0.02364336 -0.0495204  -0.04358834]\n",
      "  [ 0.03390725 -0.06269865 -0.09671289 -0.09665561]]\n",
      "\n",
      " [[ 0.02791822 -0.03666179 -0.06525123 -0.06127743]\n",
      "  [ 0.03690176 -0.07571708 -0.11244373 -0.1143447 ]]]\n"
     ]
    }
   ],
   "source": [
    "import mindspore.nn as nn\n",
    "from mindspore import Parameter\n",
    "from mindspore.common.initializer import initializer\n",
    "\n",
    "class Dense(nn.Cell):\n",
    "    def __init__(self, in_channels, out_channels, weight_init='normal', bias_init='zeros'):\n",
    "        super(Dense, self).__init__()\n",
    "        self.scalar = 1\n",
    "        self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name=\"weight\")\n",
    "        self.bias = Parameter(initializer(bias_init, [out_channels]), name=\"bias\")\n",
    "        self.matmul = ops.MatMul(transpose_b=True)\n",
    "\n",
    "    def construct(self, x):\n",
    "        x = self.matmul(x, self.weight)\n",
    "        output = ops.bias_add(x, self.bias)\n",
    "        return output\n",
    "\n",
    "input_a = Tensor([[1, 2, 3], [4, 5, 6]], mindspore.float32)\n",
    "input_b = Tensor([[2, 3, 4], [5, 6, 7]], mindspore.float32)\n",
    "input_c = Tensor([[3, 4, 5], [6, 7, 8]], mindspore.float32)\n",
    "\n",
    "dense_net = Dense(3, 4)\n",
    "print(dense_net(input_a))\n",
    "print(dense_net(input_b))\n",
    "print(dense_net(input_c))\n",
    "\n",
    "inputs = mnp.stack([input_a, input_b, input_c])\n",
    "\n",
    "vmap_dense_net = vmap(dense_net)\n",
    "print(vmap_dense_net(inputs))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`Cell`和函数式的自动向量化用法基本一致,只需要将`vmap`的第一个入参替换为`Cell`实例即可,Vmap将`construct`转换为作用于批处理数据的向量化`construct`。另外,该用例中初始化函数定义了两个`Parameter`参数, Vmap对于这类执行函数的自由变量的处理等同于将其作为入参并配置对应`in_axes`位置为`None`的场景。\n",
    "\n",
    "通过这种方式,我们可以实现批量输入在同一个模型上进行训练或推理等功能,与现有网络模型输入支持batch轴输入的区别在于,利用Vmap实现的批处理维度更加灵活,不局限于`NCHW`等输入格式。\n",
    "\n",
    "### 模型集成场景\n",
    "\n",
    "模型集成场景将来自多个模型的预测结果组合在一起,传统的实现方式是通过分别在某些输入上运行各个模型,然后将各自的预测结果组合在一起。假如您正在运行的是具有相同架构的模型,那么您可以借助Vmap将它们进行向量化,从而实现加速效果。\n",
    "\n",
    "该场景下涉及权重数据的向量化,如果您运行的模型是通过函数式编程形式实现,即权重参数在模型外部定义并通过入参传递给模型操作,那您可以直接通过配置`in_axes`的方式进行相应的批处理。而MindSpore框架为了提供便捷的模型定义功能,绝大部分`nn`接口的权重参数都在接口内部定义并初始化,这意味着模型中的权重参数在原始Vmap中无法对权重进行批处理,改造成通过入参传递的函数式实现需要额外工作量。不过您不必担心,MindSpore的`vmap`接口已经替您优化了该场景。您只需要将运行的多个模型实例以`CellList`的形式传入给`vmap`,框架即可自动实现权重参数的批处理。\n",
    "\n",
    "让我们演示如何使用一组简单的CNN模型来实现模型集成推理和训练。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LeNet5(nn.Cell):\n",
    "    \"\"\"\n",
    "    LeNet-5网络结构\n",
    "    \"\"\"\n",
    "    def __init__(self, num_class=10, num_channel=1):\n",
    "        super(LeNet5, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')\n",
    "        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')\n",
    "        self.fc1 = nn.Dense(16 * 5 * 5, 120)\n",
    "        self.fc2 = nn.Dense(120, 84)\n",
    "        self.fc3 = nn.Dense(84, num_class)\n",
    "        self.relu = nn.ReLU()\n",
    "        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)\n",
    "        self.flatten = nn.Flatten()\n",
    "\n",
    "    def construct(self, x):\n",
    "        x = self.conv1(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.max_pool2d(x)\n",
    "        x = self.conv2(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.max_pool2d(x)\n",
    "        x = self.flatten(x)\n",
    "        x = self.fc1(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.fc2(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.fc3(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "假设我们正在验证同一模型架构在不同权重参数下的效果,让我们模拟四个已经训练好的模型实例和一份batch大小为16,尺寸为32 x 32的虚拟图像数据集的`minibatch`。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "net1 = LeNet5()\n",
    "net2 = LeNet5()\n",
    "net3 = LeNet5()\n",
    "net4 = LeNet5()\n",
    "\n",
    "minibatch = Tensor(mnp.randn(3, 1, 32, 32), mindspore.float32)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "相较于利用for循环分别运行各个模型后将预测结果集合到一起,Vmap能够一次运行获得多个模型的预测结果。\n",
    "\n",
    "> 注意,由于vmap的实现机制,会对设备运行内存有要求,使用vmap可能会占用更多内存,请用户根据实际场景使用。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Tensor(shape=[4, 3, 10], dtype=Float32, value=\n",
       "[[[ 4.66281290e-06, -7.24548045e-06,  8.68147254e-07 ...  1.42438457e-05,  1.49375774e-05, -1.18535736e-05],\n",
       "  [ 9.10962353e-06, -5.63606591e-06, -7.06250285e-06 ...  1.68580664e-05,  1.41603141e-05, -3.55220163e-06],\n",
       "  [ 1.11184154e-05, -6.08020900e-06, -5.08124231e-06 ...  1.37913748e-05,  1.20597506e-05, -1.01803471e-05]],\n",
       " [[ 3.22165624e-06,  6.22022753e-06,  2.60713023e-07 ... -1.53302244e-05,  2.34313102e-05, -4.16413786e-06],\n",
       "  [ 2.82950850e-06,  1.54561894e-06,  5.19753303e-06 ... -1.53819674e-05,  1.58681542e-05, -7.10185304e-07],\n",
       "  [ 1.77780748e-07,  4.33479636e-06, -1.35376536e-06 ... -1.06113021e-05,  1.58355688e-05, -5.78900153e-06]],\n",
       " [[ 6.66864071e-06, -1.99870119e-05, -1.30958688e-05 ...  3.68208202e-06, -9.69678968e-06,  9.59075351e-06],\n",
       "  [ 7.99765985e-06, -1.16931469e-05, -1.06589669e-05 ... -1.24687813e-06, -8.65744005e-06,  6.81729716e-06],\n",
       "  [ 6.87587362e-06, -1.23972441e-05, -1.05251866e-05 ...  1.44004912e-06, -5.40550172e-06,  6.71799853e-06]],\n",
       " [[-3.44783439e-06,  2.32537104e-07, -8.64402864e-06 ...  3.52633970e-06, -6.27670488e-06,  3.27721250e-06],\n",
       "  [-6.90392517e-06, -9.97693860e-07, -6.48076320e-06 ...  7.61923275e-07, -2.54563452e-06,  3.08638573e-06],\n",
       "  [-3.78440518e-06,  3.93633945e-06, -7.90367903e-06 ...  5.13138957e-07, -4.50420839e-06,  2.13702333e-06]]])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nets = nn.CellList([net1, net2, net3, net4])\n",
    "vmap(nets, in_axes=None)(minibatch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "又或者,我们期望得到多个模型分别执行不同`minibatch`数据的预测结果。\n",
    "\n",
    "> 模型集成场景下,`vmap`的第一个入参应为`CellList`类型,需要确保每个模型的架构完全相同,否则无法保证计算正确性,如果`in_axes`不为`None`是需保证模型数量与映射轴索引对应的`axis_size`一致,以实现一一映射关系。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Tensor(shape=[4, 3, 10], dtype=Float32, value=\n",
       "[[[ 6.52808285e-06, -4.15002341e-06, -3.80283609e-06 ...  1.54428089e-05,  1.44425348e-05, -9.00016857e-06],\n",
       "  [ 7.39091365e-06, -5.19960076e-06,  3.83916813e-07 ...  1.67857870e-05,  1.80104271e-05, -1.56435199e-05],\n",
       "  [ 1.11604741e-05, -7.59019804e-06,  2.54263796e-07 ...  1.21071571e-05,  1.66683039e-05, -1.09967377e-05]],\n",
       " [[ 1.48978233e-06,  1.02267529e-06,  1.33801677e-06 ... -1.32894393e-05,  1.36311328e-05, -3.29658405e-06],\n",
       "  [ 1.09956818e-06, -5.06103561e-07,  3.04885953e-06 ... -1.76028752e-05,  1.66466998e-05, -1.17290392e-06],\n",
       "  [ 2.96090502e-06,  1.87074147e-06,  5.76813818e-06 ... -1.09994007e-05,  1.35614964e-05, -2.19983576e-06]],\n",
       " [[ 6.74323928e-06, -1.03955799e-05, -6.92168396e-06 ...  4.88165415e-06, -5.40378596e-06,  3.09346888e-06],\n",
       "  [ 7.28906161e-06, -1.34921102e-05, -1.00995640e-05 ...  9.44596650e-07, -6.40979761e-06,  1.26146606e-05],\n",
       "  [ 9.43304440e-06, -1.61852931e-05, -1.16265892e-05 ...  5.31926253e-06, -1.28484417e-05,  8.03831313e-07]],\n",
       " [[-5.51165886e-06, -1.09487860e-06, -6.07781249e-06 ...  7.51453626e-06, -3.29403338e-06,  3.45475746e-06],\n",
       "  [-6.27516283e-06,  1.40756754e-06, -9.18502155e-06 ...  4.16079911e-06, -5.30383022e-06,  5.12517454e-06],\n",
       "  [-6.19608954e-06,  5.12868655e-06, -1.00337056e-05 ...  2.93281119e-07, -6.52256404e-06,  3.62988635e-06]]])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "minibatchs = Tensor(mnp.randn(4, 3, 1, 32, 32), mindspore.float32)\n",
    "vmap(nets, in_axes=0)(minibatchs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "除了支持模型集成推理外,结合Vmap特性同样能够实现模型集成训练。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step 1 - loss: [2.3025837 2.3025882 2.3025858 2.3025842]\n",
      "Step 2 - loss: [2.260927  2.301028  2.2992857 2.2976868]\n",
      "Step 3 - loss: [1.8539654 2.2993202 2.2951114 2.2899477]\n",
      "Step 4 - loss: [0.77165794 2.2973287  2.288719   2.2726345 ]\n",
      "Step 5 - loss: [0.9397469 2.2948549 2.2777178 2.2313874]\n",
      "Step 6 - loss: [0.6747699 2.29158   2.2579656 2.1410708]\n",
      "Step 7 - loss: [0.64673084 2.2870557  2.2232006  1.966973  ]\n",
      "Step 8 - loss: [1.0506033 2.2806385 2.1645374 1.6848679]\n",
      "Step 9 - loss: [0.612196  2.2714498 2.0706694 1.3499321]\n",
      "Step 10 - loss: [0.8843982 2.258316  1.9299208 1.1472267]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Tensor(shape=[4, 3, 10], dtype=Float32, value=\n",
       "[[[-1.91058636e+01, -1.92182674e+01,  1.06328402e+01 ... -1.87287464e+01, -1.87855473e+01, -2.02504387e+01],\n",
       "  [-1.94767399e+01, -1.95909595e+01,  1.08379564e+01 ... -1.90921249e+01, -1.91503220e+01, -2.06434765e+01],\n",
       "  [-1.96521702e+01, -1.97674465e+01,  1.09355783e+01 ... -1.92643051e+01, -1.93227654e+01, -2.08293762e+01]],\n",
       " [[-4.07293849e-02, -4.27918807e-02,  5.22112176e-02 ... -4.67570126e-02, -3.88025381e-02,  4.88412194e-02],\n",
       "  [-3.91553082e-02, -4.11494374e-02,  5.00433967e-02 ... -4.48847115e-02, -3.73134986e-02,  4.68519926e-02],\n",
       "  [-3.80369201e-02, -3.99325565e-02,  4.84890938e-02 ... -4.35365662e-02, -3.62745039e-02,  4.54225838e-02]],\n",
       " [[-5.08784056e-01, -5.05123973e-01,  5.20882547e-01 ...  4.72596169e-01, -5.00697553e-01, -4.60489392e-01],\n",
       "  [-4.80103493e-01, -4.76664037e-01,  4.91507798e-01 ...  4.46062207e-01, -4.72493649e-01, -4.34652239e-01],\n",
       "  [-4.81168061e-01, -4.77702975e-01,  4.92583781e-01 ...  4.47029382e-01, -4.73524809e-01, -4.35579300e-01]],\n",
       " [[-3.66236401e+00, -3.25362825e+00,  3.51312804e+00 ...  3.77490187e+00, -3.36864424e+00, -3.34358120e+00],\n",
       "  [-3.49160767e+00, -3.10209608e+00,  3.34935308e+00 ...  3.59894991e+00, -3.21167707e+00, -3.18782210e+00],\n",
       "  [-3.57623625e+00, -3.17717075e+00,  3.43059325e+00 ...  3.68615556e+00, -3.28948307e+00, -3.26504302e+00]]])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from mindspore import ParameterTuple\n",
    "\n",
    "class TrainOneStepNet(nn.Cell):\n",
    "    def __init__(self, net, lr):\n",
    "        super(TrainOneStepNet, self).__init__()\n",
    "        self.loss_fn = nn.WithLossCell(net, nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean'))\n",
    "        self.weight = ParameterTuple(net.trainable_params())\n",
    "        self.adam_optim = nn.Adam(self.weight, learning_rate=lr, use_amsgrad=True)\n",
    "\n",
    "    def construct(self, batch, targets):\n",
    "        loss = self.loss_fn(batch, targets)\n",
    "        grad_weights = grad(self.loss_fn, grad_position=None, weights=self.weight)(batch, targets)\n",
    "        self.adam_optim(grad_weights)\n",
    "        return loss\n",
    "\n",
    "train_net1 = TrainOneStepNet(net1, lr=1e-2)\n",
    "train_net2 = TrainOneStepNet(net2, lr=1e-3)\n",
    "train_net3 = TrainOneStepNet(net3, lr=2e-3)\n",
    "train_net4 = TrainOneStepNet(net4, lr=3e-3)\n",
    "\n",
    "train_nets = nn.CellList([train_net1, train_net2, train_net3, train_net4])\n",
    "model_ensembling_train_one_step = vmap(train_nets)\n",
    "\n",
    "images = Tensor(mnp.randn(4, 3, 1, 32, 32), mindspore.float32)\n",
    "labels = Tensor(mnp.randint(1, 10, (4, 3)), mindspore.int32)\n",
    "\n",
    "for i in range(1, 11):\n",
    "    loss = model_ensembling_train_one_step(images, labels)\n",
    "    print(\"Step {} - loss: {}\".format(i, loss))\n",
    "\n",
    "vmap(nets, in_axes=None)(minibatch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "经过模型集成训练的模型除了可以集成推理之外,仍然可以单独进行推理。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Tensor(shape=[3, 10], dtype=Float32, value=\n",
       "[[-1.91058636e+01, -1.92182674e+01,  1.06328402e+01 ... -1.87287483e+01, -1.87855473e+01, -2.02504387e+01],\n",
       " [-1.94767399e+01, -1.95909595e+01,  1.08379564e+01 ... -1.90921249e+01, -1.91503220e+01, -2.06434765e+01],\n",
       " [-1.96521702e+01, -1.97674465e+01,  1.09355783e+01 ... -1.92643051e+01, -1.93227654e+01, -2.08293762e+01]])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net1(minibatch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 更多实践案例\n",
    "\n",
    "- [差分隐私场景下利用Vmap加速逐样本梯度计算过程](https://gitee.com/mindspore/docs/blob/r2.3.0rc2/docs/sample_code/per_sample_gradient.py);\n",
    "\n",
    "- 科学计算领域AI电磁模型结合Vmap加速点源时域麦克斯韦方程求解;\n",
    "\n",
    "- 强化学习场景下结合Vmap实现多Agent并行训练和推理;\n",
    "\n",
    "- 自动微分场景中基于Vmap提供高效的雅可比矩阵计算接口[jacrev](https://www.mindspore.cn/docs/zh-CN/r2.3.0rc2/api_python/mindspore/mindspore.jacrev.html)和[jacfwd](https://www.mindspore.cn/docs/zh-CN/r2.3.0rc2/api_python/mindspore/mindspore.jacfwd.html);\n",
    "\n",
    "### 总结\n",
    "\n",
    "本教程重点在于介绍Vmap的场景使用说明,本质上自动向量化并非将循环逻辑执行于函数外部,而是将循环下沉至函数的各个基元操作中,并将映射轴信息在基元操作间传递,从而保证计算逻辑的正确性。Vmap的性能收益主要来自于各个基元操作所对应的`VmapRule`实现,由于循环下沉至算子层级,因而更容易结合并行技术进行性能优化,如果您有自定义算子的场景也可以尝试为自定义算子实现特定的`VmapRule`,从而获得更好的性能。对于性能极致追求的场景还可以再结合图算融合特性进行优化。\n",
    "\n",
    "> Vmap特性当前支持GPU、CPU平台,Ascend平台功能仍在不断完善中。\n",
    ">\n",
    "> 在Vmap包含控制流的场景中,当前仅支持每个批处理分支具有相同处理操作或控制流逻辑中所有变量均无切分轴的场景。\n",
    "\n",
    "通过上列用例的展示,相信您已经对Vmap的能力有了大致了解,但Vmap的使用场景绝不仅限于此教程陈列的内容。此教程仅起到抛砖引玉的作用,您可以尝试更多有意思的场景,也欢迎加入我们的讨论和工作中。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MindSpore",
   "language": "python",
   "name": "mindspore"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}