自动求导

下载Notebook下载样例代码查看源文件

mindspore.ops模块提供的GradOperation接口可以生成网络模型的梯度。本文主要介绍如何使用GradOperation接口进行一阶、二阶求导,以及如何停止计算梯度。

更多求导接口相关信息可参考API文档

一阶求导

计算一阶导数方法:mindspore.ops.GradOperation(),其中参数使用方式为:

  • get_all:为False时,只会对第一个输入求导;为True时,会对所有输入求导。

  • get_by_list:False时,不会对权重求导;为True时,会对权重求导。

  • sens_param:对网络的输出值做缩放以改变最终梯度,故其维度与输出维度保持一致;

下面我们先使用MatMul算子构建自定义网络模型Net,再对其进行一阶求导,通过这样一个例子对GradOperation接口的使用方式做简单介绍,即公式:

\[f(x, y)=(x * z) * y \tag{1}\]

首先我们要定义网络模型Net、输入x和输入y

[1]:
import numpy as np
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore as ms

# 定义输入x和y
x = ms.Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=ms.float32)
y = ms.Tensor([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=ms.float32)

class Net(nn.Cell):
    """定义矩阵相乘网络Net"""

    def __init__(self):
        super(Net, self).__init__()
        self.matmul = ops.MatMul()
        self.z = ms.Parameter(ms.Tensor(np.array([1.0], np.float32)), name='z')

    def construct(self, x, y):
        x = x * self.z
        out = self.matmul(x, y)
        return out

对输入进行求导

对输入值进行求导,代码如下:

[2]:
class GradNetWrtX(nn.Cell):
    """定义网络输入的一阶求导"""

    def __init__(self, net):
        super(GradNetWrtX, self).__init__()
        self.net = net
        self.grad_op = ops.GradOperation()

    def construct(self, x, y):
        gradient_function = self.grad_op(self.net)
        return gradient_function(x, y)

output = GradNetWrtX(Net())(x, y)
print(output)
[[4.5099998 2.7       3.6000001]
 [4.5099998 2.7       3.6000001]]

接下来我们对上面的结果做一个解释。为便于分析,我们把上面的输入xy以及权重z表示成如下形式:

x = ms.Tensor([[x1, x2, x3], [x4, x5, x6]])
y = ms.Tensor([[y1, y2, y3], [y4, y5, y6], [y7, y8, y9]])
z = ms.Tensor([z])

根据MatMul算子定义可得前向结果:

\[output = [[(x_1 \cdot y_1 + x_2 \cdot y_4 + x_3 \cdot y_7) \cdot z, (x_1 \cdot y_2 + x_2 \cdot y_5 + x_3 \cdot y_8) \cdot z, (x_1 \cdot y_3 + x_2 \cdot y_6 + x_3 \cdot y_9) \cdot z],\]
\[[(x_4 \cdot y_1 + x_5 \cdot y_4 + x_6 \cdot y_7) \cdot z, (x_4 \cdot y_2 + x_5 \cdot y_5 + x_6 \cdot y_8) \cdot z, (x_4 \cdot y_3 + x_5 \cdot y_6 + x_6 \cdot y_9) \cdot z]] \tag{2}\]

梯度计算时由于MindSpore采用的是Reverse自动微分机制,会对输出结果求和后再对输入x求导:

  1. 求和公式:

\[\sum{output} = [(x_1 \cdot y_1 + x_2 \cdot y_4 + x_3 \cdot y_7) + (x_1 \cdot y_2 + x_2 \cdot y_5 + x_3 \cdot y_8) + (x_1 \cdot y_3 + x_2 \cdot y_6 + x_3 \cdot y_9)\]
\[+ (x_4 \cdot y_1 + x_5 \cdot y_4 + x_6 \cdot y_7) + (x_4 \cdot y_2 + x_5 \cdot y_5 + x_6 \cdot y_8) + (x_4 \cdot y_3 + x_5 \cdot y_6 + x_6 \cdot y_9)] \cdot z \tag{3}\]
  1. 求导公式:

\[\frac{\mathrm{d}(\sum{output})}{\mathrm{d}x} = [[(y_1 + y_2 + y_3) \cdot z, (y_4 + y_5 + y_6) \cdot z, (y_7 + y_8 + y_9) \cdot z],\]
\[[(y_1 + y_2 + y_3) \cdot z, (y_4 + y_5 + y_6) \cdot z, (y_7 + y_8 + y_9) \cdot z]] \tag{4}\]
  1. 计算结果:

\[\frac{\mathrm{d}(\sum{output})}{\mathrm{d}x} = [[4.51 \quad 2.7 \quad 3.6] [4.51 \quad 2.7 \quad 3.6]] \tag{5}\]

若考虑对xy输入求导,只需在GradNetWrtX中设置self.grad_op = GradOperation(get_all=True)

对权重进行求导

对权重进行求导,示例代码如下:

[3]:
class GradNetWrtZ(nn.Cell):
    """定义网络权重的一阶求导"""

    def __init__(self, net):
        super(GradNetWrtZ, self).__init__()
        self.net = net
        self.params = ms.ParameterTuple(net.trainable_params())
        self.grad_op = ops.GradOperation(get_by_list=True)

    def construct(self, x, y):
        gradient_function = self.grad_op(self.net, self.params)
        return gradient_function(x, y)

output = GradNetWrtZ(Net())(x, y)
print(output[0])
[21.536]

下面我们通过公式对上面的结果做一个解释。对权重的求导公式为:

\[\frac{\mathrm{d}(\sum{output})}{\mathrm{d}z} = (x_1 \cdot y_1 + x_2 \cdot y_4 + x_3 \cdot y_7) + (x_1 \cdot y_2 + x_2 \cdot y_5 + x_3 \cdot y_8) + (x_1 \cdot y_3 + x_2 \cdot y_6 + x_3 \cdot y_9)\]
\[+ (x_4 \cdot y_1 + x_5 \cdot y_4 + x_6 \cdot y_7) + (x_4 \cdot y_2 + x_5 \cdot y_5 + x_6 \cdot y_8) + (x_4 \cdot y_3 + x_5 \cdot y_6 + x_6 \cdot y_9) \tag{6}\]

计算结果:

\[\frac{\mathrm{d}(\sum{output})}{\mathrm{d}z} = [2.1536e+01] \tag{7}\]

梯度值缩放

可以通过sens_param参数控制梯度值的缩放:

[4]:
class GradNetWrtN(nn.Cell):
    """定义网络的一阶求导,控制梯度值缩放"""
    def __init__(self, net):
        super(GradNetWrtN, self).__init__()
        self.net = net
        self.grad_op = ops.GradOperation(sens_param=True)

        # 定义梯度值缩放
        self.grad_wrt_output = ms.Tensor([[0.1, 0.6, 0.2], [0.8, 1.3, 1.1]], dtype=ms.float32)

    def construct(self, x, y):
        gradient_function = self.grad_op(self.net)
        return gradient_function(x, y, self.grad_wrt_output)

output = GradNetWrtN(Net())(x, y)
print(output)
[[2.211 0.51  1.49 ]
 [5.588 2.68  4.07 ]]

为了方便对上面的结果进行解释,我们把self.grad_wrt_output记作如下形式:

self.grad_wrt_output = ms.Tensor([[s1, s2, s3], [s4, s5, s6]])

缩放后的输出值为原输出值与self.grad_wrt_output对应元素的乘积,公式为:

\[output = [[(x_1 \cdot y_1 + x_2 \cdot y_4 + x_3 \cdot y_7) \cdot z \cdot s_1, (x_1 \cdot y_2 + x_2 \cdot y_5 + x_3 \cdot y_8) \cdot z \cdot s_2, (x_1 \cdot y_3 + x_2 \cdot y_6 + x_3 \cdot y_9) \cdot z \cdot s_3],\]
\[[(x_4 \cdot y_1 + x_5 \cdot y_4 + x_6 \cdot y_7) \cdot z \cdot s_4, (x_4 \cdot y_2 + x_5 \cdot y_5 + x_6 \cdot y_8) \cdot z \cdot s_5, (x_4 \cdot y_3 + x_5 \cdot y_6 + x_6 \cdot y_9) \cdot z \cdot s_6]] \tag{8}\]

求导公式变为输出值总和对x的每个元素求导:

\[\frac{\mathrm{d}(\sum{output})}{\mathrm{d}x} = [[(s_1 \cdot y_1 + s_2 \cdot y_2 + s_3 \cdot y_3) \cdot z, (s_1 \cdot y_4 + s_2 \cdot y_5 + s_3 \cdot y_6) \cdot z, (s_1 \cdot y_7 + s_2 \cdot y_8 + s_3 \cdot y_9) \cdot z],\]
\[[(s_4 \cdot y_1 + s_5 \cdot y_2 + s_6 \cdot y_3) \cdot z, (s_4 \cdot y_4 + s_5 \cdot y_5 + s_6 \cdot y_6) \cdot z, (s_4 \cdot y_7 + s_5 \cdot y_8 + s_6 \cdot y_9) \cdot z]] \tag{9}\]

计算结果:

\[\frac{\mathrm{d}(\sum{output})}{\mathrm{d}x} = [[2.211 \quad 0.51 \quad 1.49][5.588 \quad 2.68 \quad 4.07]] \tag{10}\]

停止计算梯度

我们可以使用stop_gradient来停止计算指定算子的梯度,从而消除该算子对梯度的影响。

在上面一阶求导使用的矩阵相乘网络模型的基础上,我们再增加一个算子out2并禁止计算其梯度,得到自定义网络Net2,然后看一下对输入的求导结果情况。

示例代码如下:

[5]:
class Net(nn.Cell):

    def __init__(self):
        super(Net, self).__init__()
        self.matmul = ops.MatMul()

    def construct(self, x, y):
        out1 = self.matmul(x, y)
        out2 = self.matmul(x, y)
        out2 = ops.stop_gradient(out2)  # 停止计算out2算子的梯度
        out = out1 + out2
        return out

class GradNetWrtX(nn.Cell):

    def __init__(self, net):
        super(GradNetWrtX, self).__init__()
        self.net = net
        self.grad_op = ops.GradOperation()

    def construct(self, x, y):
        gradient_function = self.grad_op(self.net)
        return gradient_function(x, y)

output = GradNetWrtX(Net())(x, y)
print(output)
[[4.5099998 2.7       3.6000001]
 [4.5099998 2.7       3.6000001]]

从上面的打印可以看出,由于对out2设置了stop_gradient, 所以out2没有对梯度计算有任何的贡献,其输出结果与未加out2算子时一致。

下面我们删除out2 = stop_gradient(out2),再来看一下输出结果。示例代码为:

[6]:
class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.matmul = ops.MatMul()

    def construct(self, x, y):
        out1 = self.matmul(x, y)
        out2 = self.matmul(x, y)
        # out2 = stop_gradient(out2)
        out = out1 + out2
        return out

class GradNetWrtX(nn.Cell):
    def __init__(self, net):
        super(GradNetWrtX, self).__init__()
        self.net = net
        self.grad_op = ops.GradOperation()

    def construct(self, x, y):
        gradient_function = self.grad_op(self.net)
        return gradient_function(x, y)

output = GradNetWrtX(Net())(x, y)
print(output)
[[9.0199995 5.4       7.2000003]
 [9.0199995 5.4       7.2000003]]

打印结果可以看出,在我们把out2算子的梯度也计算进去之后,由于out2out1算子完全相同,因此它们产生的梯度也完全相同,所以我们可以看到,结果中每一项的值都变为了原来的两倍(存在精度误差)。

自定义反向传播函数

使用MindSpore构建神经网络时,需要继承 nn.Cell 类。当网络中存在一些尚未定义反向传播规则的操作,或者当我们想控制整个网络的梯度计算过程时,可以使用自定义 nn.Cell 对象反向传播函数的功能,形式为:

def bprop(self, ..., out, dout):
    return ...
  • 输入参数: 与正向部分相同的输入参数再加上 outdoutout 表示正向部分的计算结果, dout 表示回传到该 nn.Cell 对象的梯度。

  • 返回值: 关于正向部分每个输入的梯度,所以返回值的数量需要与正向部分输入的数量相同。

完整示例如下:

[7]:
import mindspore.nn as nn
import mindspore as ms
import mindspore.ops as ops

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.matmul = ops.MatMul()

    def construct(self, x, y):
        out = self.matmul(x, y)
        return out

    def bprop(self, x, y, out, dout):
        dx = x + 1
        dy = y + 1
        return dx, dy


class GradNet(nn.Cell):
    def __init__(self, net):
        super(GradNet, self).__init__()
        self.net = net
        self.grad_op = ops.GradOperation(get_all=True)

    def construct(self, x, y):
        gradient_function = self.grad_op(self.net)
        return gradient_function(x, y)


x = ms.Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=ms.float32)
y = ms.Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=ms.float32)
out = GradNet(Net())(x, y)
print(out)
(Tensor(shape=[2, 3], dtype=Float32, value=
[[ 1.50000000e+00,  1.60000002e+00,  1.39999998e+00],
 [ 2.20000005e+00,  2.29999995e+00,  2.09999990e+00]]), Tensor(shape=[3, 3], dtype=Float32, value=
[[ 1.00999999e+00,  1.29999995e+00,  2.09999990e+00],
 [ 1.10000002e+00,  1.20000005e+00,  2.29999995e+00],
 [ 3.09999990e+00,  2.20000005e+00,  4.30000019e+00]]))

约束与限制:

  • bprop 函数的返回值数量为1时,也需要写成tuple的形式,即 return (dx,)

  • 图模式下, bprop 函数需要转换成图IR,所以需要遵循静态图语法,请参考静态图语法支持

  • 只支持返回关于正向部分输入的梯度,不支持返回关于 Parameter 的梯度。

  • 不支持在 bprop 中使用 Parameter

高阶求导

高阶微分在AI支持科学计算、二阶优化等领域均有应用。如分子动力学模拟中,利用神经网络训练势能时,损失函数中需计算神经网络输出对输入的导数,则反向传播便存在损失函数对输入、权重的二阶交叉导数。

此外,AI求解微分方程(如PINNs方法)还会存在输出对输入的二阶导数。又如二阶优化中,为了能够让神经网络快速收敛,牛顿法等需计算损失函数对权重的二阶导数。

MindSpore可通过多次求导的方式支持高阶导数,下面通过几类例子展开阐述。

单输入单输出高阶导数

例如Sin算子,其公式为:

\[f(x) = sin(x) \tag{1}\]

其一阶导数是:

\[f'(x) = cos(x) \tag{2}\]

其二阶导数为:

\[f''(x) = cos'(x) = -sin(x) \tag{3}\]

其二阶导数(-Sin)实现如下:

[8]:
import numpy as np
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore as ms

class Net(nn.Cell):
    """前向网络模型"""
    def __init__(self):
        super(Net, self).__init__()
        self.sin = ops.Sin()

    def construct(self, x):
        out = self.sin(x)
        return out

class Grad(nn.Cell):
    """一阶求导"""
    def __init__(self, network):
        super(Grad, self).__init__()
        self.grad = ops.GradOperation()
        self.network = network

    def construct(self, x):
        gout = self.grad(self.network)(x)
        return gout

class GradSec(nn.Cell):
    """二阶求导"""
    def __init__(self, network):
        super(GradSec, self).__init__()
        self.grad = ops.GradOperation()
        self.network = network

    def construct(self, x):
        gout = self.grad(self.network)(x)
        return gout

x_train = ms.Tensor(np.array([3.1415926]), dtype=ms.float32)

net = Net()
firstgrad = Grad(net)
secondgrad = GradSec(firstgrad)
output = secondgrad(x_train)

# 打印结果
result = np.around(output.asnumpy(), decimals=2)
print(result)
[-0.]

从上面的打印结果可以看出,\(-sin(3.1415926)\)的值接近于\(0\)

单输入多输出高阶导数

对如下公式求导:

\[f(x) = (f_1(x), f_2(x)) \tag{1}\]

其中:

\[f_1(x) = sin(x) \tag{2}\]
\[f_2(x) = cos(x) \tag{3}\]

梯度计算时由于MindSpore采用的是反向自动微分机制, 会对输出结果求和后再对输入求导。 因此其一阶导数是:

\[f'(x) = cos(x) -sin(x) \tag{4}\]

其二阶导数为:

\[f''(x) = -sin(x) - cos(x) \tag{5}\]
[9]:
import numpy as np
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore as ms

class Net(nn.Cell):
    """前向网络模型"""
    def __init__(self):
        super(Net, self).__init__()
        self.sin = ops.Sin()
        self.cos = ops.Cos()

    def construct(self, x):
        out1 = self.sin(x)
        out2 = self.cos(x)
        return out1, out2

class Grad(nn.Cell):
    """一阶求导"""
    def __init__(self, network):
        super(Grad, self).__init__()
        self.grad = ops.GradOperation()
        self.network = network

    def construct(self, x):
        gout = self.grad(self.network)(x)
        return gout

class GradSec(nn.Cell):
    """二阶求导"""
    def __init__(self, network):
        super(GradSec, self).__init__()
        self.grad = ops.GradOperation()
        self.network = network

    def construct(self, x):
        gout = self.grad(self.network)(x)
        return gout

x_train = ms.Tensor(np.array([3.1415926]), dtype=ms.float32)

net = Net()
firstgrad = Grad(net)
secondgrad = GradSec(firstgrad)
output = secondgrad(x_train)

# 打印结果
result = np.around(output.asnumpy(), decimals=2)
print(result)
[1.]

从上面的打印结果可以看出,\(-sin(3.1415926) - cos(3.1415926)\)的值接近于\(1\)

多输入多输出高阶导数

对如下公式求导:

\[f(x, y) = (f_1(x, y), f_2(x, y)) \tag{1}\]

其中:

\[f_1(x, y) = sin(x) - cos(y) \tag{2}\]
\[f_2(x, y) = cos(x) - sin(y) \tag{3}\]

梯度计算时由于MindSpore采用的是反向自动微分机制, 会对输出结果求和后再对输入求导。

求和:

\[\sum{output} = sin(x) + cos(x) - sin(y) - cos(y) \tag{4}\]

输出和关于输入\(x\)的一阶导数为:

\[\dfrac{\mathrm{d}\sum{output}}{\mathrm{d}x} = cos(x) - sin(x) \tag{5}\]

输出和关于输入\(x\)的二阶导数为:

\[\dfrac{\mathrm{d}\sum{output}^{2}}{\mathrm{d}^{2}x} = -sin(x) - cos(x) \tag{6}\]

输出和关于输入\(y\)的一阶导数为:

\[\dfrac{\mathrm{d}\sum{output}}{\mathrm{d}y} = -cos(y) + sin(y) \tag{7}\]

输出和关于输入\(y\)的二阶导数为:

\[\dfrac{\mathrm{d}\sum{output}^{2}}{\mathrm{d}^{2}y} = sin(y) + cos(y) \tag{8}\]
[10]:
import numpy as np
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore as ms

class Net(nn.Cell):
    """前向网络模型"""
    def __init__(self):
        super(Net, self).__init__()
        self.sin = ops.Sin()
        self.cos = ops.Cos()

    def construct(self, x, y):
        out1 = self.sin(x) - self.cos(y)
        out2 = self.cos(x) - self.sin(y)
        return out1, out2

class Grad(nn.Cell):
    """一阶求导"""
    def __init__(self, network):
        super(Grad, self).__init__()
        self.grad = ops.GradOperation(get_all=True)
        self.network = network

    def construct(self, x, y):
        gout = self.grad(self.network)(x, y)
        return gout

class GradSec(nn.Cell):
    """二阶求导"""
    def __init__(self, network):
        super(GradSec, self).__init__()
        self.grad = ops.GradOperation(get_all=True)
        self.network = network

    def construct(self, x, y):
        gout = self.grad(self.network)(x, y)
        return gout

x_train = ms.Tensor(np.array([3.1415926]), dtype=ms.float32)
y_train = ms.Tensor(np.array([3.1415926]), dtype=ms.float32)

net = Net()
firstgrad = Grad(net)
secondgrad = GradSec(firstgrad)
output = secondgrad(x_train, y_train)

# 打印结果
print(np.around(output[0].asnumpy(), decimals=2))
print(np.around(output[1].asnumpy(), decimals=2))
[1.]
[-1.]

从上面的打印结果可以看出,输出对输入\(x\)的二阶导数\(-sin(3.1415926) - cos(3.1415926)\)的值接近于\(1\), 输出对输入\(y\)的二阶导数\(sin(3.1415926) + cos(3.1415926)\)的值接近于\(-1\)

由于不同计算平台的精度可能存在差异,因此本章节中的代码在不同平台上的执行结果会存在微小的差别。