流管理介绍

设备流

设备流是属于特定设备的线性执行序列。默认情况下,每个设备都使用自己的“默认”流,不需要显式地创建一个设备流。

每个流中的操作都按照它们创建的顺序执行,但是不同流中的操作可以以任何相对顺序并发执行,除非使用显式同步函数(如synchrize()wait_stream())。例如,下面的代码是不正确的:

...
s = ms.hal.Stream()  # Create a new stream.
A = Tensor(np.random.randn(20, 20), ms.float32)
B = ops.matmul(A, A)
with ms.hal.StreamCtx(s):
    # sum() may start execution before matmul() finishes!
    C = ops.sum(B)

current stream 是默认流,切换至非默认流时,用户有责任确保正确的同步。正确的示例为:

...
s = ms.hal.Stream()  # Create a new stream.
A = Tensor(np.random.randn(20, 20), ms.float32)
B = ops.matmul(A, A)
s.wait_stream(ms.hal.current_stream())  # Dispatch wait event to device.
with ms.hal.StreamCtx(s):
    C = ops.sum(B)

由于申请和释放都需要开销,所以建议在同一个Cell或函数中使用同一个流,而不是每次使用前都申请一个流。不推荐的写法如下:

# The following cell formats are not recommended:
class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        ...

    def construct(self, x):
        s1 = ms.hal.Stream()
        with ms.hal.StreamCtx(s1):
            ...

# The following function formats are not recommended:
def func(A):
    s1 = ms.hal.Stream()
    with ms.hal.StreamCtx(s1):
        ...

推荐的写法如下:

# The following cell formats are recommended:
class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.s1 = ms.hal.Stream()
        ...

    def construct(self, x):
        with ms.hal.StreamCtx(self.s1):
            ...

# The following function formats are recommended:
s1 = ms.hal.Stream()
def func(A, s1):
    with ms.hal.StreamCtx(s1):
        ...

设备事件

设备事件可用于监视设备的进度、精确测量计时以及同步设备流。出于易用性考虑,我们在Stream类中提供了几个封装接口,比如上面例子中使用的 wait_stream() 接口,你可以通过接口文档了解详情。

ms.hal.Event 最常见的用法是 record()wait() 的组合,以确保在多流网络中执行序能够满足用户的期望,如下所示:

s = ms.hal.Stream()     # Create a new stream.
ev1 = ms.hal.Event()    # Create a new event.
ev2 = ms.hal.Event()    # Create a new event.

A = Tensor(np.random.randn(20, 20), ms.float32)
B = ops.matmul(A, A)
ev1.record()
with ms.hal.StreamCtx(s):
    ev1.wait()  # Ensure 'B = ops.matmul(A, A)' is complete.
    C = ops.sum(B)
    ev2.record()
    ...
ev1.wait()      # Ensure 'C = ops.sum(B)' is complete.
D = C + 1

另外,由于Host算子的下发和设备算子的执行是异步的。在下面的例子中,Host上的 add_s0 持有的设备地址释放后,被 mess_output 申请到了,并将输出数据写入该设备地址。但是,此时设备可能会执行 add_s1 = add_s0 + 1 ,但是 add_s0 的设备内存已经被其他算子使用了。最终导致精度不正确。

import numpy as np
import mindspore as ms
from mindspore import context, Tensor, ops

context.set_context(mode=context.PYNATIVE_MODE)
s1 = ms.hal.Stream()

input_s0 = Tensor(np.random.randn(10, 32, 32, 32), ms.float32)
weight_s0 = Tensor(np.random.randn(32, 32, 3, 3), ms.float32)
mess_input = Tensor(np.random.randn(10, 32, 32, 32), ms.float32)

# Use default stream
add_s0 = input_s0 + 1
s1.wait_stream(ms.hal.default_stream())
with ms.hal.StreamCtx(s1):
  # Use conv2d because this operator takes a long time in device.
  conv_res = ops.conv2d(add_s0, weight_s0)
  add_s1 = add_s0 + 1
  ev = s1.record_event()
del add_s0

# Mess up the mem of add_s0. The value of add_s1 might be wrong.
mess_output = mess_input + 1
ev.wait()
add_1_s0 = add_s1 + 1

因此,MindSpore开发了一套自动延长跨流地址的生命周期的机制。你不需要手动执行类似 record_stream() 的函数来延长它的生命周期。总体原则是:MindSpore能够识别算子和输入输出所在的流,在识别出算子和输入输出跨流后(如 add_s0 来自stream0,算子在stream1上),对输入输出地址插入事件,并将地址和事件信息保存到显存池中。如果后续有“回流”(stream1上执行 record() ,stream0上执行 wait() ),则显存池中的地址和事件信息将被释放。如果没有“回流”,但是显存池分配失败,则所有事件将被同步并释放绑定的信息。