{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 自动求导\n", "\n", "[![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.8/resource/_static/logo_notebook.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/r1.8/tutorials/zh_cn/advanced/network/mindspore_derivation.ipynb) \n", "[![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.8/resource/_static/logo_download_code.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/r1.8/tutorials/zh_cn/advanced/network/mindspore_derivation.py) \n", "[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.8/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r1.8/tutorials/source_zh_cn/advanced/network/derivation.ipynb)\n", "\n", "`mindspore.ops`模块提供的`GradOperation`接口可以生成网络模型的梯度。本文主要介绍如何使用`GradOperation`接口进行一阶、二阶求导,以及如何停止计算梯度。\n", "\n", "> 更多求导接口相关信息可参考[API文档](https://mindspore.cn/docs/zh-CN/r1.8/api_python/ops/mindspore.ops.GradOperation.html#mindspore.ops.GradOperation)。\n", "\n", "## 一阶求导\n", "\n", "计算一阶导数方法:`mindspore.ops.GradOperation()`,其中参数使用方式为:\n", "\n", "- `get_all`:为`False`时,只会对第一个输入求导;为`True`时,会对所有输入求导。\n", "- `get_by_list:`为`False`时,不会对权重求导;为`True`时,会对权重求导。\n", "- `sens_param`:对网络的输出值做缩放以改变最终梯度,故其维度与输出维度保持一致;\n", "\n", "下面我们先使用[MatMul](https://mindspore.cn/docs/zh-CN/r1.8/api_python/ops/mindspore.ops.MatMul.html#mindspore.ops.MatMul)算子构建自定义网络模型`Net`,再对其进行一阶求导,通过这样一个例子对`GradOperation`接口的使用方式做简单介绍,即公式:\n", "\n", "$$f(x, y)=(x * z) * y \\tag{1}$$\n", "\n", "首先我们要定义网络模型`Net`、输入`x`和输入`y`:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import mindspore.nn as nn\n", "import mindspore.ops as ops\n", "import mindspore as ms\n", "\n", "# 定义输入x和y\n", "x = ms.Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=ms.float32)\n", "y = ms.Tensor([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=ms.float32)\n", "\n", "class Net(nn.Cell):\n", " \"\"\"定义矩阵相乘网络Net\"\"\"\n", "\n", " def __init__(self):\n", " super(Net, self).__init__()\n", " self.matmul = ops.MatMul()\n", " self.z = ms.Parameter(ms.Tensor(np.array([1.0], np.float32)), name='z')\n", "\n", " def construct(self, x, y):\n", " x = x * self.z\n", " out = self.matmul(x, y)\n", " return out" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 对输入进行求导\n", "\n", "对输入值进行求导,代码如下:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[4.5099998 2.7 3.6000001]\n", " [4.5099998 2.7 3.6000001]]\n" ] } ], "source": [ "class GradNetWrtX(nn.Cell):\n", " \"\"\"定义网络输入的一阶求导\"\"\"\n", "\n", " def __init__(self, net):\n", " super(GradNetWrtX, self).__init__()\n", " self.net = net\n", " self.grad_op = ops.GradOperation()\n", "\n", " def construct(self, x, y):\n", " gradient_function = self.grad_op(self.net)\n", " return gradient_function(x, y)\n", "\n", "output = GradNetWrtX(Net())(x, y)\n", "print(output)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "接下来我们对上面的结果做一个解释。为便于分析,我们把上面的输入`x`、`y`以及权重`z`表示成如下形式:\n", "\n", "```text\n", "x = ms.Tensor([[x1, x2, x3], [x4, x5, x6]])\n", "y = ms.Tensor([[y1, y2, y3], [y4, y5, y6], [y7, y8, y9]])\n", "z = ms.Tensor([z])\n", "```\n", "\n", "根据MatMul算子定义可得前向结果:\n", "\n", "$$output = [[(x_1 \\cdot y_1 + x_2 \\cdot y_4 + x_3 \\cdot y_7) \\cdot z, (x_1 \\cdot y_2 + x_2 \\cdot y_5 + x_3 \\cdot y_8) \\cdot z, (x_1 \\cdot y_3 + x_2 \\cdot y_6 + x_3 \\cdot y_9) \\cdot z],$$\n", "\n", "$$[(x_4 \\cdot y_1 + x_5 \\cdot y_4 + x_6 \\cdot y_7) \\cdot z, (x_4 \\cdot y_2 + x_5 \\cdot y_5 + x_6 \\cdot y_8) \\cdot z, (x_4 \\cdot y_3 + x_5 \\cdot y_6 + x_6 \\cdot y_9) \\cdot z]] \\tag{2}$$\n", "\n", "梯度计算时由于MindSpore采用的是Reverse自动微分机制,会对输出结果求和后再对输入`x`求导:\n", "\n", "1. 求和公式:\n", "\n", "$$\\sum{output} = [(x_1 \\cdot y_1 + x_2 \\cdot y_4 + x_3 \\cdot y_7) + (x_1 \\cdot y_2 + x_2 \\cdot y_5 + x_3 \\cdot y_8) + (x_1 \\cdot y_3 + x_2 \\cdot y_6 + x_3 \\cdot y_9)$$\n", "\n", "$$+ (x_4 \\cdot y_1 + x_5 \\cdot y_4 + x_6 \\cdot y_7) + (x_4 \\cdot y_2 + x_5 \\cdot y_5 + x_6 \\cdot y_8) + (x_4 \\cdot y_3 + x_5 \\cdot y_6 + x_6 \\cdot y_9)] \\cdot z \\tag{3}$$\n", "\n", "2. 求导公式:\n", "\n", "$$\\frac{\\mathrm{d}(\\sum{output})}{\\mathrm{d}x} = [[(y_1 + y_2 + y_3) \\cdot z, (y_4 + y_5 + y_6) \\cdot z, (y_7 + y_8 + y_9) \\cdot z],$$\n", "\n", "$$[(y_1 + y_2 + y_3) \\cdot z, (y_4 + y_5 + y_6) \\cdot z, (y_7 + y_8 + y_9) \\cdot z]] \\tag{4}$$\n", "\n", "3. 计算结果:\n", "\n", "$$\\frac{\\mathrm{d}(\\sum{output})}{\\mathrm{d}x} = [[4.51 \\quad 2.7 \\quad 3.6] [4.51 \\quad 2.7 \\quad 3.6]] \\tag{5}$$\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> 若考虑对`x`、`y`输入求导,只需在`GradNetWrtX`中设置`self.grad_op = GradOperation(get_all=True)`。\n", "\n", "### 对权重进行求导\n", "\n", "对权重进行求导,示例代码如下:\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[21.536]\n" ] } ], "source": [ "class GradNetWrtZ(nn.Cell):\n", " \"\"\"定义网络权重的一阶求导\"\"\"\n", "\n", " def __init__(self, net):\n", " super(GradNetWrtZ, self).__init__()\n", " self.net = net\n", " self.params = ms.ParameterTuple(net.trainable_params())\n", " self.grad_op = ops.GradOperation(get_by_list=True)\n", "\n", " def construct(self, x, y):\n", " gradient_function = self.grad_op(self.net, self.params)\n", " return gradient_function(x, y)\n", "\n", "output = GradNetWrtZ(Net())(x, y)\n", "print(output[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "下面我们通过公式对上面的结果做一个解释。对权重的求导公式为:\n", "\n", "$$\\frac{\\mathrm{d}(\\sum{output})}{\\mathrm{d}z} = (x_1 \\cdot y_1 + x_2 \\cdot y_4 + x_3 \\cdot y_7) + (x_1 \\cdot y_2 + x_2 \\cdot y_5 + x_3 \\cdot y_8) + (x_1 \\cdot y_3 + x_2 \\cdot y_6 + x_3 \\cdot y_9)$$\n", "\n", "$$+ (x_4 \\cdot y_1 + x_5 \\cdot y_4 + x_6 \\cdot y_7) + (x_4 \\cdot y_2 + x_5 \\cdot y_5 + x_6 \\cdot y_8) + (x_4 \\cdot y_3 + x_5 \\cdot y_6 + x_6 \\cdot y_9) \\tag{6}$$\n", "\n", "计算结果:\n", "\n", "$$\\frac{\\mathrm{d}(\\sum{output})}{\\mathrm{d}z} = [2.1536e+01] \\tag{7}$$\n", "\n", "### 梯度值缩放\n", "\n", "可以通过`sens_param`参数控制梯度值的缩放:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[2.211 0.51 1.49 ]\n", " [5.588 2.68 4.07 ]]\n" ] } ], "source": [ "class GradNetWrtN(nn.Cell):\n", " \"\"\"定义网络的一阶求导,控制梯度值缩放\"\"\"\n", " def __init__(self, net):\n", " super(GradNetWrtN, self).__init__()\n", " self.net = net\n", " self.grad_op = ops.GradOperation(sens_param=True)\n", "\n", " # 定义梯度值缩放\n", " self.grad_wrt_output = ms.Tensor([[0.1, 0.6, 0.2], [0.8, 1.3, 1.1]], dtype=ms.float32)\n", "\n", " def construct(self, x, y):\n", " gradient_function = self.grad_op(self.net)\n", " return gradient_function(x, y, self.grad_wrt_output)\n", "\n", "output = GradNetWrtN(Net())(x, y)\n", "print(output)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "为了方便对上面的结果进行解释,我们把`self.grad_wrt_output`记作如下形式:\n", "\n", "```text\n", "self.grad_wrt_output = ms.Tensor([[s1, s2, s3], [s4, s5, s6]])\n", "```\n", "\n", "缩放后的输出值为原输出值与`self.grad_wrt_output`对应元素的乘积,公式为:\n", "\n", "$$output = [[(x_1 \\cdot y_1 + x_2 \\cdot y_4 + x_3 \\cdot y_7) \\cdot z \\cdot s_1, (x_1 \\cdot y_2 + x_2 \\cdot y_5 + x_3 \\cdot y_8) \\cdot z \\cdot s_2, (x_1 \\cdot y_3 + x_2 \\cdot y_6 + x_3 \\cdot y_9) \\cdot z \\cdot s_3], $$\n", "\n", "$$[(x_4 \\cdot y_1 + x_5 \\cdot y_4 + x_6 \\cdot y_7) \\cdot z \\cdot s_4, (x_4 \\cdot y_2 + x_5 \\cdot y_5 + x_6 \\cdot y_8) \\cdot z \\cdot s_5, (x_4 \\cdot y_3 + x_5 \\cdot y_6 + x_6 \\cdot y_9) \\cdot z \\cdot s_6]] \\tag{8}$$\n", "\n", "求导公式变为输出值总和对`x`的每个元素求导:\n", "\n", "$$\\frac{\\mathrm{d}(\\sum{output})}{\\mathrm{d}x} = [[(s_1 \\cdot y_1 + s_2 \\cdot y_2 + s_3 \\cdot y_3) \\cdot z, (s_1 \\cdot y_4 + s_2 \\cdot y_5 + s_3 \\cdot y_6) \\cdot z, (s_1 \\cdot y_7 + s_2 \\cdot y_8 + s_3 \\cdot y_9) \\cdot z],$$\n", "\n", "$$[(s_4 \\cdot y_1 + s_5 \\cdot y_2 + s_6 \\cdot y_3) \\cdot z, (s_4 \\cdot y_4 + s_5 \\cdot y_5 + s_6 \\cdot y_6) \\cdot z, (s_4 \\cdot y_7 + s_5 \\cdot y_8 + s_6 \\cdot y_9) \\cdot z]] \\tag{9}$$\n", "\n", "计算结果:\n", "\n", "$$\\frac{\\mathrm{d}(\\sum{output})}{\\mathrm{d}x} = [[2.211 \\quad 0.51 \\quad 1.49][5.588 \\quad 2.68 \\quad 4.07]] \\tag{10}$$\n", "\n", "### 停止计算梯度\n", "\n", "我们可以使用`stop_gradient`来停止计算指定算子的梯度,从而消除该算子对梯度的影响。\n", "\n", "在上面一阶求导使用的矩阵相乘网络模型的基础上,我们再增加一个算子`out2`并禁止计算其梯度,得到自定义网络`Net2`,然后看一下对输入的求导结果情况。\n", "\n", "示例代码如下:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[4.5099998 2.7 3.6000001]\n", " [4.5099998 2.7 3.6000001]]\n" ] } ], "source": [ "class Net(nn.Cell):\n", "\n", " def __init__(self):\n", " super(Net, self).__init__()\n", " self.matmul = ops.MatMul()\n", "\n", " def construct(self, x, y):\n", " out1 = self.matmul(x, y)\n", " out2 = self.matmul(x, y)\n", " out2 = ops.stop_gradient(out2) # 停止计算out2算子的梯度\n", " out = out1 + out2\n", " return out\n", "\n", "class GradNetWrtX(nn.Cell):\n", "\n", " def __init__(self, net):\n", " super(GradNetWrtX, self).__init__()\n", " self.net = net\n", " self.grad_op = ops.GradOperation()\n", "\n", " def construct(self, x, y):\n", " gradient_function = self.grad_op(self.net)\n", " return gradient_function(x, y)\n", "\n", "output = GradNetWrtX(Net())(x, y)\n", "print(output)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "从上面的打印可以看出,由于对`out2`设置了`stop_gradient`, 所以`out2`没有对梯度计算有任何的贡献,其输出结果与未加`out2`算子时一致。\n", "\n", "下面我们删除`out2 = stop_gradient(out2)`,再来看一下输出结果。示例代码为:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[9.0199995 5.4 7.2000003]\n", " [9.0199995 5.4 7.2000003]]\n" ] } ], "source": [ "class Net(nn.Cell):\n", " def __init__(self):\n", " super(Net, self).__init__()\n", " self.matmul = ops.MatMul()\n", "\n", " def construct(self, x, y):\n", " out1 = self.matmul(x, y)\n", " out2 = self.matmul(x, y)\n", " # out2 = stop_gradient(out2)\n", " out = out1 + out2\n", " return out\n", "\n", "class GradNetWrtX(nn.Cell):\n", " def __init__(self, net):\n", " super(GradNetWrtX, self).__init__()\n", " self.net = net\n", " self.grad_op = ops.GradOperation()\n", "\n", " def construct(self, x, y):\n", " gradient_function = self.grad_op(self.net)\n", " return gradient_function(x, y)\n", "\n", "output = GradNetWrtX(Net())(x, y)\n", "print(output)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "打印结果可以看出,在我们把`out2`算子的梯度也计算进去之后,由于`out2`和`out1`算子完全相同,因此它们产生的梯度也完全相同,所以我们可以看到,结果中每一项的值都变为了原来的两倍(存在精度误差)。\n", "\n", "### 自定义反向传播函数\n", "\n", "使用MindSpore构建神经网络时,需要继承 `nn.Cell` 类。当网络中存在一些尚未定义反向传播规则的操作,或者当我们想控制整个网络的梯度计算过程时,可以使用自定义 `nn.Cell` 对象反向传播函数的功能,形式为:\n", "\n", "```python\n", "def bprop(self, ..., out, dout):\n", " return ...\n", "```\n", "\n", "- 输入参数: 与正向部分相同的输入参数再加上 `out` 和 `dout` , `out` 表示正向部分的计算结果, `dout` 表示回传到该 `nn.Cell` 对象的梯度。\n", "- 返回值: 关于正向部分每个输入的梯度,所以返回值的数量需要与正向部分输入的数量相同。\n", "\n", "完整示例如下:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(Tensor(shape=[2, 3], dtype=Float32, value=\n", "[[ 1.50000000e+00, 1.60000002e+00, 1.39999998e+00],\n", " [ 2.20000005e+00, 2.29999995e+00, 2.09999990e+00]]), Tensor(shape=[3, 3], dtype=Float32, value=\n", "[[ 1.00999999e+00, 1.29999995e+00, 2.09999990e+00],\n", " [ 1.10000002e+00, 1.20000005e+00, 2.29999995e+00],\n", " [ 3.09999990e+00, 2.20000005e+00, 4.30000019e+00]]))\n" ] } ], "source": [ "import mindspore.nn as nn\n", "import mindspore as ms\n", "import mindspore.ops as ops\n", "\n", "class Net(nn.Cell):\n", " def __init__(self):\n", " super(Net, self).__init__()\n", " self.matmul = ops.MatMul()\n", "\n", " def construct(self, x, y):\n", " out = self.matmul(x, y)\n", " return out\n", "\n", " def bprop(self, x, y, out, dout):\n", " dx = x + 1\n", " dy = y + 1\n", " return dx, dy\n", "\n", "\n", "class GradNet(nn.Cell):\n", " def __init__(self, net):\n", " super(GradNet, self).__init__()\n", " self.net = net\n", " self.grad_op = ops.GradOperation(get_all=True)\n", "\n", " def construct(self, x, y):\n", " gradient_function = self.grad_op(self.net)\n", " return gradient_function(x, y)\n", "\n", "\n", "x = ms.Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=ms.float32)\n", "y = ms.Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=ms.float32)\n", "out = GradNet(Net())(x, y)\n", "print(out)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "约束与限制:\n", "\n", "- 当 `bprop` 函数的返回值数量为1时,也需要写成tuple的形式,即 `return (dx,)` 。\n", "- 图模式下, `bprop` 函数需要转换成图IR,所以需要遵循静态图语法,请参考[静态图语法支持](https://www.mindspore.cn/docs/zh-CN/r1.8/note/static_graph_syntax_support.html)。\n", "- 只支持返回关于正向部分输入的梯度,不支持返回关于 `Parameter` 的梯度。\n", "- 不支持在 `bprop` 中使用 `Parameter` 。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 高阶求导\n", "\n", "高阶微分在AI支持科学计算、二阶优化等领域均有应用。如分子动力学模拟中,利用神经网络训练势能时,损失函数中需计算神经网络输出对输入的导数,则反向传播便存在损失函数对输入、权重的二阶交叉导数。\n", "\n", "此外,AI求解微分方程(如PINNs方法)还会存在输出对输入的二阶导数。又如二阶优化中,为了能够让神经网络快速收敛,牛顿法等需计算损失函数对权重的二阶导数。\n", "\n", "MindSpore可通过多次求导的方式支持高阶导数,下面通过几类例子展开阐述。\n", "\n", "### 单输入单输出高阶导数\n", "\n", "例如Sin算子,其公式为:\n", "\n", "$$f(x) = sin(x) \\tag{1}$$\n", "\n", "其一阶导数是:\n", "\n", "$$f'(x) = cos(x) \\tag{2}$$\n", "\n", "其二阶导数为:\n", "\n", "$$f''(x) = cos'(x) = -sin(x) \\tag{3}$$\n", "\n", "其二阶导数(-Sin)实现如下:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[-0.]\n" ] } ], "source": [ "import numpy as np\n", "import mindspore.nn as nn\n", "import mindspore.ops as ops\n", "import mindspore as ms\n", "\n", "class Net(nn.Cell):\n", " \"\"\"前向网络模型\"\"\"\n", " def __init__(self):\n", " super(Net, self).__init__()\n", " self.sin = ops.Sin()\n", "\n", " def construct(self, x):\n", " out = self.sin(x)\n", " return out\n", "\n", "class Grad(nn.Cell):\n", " \"\"\"一阶求导\"\"\"\n", " def __init__(self, network):\n", " super(Grad, self).__init__()\n", " self.grad = ops.GradOperation()\n", " self.network = network\n", "\n", " def construct(self, x):\n", " gout = self.grad(self.network)(x)\n", " return gout\n", "\n", "class GradSec(nn.Cell):\n", " \"\"\"二阶求导\"\"\"\n", " def __init__(self, network):\n", " super(GradSec, self).__init__()\n", " self.grad = ops.GradOperation()\n", " self.network = network\n", "\n", " def construct(self, x):\n", " gout = self.grad(self.network)(x)\n", " return gout\n", "\n", "x_train = ms.Tensor(np.array([3.1415926]), dtype=ms.float32)\n", "\n", "net = Net()\n", "firstgrad = Grad(net)\n", "secondgrad = GradSec(firstgrad)\n", "output = secondgrad(x_train)\n", "\n", "# 打印结果\n", "result = np.around(output.asnumpy(), decimals=2)\n", "print(result)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "从上面的打印结果可以看出,$-sin(3.1415926)$的值接近于$0$。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 单输入多输出高阶导数\n", "\n", "对如下公式求导:\n", "\n", "$$f(x) = (f_1(x), f_2(x)) \\tag{1}$$\n", "\n", "其中:\n", "\n", "$$f_1(x) = sin(x) \\tag{2}$$\n", "\n", "$$f_2(x) = cos(x) \\tag{3}$$\n", "\n", "梯度计算时由于MindSpore采用的是反向自动微分机制, 会对输出结果求和后再对输入求导。 因此其一阶导数是:\n", "\n", "$$f'(x) = cos(x) -sin(x) \\tag{4}$$\n", "\n", "其二阶导数为:\n", "\n", "$$f''(x) = -sin(x) - cos(x) \\tag{5}$$" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1.]\n" ] } ], "source": [ "import numpy as np\n", "import mindspore.nn as nn\n", "import mindspore.ops as ops\n", "import mindspore as ms\n", "\n", "class Net(nn.Cell):\n", " \"\"\"前向网络模型\"\"\"\n", " def __init__(self):\n", " super(Net, self).__init__()\n", " self.sin = ops.Sin()\n", " self.cos = ops.Cos()\n", "\n", " def construct(self, x):\n", " out1 = self.sin(x)\n", " out2 = self.cos(x)\n", " return out1, out2\n", "\n", "class Grad(nn.Cell):\n", " \"\"\"一阶求导\"\"\"\n", " def __init__(self, network):\n", " super(Grad, self).__init__()\n", " self.grad = ops.GradOperation()\n", " self.network = network\n", "\n", " def construct(self, x):\n", " gout = self.grad(self.network)(x)\n", " return gout\n", "\n", "class GradSec(nn.Cell):\n", " \"\"\"二阶求导\"\"\"\n", " def __init__(self, network):\n", " super(GradSec, self).__init__()\n", " self.grad = ops.GradOperation()\n", " self.network = network\n", "\n", " def construct(self, x):\n", " gout = self.grad(self.network)(x)\n", " return gout\n", "\n", "x_train = ms.Tensor(np.array([3.1415926]), dtype=ms.float32)\n", "\n", "net = Net()\n", "firstgrad = Grad(net)\n", "secondgrad = GradSec(firstgrad)\n", "output = secondgrad(x_train)\n", "\n", "# 打印结果\n", "result = np.around(output.asnumpy(), decimals=2)\n", "print(result)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "从上面的打印结果可以看出,$-sin(3.1415926) - cos(3.1415926)$的值接近于$1$。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 多输入多输出高阶导数\n", "\n", "对如下公式求导:\n", "\n", "$$f(x, y) = (f_1(x, y), f_2(x, y)) \\tag{1}$$\n", "\n", "其中:\n", "\n", "$$f_1(x, y) = sin(x) - cos(y) \\tag{2}$$\n", "\n", "$$f_2(x, y) = cos(x) - sin(y) \\tag{3}$$\n", "\n", "梯度计算时由于MindSpore采用的是反向自动微分机制, 会对输出结果求和后再对输入求导。\n", "\n", "求和:\n", "\n", "$$\\sum{output} = sin(x) + cos(x) - sin(y) - cos(y) \\tag{4}$$\n", "\n", "输出和关于输入$x$的一阶导数为:\n", "\n", "$$\\dfrac{\\mathrm{d}\\sum{output}}{\\mathrm{d}x} = cos(x) - sin(x) \\tag{5}$$\n", "\n", "输出和关于输入$x$的二阶导数为:\n", "\n", "$$\\dfrac{\\mathrm{d}\\sum{output}^{2}}{\\mathrm{d}^{2}x} = -sin(x) - cos(x) \\tag{6}$$\n", "\n", "输出和关于输入$y$的一阶导数为:\n", "\n", "$$\\dfrac{\\mathrm{d}\\sum{output}}{\\mathrm{d}y} = -cos(y) + sin(y) \\tag{7}$$\n", "\n", "输出和关于输入$y$的二阶导数为:\n", "\n", "$$\\dfrac{\\mathrm{d}\\sum{output}^{2}}{\\mathrm{d}^{2}y} = sin(y) + cos(y) \\tag{8}$$" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1.]\n", "[-1.]\n" ] } ], "source": [ "import numpy as np\n", "import mindspore.nn as nn\n", "import mindspore.ops as ops\n", "import mindspore as ms\n", "\n", "class Net(nn.Cell):\n", " \"\"\"前向网络模型\"\"\"\n", " def __init__(self):\n", " super(Net, self).__init__()\n", " self.sin = ops.Sin()\n", " self.cos = ops.Cos()\n", "\n", " def construct(self, x, y):\n", " out1 = self.sin(x) - self.cos(y)\n", " out2 = self.cos(x) - self.sin(y)\n", " return out1, out2\n", "\n", "class Grad(nn.Cell):\n", " \"\"\"一阶求导\"\"\"\n", " def __init__(self, network):\n", " super(Grad, self).__init__()\n", " self.grad = ops.GradOperation(get_all=True)\n", " self.network = network\n", "\n", " def construct(self, x, y):\n", " gout = self.grad(self.network)(x, y)\n", " return gout\n", "\n", "class GradSec(nn.Cell):\n", " \"\"\"二阶求导\"\"\"\n", " def __init__(self, network):\n", " super(GradSec, self).__init__()\n", " self.grad = ops.GradOperation(get_all=True)\n", " self.network = network\n", "\n", " def construct(self, x, y):\n", " gout = self.grad(self.network)(x, y)\n", " return gout\n", "\n", "x_train = ms.Tensor(np.array([3.1415926]), dtype=ms.float32)\n", "y_train = ms.Tensor(np.array([3.1415926]), dtype=ms.float32)\n", "\n", "net = Net()\n", "firstgrad = Grad(net)\n", "secondgrad = GradSec(firstgrad)\n", "output = secondgrad(x_train, y_train)\n", "\n", "# 打印结果\n", "print(np.around(output[0].asnumpy(), decimals=2))\n", "print(np.around(output[1].asnumpy(), decimals=2))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "从上面的打印结果可以看出,输出对输入$x$的二阶导数$-sin(3.1415926) - cos(3.1415926)$的值接近于$1$, 输出对输入$y$的二阶导数$sin(3.1415926) + cos(3.1415926)$的值接近于$-1$。\n", "\n", "> 由于不同计算平台的精度可能存在差异,因此本章节中的代码在不同平台上的执行结果会存在微小的差别。" ] } ], "metadata": { "kernelspec": { "display_name": "MindSpore", "language": "python", "name": "mindspore" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 5 }