动静结合

下载Notebook下载样例代码查看源文件

当前在业界支持动态图和静态图两种模式,动态图通过解释执行,具有动态语法亲和性,表达灵活;静态图使用JIT(just in time)编译优化执行,偏静态语法,在语法上有较多限制。动态图和静态图的编译流程不一致,导致语法约束也不一致。

MindSpore针对动态图和静态图模式,首先统一API表达,在两种模式下使用相同的API;其次统一动态图和静态图的底层微分机制。

dynamic

实现原理

MindSpore支持使用ms_function装饰器来修饰需要用静态图执行的对象,从而实现动静结合的目的。下面我们通过一个简单的动静结合的示例来介绍其实现原理。示例代码如下:

[1]:
import numpy as np
import mindspore.nn as nn
import mindspore as ms
from mindspore import ms_function

class Add(nn.Cell):
    """自定义类实现x自身相加"""
    def construct(self, x):
        x = x + x
        x = x + x
        return x

class Mul(nn.Cell):
    """自定义类实现x自身相乘"""
    @ms_function  # 使用ms_function修饰,此函数以静态图方式执行
    def construct(self, x):
        x = x * x
        x = x * x
        return x

class Test(nn.Cell):
    """自定义类实现x先Add(x),后Mul(x),再Add(x)"""
    def __init__(self):
        super(Test, self).__init__()
        self.add = Add()
        self.mul = Mul()

    def construct(self, x):
        x = self.add(x)
        x = self.mul(x)
        x = self.add(x)
        return x

ms.set_context(mode=ms.PYNATIVE_MODE)
x = ms.Tensor(np.ones([3, 3], dtype=np.float32))
print("init x:\n", x)
net = Test()
x = net(x)
print("\nx:\n", x)
init x:
 [[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]]

x:
 [[1024. 1024. 1024.]
 [1024. 1024. 1024.]
 [1024. 1024. 1024.]]

从上面的打印结果可以看出,经过Test运算后,x最终值为每个元素都是8的3*3矩阵。该用例按照执行序,编译的方式如下图所示:

msfunction

ms_function修饰的函数将会按照静态图的方式进行编译和执行。如果网络涉及到反向求导,被ms_function修饰的部分也将以整图的形式来生成反向图,并与前后单个算子的反向图连成一个整体后被下发执行。其中,缓存的策略与静态图的缓存策略一致,相同的函数对象在输入Shape和Type信息一致时,编译的图结构将会被缓存。

ms_function装饰器

为了提高动态图模式下的前向计算任务执行速度,MindSpore提供了ms_function装饰器,可以通过修饰Python函数或者Python类的成员函数使其被编译成计算图,通过图优化等技术提高运行速度。

使用方式

MindSpore支持在动态图下使用静态编译的方式来进行混合执行,通过使用ms_function装饰符来修饰需要用静态图来执行的函数对象,即可实现动态图和静态图的混合执行。

1. 修饰独立函数

使用ms_function装饰器时,可以对独立定义的函数进行修饰,使其在Graph模式下运行,示例如下:

[2]:
import numpy as np
import mindspore.ops as ops
import mindspore as ms
from mindspore import ms_function

# 设置运行模式为动态图模式
ms.set_context(mode=ms.PYNATIVE_MODE)

# 使用装饰器,指定静态图模式下执行
@ms_function
def add_func(x, y):
    return ops.add(x, y)

x = ms.Tensor(np.array([1.0, 2.0, 3.0]).astype(np.float32))
y = ms.Tensor(np.array([4.0, 5.0, 6.0]).astype(np.float32))

out = add_func(x, y)
print(out)
[5. 7. 9.]

在上面的示例代码中,虽然一开始设置了运行模式为动态图模式,但是由于使用了ms_function装饰器对函数add_func(x, y)进行了修饰,所以函数add_func(x, y)仍然是以静态图模式运行。

2. 修饰Cell的成员函数

使用ms_function装饰器时,可以对Cell的成员函数进行修饰,示例代码如下:

[3]:
import numpy as np
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore as ms
from mindspore import ms_function

# 设置运行模式为动态图模式
ms.set_context(mode=ms.PYNATIVE_MODE)

class Add(nn.Cell):

    @ms_function # 使用装饰器,指定静态图模式下执行
    def construct(self, x, y):
        out = x + y
        return out

x = ms.Tensor(np.array([1.0, 2.0, 3.0]).astype(np.float32))
y = ms.Tensor(np.array([4.0, 5.0, 6.0]).astype(np.float32))

grad_ops = ops.GradOperation(get_all=True)  # 定义求导操作
net = Add()
grad_out = grad_ops(net)(x, y)

print("Infer result:\n", net(x, y))

print("Gradient result:")
print("Grad x Tensor1:\n", grad_out[0])  # 对x求导
print("Grad y Tensor2:\n", grad_out[1])  # 对y求导
Infer result:
 [5. 7. 9.]
Gradient result:
Grad x Tensor1:
 [1. 1. 1.]
Grad y Tensor2:
 [1. 1. 1.]

从上面的打印结果可以看出,x与y相加的结果为[5, 7, 9], 对x求导的结果和对y求导的结果相同,都为[1, 1, 1]。

注意事项

在使用ms_function来修饰函数,加速执行效率时,请注意以下几点:

  1. ms_function修饰的函数须在静态图编译支持的语法范围内,包括但不限于数据类型等。

  2. ms_function修饰的函数所支持的控制流语法,与静态图保持一致。其中,仅对固定循环次数或者分支条件的控制流结构具有加速效果。

  3. 在PyNative模式下使用ms_function功能时,非ms_function修饰的部分支持断点调试;被ms_function修饰的部分由于是以静态图的方式编译,不支持断点调试。

  4. 由于ms_function修饰的函数将按照静态图的方式编译执行,因此ms_function不支持修饰的函数中含有Hook算子,也不支持修饰自定义Bprop函数。

  5. ms_function修饰的函数会受到静态图函数副作用的影响。函数副作用指:当调用函数时,除了函数返回值之外,还对主调用函数产生的附加影响,例如修改全局变量(函数外的变量),修改函数的参数等。

场景1:

[4]:
import numpy as np
import mindspore as ms
from mindspore import ms_function

# pylint: disable=W0612

value = 5

@ms_function
def func(x, y):
    out = x + y
    value = 1
    return out

ms.set_context(mode=ms.PYNATIVE_MODE)
x = ms.Tensor(np.array([1.0, 2.0, 3.0]).astype(np.float32))
y = ms.Tensor(np.array([1.0, 2.0, 3.0]).astype(np.float32))
func(x, y)
print(value)
5

该场景下,value是全局变量且在func函数中被修改。此时,如果用ms_function修饰func函数,全局变量value的值将不会被修改。原因是:静态图编译时,会优化掉与返回值无关的语句

场景2:

[5]:
import numpy as np
import mindspore.nn as nn
import mindspore as ms
from mindspore import ms_function

class Func(nn.Cell):
    def __init__(self):
        super(Func, self).__init__()
        self.value = 5

    @ms_function
    def construct(self, x):
        out = self.value + x
        return out

ms.set_context(mode=ms.PYNATIVE_MODE)
x = ms.Tensor(np.array([1.0, 2.0, 3.0]).astype(np.float32))
func = Func()
print("out1:", func(x))
func.value = 1
print("out2:", func(x))
out1: [6. 7. 8.]
out2: [6. 7. 8.]

从上面的打印可以看出,在修改了Func类的成员变量value的值为1之后,对成员函数construct的操作并无影响。这是因为在此场景下,用ms_function修饰了Func对象的construct成员函数,执行Func时将会以静态图的方式编译执行。由于静态图会缓存编译结果,第二次调用Func时,对value的修改不会生效。

  1. 加装了ms_function装饰器的函数中,如果包含不需要进行参数训练的算子(如MatMulAdd等算子),则这些算子可以在被装饰的函数中直接调用;如果被装饰的函数中包含了需要进行参数训练的算子(如Conv2DBatchNorm等算子),则这些算子必须在被装饰的函数之外完成实例化操作。下面我们通过示例代码对这两种场景进行说明。

场景1:在被装饰的函数中直接调用不需要进行参数训练的算子(示例中为mindspore.ops.Add)。示例代码如下:

[6]:
import numpy as np
import mindspore as ms
import mindspore.ops as ops
from mindspore import ms_function

ms.set_context(mode=ms.PYNATIVE_MODE)

add = ops.Add()

@ms_function
def add_fn(x, y):
    res = add(x, y)
    return res

x = ms.Tensor(np.array([1.0, 2.0, 3.0]).astype(np.float32))
y = ms.Tensor(np.array([4.0, 5.0, 6.0]).astype(np.float32))
z = add_fn(x, y)

print("x:", x.asnumpy(), "\ny:", y.asnumpy(), "\nz:", z.asnumpy())
x: [1. 2. 3.]
y: [4. 5. 6.]
z: [5. 7. 9.]

场景2:需要进行参数训练的算子(示例中为mindspore.nn.Conv2d),必须在被装饰的函数之外完成实例化操作,示例代码如下:

[7]:
import numpy as np
import mindspore.nn as nn
import mindspore as ms
from mindspore import ms_function

ms.set_context(mode=ms.PYNATIVE_MODE)

# 对函数conv_fn中的算子conv_obj完成实例化操作
conv_obj = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3, stride=2, padding=0)
conv_obj.init_parameters_data()

@ms_function
def conv_fn(x):
    res = conv_obj(x)
    return res

input_data = np.random.randn(1, 3, 3, 3).astype(np.float32)
z = conv_fn(ms.Tensor(input_data))
print(z.asnumpy())
[[[[ 0.00829158 -0.02994147]
   [-0.09116832 -0.00181637]]

  [[-0.00519348 -0.02172063]
   [-0.04015012 -0.02083161]]

  [[ 0.00608188 -0.01443425]
   [-0.01468289  0.01200477]]

  [[ 0.00845292  0.00044869]
   [-0.00361492  0.01993337]]]]