实现高阶自动微分
CPU
GPU
Ascend
全流程
初级
中级
高级
概述
高阶微分在AI支持科学计算、二阶优化等领域均有应用。如分子动力学模拟中,利用神经网络训练势能时[1],损失函数中需计算神经网络输出对输入的导数,则反向传播便存在损失函数对输入、权重的二阶交叉导数;此外,AI求解微分方程(如PINNs[2]方法)还会存在输出对输入的二阶导数。又如二阶优化中,为了能够让神经网络快速收敛,牛顿法等需计算损失函数对权重的二阶导数。以下将主要介绍MindSpore图模式下的高阶导数。
完整样例代码见:导数样例代码
一阶求导
首先回顾下MindSpore计算一阶导数方法mindspore.ops.GradOperation (get_all=False, get_by_list=False, sens_param=False)
,其中get_all
为False
时,只会对第一个输入求导,为True
时,会对所有输入求导;get_by_list
为False
时,不会对权重求导,为True
时,会对权重求导;sens_param
对网络的输出值做缩放以改变最终梯度,故其维度与输出维度保持一致。下面用MatMul算子的一阶求导做深入分析。
输入求导
对输入求导代码如下:
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore import ParameterTuple, Parameter
from mindspore import dtype as mstype
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.matmul = ops.MatMul()
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
def construct(self, x, y):
x = x * self.z
out = self.matmul(x, y)
return out
class GradNetWrtX(nn.Cell):
def __init__(self, net):
super(GradNetWrtX, self).__init__()
self.net = net
self.grad_op = ops.GradOperation()
def construct(self, x, y):
gradient_function = self.grad_op(self.net)
return gradient_function(x, y)
x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32)
y = Tensor([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=mstype.float32)
output = GradNetWrtX(Net())(x, y)
print(output)
输出结果如下:
[[4.5099998 2.7 3.6000001] [4.5099998 2.7 3.6000001]]
为便于分析,输入x
、y
以及权重z
可以表示成如下形式:
x = Tensor([[x1, x2, x3], [x4, x5, x6]])
y = Tensor([[y1, y2, y3], [y4, y5, y6], [y7, y8, y9]])
z = Tensor([z])
根据MatMul算子定义可得前向结果:
\(output = [[(x1 \cdot y1 + x2 \cdot y4 + x3 \cdot y7) \cdot z, (x1 \cdot y2 + x2 \cdot y5 + x3 \cdot y8) \cdot z, (x1 \cdot y3 + x2 \cdot y6 + x3 \cdot y9) \cdot z]\),
\([(x4 \cdot y1 + x5 \cdot y4 + x6 \cdot y7) \cdot z, (x4 \cdot y2 + x5 \cdot y5 + x6 \cdot y8) \cdot z, (x4 \cdot y3 + x5 \cdot y6 + x6 \cdot y9) \cdot z]]\)
梯度计算时由于MindSpore采用的是Reverse[3]自动微分机制,会对输出结果求和后再对输入x
求导:
(1) 求和公式:
\(\sum{output} = [(x1 \cdot y1 + x2 \cdot y4 + x3 \cdot y7) + (x1 \cdot y2 + x2 \cdot y5 + x3 \cdot y8) + (x1 \cdot y3 + x2 \cdot y6 + x3 \cdot y9) +\)
\((x4 \cdot y1 + x5 \cdot y4 + x6 \cdot y7) + (x4 \cdot y2 + x5 \cdot y5 + x6 \cdot y8) + (x4 \cdot y3 + x5 \cdot y6 + x6 \cdot y9)] \cdot z\)
(2) 求导公式:
\(\frac{\mathrm{d}(\sum{output})}{\mathrm{d}x} = [[(y1 + y2 + y3) \cdot z,(y4 + y5 + y6) \cdot z,(y7 + y8 + y9) \cdot z],[(y1 + y2 + y3) \cdot z,(y4 + y5 + y6) \cdot z,(y7 + y8 + y9) \cdot z]]\)
(3) 计算结果:
\(\frac{\mathrm{d}(\sum{output})}{\mathrm{d}x} = [[4.5099998 \quad 2.7 \quad 3.6000001] [4.5099998 \quad 2.7 \quad 3.6000001]]\)
若考虑对x
、y
输入求导,只需在GradNetWrtX
中设置self.grad_op = GradOperation(get_all=True)
。
权重求导
若考虑对权重的求导,将GradNetWrtX
修改成:
class GradNetWrtX(nn.Cell):
def __init__(self, net):
super(GradNetWrtX, self).__init__()
self.net = net
self.params = ParameterTuple(net.trainable_params())
self.grad_op = ops.GradOperation(get_by_list=True)
def construct(self, x, y):
gradient_function = self.grad_op(self.net, self.params)
return gradient_function(x, y)
output = GradNetWrtX(Net())(x, y)
print(output)
输出结果如下:
[2.15359993e+01]
求导公式变为:
\(\frac{\mathrm{d}(\sum{output})}{\mathrm{d}z} = (x1 \cdot y1 + x2 \cdot y4 + x3 \cdot y7) + (x1 \cdot y2 + x2 \cdot y5 + x3 \cdot y8) + (x1 \cdot y3 + x2 \cdot y6 + x3 \cdot y9) + \)
\((x4 \cdot y1 + x5 \cdot y4 + x6 \cdot y7) + (x4 \cdot y2 + x5 \cdot y5 + x6 \cdot y8) + (x4 \cdot y3 + x5 \cdot y6 + x6 \cdot y9)\)
计算结果:
\(\frac{\mathrm{d}(\sum{output})}{\mathrm{d}z} = [2.15359993e+01]\)
梯度值缩放
可以通过sens_param
参数控制梯度值的缩放:
class GradNetWrtX(nn.Cell):
def __init__(self, net):
super(GradNetWrtX, self).__init__()
self.net = net
self.grad_op = ops.GradOperation(sens_param=True)
self.grad_wrt_output = Tensor([[0.1, 0.6, 0.2], [0.8, 1.3, 1.1]], dtype=mstype.float32)
def construct(self, x, y):
gradient_function = self.grad_op(self.net)
return gradient_function(x, y, self.grad_wrt_output)
output = GradNetWrtX(Net())(x, y)
print(output)
输出结果如下:
[[2.211 0.51 1.49] [5.588 2.68 4.07]]
self.grad_wrt_output
可以记作如下形式:
self.grad_wrt_output = Tensor([[s1, s2, s3], [s4, s5, s6]])
缩放后的输出值为原输出值与self.grad_wrt_output
对应元素的乘积:
\(output = [[(x1 \cdot y1 + x2 \cdot y4 + x3 \cdot y7) \cdot z \cdot s1,(x1 \cdot y2 + x2 \cdot y5 + x3 \cdot y8) \cdot z \cdot s2,(x1 \cdot y3 + x2 \cdot y6 + x3 \cdot y9) \cdot z \cdot s3],\)
\([(x4 \cdot y1 + x5 \cdot y4 + x6 \cdot y7) \cdot z \cdot s4,(x4 \cdot y2 + x5 \cdot y5 + x6 \cdot y8) \cdot z \cdot s5,(x4 \cdot y3 + x5 \cdot y6 + x6 \cdot y9) \cdot z \cdot s6]]\)
求导公式变为输出值总和对x
的每个元素求导:
\(\frac{\mathrm{d}(\sum{output})}{\mathrm{d}x} = [[(s1 \cdot y1 + s2 \cdot y2 + s3 \cdot y3) \cdot z,(s1 \cdot y4 + s2 \cdot y5 + s3 \cdot y6) \cdot z,(s1 \cdot y7 + s2 \cdot y8 + s3 \cdot y9) \cdot z],\)
\([(s4 \cdot y1 + s5 \cdot y2 + s6 \cdot y3) \cdot z,(s4 \cdot y4 + s5 \cdot y5 + s6 \cdot y6) \cdot z,(s4 \cdot y7 + s5 \cdot y8 + s6 \cdot y9) \cdot z]]\)
如果想计算单个输出(例如output[0][0]
)对输入的导数,可以将相应位置的缩放值置为1,其他置为0;也可以改变网络结构:
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.matmul = ops.MatMul()
self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
def construct(self, x, y):
x = x * self.z
out = self.matmul(x, y)
return out[0][0]
output = GradNetWrtX(Net())(x, y)
print(output)
输出结果如下:
[[0.11 1.1 1.1] [0. 0. 0.]]
高阶求导
MindSpore可通过多次求导的方式支持高阶导数,下面通过几类例子展开阐述。
单输入单输出高阶导数
例如Sin算子,其二阶导数(-Sin)实现如下:
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.sin = ops.Sin()
def construct(self, x):
out = self.sin(x)
return out
class Grad(nn.Cell):
def __init__(self, network):
super(Grad, self).__init__()
self.grad = ops.GradOperation()
self.network = network
def construct(self, x):
gout= self.grad(self.network)(x)
return gout
class GradSec(nn.Cell):
def __init__(self, network):
super(GradSec, self).__init__()
self.grad = ops.GradOperation()
self.network = network
def construct(self, x):
gout= self.grad(self.network)(x)
return gout
net=Net()
firstgrad = Grad(net) # first order
secondgrad = GradSec(firstgrad) # second order
x_train = Tensor(np.array([1.0], dtype=np.float32))
output = secondgrad(x_train)
print(output)
输出结果如下:
[-0.841471]
单输入多输出高阶导数
例如多输出的乘法运算,其高阶导数如下:
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore import dtype as mstype
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.mul = ops.Mul()
def construct(self, x):
out = self.mul(x, x)
return out
class Grad(nn.Cell):
def __init__(self, network):
super(Grad, self).__init__()
self.grad = ops.GradOperation(sens_param=False)
self.network = network
def construct(self, x):
gout = self.grad(self.network)(x)
return gout
class GradSec(nn.Cell):
def __init__(self, network):
super(GradSec, self).__init__()
self.grad = ops.GradOperation(sens_param=False)
self.network = network
def construct(self, x):
gout = self.grad(self.network)(x)
return gout
net=Net()
firstgrad = Grad(net) # first order
secondgrad = GradSec(firstgrad) # second order
x = Tensor([0.1, 0.2, 0.3], dtype=mstype.float32)
output = secondgrad(x)
print(output)
输出结果如下:
[2.00000000e+00, 2.00000000e+00, 2.00000000e+00]
多输入多输出高阶导数
例如神经网络有多个输入x
、y
,可以通过梯度缩放机制获得二阶导数dxdx
,dydy
,dxdy
,dydx
如下:
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.mul = ops.Mul()
def construct(self, x, y):
x_square = self.mul(x, x)
x_square_y = self.mul(x_square, y)
return x_square_y
class Grad(nn.Cell):
def __init__(self, network):
super(Grad, self).__init__()
self.grad = ops.GradOperation(get_all=True, sens_param=False)
self.network = network
def construct(self, x, y):
gout = self.grad(self.network)(x, y) # return dx, dy
return gout
class GradSec(nn.Cell):
def __init__(self, network):
super(GradSec, self).__init__()
self.grad = ops.GradOperation(get_all=True, sens_param=True)
self.network = network
self.sens1 = Tensor(np.array([1]).astype('float32'))
self.sens2 = Tensor(np.array([0]).astype('float32'))
def construct(self, x, y):
dxdx, dxdy = self.grad(self.network)(x, y, (self.sens1,self.sens2))
dydx, dydy = self.grad(self.network)(x, y, (self.sens2,self.sens1))
return dxdx, dxdy, dydx, dydy
net = Net()
firstgrad = Grad(net) # first order
secondgrad = GradSec(firstgrad) # second order
x_train = Tensor(np.array([4],dtype=np.float32))
y_train = Tensor(np.array([5],dtype=np.float32))
dxdx, dxdy, dydx, dydy = secondgrad(x_train, y_train)
print(dxdx, dxdy, dydx, dydy)
输出结果如下:
[10] [8.] [8.] [0.]
具体地,一阶导数计算的结果是dx
、dy
:如果计算dxdx
,则一阶导数只需保留dx
,对应x
、y
的缩放值分别设置成1和0,即self.grad(self.network)(x, y, (self.sens1,self.sens2))
;同理计算dydy
,则一阶导数只保留dy
,对应x
、y
的sens_param
分别设置成0和1,即self.grad(self.network)(x, y, (self.sens2,self.sens1))
。
二阶微分算子支持情况
CPU支持算子:Square、 Exp、Neg、Mul、MatMul;
GPU支持算子:Pow、Log、Square、Exp、Neg、Mul、Div、MatMul、Sin、Cos、Tan、Atanh;
Ascend支持算子:Pow、Log、Square、Exp、Neg、Mul、Div、MatMul、Sin、Cos、Tan、Sinh、Cosh、Atanh。
引用
[1] Zhang L, Han J, Wang H, et al. Deep potential molecular dynamics: a scalable model with the accuracy of quantum mechanics[J]. Physical review letters, 2018, 120(14): 143001.
[2] Raissi M, Perdikaris P, Karniadakis G E. Physics informed deep learning (part i): Data-driven solutions of nonlinear partial differential equations[J]. arXiv preprint arXiv:1711.10561, 2017.
[3] Baydin A G, Pearlmutter B A, Radul A A, et al. Automatic differentiation in machine learning: a survey[J]. The Journal of Machine Learning Research, 2017, 18(1): 5595-5637.