{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Per-sample-gradients\n",
    "\n",
    "[![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.3.0/resource/_static/logo_notebook.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.3.0/tutorials/experts/zh_cn/func_programming/mindspore_per_sample_gradients.ipynb) [![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.3.0/resource/_static/logo_download_code.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.3.0/tutorials/experts/zh_cn/func_programming/mindspore_per_sample_gradients.py) [![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.3.0/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/r2.3.0/tutorials/experts/source_zh_cn/func_programming/per_sample_gradients.ipynb)\n",
    "\n",
    "计算per-sample-gradients是指计算一个批量样本中每个样本的梯度。在训练神经网络时,很多深度学习框架会计算批量样本的梯度,并利用批量样本的梯度更新网络参数。per-sample-gradients可以帮助我们在训练神经网络时,更准确地计算每个样本对网络参数的影响,从而更好地提高模型的训练效果。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在很多深度学习计算框架中,计算per-sample-gradients是一件很麻烦的事情,因为这些框架会直接累加整个批量样本的梯度。利用这些框架,我们可以想到一个简单的方法来计算per-sample-gradients,即计算批量样本中的每一个样本的预测值和标签值的损失,并计算该损失关于网络参数的梯度,但这个方法显然是很低效的。\n",
    "\n",
    "MindSpore为我们提供了更高效的方法来计算per-sample-gradients。\n",
    "\n",
    "我们以TD(0)(Temporal Difference)算法为例对计算per-sample-gradients的高效方法进行说明。TD(0)是一种基于时间差分的强化学习算法,它可以在没有环境模型的情况下学习最优策略。在TD(0)算法中,会根据当前的奖励,对值函数的估计值进行更新,TD(0)算法公式如下,\n",
    "\n",
    "$$V(S_{t}) = V(S_{t}) + \\alpha (R_{t+1} + \\gamma V(S_{t+1}) - V(S_{t}))$$\n",
    "\n",
    "其中$V(S_{t})$是当前的值函数估计值,$\\alpha$是学习率,$R_{t+1}$是在状态$S_{t}$下执行动作后获得的奖励,$\\gamma$是折扣因子,$V(S_{t+1})$是下一个状态$S_{t+1}$的值函数估计值,$R_{t+1} + \\gamma V(S_{t+1})$被称为TD目标,$R_{t+1} + \\gamma V(S_{t+1}) - V(S_{t})$被称为TD偏差。\n",
    "\n",
    "通过不断地使用TD(0)算法更新值函数估计值,可以逐步学习到最优策略,从而使在环境中获得的奖励最大化。\n",
    "\n",
    "在MindSpore中,将jit,vmap和grad组合在一起,我们可以得到更高效的方法来计算per-sample-gradients。\n",
    "\n",
    "下面对该方法进行介绍,假设在状态$s_{t}$时的估计值$v_{\\theta}$由一个线性函数进行参数化。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindspore import ops, Tensor, vmap, jit, grad\n",
    "\n",
    "\n",
    "value_fn = lambda theta, state: ops.tensor_dot(theta, state, axes=1)\n",
    "theta = Tensor([0.2, -0.2, 0.1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "考虑如下场景,从状态$s_{t}$转换到状态$s_{t+1}$,且在这个过程中,我们观察到的奖励为$r_{t+1}$。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "s_t = Tensor([2., 1., -2.])\n",
    "r_tp1 = Tensor(2.)\n",
    "s_tp1 = Tensor([1., 2., 0.])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "参数${\\theta}$的更新量的计算公式为:\n",
    "\n",
    "$$\\Delta{\\theta}=(r_{t+1} + v_{\\theta}(s_{t+1}) - v_{\\theta}(s_{t}))\\nabla v_{\\theta}(s_{t})$$\n",
    "\n",
    "参数${\\theta}$的更新量并不是任何损失函数的梯度,然而,它可以被认为是下面的伪损失函数的梯度(假设忽略目标值$r_{t+1} + v_{\\theta}(s_{t+1})$对计算$L(\\theta)$关于${\\theta}$的梯度的影响),\n",
    "\n",
    "$$L(\\theta) = [r_{t+1} + v_{\\theta}(s_{t+1}) - v_{\\theta}(s_{t})]^{2}$$\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "计算参数${\\theta}$的更新量(计算$L(\\theta)$关于${\\theta}$的梯度)时,我们需要使用`ops.stop_gradient`消除目标值$r_{t+1} + v_{\\theta}(s_{t+1})$对计算${\\theta}$梯度的影响,这可以使得在求导过程中,目标值$r_{t+1} + v_{\\theta}(s_{t+1})$不对${\\theta}$求导,以得到参数${\\theta}$正确的更新量。\n",
    "\n",
    "我们给出伪损失函数$L(\\theta)$在MindSpore中的实现,"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def td_loss(theta, s_tm1, r_t, s_t):\n",
    "    v_t = value_fn(theta, s_t)\n",
    "    target = r_tp1 + value_fn(theta, s_tp1)\n",
    "    return (ops.stop_gradient(target) - v_t) ** 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "将`td_loss`传入`grad`中,计算td_loss关于`theta`的梯度,即`theta`的更新量。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[-4. -8. -0.]\n"
     ]
    }
   ],
   "source": [
    "td_update = grad(td_loss)\n",
    "delta_theta = td_update(theta, s_t, r_tp1, s_tp1)\n",
    "print(delta_theta)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`td_update`仅根据一个样本,计算td_loss关于参数${\\theta}$的梯度,我们可以使用`vmap`对该函数进行矢量化,它会对所有的inputs和outputs添加一个批处理维度。现在,我们给出一批量的输入,并产生一批量的输出,输出批量中的每个输出元素都对应于输入批量中相应的输入元素。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[-4. -8.  0.]\n",
      " [-4. -8.  0.]]\n"
     ]
    }
   ],
   "source": [
    "batched_s_t = ops.stack([s_t, s_t])\n",
    "batched_r_tp1 = ops.stack([r_tp1, r_tp1])\n",
    "batched_s_tp1 = ops.stack([s_tp1, s_tp1])\n",
    "batched_theta = ops.stack([theta, theta])\n",
    "\n",
    "per_sample_grads = vmap(td_update)\n",
    "batch_theta = ops.stack([theta, theta])\n",
    "delta_theta = per_sample_grads(batched_theta, batched_s_t, batched_r_tp1, batched_s_tp1)\n",
    "print(delta_theta)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在上面的例子中,我们需要手动地为`per_sample_grads`传递一批量的`theta`,但实际上,我们可以仅传入单个的`theta`,为了实现这一点,我们对`vmap`传入参数`in_axes`,在`in_axes`中,参数`theta`对应的位置被设置为`None`,其他参数对应的位置被设置为`0`。这使得我们仅需向除`theta`以外的其他参数添加一个额外的轴。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[-4. -8.  0.]\n",
      " [-4. -8.  0.]]\n"
     ]
    }
   ],
   "source": [
    "inefficiecient_per_sample_grads = vmap(td_update, in_axes=(None, 0, 0, 0))\n",
    "delta_theta = inefficiecient_per_sample_grads(theta, batched_s_t, batched_r_tp1, batched_s_tp1)\n",
    "print(delta_theta)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "到这里,已经可以正确地计算每个样本的梯度了,但是我们还可以让计算过程变得更快些,我们使用`jit`调用`inefficiecient_per_sample_grads`,这会将`inefficiecient_per_sample_grads`编译为一张可调用的MindSpore图,这会提升它的运行效率。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[-4. -8.  0.]\n",
      " [-4. -8.  0.]]\n"
     ]
    }
   ],
   "source": [
    "efficiecient_per_sample_grads = jit(inefficiecient_per_sample_grads)\n",
    "delta_theta = efficiecient_per_sample_grads(theta, batched_s_t, batched_r_tp1, batched_s_tp1)\n",
    "print(delta_theta)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MindSpore",
   "language": "python",
   "name": "mindspore"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}